diff options
author | Stéphane Glondu <glondu@debian.org> | 2019-08-11 17:55:06 +0200 |
---|---|---|
committer | Stéphane Glondu <glondu@debian.org> | 2019-08-11 17:55:06 +0200 |
commit | 5c8e8182515d6d1e12c6c7282630a8a7b11143a0 (patch) | |
tree | b3013e583654d5bb768641f07e83d1facf296cec |
Import janest-base_0.12.2.orig.tar.gz
[dgit import orig janest-base_0.12.2.orig.tar.gz]
377 files changed, 45487 insertions, 0 deletions
diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..85f39e5 --- /dev/null +++ b/.gitignore @@ -0,0 +1,4 @@ +_build +*.install +*.merlin + diff --git a/CHANGES.md b/CHANGES.md new file mode 100644 index 0000000..606e4f0 --- /dev/null +++ b/CHANGES.md @@ -0,0 +1,275 @@ +## git version + +- `Ordered_collection_common.get_pos_len` now returns an `Or_error.t` + +- Added `Bool.Non_short_circuiting`. + +- Added `Float.square`. + +- Remove module `Or_error.Ok`. + +- module `Ref` doesn't implement `Container.S1` anymore. + +- Rename parameter of `Sequence.merge` from `cmp` to `compare`. + +- Added `Info.of_lazy_t` + +- Added `List.partition_result` function, to partition a list of `Result.t` + values + +- Changed the signature of `equal` from `'a t -> 'a t -> equal:('a -> 'a -> + bool) -> bool` to `('a -> 'a -> bool) -> 'a t -> 'a t -> bool`. + +- Optimized `Lazy.compare` to check physical equality before forcing the lazy + values. + +- Deprecated `Args` in the `Applicative` interface in favor of using `ppx_let`. + +- Deprecated `Array.replace arr i ~f` in favor of using `arr.(i) <- (f (arr.(i)))` + +- Rename collection length parameter of `Ordered_collection_common` functions + from `length` to `total_length`, and add a unit argument to `get_pos_len` and + `get_pos_len_exn`. + +- Removed functions that were deprecated in 2016 from the `Array` and `Set` + modules. + +- [Int.Hex.of_string] and friends no longer silently ignore a suffix + of non-hexadecimal garbage. + +- Added `?backtrace` argument to `Or_error.of_exn_result`. + +- `List.zip` now returns a `List.Or_unequal_lengths.t` instead of an `option`. + +- Remove functions from the `Sequence` module that were deprecated in 2015. + +- `Container.Make` and `Container.Make0` now require callers to either provide a + custom `length` function or request that one be derived from `fold`. + `Container.to_array`'s signature is also changed to accept `length` and `iter` + instead of `fold`. + +## v0.11 + +- Deprecated `Not_found`, people who need it can use `Caml.Not_found`, but its + use isn't recommended. + +- Added the `Sexp.Not_found_s` exception which will replace `Caml.Not_found` as + the default exception in a future release. + +- Document that `Array.find_exn`, `Array.find_map_exn`, and `Array.findi_exn` + may throw `Caml.Not_found` _or_ `Not_found_s`. + +- Document that `Hashtbl.find_exn` may throw `Caml.Not_found` _or_ + `Not_found_s`. + +- Document that `List.find_exn`, and `List.find_map_exn` may throw + `Caml.Not_found` _or_ `Not_found_s`. + +- Document that `List.find_exn` may throw `Caml.Not_found` _or_ `Not_found_s`. + +- Document that `String.lsplit2_exn`, and `String.rsplit2_exn` may throw + `Caml.Not_found` _or_ `Not_found_s`. + +- Added `Sys.backend_type`. + +- Removed unnecessary unit argument from `Hashtbl.create`. + +- Removed deprecated operations from `Hashtbl`. + +- Removed `Hashable.t` constructors from `Hashtbl` and `Hash_set`, instead + favoring the first-class module constructors. + +- Removed `Container` operations from `Either.First` and `Either.Second`. + +- Changed the type of `fold_until` in the `Container` interfaces. Rather than + returning a `Finished_or_stopped_early.t` (which has also been removed), the + function now takes a `finish` function that will be applied the result if `f` + never returned a `Stop _`. + +- Removed the `String_dict` module. + +- Added a `Queue` module that is backed by an `Option_array` for efficient and + (non-allocating) implementations of most operations. + +- Added a `Poly` submodule to `Map` and `Set` that exposes constructors that + use polymorphic compare. + +- Deprecated `all_ignore` in the `Monad` and `Applicative` interfaces in favor + of `all_unit`. + +- Deprecated `Array.replace_all` in favor of `Array.map_inplace`, which is the + standard name for that sort of operation within Base. + +- Document that `List.find_exn`, and `List.find_map_exn` may throw + `Caml.Not_found` _or_ `Not_found_s`. + +- Make `~compare` a required argument to `List.dedup_and_sort`, `List.dedup`, + `List.find_a_dup`, `List.contains_dup`, and `List.find_all_dups`. + +- Removed `List.exn_if_dup`. It is still available in core_kernel. + +- Removed "normalized" index operation `List.slice`. It is still available in + core_kernel. + +- Remove "normalized" index operations from `Array`, which incluced + `Array.normalize`, `Array.slice`, `Array.nget` and `Array.nset`. These + operations are still available in core_kernel. + +- Added `Uniform_array` module that is just like an `Array` except guarantees + that the representation array is not tagged with `Double_array_tag`, the tag + for float arrays. + +- Added `Option_array` module that allows for a compact representation of `'a + optoin array`, which avoids allocating heap objects representing `Some a`. + +- Remove "normalized" index operations from `String`, which incluced + `String.normalize`, `String.slice`, `String.nget` and `String.nset`. These + operations are still available in core_kernel. + +- Added missing conversions between `Int63` and other integer types, + specifically, the versions that return options. + +- Added truncating versions of integer conversions, with a suffix of + `_trunc`. These allow fast conversions via bit arithmetic without + any conditional failure; excess bits beyond the width of the output + type are simply dropped. + +- Added `Sequence.group`, similar to `List.group`. + +- Reimplemented `String.Caseless.compare` so that it does not + allocate. + +- Added `String.is_substring_at string ~pos ~substring`. Used it as + back-end for `is_suffix` and `is_prefix`. + +- Moved all remaining `Replace_polymorphic_compare` submodules from Base + types and consolidated them in one place within `Import0`. + +- Removed `(<=.)` and its friends. + +- Added `Sys.argv`. + +- Added a infix exponentation operator for int. + +- Added a `Formatter` module to reexport the `Format.formatter` type and updated + the deprecation message for `Format`. + +## v0.10 + +(Changes that can break existing programs are marked with a "\*") + +### Bugfixes + +- Generalized the type of `Printf.ifprintf` to reflect OCaml's stdlib. + +- Made `Sequence.fold_m` and `iter_m` respect `Skip` steps and explicitly bind + when they occur. + +- Changed `Float.is_negative` and `is_non_positive` on `NaN` to return `false` + rather than `true`. + +- Fixed the `Validate.protect` function, which was mistakenly raising exceptions. + +### API changes + +- Renamed `Map.add` as `set`, and deprecated `add`. A later feature will add + `add` and `add_exn` in the style of `Hashtbl`. + +- A different hash function is used to implement [Base.Int.hash]. + The old implementation was [Int.abs] but collision resistance is not enough, + we want avalanching as well. + The new function is an adaptation of one of the + [Thomas Wang](http://web.archive.org/web/20071223173210/http://www.concentric.net/~Ttwang/tech/inthash.htm) + hash functions to OCaml (63-bit integers), which results in reasonably good avalanching. + + +- Made `open Base` expose infix float operators (+., -., etc.). + +* Renamed `List.dedup` to `List.dedup_and_sort`, to better reflect its existing behavior. + +- Added `Hashtbl.find_multi` and `Map.find_multi`. + +- Added function `Map.of_increasing_sequence` for constructing a `Map.t` from an + ordered `Sequence.t` + +- Added function `List.chunks_of : 'a t -> length : int -> 'a t t`, for breaking + a list into chunks of equal length. + +- Add to module `Random` numeric functions that take upper and lower inclusive + bounds, e.g. `Random.int_incl : int -> int -> int`. + +* Replaced `Exn.Never_elide_backtrace` with `Backtrace.elide`, a `ref` cell that + determines whether `Backtrace.to_string` and `Backtrace.sexp_of_t` elide + backtraces. + +- Exposed infix operator `Base.( @@ )`. + +- Exposed modules `Base.Continue_or_stop` and `Finished_or_stopped_early`, used + with the `Container.fold_until` function. + +- Exposed module types Base.T, T1, T2, and T3. + +- Added `Sequence.Expert` functions `next_step` and + `delayed_fold_step`, for clients that want to explicitly handle `Skip` steps. + +- Added `Bytes` module. + This includes the submodules `From_string` and `To_string` with blit + functions. + N.B. the signature (and name) of `unsafe_to_string` and `unsafe_of_string` are + different from the one in the standard library (and hopefully more explicit). + +- Add bytes functions to `Buffer`. + Also added `Buffer.content_bytes`, the analog of `contents` but that returns + `bytes` rather than `string`. + +* Enabled `-safe-string`. + +- Added function `Int63.of_int32`, which was missing. + +* Deprecated a number of `String` mutating functions. + +- Added module `Obj_array`, moved in from `Core_kernel`. + +* In module type `Hashtbl.Accessors`, removed deprecated functions, moving them + into a new module type, `Deprecated`. + +- Exported `sexp_*` types that are recognized by `ppx_sexp_*` converters: + `sexp_array`, `sexp_list`, `sexp_opaque`, `sexp_option`. + +* Reworked the `Or_error` module's interface, moving the `Container.S` interface + to an `Ok` submodule, and adding functions `is_ok`, `is_error`, and `ok` to + more closely resemble the interface of the `Result` module. + +- Removed `Int.O.of_int_exn`. + +- Exposed `Base.force` function. + +- Changed the deprecation warning for `mod` to recommend `( % )` rather than + `Caml.( mod )`. + +### Performance related changes + +- Optimized `List.compare`, removing its closure allocation. + +- Optimized `String.mem` to not allocate. + +- Optimized `Float.is_negative`, `is_non_negative`, `is_positive`, and + `is_non_positive` to avoid some boxing. + +- Changed `Hashtbl.merge` to relax its equality check on the input tables' + `Hashable.t` records, checking physical equality componentwise if the records + aren't physically equal. + +- Added `Result.combine_errors`, similar to `Or_error.combine_errors`, with a + slightly different type. + +- Added `Result.combine_errors_unit`, similar to `Or_error.combine_errors_unit`. + +- Optimized the `With_return.return` type by adding the `[@@unboxed]` attribute. + +- Improved a number of deprecation warnings. + + +## v0.9 + +Initial release. diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000..45e1a22 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,67 @@ +This repository contains open source software that is developed and +maintained by [Jane Street][js]. + +Contributions to this project are welcome and should be submitted via +GitHub pull requests. + +Signing contributions +--------------------- + +We require that you sign your contributions. Your signature certifies +that you wrote the patch or otherwise have the right to pass it on as +an open-source patch. The rules are pretty simple: if you can certify +the below (from [developercertificate.org][dco]): + +``` +Developer Certificate of Origin +Version 1.1 + +Copyright (C) 2004, 2006 The Linux Foundation and its contributors. +1 Letterman Drive +Suite D4700 +San Francisco, CA, 94129 + +Everyone is permitted to copy and distribute verbatim copies of this +license document, but changing it is not allowed. + + +Developer's Certificate of Origin 1.1 + +By making a contribution to this project, I certify that: + +(a) The contribution was created in whole or in part by me and I + have the right to submit it under the open source license + indicated in the file; or + +(b) The contribution is based upon previous work that, to the best + of my knowledge, is covered under an appropriate open source + license and I have the right under that license to submit that + work with modifications, whether created in whole or in part + by me, under the same open source license (unless I am + permitted to submit under a different license), as indicated + in the file; or + +(c) The contribution was provided directly to me by some other + person who certified (a), (b) or (c) and I have not modified + it. + +(d) I understand and agree that this project and the contribution + are public and that a record of the contribution (including all + personal information I submit with it, including my sign-off) is + maintained indefinitely and may be redistributed consistent with + this project or the open source license(s) involved. +``` + +Then you just add a line to every git commit message: + +``` +Signed-off-by: Joe Smith <joe.smith@email.com> +``` + +Use your real name (sorry, no pseudonyms or anonymous contributions.) + +If you set your `user.name` and `user.email` git configs, you can sign +your commit automatically with git commit -s. + +[dco]: http://developercertificate.org/ +[js]: https://opensource.janestreet.com/ diff --git a/LICENSE.md b/LICENSE.md new file mode 100644 index 0000000..0680c3e --- /dev/null +++ b/LICENSE.md @@ -0,0 +1,21 @@ +The MIT License + +Copyright (c) 2016--2019 Jane Street Group, LLC <opensource@janestreet.com> + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..1965878 --- /dev/null +++ b/Makefile @@ -0,0 +1,17 @@ +INSTALL_ARGS := $(if $(PREFIX),--prefix $(PREFIX),) + +default: + dune build + +install: + dune install $(INSTALL_ARGS) + +uninstall: + dune uninstall $(INSTALL_ARGS) + +reinstall: uninstall install + +clean: + dune clean + +.PHONY: default install uninstall reinstall clean diff --git a/README.org b/README.org new file mode 100644 index 0000000..83a9061 --- /dev/null +++ b/README.org @@ -0,0 +1,187 @@ +* Base + +Base is a standard library for OCaml. It provides a standard set of +general purpose modules that are well-tested, performant, and +fully-portable across any environment that can run OCaml code. Unlike +other standard library projects, Base is meant to be used as a +wholesale replacement of the standard library distributed with the +OCaml compiler. In particular it makes different choices and doesn't +re-export features that are not fully portable such as I/O, which are +left to other libraries. + +You also might want to browse the [[https://ocaml.janestreet.com/ocaml-core/latest/doc/base/index.html][API Documentation]]. + +** Installation + +Install Base via [[https://opam.ocaml.org][OPAM]]: + +#+begin_src +$ opam install base +#+end_src + +Base has no runtime dependencies and is fast to build. Its sole build +dependency is [[https://github.com/ocaml/dune][dune]], which itself requires nothing more than the +compiler. + +** Using the OCaml standard library with Base + +Base is intended as a full stdlib replacement. As a result, after an +=open Base=, all the modules, values, types, ... coming from the OCaml +standard library that one normally gets in the default environment are +deprecated. + +In order to access these values, one must use the =Caml= library, +which re-exports them all through the toplevel name =Caml=: +=Caml.String=, =Caml.print_string=, ... + +The recommended way to build code using Base is as follows: + +#+begin_src ocaml +$ ocamlc -open Base +#+end_src + +** Differences between Base and the OCaml standard library + +Programmers who are used to the OCaml standard library should read +through this section to understand major differences between the two +libraries that one should be aware of when switching to Base. + +*** Comparison operators + +The comparison operators exposed by the OCaml standard library are +polymorphic: + +#+begin_src ocaml +val compare : 'a -> 'a -> int +val ( <= ) : 'a -> 'a -> bool +... +#+end_src + +What they implement is structural comparison of the runtime +representation of values. Since these are often error-prone, +i.e. they don't correspond to what the user expects, they are not +exposed directly by Base. + +To use polymorphic comparison with Base, one should use the +=Polymorphic_compare= module. The default comparison operators exposed +by Base are the integer ones, just like the default arithmetic +operators are the integer ones. + +The recommended way to compare arbitrary complex data structures is to +use the specific =compare= functions. For instance: + +#+begin_src ocaml +List.compare String.compare x y +#+end_src + +The [[https://github.com/janestreet/ppx_compare][ppx_compare]] rewriter +offers an alternative way to write this: + +#+begin_src ocaml +[%compare: string list] x y +#+end_src + +** Base and ppx code generators + +Base uses a few ppx code generators to implement: + +- reliable and customizable comparison of OCaml values +- reliable and customizable hash of OCaml values +- conversions between OCaml values and s-expression + +However, it doesn't need these code generators to build. What it does +instead is use ppx as a code verification tool during development. It +works in a very similar fashion to +[[https://github.com/janestreet/ppx_expect][expectation tests]]. + +Whenever you see this in the code source: + +#+begin_src ocaml +type t = ... [@@deriving_inline sexp_of] +let sexp_of_t = ... +[@@@end] +#+end_src + +the code between the =[@@deriving_inline]= and the =[@@@end]= is +generated code. The generated code is currently quite big and hard to +read, however we are working on making it look like human-written +code. + +You can put the following elisp code in your =~/.emacs= file to hide +these blocks: + +#+begin_src scheme +(defun deriving-inline-forward-sexp (&optional arg) + (search-forward-regexp "\\[@@@end\\]") nil nil arg) + +(defun setup-hide-deriving-inline () + (inline) + (hs-minor-mode t) + (let ((hs-hide-comments-when-hiding-all nil)) + (hs-hide-all))) + +(require 'hideshow) +(add-to-list 'hs-special-modes-alist + '(tuareg-mode "\\[@@deriving_inline[^]]*\\]" "\\[@@@end\\]" nil + deriving-inline-forward-sexp nil)) +(add-hook 'tuareg-mode-hook 'setup-hide-deriving-inline) +#+end_src + +Things are not yet setup in the git repository to make it convenient +to change types and update the generated code, but they will be setup +soon. + +** Base coding rules + +There are a few coding rules across the code base that are enforced by +lint tools. + +These rules are: + +- Opening the =Caml= module is not allowed. Inside Base, the OCaml + stdlib is shadowed and accessible through the =Caml= module. We + forbid opening =Caml= so that we know exactly where things come + from. +- =Caml.Foo= modules cannot be aliased, one must use =Caml.Foo= + explicitly. This is to avoid having to remember a list of aliases + at the beginning of each file. +- For some modules that are both in the OCaml stdlib and Base, such as + =String=, we define a module =String0= for common functions that + cannot be defined directly in =Base.String= to avoid creating a + circular dependency. Except for =String= itself, other modules + are not allowed to use =Caml.String= and must use either =String= or + =String0= instead. +- Indentation is exactly the one of =ocp-indent=. +- A few other coding style rules enforced by + [[https://github.com/janestreet/ppx_js_style][ppx_js_style]]. + +The Base specific coding rules are checked by =ppx_base_lint=, in the +=lint= subfolder. The indentation rules are checked by a wrapper around +=ocp-indent= and the coding style rules are checked by =ppx_js_style=. + +These checks are currently not run by =dune=, but it will soon get a +=-dev= flag to run them automatically. + +** Roadmap + +Following is the current plan for a stable version 1 of Base. + +*** Add more integer types + +Add support for ={,u}int{8,16,32,64}=. These are always useful when +implementing binary protocols. + +Initially they should be implemented with C stubs and eventually we +should propose their inclusion in the compiler. + +*** 80 columns limit + +Currently lines in Base are limited to a maximum width of 90 +characters. To make things more standard, we should use an 80 columns +limit. The only thing needed for this is to extend the style checker +to enforce a maximum line width. + +*** Improve the generated code + +Improve our code generators to produce code that looks more like +hand-written code. diff --git a/ROADMAP.md b/ROADMAP.md new file mode 100644 index 0000000..d7644e8 --- /dev/null +++ b/ROADMAP.md @@ -0,0 +1,112 @@ +# Stable Interface (v1.0) + + - [X] Make the entire library `-safe-string` compliant. This will involve + introducing a `Bytes` module, removing all direct mutation on strings from + the `String` module, and "re-typing" string values that require mutation to + `bytes`. + + - [X] Do not export the `\*\_intf` modules from Base. Instead, any signatures + should be exported by the `.ml` and `.mli`s. + + - [X] Only expose the first-class module interface of `Hashtbl`. Accompanying + this should be cleanup of `Hashtbl_intf`, moving anything that's still + required in core_kernel to the appropriate files in that project. + + - [X] Replace `Hashtbl.create (module String) ()` by just + `Hashtbl.create (module String)` + + - [X] Remove `replace` from `Hashtbl_intf.Accessors`. + + - [X] Label one of the arguments of `Hashtbl_intf.merge_into` to indicate the + flow of data. + + - [X] Merge `Hashtbl_intf.Key_common` and `Hashtbl_intf.Key_plain`. + + - [ ] Use `Either.t` as the return value for `Map.partition`. + + - [X] Rename `Monad_intf.all_ignore` to `Monad_intf.all_unit`. + + - [ ] Eliminate all uses of `Not_found`, replacing them with descriptive error messages. + + - [X] Move the various private modules to `Base.Base_private` + instead of `Base.Exported_for_specific_uses` and `Base.Not_exposed_properly` + + - [X] Use `compare` rather than `cmp` as the label for comparison functions + throughout. + +# Implementation Cleanup + + - [ ] Remove `ignore` and `(=)` from `Sexp_conv`'s public interface. These + values are hidden from the documentation so their removal won't be + considered a breaking API change. + + - [ ] Do not expose the type equality `Int63_emul.W.t = int64`. + + - [ ] Replace the exception thrown by `Float.of_string` with a named + exception that's more descriptive. + + - [X] Delete the `Hashable` toplevel module. This is a vestige of the previous + `Map` and `Set` implementations and is no longer needed. + + - [ ] Ensure that `Map` operations that are effective NO-OPs return the same + `Map.t` they were provided. Candidate operations include e.g `add`, `remove`, + `filter`. + + - [ ] Simplify the implementation of `Option.value_exn`, if possible. + + - [ ] Eliminate all instances of `open! Polymorphic_compare` + + - [ ] Refactor common blit code in `String.replace_all` and `String.replace_first`. + + - [ ] Delete unused function aliases in `Import0` + + - [ ] Put `Sexp_conv.Exn_converter` into its own file, with only an + alias in Sexp_conv, so that it doesn't get pulled unless used + + - [ ] Create a file with all the basic types and their associated + combinators (`sexp_of_t`, `compare`, `hash`), and expose the + declaration + + - [ ] Put all the exported private modules from + `Base.Exported_for_specific_uses` and `Base.Not_exposed_properly` + in `Base.Base_private` + + - [ ] Decide on a better name for `Polymorphic_compare`. + `Polymorphic_compare_intf` contains interface for comparison + of non-polymorphic types, which is weird. Get rid of it and + inline things in `Comparable_intf` + + - [X] `hashtbl_of_sexp` shouldn't live in Base.Sexp_conv since we + have our own hash tables. Move it to sexplib + +# Performance Improvements + + - [ ] In `Hash_set.diff`, use the size of each set to determine which to iterate + over. + + - [ ] Ensure that the correct `compare` function and other related functions are + exported by all modules. These functions should not be derived from + a functor application, in order to ensure proper inlining. Implementing + this change should also include benchmarks to verify the initial result, + and to maintain it on an ongoing basis. See `bench/bench_int.ml` for + examples. + + - [X] Optimize `Lazy.compare` by performing a `phys_equal` check before + forcing the lazy value. Note that this will also change the semantics of + `compare` and should be documented and rolled out with care. + + - [ ] Conduct a thorough performance review of the `Sequence` module. + +# Documentation + + - [ ] Consolidate documentation the interface and implementation files + related to the `Hash` module. + + - [ ] Add documentation to the `Ref` toplevel module. + + - [ ] Document properly how `String.unescape_gen` handles malformed strings + +# Changes For The Distant Future + + - [ ] Make the various comparison functions return an `Ordering.t` + instead of an `int`. diff --git a/base.opam b/base.opam new file mode 100644 index 0000000..7662dad --- /dev/null +++ b/base.opam @@ -0,0 +1,35 @@ +opam-version: "2.0" +version: "v0.12.0" +maintainer: "opensource@janestreet.com" +authors: ["Jane Street Group, LLC <opensource@janestreet.com>"] +homepage: "https://github.com/janestreet/base" +bug-reports: "https://github.com/janestreet/base/issues" +dev-repo: "git+https://github.com/janestreet/base.git" +doc: "https://ocaml.janestreet.com/ocaml-core/latest/doc/base/index.html" +license: "MIT" +build: [ + ["dune" "build" "-p" name "-j" jobs] +] +depends: [ + "ocaml" {>= "4.04.2" & < "4.09.0"} + "sexplib0" {>= "v0.12" & < "v0.13"} + "dune" {build & >= "1.5.1"} +] +depopts: [ + "base-native-int63" +] +synopsis: "Full standard library replacement for OCaml" +description: " +Full standard library replacement for OCaml + +Base is a complete and portable alternative to the OCaml standard +library. It provides all standard functionalities one would expect +from a language standard library. It uses consistent conventions +across all of its module. + +Base aims to be usable in any context. As a result system dependent +features such as I/O are not offered by Base. They are instead +provided by companion libraries such as stdio: + + https://github.com/janestreet/stdio +" diff --git a/compiler-stdlib/gen/dune b/compiler-stdlib/gen/dune new file mode 100644 index 0000000..0f60968 --- /dev/null +++ b/compiler-stdlib/gen/dune @@ -0,0 +1,3 @@ +(executables (names gen) + (libraries compiler-libs.common compiler-libs.bytecomp) + (preprocess no_preprocessing))
\ No newline at end of file diff --git a/compiler-stdlib/gen/gen.ml b/compiler-stdlib/gen/gen.ml new file mode 100644 index 0000000..7b90ed3 --- /dev/null +++ b/compiler-stdlib/gen/gen.ml @@ -0,0 +1,88 @@ + +open StdLabels + +module Ocaml_version : sig + type t + val parse : string -> t + val v407 : t + val v408 : t + val current : t + val compare : t -> t -> int +end = struct + type t = int * int + + let parse s = + try + let d1 = String.index_from s 0 '.' in + let d2 = try String.index_from s (d1 + 1) '.' with Not_found -> String.length s in + let p1 = int_of_string (String.sub s ~pos:0 ~len:d1) in + let p2 = int_of_string (String.sub s ~pos:(d1 + 1) ~len:(d2 - d1 - 1)) in + p1, p2 + with _ -> failwith (Printf.sprintf "Invalid ocaml version %S" s) + + let v407 = parse "4.07" + let v408 = parse "4.08" + + let current = parse Sys.ocaml_version + + let compare ((a1,b1): t) ((a2,b2):t) = + match compare a1 a2 with + | 0 -> compare b1 b2 + | c -> c +end + +let () = + let ocaml_where, oc = + match Sys.argv with + | [|_; "-ocaml-where"; ocaml_where; "-o"; fn|] -> + (ocaml_where, open_out fn) + | _ -> + failwith "bad command line arguments" + in + let pr fmt = Printf.fprintf oc (fmt ^^ "\n") in + pr "(* This file is automatically generated *)"; + pr ""; + (if Ocaml_version.(compare current v407) >= 0 + then pr "include Stdlib" + else begin + (* The cma format is documented in typing/cmo_format.mli in the compiler sources *) + let ic = + let (^/) = Filename.concat in + try open_in_bin (ocaml_where ^/ "stdlib" ^/ "stdlib.cma") + with Sys_error _ -> open_in_bin (ocaml_where ^/ "stdlib.cma") + in + let len_magic_number = String.length Config.cma_magic_number in + let magic_number = really_input_string ic len_magic_number in + assert (magic_number = Config.cma_magic_number); + let toc_pos = input_binary_int ic in + seek_in ic toc_pos; + let toc = (input_value ic : Cmo_format.library) in + close_in ic; + let units = + List.map toc.lib_units ~f:(fun cu -> cu.Cmo_format.cu_name) + |> List.sort ~cmp:String.compare + in + let max_len = + List.fold_left units ~init:0 ~f:(fun acc unit -> + max acc (String.length unit)) + in + List.iter units ~f:(fun u -> pr "module %-*s = %s" max_len u u); + pr ""; + pr "include Pervasives"; + end); + pr ""; + (if Ocaml_version.(compare current v407) < 0 + then + pr "module Float = struct end"); + (if Ocaml_version.(compare current v408) < 0 + then begin + pr "module Bool = struct end"; + pr "module Int = struct end"; + pr "module Option = struct end"; + pr "module Result = struct end"; + pr "module Unit = struct end"; + pr "module Fun = struct end" + end; + pr ""; + pr "exception Not_found = Not_found") + diff --git a/compiler-stdlib/src/dune b/compiler-stdlib/src/dune new file mode 100644 index 0000000..317f8e4 --- /dev/null +++ b/compiler-stdlib/src/dune @@ -0,0 +1,4 @@ +(library (name caml) (public_name base.caml) (preprocess no_preprocessing)) + +(rule (targets caml.ml) (deps (:first_dep ../gen/gen.exe)) + (action (run %{first_dep} -ocaml-where %{ocaml_where} -o %{targets})))
\ No newline at end of file diff --git a/dune-project b/dune-project new file mode 100644 index 0000000..598db56 --- /dev/null +++ b/dune-project @@ -0,0 +1 @@ +(lang dune 1.5)
\ No newline at end of file diff --git a/generate/dune b/generate/dune new file mode 100644 index 0000000..8c94536 --- /dev/null +++ b/generate/dune @@ -0,0 +1,2 @@ +(executables (names generate_pow_overflow_bounds) (libraries num) + (preprocess no_preprocessing))
\ No newline at end of file diff --git a/generate/generate_pow_overflow_bounds.ml b/generate/generate_pow_overflow_bounds.ml new file mode 100644 index 0000000..4f1f623 --- /dev/null +++ b/generate/generate_pow_overflow_bounds.ml @@ -0,0 +1,185 @@ +(* NB: This needs to be pure OCaml (no Base!), since we need this in order to build + Base. *) + +(* This module generates lookup tables to detect integer overflow when calculating integer + exponents. At index [e], [table.[e]^e] will not overflow, but [(table[e] + 1)^e] + will. *) + +type mode = Normal | Atomic of { out_fn : string; tmp_fn : string } + +let oc, mode = + match Sys.argv with + | [|_|] -> (stdout, Normal) + | [|_; "-o"; out_fn|] + | [|_; "-atomic"; "-o"; out_fn|] -> + (* Always produce the file atomically, we just have this option to remember that we + need to do it *) + let tmp_fn, oc = + Filename.open_temp_file + ~temp_dir:(Filename.dirname out_fn) + "generate_pow_overflow_bounds" ".ml.tmp" + in + (oc, Atomic { out_fn; tmp_fn }) + | _ -> failwith "bad command line arguments" + +module Big_int = struct + include Big_int + type t = big_int + let (>) = gt_big_int + let (<=) = le_big_int + let (^) = power_big_int_positive_int + let (-) = sub_big_int + let (+) = add_big_int + let one = unit_big_int + let sqrt = sqrt_big_int + let to_string = string_of_big_int +end + +module Array = StdLabels.Array + +type generated_type = + | Int + | Int32 + | Int63 + | Int64 + +let max_big_int_for_bits bits = + let shift = bits - 1 in (* sign bit *) + Big_int.((shift_left_big_int one shift) - one) +;; + +let safe_to_print_as_int = + let int31_max = max_big_int_for_bits 31 in + fun x -> Big_int.(x <= int31_max) + +let format_entry typ b = + let s = Big_int.to_string b in + match typ with + | Int -> + if safe_to_print_as_int b + then s + else Printf.sprintf "Caml.Int64.to_int %sL" s + | Int32 -> s ^ "l" + | Int63 + | Int64 -> s ^ "L" + +let bits = function + | Int -> assert false (* architecture dependent *) + | Int32 -> 32 + | Int63 -> 63 + | Int64 -> 64 + +let max_val typ = max_big_int_for_bits (bits typ) + +let name = function + | Int -> "int" + | Int32 -> "int32" + | Int63 -> "int63_on_int64" + | Int64 -> "int64" + +let ocaml_type_name = function + | Int -> "int" + | Int32 -> "int32" + | Int63 + | Int64 -> "int64" + +let generate_negative_bounds = function + | Int -> false + | Int32 -> false + | Int63 -> false + | Int64 -> true + +let highest_base exponent max_val = + let open Big_int in + match exponent with + | 0 | 1 -> max_val + | 2 -> sqrt max_val + | _ -> + let rec search possible_base = + if possible_base ^ exponent > max_val then + begin + let res = possible_base - one in + assert (res ^ exponent <= max_val); + res + end + else + search (possible_base + one) + in + search one +;; + +type sign = Pos | Neg + +let pr fmt = Printf.fprintf oc (fmt ^^ "\n") + +let gen_array ~typ ~bits ~sign ~indent = + let pr fmt = pr ("%*s" ^^ fmt) indent "" in + let max_val = max_big_int_for_bits bits in + let pos_bounds = Array.init 64 ~f:(fun i -> highest_base i max_val) in + let bounds = + match sign with + | Pos -> pos_bounds + | Neg -> Array.map pos_bounds ~f:Big_int.minus_big_int + in + pr "[| %s" (format_entry typ bounds.(0)); + for i = 1 to Array.length bounds - 1 do + pr "; %s" (format_entry typ bounds.(i)) + done; + pr "|]"; +;; + + +let gen_bounds typ = + pr "let overflow_bound_max_%s_value : %s =" (name typ) (ocaml_type_name typ); + (match typ with + | Int -> pr " (-1) lsr 1" + | _ -> pr " %s" (format_entry typ (max_val typ))); + pr ""; + + let array_name typ sign = + Printf.sprintf "%s_%s_overflow_bounds" (name typ) + (match sign with Pos -> "positive" | Neg -> "negative") + in + + pr "let %s : %s array =" (array_name typ Pos) (ocaml_type_name typ); + (match typ with + | Int -> + pr " match Int_conversions.num_bits_int with"; + pr " | 32 -> Array.map %s ~f:Caml.Int32.to_int" (array_name Int32 Pos); + pr " | 63 ->"; + gen_array ~typ ~bits:63 ~sign:Pos ~indent:4; + pr " | 31 ->"; + gen_array ~typ ~bits:31 ~sign:Pos ~indent:4; + pr " | _ -> assert false" + | _ -> + gen_array ~typ ~bits:(bits typ) ~sign:Pos ~indent:2); + pr ""; + + if generate_negative_bounds typ then begin + pr "let %s : %s array =" (array_name typ Neg) (ocaml_type_name typ); + gen_array ~typ ~bits:(bits typ) ~sign:Neg ~indent:2 + end; +;; + +let () = + pr "(* This file was autogenerated by %s *)" Sys.argv.(0); + pr ""; + pr "open! Import"; + pr ""; + pr "module Array = Array0"; + pr ""; + pr "(* We have to use Int64.to_int_exn instead of int constants to make"; + pr " sure that file can be preprocessed on 32-bit machines. *)"; + pr ""; + gen_bounds Int32; + gen_bounds Int; + gen_bounds Int63; + gen_bounds Int64; +;; + +let () = + match mode with + | Normal -> () + | Atomic { tmp_fn; out_fn } -> + close_out oc; + Sys.rename tmp_fn out_fn diff --git a/lint/dune b/lint/dune new file mode 100644 index 0000000..40445e6 --- /dev/null +++ b/lint/dune @@ -0,0 +1,3 @@ +(library (name ppx_base_lint) (kind ppx_rewriter) + (libraries compiler-libs.common base ppxlib) + (preprocess no_preprocessing))
\ No newline at end of file diff --git a/lint/ppx_base_lint.ml b/lint/ppx_base_lint.ml new file mode 100644 index 0000000..a012454 --- /dev/null +++ b/lint/ppx_base_lint.ml @@ -0,0 +1,105 @@ +open Ppxlib +open Base + +let error ~loc fmt = + Location.raise_errorf ~loc (Caml.(^^) "ppx_base_lint:" fmt) + +type suspicious_id = + | Caml_submodule of string + +let rec iter_suspicious (id : Longident.t) ~f = + match id with + | Ldot (Lident "Caml", s) when + String.(<>) s "" && + match s.[0] with + | 'A'..'Z' -> true + | _ -> false + -> + f (Caml_submodule s) + | Ldot (x, _) -> iter_suspicious x ~f + | Lapply (a, b) -> + iter_suspicious a ~f; + iter_suspicious b ~f + | Lident _ -> () + +let zero_modules () = + Caml.Sys.readdir "." + |> Array.to_list + |> List.filter ~f:(fun fn -> + Caml.Filename.check_suffix fn "0.ml") + |> List.map ~f:(fun fn -> + String.capitalize (String.sub fn ~pos:0 ~len:(String.length fn - 4))) + |> Set.of_list (module String) + +let check_open (id : Longident.t Asttypes.loc) = + match id.txt with + | Lident "Caml" -> + error ~loc:id.loc + "you are not allowed to open Caml inside Base" + | _ -> () + +let rec is_caml_dot_something : Longident.t -> bool = function + | Ldot (Lident "Caml", _) -> true + | Ldot (id, _) -> is_caml_dot_something id + | _ -> false + +let check current_module = + let zero_modules = zero_modules () in + object + inherit Ast_traverse.iter as super + + method! longident_loc { txt = id; loc } = + (* Note: we don't distinguish between module identifiers and constructors + names. Since there is no [Caml.String], [Caml.Array], ... constructors this is + not a problem. *) + iter_suspicious id ~f:(function + | Caml_submodule m -> + if not (Set.mem zero_modules m) then + () (* We are allowed to use Caml modules that don't have a Foo0 version *) + else if String.equal (m ^ "0") current_module then + () (* Foo0 is allowed to use Caml.Foo *) + else + match current_module with + | "Import0" | "Base" -> () + | _ -> + error ~loc + "you cannot use [Caml.%s] here, use [%s0] instead" m m) + + method! expression e = + super#expression e; + match e.pexp_desc with + | Pexp_open (_, id, _) -> check_open id + | _ -> () + + method! open_description op = + super#open_description op; + check_open op.popen_lid + + method! module_binding mb = + super#module_binding mb; + match current_module with + | "Import0" -> () + | _ -> + match mb.pmb_expr.pmod_desc with + | Pmod_ident { txt = id; _ } when is_caml_dot_something id -> + error ~loc:mb.pmb_loc + "you cannot alias [Caml] sub-modules, use them directly" + | _ -> () + end + +let module_of_loc (loc : Location.t) = + String.capitalize (Caml.Filename.chop_extension + (Caml.Filename.basename loc.loc_start.pos_fname)) + +let () = + Ppxlib.Driver.register_transformation "base_lint" + ~impl:(function + | [] -> [] + | { pstr_loc = loc; _ } :: _ as st -> + (check (module_of_loc loc))#structure st; + st) + ~intf:(function + | [] -> [] + | { psig_loc = loc; _ } :: _ as sg -> + (check (module_of_loc loc))#signature sg; + sg) diff --git a/md5/src/dune b/md5/src/dune new file mode 100644 index 0000000..a10038e --- /dev/null +++ b/md5/src/dune @@ -0,0 +1,2 @@ +(library (name md5_lib) (public_name base.md5) (preprocess no_preprocessing) + (libraries) (js_of_ocaml (javascript_files)))
\ No newline at end of file diff --git a/md5/src/md5_lib.ml b/md5/src/md5_lib.ml new file mode 100644 index 0000000..52eccc1 --- /dev/null +++ b/md5/src/md5_lib.ml @@ -0,0 +1,26 @@ +type t = string + +(* Share the digest of the empty string *) +let empty = Digest.string "" +let make s = + if s = empty then + empty + else + s + +let compare = compare + +let length = 16 + +let to_binary s = s +let of_binary_exn s = assert (String.length s = length); make s +let unsafe_of_binary = make + +let to_hex = Digest.to_hex +let of_hex_exn s = make (Digest.from_hex s) + +let string s = make (Digest.string s) + +let bytes s = make (Digest.bytes s) + +let subbytes bytes ~pos ~len = make (Digest.subbytes bytes pos len) diff --git a/md5/src/md5_lib.mli b/md5/src/md5_lib.mli new file mode 100644 index 0000000..f784eb9 --- /dev/null +++ b/md5/src/md5_lib.mli @@ -0,0 +1,21 @@ +type t + +val compare : t -> t -> int + +(** [length = 16] is the size of the digest in bytes. *) +val length : int + +val to_binary : t -> string +val of_binary_exn : string -> t + +(** assumes the input is 16 bytes without checking *) +val unsafe_of_binary : string -> t + +val to_hex : t -> string +val of_hex_exn : string -> t + +val string : string -> t + +val bytes : bytes -> t + +val subbytes : bytes -> pos:int -> len:int -> t diff --git a/shadow-stdlib/gen/dune b/shadow-stdlib/gen/dune new file mode 100644 index 0000000..9e34b3a --- /dev/null +++ b/shadow-stdlib/gen/dune @@ -0,0 +1,4 @@ +(executables (names gen) (libraries str compiler-libs.common caml) + (link_flags -linkall) (preprocess no_preprocessing)) + +(ocamllex mapper)
\ No newline at end of file diff --git a/shadow-stdlib/gen/gen.ml b/shadow-stdlib/gen/gen.ml new file mode 100644 index 0000000..a763c86 --- /dev/null +++ b/shadow-stdlib/gen/gen.ml @@ -0,0 +1,36 @@ +open StdLabels + +let () = + let cmi_fn, oc = + match Sys.argv with + | [|_; "-caml-cmi"; cmi_fn; "-o"; fn|] -> + (cmi_fn, open_out fn) + | [|_; "-caml-cmi"; cmi_fn1; cmi_fn2; "-o"; fn|] -> + let cmi_fn = + if Sys.file_exists cmi_fn1 then + cmi_fn1 + else + cmi_fn2 + in + (cmi_fn, open_out fn) + | _ -> + failwith "bad command line arguments" + in + + try + let cmi = Cmi_format.read_cmi cmi_fn in + let buf = Buffer.create 512 in + let pp = Format.formatter_of_buffer buf in + Format.pp_set_margin pp max_int; (* so we can parse line by line below *) + Format.fprintf pp "%a@." Printtyp.signature cmi.Cmi_format.cmi_sign; + let s = Buffer.contents buf in + let lines = Str.split (Str.regexp "\n") s in + Printf.fprintf oc "[@@@warning \"-3\"]\n\n"; + List.iter lines ~f:(fun line -> + let repl = Mapper.line (Lexing.from_string line) in + if repl <> "" then + Printf.fprintf oc "%s\n\n" repl); + flush oc + with exn -> + Location.report_exception Format.err_formatter exn; + exit 2 diff --git a/shadow-stdlib/gen/mapper.mll b/shadow-stdlib/gen/mapper.mll new file mode 100644 index 0000000..3de7c37 --- /dev/null +++ b/shadow-stdlib/gen/mapper.mll @@ -0,0 +1,267 @@ +{ +open StdLabels +open Printf + +module String = struct + [@@@warning "-32-3"] + let capitalize_ascii = String.capitalize + let uncapitalize_ascii = String.uncapitalize + let uppercase_ascii = String.uppercase + let lowercase_ascii = String.lowercase + include String +end + +let deprecated_msg ~is_exn what = + sprintf + "[%sdeprecated \"\\\n\ + [2016-09] this element comes from the stdlib distributed with OCaml.\n\ + Referring to the stdlib directly is discouraged by Base. You should either\n\ + use the equivalent functionality offered by Base, or if you really want to\n\ + refer to the stdlib, use Caml.%s instead\"]" + (if is_exn then "@" else "@@") + what + +let deprecated_msg_no_equivalent ~is_exn what = + sprintf + "[%sdeprecated \"\\\n\ + [2016-09] this element comes from the stdlib distributed with OCaml.\n\ + There is not equivalent functionality in Base or Stdio at the moment,\n\ + so you need to use [Caml.%s] instead\"]" + (if is_exn then "@" else "@@") + what + +let deprecated_msg_with_repl_text ~is_exn text = + sprintf + "[%sdeprecated \"\\\n\ + [2016-09] this element comes from the stdlib distributed with OCaml.\n\ + %s.\"]" + (if is_exn then "@" else "@@") + text + +let deprecated_msg_with_repl ~is_exn repl = + deprecated_msg_with_repl_text ~is_exn (sprintf "Use [%s] instead" repl) + +let deprecated_msg_with_approx_repl ~is_exn ~id repl = + sprintf + "[%sdeprecated \"\\\n\ + [2016-09] this element comes from the stdlib distributed with OCaml.\n\ + There is no equivalent functionality in Base or Stdio but you can use\n\ + [%s] instead.\n\ + Alternatively, if you really want to refer the stdlib function you can\n\ + use [Caml.%s].\"]" + (if is_exn then "@" else "@@") + repl id + +type replacement = + | No_equivalent + | Repl of string + | Repl_text of string + | Approx of string + +let val_replacement = function + | "( != )" -> Repl "not (phys_equal ...)" + | "( == )" -> Repl "phys_equal" + | "( ** )" -> Repl "**." + | "( mod )" -> Repl_text "Use (%), which has slightly different \ + semantics, or Int.rem which is equivalent" + | "acos" -> Repl "Float.acos" + | "asin" -> Repl "Float.asin" + | "atan" -> Repl "Float.atan" + | "atan2" -> Repl "Float.atan2" + | "bool_of_string" -> Repl "Bool.of_string" + | "ceil" -> Repl "Float.round_up" + | "char_of_int" -> Repl "Char.of_int_exn" + | "classify_float" -> Repl "Float.classify" + | "close_in" -> Repl "Stdio.In_channel.close" + | "close_in_noerr" -> Repl "Stdio.In_channel.close" + | "close_out" -> Repl "Stdio.Out_channel.close" + | "close_out_noerr" -> Repl "Stdio.Out_channel.close" + | "copysign" -> Repl "Float.copysign" + | "cos" -> Repl "Float.cos" + | "cosh" -> Repl "Float.cosh" + | "decr" -> Repl "Int.decr" + | "epsilon_float" -> Repl "Float.epsilon_float" + | "exp" -> Repl "Float.exp" + | "expm1" -> Repl "Float.expm1" + | "float" -> Repl "Float.of_int" + | "float_of_int" -> Repl "Float.of_int" + | "float_of_string" -> Repl "Float.of_string" + | "floor" -> Repl "Float.round_down" + | "flush" -> Repl "Stdio.Out_channel.flush" + | "flush_all" -> No_equivalent + | "frexp" -> Repl "Float.frexp" + | "hypot" -> Repl "Float.hypot" + | "in_channel_length" -> Repl "Stdio.In_channel.length" + | "infinity" -> Repl "Float.infinity" + | "incr" -> Repl "Int.incr" + | "input" -> Repl "Stdio.In_channel.input" + | "input_binary_int" -> Repl "Stdio.In_channel.input_binary_int" + | "input_byte" -> Repl "Stdio.In_channel.input_byte" + | "input_char" -> Repl "Stdio.In_channel.input_char" + | "input_line" -> Repl "Stdio.In_channel.input_line" + | "input_value" -> Repl "Stdio.In_channel.unsafe_input_value" + | "int_of_char" -> Repl "Char.to_int" + | "int_of_float" -> Repl "Int.of_float" + | "int_of_string" -> Repl "Int.of_string" + | "ldexp" -> Repl "Float.ldexp" + | "log" -> Repl "Float.log" + | "log10" -> Repl "Float.log10" + | "log1p" -> Repl "Float.log1p" + | "max_float" -> Repl "Float.max_finite_value" + | "max_int" -> Repl "Int.max_value" + | "min_float" -> Repl "Float.min_positive_normal_value" + | "min_int" -> Repl "Int.min_value" + | "mod_float" -> Repl "Float.mod_float" + | "modf" -> Repl "Float.modf" + | "nan" -> Repl "Float.nan" + | "neg_infinity" -> Repl "Float.neg_infinity" + | "open_in" -> Repl "Stdio.In_channel.create" + | "open_in_bin" -> Repl "Stdio.In_channel.create" + | "open_in_gen" -> No_equivalent + | "open_out" -> Repl "Stdio.Out_channel.create" + | "open_out_bin" -> Repl "Stdio.Out_channel.create" + | "open_out_gen" -> No_equivalent + | "out_channel_length" -> Repl "Stdio.Out_channel.length" + | "output" -> Repl "Stdio.Out_channel.output" + | "output_binary_int" -> Repl "Stdio.Out_channel.output_binary_int" + | "output_byte" -> Repl "Stdio.Out_channel.output_byte" + | "output_bytes" -> Repl "Stdio.Out_channel.output_bytes" + | "output_char" -> Repl "Stdio.Out_channel.output_char" + | "output_string" -> Repl "Stdio.Out_channel.output_string" + | "output_substring" -> Repl "Stdio.Out_channel.output" + | "output_value" -> Repl "Stdio.Out_channel.output_value" + | "pos_in" -> Repl "Stdio.In_channel.pos" + | "pos_out" -> Repl "Stdio.Out_channel.pos" + | "pred" -> Repl "Int.pred" + | "prerr_bytes" -> Repl "Stdio.Out_channel.output_bytes Stdio.stderr" + | "prerr_char" -> Repl "Stdio.Out_channel.output_char Stdio.stderr" + | "prerr_endline" -> Repl "Stdio.prerr_endline" + | "prerr_float" -> Repl "Stdio.eprintf \"%f\"" + | "prerr_int" -> Repl "Stdio.eprintf \"%d\"" + | "prerr_newline" -> Repl "Stdio.eprintf \"\n%!\"" + | "prerr_string" -> Repl "Stdio.Out_channel.output_string Stdio.stderr" + | "print_bytes" -> Repl "Stdio.Out_channel.output_bytes Stdio.stdout" + | "print_char" -> Repl "Stdio.Out_channel.output_char Stdio.stdout" + | "print_endline" -> Repl "Stdio.print_endline" + | "print_float" -> Repl "Stdio.eprintf \"%f\"" + | "print_int" -> Repl "Stdio.eprintf \"%d\"" + | "print_newline" -> Repl "Stdio.eprintf \"\n%!\"" + | "print_string" -> Repl "Stdio.Out_channel.output_string Stdio.stdout" + | "read_float" -> No_equivalent + | "read_int" -> No_equivalent + | "read_line" -> Repl "Stdio.In_channel.input_line" + | "really_input" -> Repl "Stdio.In_channel.really_input" + | "really_input_string" -> Approx "Stdio.Out_channel" + | "seek_in" -> Repl "Stdio.In_channel.seek" + | "seek_out" -> Repl "Stdio.Out_channel.seek" + | "set_binary_mode_in" -> Repl "Stdio.In_channel.set_binary_mode" + | "set_binary_mode_out" -> Repl "Stdio.Out_channel.set_binary_mode" + | "sin" -> Repl "Float.sin" + | "sinh" -> Repl "Float.sinh" + | "sqrt" -> Repl "Float.sqrt" + | "stderr" -> Repl "Stdio.stderr" + | "stdin" -> Repl "Stdio.stdin" + | "stdout" -> Repl "Stdio.stdout" + | "string_of_bool" -> Repl "Bool.to_string" + | "string_of_float" -> Repl "Float.to_string" + | "string_of_int" -> Repl "Int.to_string" + | "succ" -> Repl "Int.succ" + | "tan" -> Repl "Float.tan" + | "tanh" -> Repl "Float.tanh" + | "truncate" -> Repl "Int.of_float" + (* This is documented as DO-NOT-USE in the stdlib *) + | "unsafe_really_input" -> No_equivalent + | _ -> No_equivalent +;; + +let exception_replacement = function + | "Not_found" -> + Some (Repl_text "\ +Instead of raising [Not_found], consider using [raise_s] with an informative error\n\ +message. If code needs to distinguish [Not_found] from other exceptions, please change\n\ +it to handle both [Not_found] and [Not_found_s]. Then, instead of raising [Not_found],\n\ +raise [Not_found_s] with an informative error message") + | _ -> None + +let module_replacement = function + | "Printexc" -> Some (Repl_text "Use [Exn] or [Backtrace] instead") + | "Format" -> + let repl_text = + "[Base] doesn't export a [Format] module, although the \n\ + [Caml.Format.formatter] type is available (as [Formatter.t])\n\ + for interaction with other libraries" + in + Some (Repl_text repl_text) + | "Fun" -> Some (Repl_text "Use [Fn] instead") + | _ -> None + +let replace ~is_exn id replacement line = + let msg = + match replacement with + | No_equivalent -> deprecated_msg_no_equivalent ~is_exn id + | Repl repl -> deprecated_msg_with_repl ~is_exn repl + | Repl_text text -> deprecated_msg_with_repl_text ~is_exn text + | Approx repl -> deprecated_msg_with_approx_repl ~is_exn repl ~id + in + sprintf "%s\n%s" line msg +;; +} + +let id_trail = ['a'-'z' 'A'-'Z' '_' '0'-'9']* + let id = ['a'-'z' 'A'-'Z' '_' '0'-'9'] id_trail +let val_id = id | '(' [^ ')']* ')' +let params = ('(' [^')']* ')' | ['+' '-']? '\'' id) " " + +let val_ = "val " | "external " + +rule line = parse + | "module Camlinternal" _* + { "" (* We can't deprecate these *) } + | "module Bigarray" _* { "" (* Don't deprecate it yet *) } + | "type " (params? (id as id) _* as def) + { sprintf "type nonrec %s\n%s" def + (match id with + | "in_channel" -> deprecated_msg_with_repl ~is_exn:false "Stdio.In_channel.t" + | "out_channel" -> deprecated_msg_with_repl ~is_exn:false "Stdio.Out_channel.t" + | _ -> deprecated_msg ~is_exn:false id) + } + + | val_ (val_id as id) _* as line { replace ~is_exn:false id (val_replacement id) line } + + | "module " (id as id) " = Stdlib__" (id as id2) (_* as line) + { + let line = + Printf.sprintf "module %s = Stdlib.%s %s" + id (String.capitalize_ascii id2) line in + match module_replacement id with + | Some replacement -> replace ~is_exn:false id replacement line + | None -> sprintf "%s\n%s" line (deprecated_msg ~is_exn:false id) } + + | "exception " (id as id) _* as line + { match exception_replacement id with + | Some replacement -> replace ~is_exn:true id replacement line + | None -> + let predefined_exceptions = + [ "Out_of_memory" + ; "Sys_error" + ; "Failure" + ; "Invalid_argument" + ; "End_of_file" + ; "Division_by_zero" + ; "Not_found" + ; "Match_failure" + ; "Stack_overflow" + ; "Sys_blocked_io" + ; "Assert_failure" + ; "Undefined_recursive_module" ] + in + if List.mem id ~set:predefined_exceptions + then "" + else sprintf "%s\n%s" line (deprecated_msg ~is_exn:true id) + } + | "module " (id as id) _* as line + { match module_replacement id with + | Some replacement -> replace ~is_exn:false id replacement line + | None -> sprintf "%s\n%s" line (deprecated_msg ~is_exn:false id) } + | _* as line + { ksprintf failwith "cannot parse this: %s" line } diff --git a/shadow-stdlib/src/dune b/shadow-stdlib/src/dune new file mode 100644 index 0000000..edea282 --- /dev/null +++ b/shadow-stdlib/src/dune @@ -0,0 +1,10 @@ +(library (name shadow_stdlib) (public_name base.shadow_stdlib) + (libraries caml) (preprocess no_preprocessing)) + +(rule (targets shadow_stdlib.mli) + (deps (:first_dep ../gen/gen.exe) + ../../compiler-stdlib/src/caml.cma) + (action + (run %{first_dep} -caml-cmi ../../compiler-stdlib/src/.caml.objs/caml.cmi + ../../compiler-stdlib/src/.caml.objs/byte/caml.cmi + -o %{targets}))) diff --git a/shadow-stdlib/src/shadow_stdlib.ml b/shadow-stdlib/src/shadow_stdlib.ml new file mode 100644 index 0000000..62ab439 --- /dev/null +++ b/shadow-stdlib/src/shadow_stdlib.ml @@ -0,0 +1 @@ +include Caml diff --git a/src/am_testing.c b/src/am_testing.c new file mode 100644 index 0000000..30ce05c --- /dev/null +++ b/src/am_testing.c @@ -0,0 +1,9 @@ +#include <caml/mlvalues.h> + +/* The default [Base_am_testing] value is [false]. [ppx_inline_test] overrides + the default by linking against an implementation of [Base_am_testing] that + returns [true]. */ +CAMLprim CAMLweakdef value Base_am_testing() +{ + return Val_false; +} diff --git a/src/am_testing.h b/src/am_testing.h new file mode 100644 index 0000000..cb5d923 --- /dev/null +++ b/src/am_testing.h @@ -0,0 +1,11 @@ +#ifndef BASE_AM_TESTING_H +#define BASE_AM_TESTING_H +#include <caml/mlvalues.h> + +CAMLprim value Base_am_testing (); + +static inline int am_testing () { + return Bool_val (Base_am_testing ()); +} + +#endif diff --git a/src/applicative.ml b/src/applicative.ml new file mode 100644 index 0000000..3a32955 --- /dev/null +++ b/src/applicative.ml @@ -0,0 +1,159 @@ +open! Import + +include Applicative_intf + +(** This module serves mostly as a partial check that [S2] and [S] are in sync, but + actually calling it is occasionally useful. *) +module S_to_S2 (X : S) : (S2 with type ('a, 'e) t = 'a X.t) = struct + type ('a, 'e) t = 'a X.t + include (X : S with type 'a t := 'a X.t) +end + +module S2_to_S (X : S2) : (S with type 'a t = ('a, unit) X.t) = struct + type 'a t = ('a, unit) X.t + include (X : S2 with type ('a, 'e) t := ('a, 'e) X.t) +end + +module Args_to_Args2 (X : Args) : ( + Args2 with type ('a, 'e) arg = 'a X.arg + with type ('f, 'r, 'e) t = ('f, 'r) X.t +) = struct + type ('a, 'e) arg = 'a X.arg + type ('f, 'r, 'e) t = ('f, 'r) X.t + include (X : Args with type 'a arg := 'a X.arg and type ('f, 'r) t := ('f, 'r) X.t) +end +[@@warning "-3"] + +module Make2 (X : Basic2) : S2 with type ('a, 'e) t := ('a, 'e) X.t = struct + + include X + + let (<*>) = apply + + let derived_map t ~f = return f <*> t + + let map = + match X.map with + | `Define_using_apply -> derived_map + | `Custom x -> x + + let ( >>|) t f = map t ~f + + let map2 ta tb ~f = map ~f ta <*> tb + + let map3 ta tb tc ~f = map ~f ta <*> tb <*> tc + + let all ts = List.fold_right ts ~init:(return []) ~f:(map2 ~f:(fun x xs -> x :: xs)) + + let both ta tb = map2 ta tb ~f:(fun a b -> (a, b)) + + let ( *> ) u v = return (fun () y -> y) <*> u <*> v + let ( <* ) u v = return (fun x () -> x) <*> u <*> v + + let all_unit ts = List.fold ts ~init:(return ()) ~f:( *> ) + let all_ignore = all_unit + + module Applicative_infix = struct + let ( <*> ) = ( <*> ) + let ( *> ) = ( *> ) + let ( <* ) = ( <* ) + let ( >>| ) = ( >>| ) + end +end + +module Make (X : Basic) : S with type 'a t := 'a X.t = + Make2 (struct + type ('a, 'e) t = 'a X.t + include (X : Basic with type 'a t := 'a X.t) + end) + +module Make_let_syntax (X : For_let_syntax) (Intf : sig module type S end) (Impl : Intf.S) = struct + module Let_syntax = struct + include X + module Let_syntax = struct + include X + module Open_on_rhs = Impl + end + end +end + +module Make2_using_map2 (X : Basic2_using_map2) = + Make2 (struct + include X + let apply tf tx = map2 tf tx ~f:(fun f x -> f x) + let map = + match map with + | `Custom map -> `Custom map + | `Define_using_map2 -> `Define_using_apply + end) + +module Make_using_map2 (X : Basic_using_map2) : S with type 'a t := 'a X.t = + Make2_using_map2 (struct + type ('a, 'e) t = 'a X.t + include (X : Basic_using_map2 with type 'a t := 'a X.t) + end) + +module Make_args' (X : S2) = struct + open X + + type ('f, 'r, 'e) t_ = { applyN : ('f, 'e) X.t -> ('r, 'e) X.t } + + let nil = { applyN = Fn.id } + + let cons arg t = { applyN = fun d -> t.applyN (apply d arg) } + + let step t ~f = { applyN = fun d -> t.applyN (map ~f d) } + + let (@>) = cons + + let applyN arg t = t.applyN arg + + let mapN ~f t = applyN (return f) t +end + +module Make_args (X : S) : Args with type 'a arg := 'a X.t = struct + include Make_args' (struct + type ('a, 'e) t = 'a X.t + include (X : S with type 'a t := 'a X.t) + end) + + type ('f, 'r) t = ('f, 'r, unit) t_ +end +[@@warning "-3"] + +module Make_args2 (X : S2) : Args2 with type ('a, 'e) arg := ('a, 'e) X.t = struct + include Make_args' (X) + + type ('f, 'r, 'e) t = ('f, 'r, 'e) t_ +end +[@@warning "-3"] + +module Of_monad (M : Monad.S) : S with type 'a t := 'a M.t = + Make (struct + type 'a t = 'a M.t + let return = M.return + let apply mf mx = M.bind mf ~f:(fun f -> M.map mx ~f) + let map = `Custom M.map + end) + +module Compose (F : S) (G : S) : S with type 'a t = 'a F.t G.t = struct + type 'a t = 'a F.t G.t + include Make (struct + type nonrec 'a t = 'a t + let return a = G.return (F.return a) + let apply tf tx = G.apply (G.map ~f:F.apply tf) tx + let custom_map t ~f = G.map ~f:(F.map ~f) t + let map = `Custom custom_map + end) +end + +module Pair (F : S) (G : S) : S with type 'a t = 'a F.t * 'a G.t = struct + type 'a t = 'a F.t * 'a G.t + include Make (struct + type nonrec 'a t = 'a t + let return a = (F.return a, G.return a) + let apply tf tx = (F.apply (fst tf) (fst tx), G.apply (snd tf) (snd tx)) + let custom_map t ~f = (F.map ~f (fst t), G.map ~f (snd t)) + let map = `Custom custom_map + end) +end diff --git a/src/applicative.mli b/src/applicative.mli new file mode 100644 index 0000000..3a4f2d2 --- /dev/null +++ b/src/applicative.mli @@ -0,0 +1 @@ +include Applicative_intf.Applicative (** @inline *) diff --git a/src/applicative_intf.ml b/src/applicative_intf.ml new file mode 100644 index 0000000..e866442 --- /dev/null +++ b/src/applicative_intf.ml @@ -0,0 +1,318 @@ +(** Applicatives model computations in which values computed by subcomputations cannot + affect what subsequent computations will take place. + + Relative to monads, this restriction takes power away from the user of the interface + and gives it to the implementation. In particular, because the structure of the + entire computation is known, one can augment its definition with some description of + that structure. + + For more information, see: + + {v + Applicative Programming with Effects. + Conor McBride and Ross Paterson. + Journal of Functional Programming 18:1 (2008), pages 1-13. + http://staff.city.ac.uk/~ross/papers/Applicative.pdf + v} *) + +open! Import + +module type Basic = sig + type 'a t + val return : 'a -> 'a t + val apply : ('a -> 'b) t -> 'a t -> 'b t + (** The following identities ought to hold for every Applicative (for some value of =): + + - identity: [return Fn.id <*> t = t] + - composition: [return Fn.compose <*> tf <*> tg <*> tx = tf <*> (tg <*> tx)] + - homomorphism: [return f <*> return x = return (f x)] + - interchange: [tf <*> return x = return (fun f -> f x) <*> tf] + + Note: <*> is the infix notation for apply. *) + + (** The [map] argument to [Applicative.Make] says how to implement the applicative's + [map] function. [`Define_using_apply] means to define [map t ~f = return f <*> t]. + [`Custom] overrides the default implementation, presumably with something more + efficient. + + Some other functions returned by [Applicative.Make] are defined in terms of [map], + so passing in a more efficient [map] will improve their efficiency as well. *) + val map : [`Define_using_apply | `Custom of ('a t -> f:('a -> 'b) -> 'b t)] +end + +module type Basic_using_map2 = sig + type 'a t + val return : 'a -> 'a t + val map2 : 'a t -> 'b t -> f:('a -> 'b -> 'c) -> 'c t + val map : [`Define_using_map2 | `Custom of ('a t -> f:('a -> 'b) -> 'b t)] +end + +module type Applicative_infix = sig + type 'a t + + val ( <*> ) : ('a -> 'b) t -> 'a t -> 'b t (** same as [apply] *) + + val ( <* ) : 'a t -> unit t -> 'a t + val ( *> ) : unit t -> 'a t -> 'a t + + val ( >>| ) : 'a t -> ('a -> 'b) -> 'b t +end + +module type For_let_syntax = sig + type 'a t + + val return : 'a -> 'a t + + val map : 'a t -> f:('a -> 'b) -> 'b t + + val both : 'a t -> 'b t -> ('a * 'b) t + + include Applicative_infix with type 'a t := 'a t + +end + +module type S = sig + + include For_let_syntax + + val apply : ('a -> 'b) t -> 'a t -> 'b t + + val map2 : 'a t -> 'b t -> f:('a -> 'b -> 'c) -> 'c t + + val map3 : 'a t -> 'b t -> 'c t -> f:('a -> 'b -> 'c -> 'd) -> 'd t + + val all : 'a t list -> 'a list t + + val all_unit : unit t list -> unit t + + val all_ignore : unit t list -> unit t [@@deprecated "[since 2018-02] Use [all_unit]"] + + module Applicative_infix : Applicative_infix with type 'a t := 'a t +end + +module type Let_syntax = sig + type 'a t + + module Open_on_rhs_intf : sig + module type S + end + + module Let_syntax : sig + + val return : 'a -> 'a t + include Applicative_infix with type 'a t := 'a t + + module Let_syntax : sig + + val return : 'a -> 'a t + + val map : 'a t -> f:('a -> 'b) -> 'b t + + val both : 'a t -> 'b t -> ('a * 'b) t + + module Open_on_rhs : Open_on_rhs_intf.S + end + end +end + +(** Argument lists and associated N-ary map and apply functions. *) +module type Args = sig + + type 'a arg (** the underlying applicative *) + + (** ['f] is the type of a function that consumes the list of arguments and returns an + ['r]. *) + type ('f, 'r) t + + (** the empty argument list **) + val nil : ('r, 'r) t + + (** prepend an argument *) + val cons : 'a arg -> ('f, 'r) t -> ('a -> 'f, 'r) t + + (** infix operator for [cons] *) + val (@>) : 'a arg -> ('f, 'r) t -> ('a -> 'f, 'r) t + + (** Transform argument values in some way. For example, one can label a function + argument like so: + + {[ + step ~f:(fun f x -> f ~foo:x) : ('a -> 'r1, 'r2) t -> (foo:'a -> 'r1, 'r2) t + ]} *) + val step : ('f1, 'r) t -> f:('f2 -> 'f1) -> ('f2, 'r) t + + (** The preferred way to factor out an [Args] sub-sequence: + + {[ + let args = + Foo.Args.( + bar "A" + (* TODO: factor out the common baz qux sub-sequence *) + @> baz "B" + @> qux "C" + @> zap "D" + @> nil + ) + ]} + + is to write a function that prepends the sub-sequence: + + {[ + let baz_qux remaining_args = + Foo.Args.( + baz "B" + @> qux "C" + @> remaining_args + ) + ]} + + and splice it back into the original sequence using [@@] so that things line up + nicely: + + {[ + let args = + Foo.Args.( + bar "A" + @> baz_qux + @@ zap "D" + @> nil + ) + ]} *) + + val mapN : f:'f -> ('f, 'r) t -> 'r arg + + val applyN : 'f arg -> ('f, 'r) t -> 'r arg + +end +[@@deprecated "[since 2018-09] Use [ppx_let] instead."] + +module type Basic2 = sig + type ('a, 'e) t + val return : 'a -> ('a, _) t + val apply : ('a -> 'b, 'e) t -> ('a, 'e) t -> ('b, 'e) t + val map : [`Define_using_apply | `Custom of (('a, 'e) t -> f:('a -> 'b) -> ('b, 'e) t)] +end + +module type Basic2_using_map2 = sig + type ('a, 'e) t + val return : 'a -> ('a, _) t + val map2 : ('a, 'e) t -> ('b, 'e) t -> f:('a -> 'b -> 'c) -> ('c, 'e) t + val map : [`Define_using_map2 | `Custom of (('a, 'e) t -> f:('a -> 'b) -> ('b, 'e) t)] +end + +module type S2 = sig + type ('a, 'e) t + + val return : 'a -> ('a, _) t + + val apply : ('a -> 'b, 'e) t -> ('a, 'e) t -> ('b, 'e) t + + val map : ('a, 'e) t -> f:('a -> 'b) -> ('b, 'e) t + + val map2 : ('a, 'e) t -> ('b, 'e) t -> f:('a -> 'b -> 'c) -> ('c, 'e) t + + val map3 + : ('a, 'e) t + -> ('b, 'e) t + -> ('c, 'e) t + -> f:('a -> 'b -> 'c -> 'd) + -> ('d, 'e) t + + val all : ('a, 'e) t list -> ('a list, 'e) t + + val all_unit : (unit, 'e) t list -> (unit, 'e) t + + val all_ignore : (unit, 'e) t list -> (unit, 'e) t + [@@deprecated "[since 2018-02] Use [all_unit]"] + + val both : ('a, 'e) t -> ('b, 'e) t -> ('a * 'b, 'e) t + + module Applicative_infix : sig + val ( <*> ) : ('a -> 'b, 'e) t -> ('a, 'e) t -> ('b, 'e) t + val ( <* ) : ('a, 'e) t -> (unit, 'e) t -> ('a, 'e) t + val ( *> ) : (unit, 'e) t -> ('a, 'e) t -> ('a, 'e) t + val ( >>| ) : ('a, 'e) t -> ('a -> 'b) -> ('b, 'e) t + end + + include module type of Applicative_infix +end + +module type Args2 = sig + type ('a, 'e) arg + + type ('f, 'r, 'e) t + + val nil : ('r, 'r, _) t + + val cons : ('a, 'e) arg -> ('f, 'r, 'e) t -> ('a -> 'f, 'r, 'e) t + val (@>) : ('a, 'e) arg -> ('f, 'r, 'e) t -> ('a -> 'f, 'r, 'e) t + + val step : ('f1, 'r, 'e) t -> f:('f2 -> 'f1) -> ('f2, 'r, 'e) t + + val mapN : f:'f -> ('f, 'r, 'e) t -> ('r, 'e) arg + val applyN : ('f, 'e) arg -> ('f, 'r, 'e) t -> ('r, 'e) arg +end +[@@deprecated "[since 2018-09] Use [ppx_let] instead."] + +module type Applicative = sig + + module type Applicative_infix = Applicative_infix + module type Args = Args + [@@warning "-3"] + [@@deprecated "[since 2018-09] Use [ppx_let] instead."] + module type Args2 = Args2 + [@@warning "-3"] + [@@deprecated "[since 2018-09] Use [ppx_let] instead."] + module type Basic = Basic + module type Basic2 = Basic2 + module type Basic2_using_map2 = Basic2_using_map2 + module type Basic_using_map2 = Basic_using_map2 + module type Let_syntax = Let_syntax + module type S = S + module type S2 = S2 + + module Args_to_Args2 (X : Args) : + Args2 + with type ('a, 'e) arg = 'a X.arg + with type ('f, 'r, 'e) t = ('f, 'r) X.t + [@@warning "-3"] + + module S2_to_S (X : S2) : S with type 'a t = ('a, unit) X.t + + module S_to_S2 (X : S) : S2 with type ('a, 'e) t = 'a X.t + + module Make (X : Basic ) : S with type 'a t := 'a X.t + module Make2 (X : Basic2) : S2 with type ('a, 'e) t := ('a, 'e) X.t + + module Make_let_syntax + (X : For_let_syntax) + (Intf : sig module type S end) + (Impl : Intf.S) + : Let_syntax with type 'a t := 'a X.t + with module Open_on_rhs_intf := Intf + + module Make_using_map2 (X : Basic_using_map2 ) : S with type 'a t := 'a X.t + module Make2_using_map2 (X : Basic2_using_map2) : S2 with type ('a, 'e) t := ('a, 'e) X.t + + module Make_args (X : S ) : Args with type 'a arg := 'a X.t [@@warning "-3"] + [@@deprecated "[since 2018-09] Use [ppx_let] instead."] + module Make_args2 (X : S2) : Args2 with type ('a, 'e) arg := ('a, 'e) X.t [@@warning "-3"] + [@@deprecated "[since 2018-09] Use [ppx_let] instead."] + + (** The following functors give a sense of what Applicatives one can define. + + Of these, [Of_monad] is likely the most useful. The others are mostly didactic. *) + + (** Every monad is Applicative via: + + {[ + let apply mf mx = + mf >>= fun f -> + mx >>| fun x -> + f x + ]} *) + module Of_monad (M : Monad.S) : S with type 'a t := 'a M.t + module Compose (F : S) (G : S) : S with type 'a t = 'a F.t G.t + module Pair (F : S) (G : S) : S with type 'a t = 'a F.t * 'a G.t + +end diff --git a/src/array.ml b/src/array.ml new file mode 100644 index 0000000..5c94b1f --- /dev/null +++ b/src/array.ml @@ -0,0 +1,731 @@ +open! Import + +include Array0 +module Int = Int0 + +let raise_s = Error.raise_s + +type 'a t = 'a array [@@deriving_inline compare, sexp] +let compare : 'a . ('a -> 'a -> int) -> 'a t -> 'a t -> int = compare_array +let t_of_sexp : + 'a . (Ppx_sexp_conv_lib.Sexp.t -> 'a) -> Ppx_sexp_conv_lib.Sexp.t -> 'a t = + array_of_sexp +let sexp_of_t : + 'a . ('a -> Ppx_sexp_conv_lib.Sexp.t) -> 'a t -> Ppx_sexp_conv_lib.Sexp.t = + sexp_of_array +[@@@end] + +(* This module implements a new in-place, constant heap sorting algorithm to replace the + one used by the standard libraries. Its only purpose is to be faster (hopefully + strictly faster) than the base sort and stable_sort. + + At a high level the algorithm is: + - pick two pivot points by: + - pick 5 arbitrary elements from the array + - sort them within the array + - take the elements on either side of the middle element of the sort as the pivots + - sort the array with: + - all elements less than pivot1 to the left (range 1) + - all elements >= pivot1 and <= pivot2 in the middle (range 2) + - all elements > pivot2 to the right (range 3) + - if pivot1 and pivot2 are equal, then the middle range is sorted, so ignore it + - recurse into range 1, 2 (if pivot1 and pivot2 are unequal), and 3 + - during recursion there are two inflection points: + - if the size of the current range is small, use insertion sort to sort it + - if the stack depth is large, sort the range with heap-sort to avoid n^2 worst-case + behavior + + See the following for more information: + - "Dual-Pivot Quicksort" by Vladimir Yaroslavskiy. + Available at + http://www.kriche.com.ar/root/programming/spaceTimeComplexity/DualPivotQuicksort.pdf + - "Quicksort is Optimal" by Sedgewick and Bentley. + Slides at http://www.cs.princeton.edu/~rs/talks/QuicksortIsOptimal.pdf + - http://www.sorting-algorithms.com/quick-sort-3-way *) + +module Sort = struct + (* For the sake of speed we could use unsafe get/set throughout, but speed tests don't + show a significant improvement. *) + let get = get + let set = set + + let swap arr i j = + let tmp = get arr i in + set arr i (get arr j); + set arr j tmp + ;; + + module type Sort = sig + val sort + : 'a t + -> compare:('a -> 'a -> int) + -> left:int (* leftmost index of sub-array to sort *) + -> right:int (* rightmost index of sub-array to sort *) + -> unit + end + + (* http://en.wikipedia.org/wiki/Insertion_sort *) + module Insertion_sort : Sort = struct + let sort arr ~compare ~left ~right = + (* loop invariant: + [arr] is sorted from [left] to [pos - 1], inclusive *) + for pos = left + 1 to right do + (* loop invariants: + 1. the subarray arr[left .. i-1] is sorted + 2. the subarray arr[i+1 .. pos] is sorted and contains only elements > v + 3. arr[i] may be thought of as containing v + + Note that this does not allocate a closure, but is left in the for + loop for the readability of the documentation. *) + let rec loop arr ~left ~compare i v = + let i_next = i - 1 in + if i_next >= left && compare (get arr i_next) v > 0 then begin + set arr i (get arr i_next); + loop arr ~left ~compare i_next v + end else + i + in + let v = get arr pos in + let final_pos = loop arr ~left ~compare pos v in + set arr final_pos v + done + ;; + end + + (* http://en.wikipedia.org/wiki/Heapsort *) + module Heap_sort : Sort = struct + (* loop invariant: + root's children are both either roots of max-heaps or > right *) + let rec heapify arr ~compare root ~left ~right = + let relative_root = root - left in + let left_child = (2 * relative_root) + left + 1 in + let right_child = (2 * relative_root) + left + 2 in + let largest = + if left_child <= right && compare (get arr left_child) (get arr root) > 0 + then left_child + else root + in + let largest = + if right_child <= right && compare (get arr right_child) (get arr largest) > 0 + then right_child + else largest + in + if largest <> root then begin + swap arr root largest; + heapify arr ~compare largest ~left ~right + end; + ;; + + let build_heap arr ~compare ~left ~right = + (* Elements in the second half of the array are already heaps of size 1. We move + through the first half of the array from back to front examining the element at + hand, and the left and right children, fixing the heap property as we go. *) + for i = (left + right) / 2 downto left do + heapify arr ~compare i ~left ~right; + done; + ;; + + let sort arr ~compare ~left ~right = + build_heap arr ~compare ~left ~right; + (* loop invariants: + 1. the subarray arr[left ... i] is a max-heap H + 2. the subarray arr[i+1 ... right] is sorted (call it S) + 3. every element of H is less than every element of S *) + for i = right downto left + 1 do + swap arr left i; + heapify arr ~compare left ~left ~right:(i - 1); + done; + ;; + end + + (* http://en.wikipedia.org/wiki/Introsort *) + module Intro_sort : sig + include Sort + val five_element_sort + : 'a t -> compare:('a -> 'a -> int) -> int -> int -> int -> int -> int -> unit + end = struct + + let five_element_sort arr ~compare m1 m2 m3 m4 m5 = + let compare_and_swap i j = + if compare (get arr i) (get arr j) > 0 then swap arr i j + in + (* optimal 5-element sorting network *) + compare_and_swap m1 m2; (* 1--o-----o-----o--------------1 *) + compare_and_swap m4 m5; (* | | | *) + compare_and_swap m1 m3; (* 2--o-----|--o--|-----o--o-----2 *) + compare_and_swap m2 m3; (* | | | | | *) + compare_and_swap m1 m4; (* 3--------o--o--|--o--|--o-----3 *) + compare_and_swap m3 m4; (* | | | *) + compare_and_swap m2 m5; (* 4-----o--------o--o--|-----o--4 *) + compare_and_swap m2 m3; (* | | | *) + compare_and_swap m4 m5; (* 5-----o--------------o-----o--5 *) + ;; + + (* choose pivots for the array by sorting 5 elements and examining the center three + elements. The goal is to choose two pivots that will either: + - break the range up into 3 even partitions + or + - eliminate a commonly appearing element by sorting it into the center partition + by itself + To this end we look at the center 3 elements of the 5 and return pairs of equal + elements or the widest range *) + let choose_pivots arr ~compare ~left ~right = + let sixth = (right - left) / 6 in + let m1 = left + sixth in + let m2 = m1 + sixth in + let m3 = m2 + sixth in + let m4 = m3 + sixth in + let m5 = m4 + sixth in + five_element_sort arr ~compare m1 m2 m3 m4 m5; + let m2_val = get arr m2 in + let m3_val = get arr m3 in + let m4_val = get arr m4 in + if compare m2_val m3_val = 0 then (m2_val, m3_val, true) + else if compare m3_val m4_val = 0 then (m3_val, m4_val, true) + else (m2_val, m4_val, false) + ;; + + let dual_pivot_partition arr ~compare ~left ~right = + let pivot1, pivot2, pivots_equal = choose_pivots arr ~compare ~left ~right in + (* loop invariants: + 1. left <= l < r <= right + 2. l <= p <= r + 3. l <= x < p implies arr[x] >= pivot1 + and arr[x] <= pivot2 + 4. left <= x < l implies arr[x] < pivot1 + 5. r < x <= right implies arr[x] > pivot2 *) + let rec loop l p r = + let pv = get arr p in + if compare pv pivot1 < 0 then begin + swap arr p l; + cont (l + 1) (p + 1) r + end else if compare pv pivot2 > 0 then begin + (* loop invariants: same as those of the outer loop *) + let rec scan_backwards r = + if r > p && compare (get arr r) pivot2 > 0 + then scan_backwards (r - 1) + else r + in + let r = scan_backwards r in + swap arr r p; + cont l p (r - 1) + end else + cont l (p + 1) r + and cont l p r = + if p > r then (l, r) else loop l p r + in + let (l, r) = cont left left right in + (l, r, pivots_equal) + ;; + + let rec intro_sort arr ~max_depth ~compare ~left ~right = + let len = right - left + 1 in + (* This takes care of some edge cases, such as left > right or very short arrays, + since Insertion_sort.sort handles these cases properly. Thus we don't need to + make sure that left and right are valid in recursive calls. *) + if len <= 32 then begin + Insertion_sort.sort arr ~compare ~left ~right + end else if max_depth < 0 then begin + Heap_sort.sort arr ~compare ~left ~right; + end else begin + let max_depth = max_depth - 1 in + let (l, r, middle_sorted) = dual_pivot_partition arr ~compare ~left ~right in + intro_sort arr ~max_depth ~compare ~left ~right:(l - 1); + if not middle_sorted then intro_sort arr ~max_depth ~compare ~left:l ~right:r; + intro_sort arr ~max_depth ~compare ~left:(r + 1) ~right; + end + ;; + + let log10_of_3 = Caml.log10 3. + + let log3 x = Caml.log10 x /. log10_of_3 + + let sort arr ~compare ~left ~right = + let len = right - left + 1 in + let heap_sort_switch_depth = + (* with perfect 3-way partitioning, this is the recursion depth *) + Int.of_float (log3 (Int.to_float len)) + in + intro_sort arr ~max_depth:heap_sort_switch_depth ~compare ~left ~right; + ;; + end +end + +let sort ?pos ?len arr ~compare = + let pos, len = + Ordered_collection_common.get_pos_len_exn () ?pos ?len ~total_length:(length arr) + in + Sort.Intro_sort.sort arr ~compare ~left:pos ~right:(pos + len - 1) + +let to_array t = t + +let is_empty t = length t = 0 + +let is_sorted t ~compare = + let rec is_sorted_loop t ~compare i = + if i < 1 + then true + else + compare t.(i - 1) t.(i) <= 0 + && is_sorted_loop t ~compare (i - 1) + in + is_sorted_loop t ~compare (length t - 1) +;; + +let is_sorted_strictly t ~compare = + let rec is_sorted_strictly_loop t ~compare i = + if i < 1 + then true + else + compare t.(i - 1) t.(i) < 0 + && is_sorted_strictly_loop t ~compare (i - 1) + in + is_sorted_strictly_loop t ~compare (length t - 1) +;; + +let folding_map t ~init ~f = + let acc = ref init in + map t ~f:(fun x -> + let new_acc, y = f !acc x in + acc := new_acc; + y) +;; + +let fold_map t ~init ~f = + let acc = ref init in + let result = + map t ~f:(fun x -> + let new_acc, y = f !acc x in + acc := new_acc; + y) + in + !acc, result +;; + +let fold_result t ~init ~f = Container.fold_result ~fold ~init ~f t +let fold_until t ~init ~f = Container.fold_until ~fold ~init ~f t +let count t ~f = Container.count ~fold t ~f +let sum m t ~f = Container.sum ~fold m t ~f +let min_elt t ~compare = Container.min_elt ~fold t ~compare +let max_elt t ~compare = Container.max_elt ~fold t ~compare + +let foldi t ~init ~f = + let rec foldi_loop t i ac ~f = + if i = length t + then ac + else foldi_loop t (i + 1) (f i ac t.(i)) ~f + in + foldi_loop t 0 init ~f +;; + +let folding_mapi t ~init ~f = + let acc = ref init in + mapi t ~f:(fun i x -> + let new_acc, y = f i !acc x in + acc := new_acc; + y) +;; + +let fold_mapi t ~init ~f = + let acc = ref init in + let result = + mapi t ~f:(fun i x -> + let new_acc, y = f i !acc x in + acc := new_acc; + y) + in + !acc, result +;; + +let counti t ~f = foldi t ~init:0 ~f:(fun idx count a -> if f idx a then count + 1 else count) + +let concat_map t ~f = concat (to_list (map ~f t)) +let concat_mapi t ~f = concat (to_list (mapi ~f t)) + +let rev_inplace t = + let i = ref 0 in + let j = ref (length t - 1) in + while !i < !j; do + swap t !i !j; + incr i; + decr j; + done +;; + +let of_list_rev l = + match l with + | [] -> [||] + | a :: l -> + let len = 1 + List.length l in + let t = create ~len a in + let r = ref l in + (* We start at [len - 2] because we already put [a] at [t.(len - 1)]. *) + for i = len - 2 downto 0 do + match !r with + | [] -> assert false + | a :: l -> t.(i) <- a; r := l + done; + t +;; + +(* [of_list_map] and [of_list_rev_map] are based on functions from the OCaml + distribution. *) + +let of_list_map xs ~f = + match xs with + | [] -> [||] + | hd::tl -> + let a = create ~len:(1 + List.length tl) (f hd) in + let rec fill i = function + | [] -> a + | hd::tl -> unsafe_set a i (f hd); fill (i+1) tl in + fill 1 tl + +let of_list_mapi xs ~f = + match xs with + | [] -> [||] + | hd::tl -> + let a = create ~len:(1 + List.length tl) (f 0 hd) in + let rec fill a i = function + | [] -> a + | hd::tl -> + unsafe_set a i (f i hd); + fill a (i+1) tl + in + fill a 1 tl + +let of_list_rev_map xs ~f = + let t = of_list_map xs ~f in + rev_inplace t; + t + +let of_list_rev_mapi xs ~f = + let t = of_list_mapi xs ~f in + rev_inplace t; + t + +(* [Obj.truncate] reduces the size of a block on the ocaml heap. For arrays, the block + size is the array length. This holds even for float arrays. *) +let unsafe_truncate t ~len = + if len <= 0 || len > length t then + raise_s (Sexp.message "Array.unsafe_truncate got invalid len" + ["len", sexp_of_int len]); + if len < length t then Caml.Obj.truncate (Caml.Obj.repr t) len; +;; + +let filter_mapi t ~f = + let r = ref [||] in + let k = ref 0 in + for i = 0 to length t - 1 do + match f i (unsafe_get t i) with + | None -> () + | Some a -> + if !k = 0 then begin + r := create ~len:(length t) a + end; + unsafe_set !r !k a; + incr k; + done; + if !k > 0 then begin + unsafe_truncate !r ~len:!k; + !r + end else + [||] + +let filter_map t ~f = + filter_mapi t ~f:(fun _i a -> f a) + +let filter_opt t = + filter_map t ~f:Fn.id + +let iter2_exn t1 t2 ~f = + if length t1 <> length t2 then invalid_arg "Array.iter2_exn"; + iteri t1 ~f:(fun i x1 -> f x1 t2.(i)) + +let map2_exn t1 t2 ~f = + let len = length t1 in + if length t2 <> len then invalid_arg "Array.map2_exn"; + init len ~f:(fun i -> f t1.(i) t2.(i)) + +let fold2_exn t1 t2 ~init ~f = + if length t1 <> length t2 then invalid_arg "Array.fold2_exn"; + foldi t1 ~init ~f:(fun i ac x -> f ac x t2.(i)) +;; + +let filter t ~f = filter_map t ~f:(fun x -> if f x then Some x else None) + +let filteri t ~f = filter_mapi t ~f:(fun i x -> if f i x then Some x else None) + +let exists t ~f = + let rec exists_loop t ~f i = + if i < 0 + then false + else f t.(i) || exists_loop t ~f (i - 1) + in + exists_loop t ~f (length t - 1) + +let existsi t ~f = + let rec existsi_loop t ~f i = + if i < 0 + then false + else f i t.(i) || existsi_loop t ~f (i - 1) + in + existsi_loop t ~f (length t - 1) + +let mem t a ~equal = exists t ~f:(equal a) + +let for_all t ~f = + let rec for_all_loop t ~f i = + if i < 0 + then true + else f t.(i) && for_all_loop t ~f (i - 1) + in + for_all_loop t ~f (length t - 1) + +let for_alli t ~f = + let rec for_alli_loop t ~f i = + if i < 0 + then true + else f i t.(i) && for_alli_loop t ~f (i - 1) + in + for_alli_loop t ~f (length t - 1) + +let exists2_exn t1 t2 ~f = + let rec exists2_exn_loop t1 t2 ~f i = + if i < 0 + then false + else f t1.(i) t2.(i) || exists2_exn_loop t1 t2 ~f (i - 1) + in + let len = length t1 in + if length t2 <> len then invalid_arg "Array.exists2_exn"; + exists2_exn_loop t1 t2 ~f (len - 1) + +let for_all2_exn t1 t2 ~f = + let rec for_all2_loop t1 t2 ~f i = + if i < 0 + then true + else f t1.(i) t2.(i) && for_all2_loop t1 t2 ~f (i - 1) + in + let len = length t1 in + if length t2 <> len then invalid_arg "Array.for_all2_exn"; + for_all2_loop t1 t2 ~f (len - 1) + +let equal equal t1 t2 = length t1 = length t2 && for_all2_exn t1 t2 ~f:equal + +let replace t i ~f = t.(i) <- f t.(i) + +let map_inplace t ~f = + for i = 0 to length t - 1 do + t.(i) <- f t.(i) + done + +let replace_all = map_inplace + +let findi t ~f = + let rec findi_loop t ~f ~length i = + if i >= length then None + else if f i t.(i) then Some (i, t.(i)) + else findi_loop t ~f ~length (i + 1) + in + let length = length t in + findi_loop t ~f ~length 0 +;; + +let findi_exn t ~f = + match findi t ~f with + | None -> raise Caml.Not_found + | Some x -> x +;; + +let find_exn t ~f = + match findi t ~f:(fun _i x -> f x) with + | None -> raise Caml.Not_found + | Some (_i, x) -> x +;; + +let find t ~f = Option.map (findi t ~f:(fun _i x -> f x)) ~f:(fun (_i, x) -> x) + +let find_map t ~f = + let rec find_map_loop t ~f ~length i = + if i >= length then None + else + match f t.(i) with + | None -> find_map_loop t ~f ~length (i + 1) + | Some _ as res -> res + in + let length = length t in + find_map_loop t ~f ~length 0 +;; + +let find_map_exn t ~f = + match find_map t ~f with + | None -> raise Caml.Not_found + | Some x -> x + +let find_mapi t ~f = + let rec find_mapi_loop t ~f ~length i = + if i >= length then None + else + match f i t.(i) with + | None -> find_mapi_loop t ~f ~length (i + 1) + | Some _ as res -> res + in + let length = length t in + find_mapi_loop t ~f ~length 0 +;; + +let find_mapi_exn t ~f = + match find_mapi t ~f with + | None -> raise Caml.Not_found + | Some x -> x + +let find_consecutive_duplicate t ~equal = + let n = length t in + if n <= 1 + then None + else begin + let result = ref None in + let i = ref 1 in + let prev = ref t.(0) in + while !i < n do + let cur = t.(!i) in + if equal cur !prev + then (result := Some (!prev, cur); i := n) + else (prev := cur; incr i) + done; + !result + end +;; + +let reduce t ~f = + if length t = 0 then None + else begin + let r = ref t.(0) in + for i = 1 to length t - 1 do + r := f !r t.(i) + done; + Some !r + end + +let reduce_exn t ~f = + match reduce t ~f with + | None -> invalid_arg "Array.reduce_exn" + | Some v -> v + +let permute = Array_permute.permute + +let random_element_exn ?(random_state = Random.State.default) t = + if is_empty t + then failwith "Array.random_element_exn: empty array" + else t.(Random.State.int random_state (length t)) + +let random_element ?(random_state = Random.State.default) t = + try Some (random_element_exn ~random_state t) + with _ -> None + +let zip t1 t2 = + if length t1 <> length t2 then None + else Some (map2_exn t1 t2 ~f:(fun x1 x2 -> x1, x2)) + +let zip_exn t1 t2 = + if length t1 <> length t2 then failwith "Array.zip_exn" + else map2_exn t1 t2 ~f:(fun x1 x2 -> x1, x2) + +let unzip t = + let n = length t in + if n = 0 then [||], [||] + else + let x, y = t.(0) in + let res1 = create ~len:n x in + let res2 = create ~len:n y in + for i = 1 to n - 1 do + let x, y = t.(i) in + res1.(i) <- x; + res2.(i) <- y; + done; + res1, res2 + +let sorted_copy t ~compare = + let t1 = copy t in + sort t1 ~compare; + t1 + +let partitioni_tf t ~f = + let both = mapi t ~f:(fun i x -> if f i x then Either.First x else Either.Second x) in + let trues = filter_map both ~f:(function First x -> Some x | Second _ -> None) in + let falses = filter_map both ~f:(function First _ -> None | Second x -> Some x) in + (trues, falses) + +let partition_tf t ~f = + partitioni_tf t ~f:(fun _i x -> f x) + +let last t = t.(length t - 1) + +(* Convert to a sequence but does not attempt to protect against modification + in the array. *) +let to_sequence_mutable t = + Sequence.unfold_step ~init:0 ~f:(fun i -> + if i >= length t + then Sequence.Step.Done + else Sequence.Step.Yield (t.(i), i+1)) + +let to_sequence t = to_sequence_mutable (copy t) + +let cartesian_product t1 t2 = + if is_empty t1 || is_empty t2 then + [||] + else + let n1 = length t1 in + let n2 = length t2 in + let t = create ~len:(n1 * n2) (t1.(0), t2.(0)) in + let r = ref 0 in + for i1 = 0 to n1 - 1 do + for i2 = 0 to n2 - 1 do + t.(!r) <- (t1.(i1), t2.(i2)); + incr r; + done + done; + t +;; + +let transpose tt = + if length tt = 0 + then Some [||] + else + let width = length tt in + let depth = length tt.(0) in + if exists tt ~f:(fun t -> length t <> depth) + then None + else Some (init depth ~f:(fun d -> init width ~f:(fun w -> tt.(w).(d)))) + +let transpose_exn tt = + match transpose tt with + | None -> invalid_arg "Array.transpose_exn"; + | Some tt' -> tt' + +include Binary_searchable.Make1 (struct + type nonrec 'a t = 'a t + + let get = get + let length = length + end) + +include + Blit.Make1 + (struct + type nonrec 'a t = 'a t + let length = length + let create_like ~len t = + if len = 0 + then [||] + else (assert (length t > 0); create ~len t.(0)) + ;; + let unsafe_blit = blit + end) +;; + +let invariant invariant_a t = iter t ~f:invariant_a + +module Private = struct + module Sort = Sort +end diff --git a/src/array.mli b/src/array.mli new file mode 100644 index 0000000..befce86 --- /dev/null +++ b/src/array.mli @@ -0,0 +1,339 @@ +(** Mutable vector of elements of type ['a] with O(1) [get] and [set] operations. *) + +open! Import + +type 'a t = 'a array [@@deriving_inline compare, sexp] +include +sig + [@@@ocaml.warning "-32"] + val compare : ('a -> 'a -> int) -> 'a t -> 'a t -> int + include Ppx_sexp_conv_lib.Sexpable.S1 with type 'a t := 'a t +end[@@ocaml.doc "@inline"] +[@@@end] + +include Binary_searchable.S1 with type 'a t := 'a t + +include Container.S1 with type 'a t := 'a t + +include Invariant.S1 with type 'a t := 'a t + +(** Maximum length of a normal array. The maximum length of a float array is + [max_length/2] on 32-bit machines and [max_length] on 64-bit machines. *) +val max_length : int + +(** [Array.get a n] returns the element number [n] of array [a]. + The first element has number 0. + The last element has number [Array.length a - 1]. + You can also write [a.(n)] instead of [Array.get a n]. + + Raise [Invalid_argument "index out of bounds"] + if [n] is outside the range 0 to [(Array.length a - 1)]. *) +external get : 'a t -> int -> 'a = "%array_safe_get" + +(** [Array.set a n x] modifies array [a] in place, replacing + element number [n] with [x]. + You can also write [a.(n) <- x] instead of [Array.set a n x]. + + Raise [Invalid_argument "index out of bounds"] + if [n] is outside the range 0 to [Array.length a - 1]. *) +external set : 'a t -> int -> 'a -> unit = "%array_safe_set" + +(** Unsafe version of [get]. Can cause arbitrary behavior when used for an out-of-bounds + array access. *) +external unsafe_get : 'a t -> int -> 'a = "%array_unsafe_get" + +(** Unsafe version of [set]. Can cause arbitrary behavior when used for an out-of-bounds + array access. *) +external unsafe_set : 'a t -> int -> 'a -> unit = "%array_unsafe_set" + +(** [create ~len x] creates an array of length [len] with the value [x] populated in + each element. *) +val create : len:int -> 'a -> 'a t + +(** [init n ~f] creates an array of length [n] where the [i]th element (starting at zero) + is initialized with [f i]. *) +val init : int -> f:(int -> 'a) -> 'a t + +(** [Array.make_matrix dimx dimy e] returns a two-dimensional array (an array of arrays) + with first dimension [dimx] and second dimension [dimy]. All the elements of this new + matrix are initially physically equal to [e]. The element ([x,y]) of a matrix [m] is + accessed with the notation [m.(x).(y)]. + + Raise [Invalid_argument] if [dimx] or [dimy] is negative or greater than + [Array.max_length]. + + If the value of [e] is a floating-point number, then the maximum size is only + [Array.max_length / 2]. *) +val make_matrix : dimx:int -> dimy:int -> 'a -> 'a t t + +(** [Array.append v1 v2] returns a fresh array containing the concatenation of the arrays + [v1] and [v2]. *) +val append : 'a t -> 'a t -> 'a t + +(** Like [Array.append], but concatenates a list of arrays. *) +val concat : 'a t list -> 'a t + +(** [Array.copy a] returns a copy of [a], that is, a fresh array + containing the same elements as [a]. *) +val copy : 'a t -> 'a t + +(** [Array.fill a ofs len x] modifies the array [a] in place, storing [x] in elements + number [ofs] to [ofs + len - 1]. + + Raise [Invalid_argument "Array.fill"] if [ofs] and [len] do not designate a valid + subarray of [a]. *) +val fill : 'a t -> pos:int -> len:int -> 'a -> unit + +(** [Array.blit v1 o1 v2 o2 len] copies [len] elements from array [v1], starting at + element number [o1], to array [v2], starting at element number [o2]. It works + correctly even if [v1] and [v2] are the same array, and the source and destination + chunks overlap. + + Raise [Invalid_argument "Array.blit"] if [o1] and [len] do not designate a valid + subarray of [v1], or if [o2] and [len] do not designate a valid subarray of [v2]. + + [int_blit] and [float_blit] provide fast bound-checked blits for immediate + data types. The unsafe versions do not bound-check the arguments. *) +include Blit.S1 with type 'a t := 'a t + +(** [Array.of_list l] returns a fresh array containing the elements of [l]. *) +val of_list : 'a list -> 'a t + +(** [Array.map t ~f] applies function [f] to all the elements of [t], and builds an array + with the results returned by [f]: [[| f t.(0); f t.(1); ...; f t.(Array.length t - 1) + |]]. *) +val map : 'a t -> f:('a -> 'b) -> 'b t + +(** [folding_map] is a version of [map] that threads an accumulator through calls to + [f]. *) +val folding_map : 'a t -> init:'b -> f:( 'b -> 'a -> 'b * 'c) -> 'c t +val folding_mapi : 'a t -> init:'b -> f:(int -> 'b -> 'a -> 'b * 'c) -> 'c t + +(** [Array.fold_map] is a combination of [Array.fold] and [Array.map] that threads an + accumulator through calls to [f]. *) +val fold_map : 'a t -> init:'b -> f:( 'b -> 'a -> 'b * 'c) -> 'b * 'c t +val fold_mapi : 'a t -> init:'b -> f:(int -> 'b -> 'a -> 'b * 'c) -> 'b * 'c t + +(** Like {!Array.iter}, but the function is applied to the index of the element as first + argument, and the element itself as second argument. *) +val iteri : 'a t -> f:(int -> 'a -> unit) -> unit + +(** Like {!Array.map}, but the function is applied to the index of the element as first + argument, and the element itself as second argument. *) +val mapi : 'a t -> f:(int -> 'a -> 'b) -> 'b t + +val foldi : 'a t -> init:'b -> f:(int -> 'b -> 'a -> 'b) -> 'b + +(** [Array.fold_right f a ~init] computes [f a.(0) (f a.(1) ( ... (f a.(n-1) init) ...))], + where [n] is the length of the array [a]. *) +val fold_right : 'a t -> f:('a -> 'b -> 'b) -> init:'b -> 'b + +(** All sort functions in this module sort in increasing order by default. *) + +(** [sort] uses constant heap space. [stable_sort] uses linear heap space. + + To sort only part of the array, specify [pos] to be the index to start sorting from + and [len] indicating how many elements to sort. *) +val sort : ?pos:int -> ?len:int -> 'a t -> compare:('a -> 'a -> int) -> unit +val stable_sort : 'a t -> compare:('a -> 'a -> int) -> unit + +val is_sorted : 'a t -> compare:('a -> 'a -> int) -> bool + +(** [is_sorted_strictly xs ~compare] iff [is_sorted xs ~compare] and no two + consecutive elements in [xs] are equal according to [compare]. *) +val is_sorted_strictly : 'a t -> compare:('a -> 'a -> int) -> bool + +(** Like [List.concat_map], [List.concat_mapi]. *) +val concat_map : 'a t -> f:( 'a -> 'b array) -> 'b array +val concat_mapi : 'a t -> f:(int -> 'a -> 'b array) -> 'b array + +val partition_tf : 'a t -> f:('a -> bool) -> 'a t * 'a t + +val partitioni_tf : 'a t -> f:(int -> 'a -> bool) -> 'a t * 'a t + +val cartesian_product : 'a t -> 'b t -> ('a * 'b) t + +(** [transpose] in the sense of a matrix transpose. It returns [None] if the arrays are + not all the same length. *) +val transpose : 'a t t -> 'a t t option +val transpose_exn : 'a t t -> 'a t t + +(** [filter_opt array] returns a new array where [None] entries are omitted and [Some x] + entries are replaced with [x]. Note that this changes the index at which elements + will appear. *) +val filter_opt : 'a option t -> 'a t + +(** [filter_map ~f array] maps [f] over [array] and filters [None] out of the + results. *) +val filter_map : 'a t -> f:('a -> 'b option) -> 'b t + +(** Like [filter_map] but uses {!Array.mapi}. *) +val filter_mapi : 'a t -> f:(int -> 'a -> 'b option) -> 'b t + +(** Like [for_all], but passes the index as an argument. *) +val for_alli : 'a t -> f:(int -> 'a -> bool) -> bool + +(** Like [exists], but passes the index as an argument. *) +val existsi : 'a t -> f:(int -> 'a -> bool) -> bool + +(** Like [count], but passes the index as an argument. *) +val counti : 'a t -> f:(int -> 'a -> bool) -> int + +(** Functions with the 2 suffix raise an exception if the lengths of the two given arrays + aren't the same. *) + +val iter2_exn : 'a t -> 'b t -> f:('a -> 'b -> unit) -> unit + +val map2_exn : 'a t -> 'b t -> f:('a -> 'b -> 'c) -> 'c t + +val fold2_exn : 'a t -> 'b t -> init:'c -> f:('c -> 'a -> 'b -> 'c) -> 'c + +(** [for_all2_exn t1 t2 ~f] fails if [length t1 <> length t2]. *) +val for_all2_exn : 'a t -> 'b t -> f:('a -> 'b -> bool) -> bool + +(** [exists2_exn t1 t2 ~f] fails if [length t1 <> length t2]. *) +val exists2_exn : 'a t -> 'b t -> f:('a -> 'b -> bool) -> bool + +(** [filter t ~f] removes the elements for which [f] returns false. *) +val filter : 'a t -> f:('a -> bool) -> 'a t + +(** Like [filter] except [f] also receives the index. *) +val filteri : 'a t -> f:(int -> 'a -> bool) -> 'a t + +(** [swap arr i j] swaps the value at index [i] with that at index [j]. *) +val swap : 'a t -> int -> int -> unit + +(** [rev_inplace t] reverses [t] in place. *) +val rev_inplace : 'a t -> unit + +(** [of_list_rev l] converts from list then reverses in place. *) +val of_list_rev : 'a list -> 'a t + +(** [of_list_map l ~f] is the same as [of_list (List.map l ~f)]. *) +val of_list_map : 'a list -> f:('a -> 'b) -> 'b t + +(** [of_list_mapi l ~f] is the same as [of_list (List.mapi l ~f)]. *) +val of_list_mapi : 'a list -> f:(int -> 'a -> 'b) -> 'b t + +(** [of_list_rev_map l ~f] is the same as [of_list (List.rev_map l ~f)]. *) +val of_list_rev_map : 'a list -> f:('a -> 'b) -> 'b t + +(** [of_list_rev_mapi l ~f] is the same as [of_list (List.rev_mapi l ~f)]. *) +val of_list_rev_mapi : 'a list -> f:(int -> 'a -> 'b) -> 'b t + +(** [replace t i ~f] = [t.(i) <- f (t.(i))]. *) +val replace : 'a t -> int -> f:('a -> 'a) -> unit +[@@deprecated "[since 2018-09] use [t.(i) <- f (t.(i))] instead"] + +(** Modifies an array in place -- [ar.(i)] will be set to [f(ar.(i))]. *) +val replace_all : 'a t -> f:('a -> 'a) -> unit +[@@deprecated "[since 2018-03] use [map_inplace] instead"] + +(** Modifies an array in place, applying [f] to every element of the array *) +val map_inplace : 'a t -> f:('a -> 'a) -> unit + +(** [find_exn f t] returns the first [a] in [t] for which [f t.(i)] is true. It raises + [Caml.Not_found] or [Not_found_s] if there is no such [a]. *) +val find_exn : 'a t -> f:('a -> bool) -> 'a + +(** Returns the first evaluation of [f] that returns [Some]. Raises [Caml.Not_found] or + [Not_found_s] if [f] always returns [None]. *) +val find_map_exn : 'a t -> f:('a -> 'b option) -> 'b + +(** [findi t f] returns the first index [i] of [t] for which [f i t.(i)] is true *) +val findi : 'a t -> f:(int -> 'a -> bool) -> (int * 'a) option + +(** [findi_exn t f] returns the first index [i] of [t] for which [f i t.(i)] is true. It + raises [Caml.Not_found] or [Not_found_s] if there is no such element. *) +val findi_exn : 'a t -> f:(int -> 'a -> bool) -> int * 'a + +(** [find_mapi t f] is like [find_map] but passes the index as an argument. *) +val find_mapi : 'a t -> f:(int -> 'a -> 'b option) -> 'b option + +(** [find_mapi_exn] is like [find_map_exn] but passes the index as an argument. *) +val find_mapi_exn : 'a t -> f:(int -> 'a -> 'b option) -> 'b + +(** [find_consecutive_duplicate t ~equal] returns the first pair of consecutive elements + [(a1, a2)] in [t] such that [equal a1 a2]. They are returned in the same order as + they appear in [t]. *) +val find_consecutive_duplicate : 'a t -> equal:('a -> 'a -> bool) -> ('a * 'a) option + +(** [reduce f [a1; ...; an]] is [Some (f (... (f (f a1 a2) a3) ...) an)]. Returns [None] + on the empty array. *) +val reduce : 'a t -> f:('a -> 'a -> 'a) -> 'a option +val reduce_exn : 'a t -> f:('a -> 'a -> 'a) -> 'a + +(** [permute ?random_state t] randomly permutes [t] in place. + + [permute] side-effects [random_state] by repeated calls to [Random.State.int]. If + [random_state] is not supplied, [permute] uses [Random.State.default]. *) +val permute : ?random_state:Random.State.t -> 'a t -> unit + +(** [random_element ?random_state t] is [None] if [t] is empty, else it is [Some x] for + some [x] chosen uniformly at random from [t]. + + [random_element] side-effects [random_state] by calling [Random.State.int]. If + [random_state] is not supplied, [random_element] uses [Random.State.default]. *) +val random_element : ?random_state:Random.State.t -> 'a t -> 'a option +val random_element_exn : ?random_state:Random.State.t -> 'a t -> 'a + +(** [zip] is like [List.zip], but for arrays. *) +val zip : 'a t -> 'b t -> ('a * 'b) t option +val zip_exn : 'a t -> 'b t -> ('a * 'b) t + +(** [unzip] is like [List.unzip], but for arrays. *) +val unzip : ('a * 'b) t -> 'a t * 'b t + +(** [sorted_copy ar compare] returns a shallow copy of [ar] that is sorted. Similar to + List.sort *) +val sorted_copy : 'a t -> compare:('a -> 'a -> int) -> 'a t + +val last : 'a t -> 'a + +val equal : ('a -> 'a -> bool) -> 'a t -> 'a t -> bool + +(** [unsafe_truncate t ~len] drops [length t - len] elements from the end of [t], changing + [t] so that [length t = len] afterwards. + + [unsafe_truncate] raises if [len <= 0 || len > length t]. + + It is not safe to do [unsafe_truncate] in the middle of a call to [map], [iter], etc., + or if you have given this array out to anything not under your control: in general, + code can rely on an array's length not changing. One must ensure code that calls + [unsafe_truncate] on an array does not interfere with other code that manipulates the + array. *) +val unsafe_truncate : _ t -> len:int -> unit + + +(** The input array is copied internally so that future modifications of it do not change + the sequence. *) +val to_sequence : 'a t -> 'a Sequence.t + +(** The input array is shared with the sequence and modifications of it will result in + modification of the sequence. *) +val to_sequence_mutable : 'a t -> 'a Sequence.t + +(**/**) +(*_ See the Jane Street Style Guide for an explanation of [Private] submodules: + + https://opensource.janestreet.com/standards/#private-submodules *) +module Private : sig + module Sort : sig + module type Sort = sig + val sort + : 'a t + -> compare:('a -> 'a -> int) + -> left:int + -> right:int + -> unit + end + + module Insertion_sort : Sort + module Heap_sort : Sort + module Intro_sort : sig + include Sort + val five_element_sort + : 'a t -> compare:('a -> 'a -> int) -> int -> int -> int -> int -> int -> unit + end + end +end diff --git a/src/array0.ml b/src/array0.ml new file mode 100644 index 0000000..dc03510 --- /dev/null +++ b/src/array0.ml @@ -0,0 +1,62 @@ +(* [Array0] defines array functions that are primitives or can be simply defined in terms + of [Caml.Array]. [Array0] is intended to completely express the part of [Caml.Array] + that [Base] uses -- no other file in Base other than array0.ml should use [Caml.Array]. + [Array0] has few dependencies, and so is available early in Base's build order. All + Base files that need to use arrays and come before [Base.Array] in build order should + do [module Array = Array0]. This includes uses of subscript syntax ([x.(i)], [x.(i) <- + e]), which the OCaml parser desugars into calls to [Array.get] and [Array.set]. + Defining [module Array = Array0] is also necessary because it prevents ocamldep from + mistakenly causing a file to depend on [Base.Array]. *) + +open! Import0 + +module Sys = Sys0 + +let invalid_argf = Printf.invalid_argf + +module Array = struct + external create : int -> 'a -> 'a array = "caml_make_vect" + external get : 'a array -> int -> 'a = "%array_safe_get" + external length : 'a array -> int = "%array_length" + external set : 'a array -> int -> 'a -> unit = "%array_safe_set" + external unsafe_get : 'a array -> int -> 'a = "%array_unsafe_get" + external unsafe_set : 'a array -> int -> 'a -> unit = "%array_unsafe_set" +end + +include Array + +let max_length = Sys.max_array_length + +let create ~len x = + try create len x + with Invalid_argument _ -> + invalid_argf "Array.create ~len:%d: invalid length" len () +;; + +let append = Caml.Array.append +let blit = Caml.Array.blit +let concat = Caml.Array.concat +let copy = Caml.Array.copy +let fill = Caml.Array.fill +let init = Caml.Array.init +let make_matrix = Caml.Array.make_matrix +let of_list = Caml.Array.of_list +let sub = Caml.Array.sub +let to_list = Caml.Array.to_list + + +(* These are eta expanded in order to permute parameter order to follow Base + conventions. *) +let fold t ~init ~f = Caml.Array.fold_left t ~init ~f +let fold_right t ~f ~init = Caml.Array.fold_right t ~f ~init +let iter t ~f = Caml.Array.iter t ~f +let iteri t ~f = Caml.Array.iteri t ~f +let map t ~f = Caml.Array.map t ~f +let mapi t ~f = Caml.Array.mapi t ~f +let stable_sort t ~compare = Caml.Array.stable_sort t ~cmp:compare + +let swap t i j = + let tmp = t.(i) in + t.(i) <- t.(j); + t.(j) <- tmp; +;; diff --git a/src/array_permute.ml b/src/array_permute.ml new file mode 100644 index 0000000..c10a950 --- /dev/null +++ b/src/array_permute.ml @@ -0,0 +1,12 @@ +(** An internal-only module factored out due to a circular dependency between core_array + and core_list. Contains code for permuting an array. *) + +open! Import + +include Array0 + +(** randomly permute an array. *) +let permute ?(random_state = Random.State.default) t = + for i = length t downto 2 do + swap t (i - 1) (Random.State.int random_state i) + done diff --git a/src/avltree.ml b/src/avltree.ml new file mode 100644 index 0000000..d49c8cd --- /dev/null +++ b/src/avltree.ml @@ -0,0 +1,411 @@ +(* A few small things copied from other parts of Base because they depend on us, so we + can't use them. *) + +open! Import + +module Int = struct + type t = int + + let max (x : t) y = if x > y then x else y +end + +(* Its important that Empty have no args. It's tempting to make this type a record + (e.g. to hold the compare function), but a lot of memory is saved by Empty being an + immediate, since all unused buckets in the hashtbl don't use any memory (besides the + array cell) *) +type ('k, 'v) t = + | Empty + | Node of { mutable left : ('k, 'v) t + ; key : 'k + ; mutable value : 'v + ; mutable height : int + ; mutable right : ('k, 'v) t + } + | Leaf of { key : 'k + ; mutable value : 'v + } + + +let empty = Empty + +let height = function + | Empty -> 0 + | Leaf _ -> 1 + | Node { left = _; key = _; value = _; height; right = _ } -> height + +let invariant compare = + let legal_left_key key = function + | Empty -> () + | Leaf { key = left_key; value = _; } + | Node { left = _; key = left_key; value = _; height = _; right = _ } -> + assert (compare left_key key < 0) + in + let legal_right_key key = function + | Empty -> () + | Leaf { key = right_key; value = _; } + | Node { left = _; key = right_key; value = _; height = _; right = _ } -> + assert (compare right_key key > 0) + in + let rec inv = function + | Empty | Leaf _ -> () + | Node { left; key = k; value = _; height = h; right } -> + let (hl, hr) = (height left, height right) in + inv left; + inv right; + legal_left_key k left; + legal_right_key k right; + assert (h = Int.max hl hr + 1); + assert (abs (hl - hr) <= 2) + in inv + +let invariant t ~compare = invariant compare t + +(* In the following comments, + 't is balanced' means that 'invariant t' does not + raise an exception. This implies of course that each node's height field is + correct. + 't is balanceable' means that height of the left and right subtrees of t + differ by at most 3. *) + +(* @pre: left and right subtrees have correct heights + @post: output has the correct height *) +let update_height = function + | Node ({ left; key = _; value = _; height = old_height; right } as x) -> + let new_height = (Int.max (height left) (height right)) + 1 in + if new_height <> old_height then x.height <- new_height + | Empty | Leaf _ -> assert false + +(* @pre: left and right subtrees are balanced + @pre: tree is balanceable + @post: output is balanced (in particular, height is correct) *) +let balance tree = + match tree with + | Empty | Leaf _ -> tree + | Node ({ left; key = _; value = _; height = _; right } as root_node) -> + let hl = height left and hr = height right in + (* + 2 is critically important, lowering it to 1 will break the Leaf + assumptions in the code below, and will force us to promote leaf nodes in + the balance routine. It's also faster, since it will balance less often. + Note that the following code is delicate. The update_height calls must + occur in the correct order, since update_height assumes its children have + the correct heights. *) + if hl > hr + 2 then begin + match left with + (* It cannot be a leaf, because even if right is empty, a leaf + is only height 1 *) + | Empty | Leaf _ -> assert false + | Node ({ left = left_node_left; key = _; value = _; height = _; + right = left_node_right; } + as left_node) -> + if height left_node_left >= height left_node_right then begin + root_node.left <- left_node_right; + left_node.right <- tree; + update_height tree; + update_height left; + left + end else begin + (* if right is a leaf, then left must be empty. That means + height is 2. Even if hr is empty we still can't get here. *) + match left_node_right with + | Empty | Leaf _ -> assert false + | Node ({ left = lr_left; key = _; value = _; height = _; right = lr_right; } + as lr_node) -> + left_node.right <- lr_left; + root_node.left <- lr_right; + lr_node .right <- tree; + lr_node .left <- left; + update_height left; + update_height tree; + update_height left_node_right; + left_node_right + end + end else if hr > hl + 2 then begin + (* see above for an explanation of why right cannot be a leaf *) + match right with + | Empty | Leaf _ -> assert false + | Node ({ left = right_node_left; key = _; value = _; height = _; + right = right_node_right } + as right_node) -> + if height right_node_right >= height right_node_left then begin + root_node .right <- right_node_left; + right_node.left <- tree; + update_height tree; + update_height right; + right + end else begin + (* see above for an explanation of why this cannot be a leaf *) + match right_node_left with + | Empty | Leaf _ -> assert false + | Node ({ left = rl_left; key = _; value = _; height = _; right = rl_right } + as rl_node) + -> + right_node.left <- rl_right; + root_node .right <- rl_left; + rl_node .left <- tree; + rl_node .right <- right; + update_height right; + update_height tree; + update_height right_node_left; + right_node_left + end + end else begin + update_height tree; + tree + end +;; + +(* @pre: tree is balanceable + @pre: abs (height (right node) - height (balance tree)) <= 3 + @post: result is balanceable *) + +(* @pre: tree is balanceable + @pre: abs (height (right node) - height (balance tree)) <= 3 + @post: result is balanceable *) +let set_left node tree = + let tree = balance tree in + match node with + | Node ({ left; key = _; value = _; height = _; right = _ } as r) -> + if phys_equal left tree then () + else + r.left <- tree; + update_height node + | _ -> assert false + +(* @pre: tree is balanceable + @pre: abs (height (left node) - height (balance tree)) <= 3 + @post: result is balanceable *) +let set_right node tree = + let tree = balance tree in + match node with + | Node ({ left = _; key = _; value = _; height = _; right } as r) -> + if phys_equal right tree then () + else + r.right <- tree; + update_height node + | _ -> assert false + +(* @pre: t is balanced. + @post: result is balanced, with new node inserted + @post: !added = true iff the shape of the input tree changed. *) +let add = + let rec add t replace added compare k v = + match t with + | Empty -> + added := true; + Leaf { key = k; value = v } + | Leaf ({ key = k'; value = _ } as r) -> + let c = compare k' k in + (* This compare is reversed on purpose, we are pretending + that the leaf was just inserted instead of the other way + round, that way we only allocate one node. *) + if c = 0 then begin + added := false; + if replace then r.value <- v; + t + end else begin + added := true; + if c < 0 then + Node { left = t; key = k; value = v; height = 2; right = Empty } + else + Node { left = Empty; key = k; value = v; height = 2; right = t } + end + | Node ({left; key = k'; value = _; height = _; right } as r) -> + let c = compare k k' in + if c = 0 then begin + added := false; + if replace then r.value <- v; + end else if c < 0 then + set_left t (add left replace added compare k v) + else + set_right t (add right replace added compare k v); + t + in + fun t ~replace ~compare ~added ~key ~data -> + let t = add t replace added compare key data in + if !added then balance t else t +;; + +let rec first t = + match t with + | Empty -> None + | Leaf { key = k; value = v } + | Node { left = Empty; key = k; value = v; height = _; right = _ } -> Some (k, v) + | Node { left = l; key = _; value = _; height = _; right = _ } -> first l +;; + +let rec last t = + match t with + | Empty -> None + | Leaf { key = k; value = v } + | Node { left = _; key = k; value = v; height = _; right = Empty } -> Some (k, v) + | Node { left = _; key = _; value = _; height = _; right = r } -> last r +;; + + +let[@inline always] rec findi_and_call_impl t ~compare k ~call_if_found ~if_found ~if_not_found = + (* A little manual unrolling of the recursion. + This is really worth 5% on average *) + match t with + | Empty -> if_not_found k + | Leaf { key = k'; value = v } -> + if compare k k' = 0 then call_if_found ~if_found ~key:k' ~data:v + else if_not_found k + | Node { left; key = k'; value = v; height = _; right } -> + let c = compare k k' in + if c = 0 then call_if_found ~if_found ~key:k' ~data:v + else if c < 0 then begin + match left with + | Empty -> if_not_found k + | Leaf { key = k'; value = v }-> + if compare k k' = 0 then call_if_found ~if_found ~key:k' ~data:v + else if_not_found k + | Node { left; key = k'; value = v; height = _; right } -> + let c = compare k k' in + if c = 0 then call_if_found ~if_found ~key:k' ~data:v + else + findi_and_call_impl (if c < 0 then left else right) ~compare k ~call_if_found ~if_found ~if_not_found + end else begin + match right with + | Empty -> if_not_found k + | Leaf { key = k'; value = v } -> + if compare k k' = 0 then call_if_found ~if_found ~key:k' ~data:v + else if_not_found k + | Node { left; key = k'; value = v; height = _; right } -> + let c = compare k k' in + if c = 0 then call_if_found ~if_found ~key:k' ~data:v + else + findi_and_call_impl (if c < 0 then left else right) ~compare k ~call_if_found ~if_found ~if_not_found + end +;; + +let find_and_call = + let call_if_found ~if_found ~key:_ ~data = if_found data in + fun t ~compare k ~if_found ~if_not_found -> + findi_and_call_impl t ~compare k ~call_if_found ~if_found ~if_not_found + +let findi_and_call = + let call_if_found ~if_found ~key ~data = if_found ~key ~data in + fun t ~compare k ~if_found ~if_not_found -> + findi_and_call_impl t ~compare k ~call_if_found ~if_found ~if_not_found + +let find = + let if_found v = Some v in + let if_not_found _ = None in + fun t ~compare k -> + find_and_call t ~compare k ~if_found ~if_not_found + +let mem = + let if_found _ = true in + let if_not_found _ = false in + fun t ~compare k -> + find_and_call t ~compare k ~if_found ~if_not_found + +let remove = + let rec min_elt tree = + match tree with + | Empty -> Empty + | Leaf _ -> tree + | Node { left = Empty; key = _; value = _; height = _; right = _ } -> tree + | Node { left; key = _; value = _; height = _; right = _ } -> min_elt left + in + let rec remove_min_elt tree = + match tree with + | Empty -> assert false + | Leaf _ -> Empty (* This must be the root *) + | Node { left = Empty; key = _; value = _; height = _; right } -> right + | Node { left = Leaf _; key = k; value = v; height = _; right = Empty } -> + Leaf { key = k; value = v } + | Node { left = Leaf _; key = _; value = _; height = _; right = _ } as node -> + set_left node Empty; tree + | Node { left; key = _; value = _; height = _; right = _ } as node -> + set_left node (remove_min_elt left); tree + in + let merge t1 t2 = + match (t1, t2) with + | (Empty, t) -> t + | (t, Empty) -> t + | (_, _) -> + let tree = min_elt t2 in + match tree with + | Empty -> assert false + | Leaf { key = k; value = v } -> + let t2 = balance (remove_min_elt t2) in + Node { left = t1; key = k; value = v; + height = Int.max (height t1) (height t2) + 1; right = t2 + } + | Node _ as node -> + set_right node (remove_min_elt t2); + set_left node t1; + node + in + let rec remove t removed compare k = + match t with + | Empty -> + removed := false; + Empty + | Leaf { key = k'; value = _ } -> + if compare k k' = 0 then begin + removed := true; + Empty + end else begin + removed := false; + t + end + | Node { left; key = k'; value = _; height = _; right } -> + let c = compare k k' in + if c = 0 then begin + removed := true; + merge left right + end else if c < 0 then begin + set_left t (remove left removed compare k); + t + end else begin + set_right t (remove right removed compare k); + t + end + in + fun t ~removed ~compare k -> balance (remove t removed compare k) +;; + +let rec fold t ~init ~f = + match t with + | Empty -> init + | Leaf { key; value = data } -> f ~key ~data init + | Node { left = Leaf { key = lkey; value = ldata }; + key; value = data; height = _; + right = Leaf { key = rkey; value = rdata } + } -> + f ~key:rkey ~data:rdata (f ~key ~data (f ~key:lkey ~data:ldata init)) + | Node { left = Leaf { key = lkey; value = ldata}; key; value = data; height = _; + right = Empty } + -> f ~key ~data (f ~key:lkey ~data:ldata init) + | Node { left = Empty; key; value = data; height = _; + right = Leaf { key = rkey; value = rdata } + } -> + f ~key:rkey ~data:rdata (f ~key ~data init) + | Node { left; key; value = data; height = _; right = Leaf { key = rkey; value = rdata } + } -> + f ~key:rkey ~data:rdata (f ~key ~data (fold left ~init ~f)) + | Node { left = Leaf { key = lkey; value = ldata}; key; value = data; height = _; right } -> + fold right ~init:(f ~key ~data (f ~key:lkey ~data:ldata init)) ~f + | Node { left; key; value = data; height = _; right } -> + fold right ~init:(f ~key ~data (fold left ~init ~f)) ~f + +let rec iter t ~f = + match t with + | Empty -> () + | Leaf { key; value = data } -> f ~key ~data + | Node { left; key; value = data; height = _; right } -> + iter left ~f; + f ~key ~data; + iter right ~f + +let rec mapi_inplace t ~f = + match t with + | Empty -> () + | Leaf ({ key ; value } as t) -> + t.value <- f ~key ~data:value + | Node ({ left ; key ; value ; height = _ ; right } as t) -> + mapi_inplace ~f left; + t.value <- f ~key ~data:value; + mapi_inplace ~f right diff --git a/src/avltree.mli b/src/avltree.mli new file mode 100644 index 0000000..382467b --- /dev/null +++ b/src/avltree.mli @@ -0,0 +1,129 @@ +(** A low-level, mutable AVL tree. + + It is not intended to be used directly by casual users. It is used for implementing + other data structures. The interface is somewhat ugly, and it's that way for a + reason: the goal of this module is minimum memory overhead and maximum performance. + + {2 Caveats} + + 1. [compare] is passed to every function where it is used. If you pass a different + [compare] to functions on the same tree, then behavior is indeterminate. Why? Because + otherwise we'd need a top-level record to store [compare], and when building a hash + table, or other structure, that little [t] is a block that increases memory + overhead. However, if an empty tree is just a constructor [Empty], then it's just a + number, and uses no extra memory beyond the array bucket that holds it. That's the + first secret of how Hashtbl's memory overhead isn't higher than INRIA's, even though + it uses a tree instead of a list for buckets. + + 2. But if it's mutable, why do all the "mutators" return [t]? Answer: it is mutable, + but the root node might change due to balancing. Since we have no top-level record to + hold the current root node (see point 1), you have to do it. If you fail to do it, and + use an old root node, you're responsible for the (sure to be nasty) consequences. + + 3. What on earth is up with the [~removed] argument to some functions? See point 1: + since there is no top-level node, it isn't possible to keep track of how many nodes + are in the tree unless each mutator tells you whether or not it added or removed a + node (vs. replacing an existing one). If you intend to keep a count (as you must in a + hash table), then you will need to pay attention to this flag. + + After all this, you're probably asking yourself whether all these hacks are worth + it. Yes! They are! With them, we built a hash table that is faster than INRIA's (no + small feat) with the same memory overhead, sane add semantics (the add semantics they + used were a performance hack), and worst-case log(N) insertion, lookup, and + removal. *) + +open! Import + +(** We expose [t] to allow an optimization in Hashtbl that makes iter and fold more than + twice as fast. We keep the type private to reduce opportunities for external code to + violate avltree invariants. *) +type ('k, 'v) t = private + | Empty + | Node of { mutable left : ('k, 'v) t + ; key : 'k + ; mutable value : 'v + ; mutable height : int + ; mutable right : ('k, 'v) t + } + | Leaf of { key : 'k + ; mutable value : 'v + } + +val empty : ('k, 'v) t + +(** Checks invariants, raising an exception if any invariants fail. *) +val invariant : ('k, 'v) t -> compare:('k -> 'k -> int) -> unit + +(** Adds the specified key and data to the tree destructively (previous [t]'s are no + longer valid) using the specified comparison function. O(log(N)) time, O(1) space. + + The returned [t] is the new root node of the tree, and should be used on all further + calls to any other function in this module. The bool [ref], added, will be set to + [true] if a new node is added to the tree, or [false] if an existing node is replaced + (in the case that the key already exists). + + If [replace] (default true) is true then [add] will overwrite any existing mapping for + [key]. If [replace] is false, and there is an existing mapping for key, then [add] has + no effect. *) +val add + : ('k, 'v) t + -> replace:bool + -> compare:('k -> 'k -> int) + -> added:bool ref + -> key:'k + -> data:'v + -> ('k, 'v) t + +(** Returns the first (leftmost) or last (rightmost) element in the tree. *) + +val first : ('k, 'v) t -> ('k * 'v) option +val last : ('k, 'v) t -> ('k * 'v) option + +(** If the specified key exists in the tree, returns the corresponding value. O(log(N)) + time and O(1) space. *) +val find : ('k, 'v) t -> compare:('k -> 'k -> int) -> 'k -> 'v option + +(** [find_and_call t ~compare k ~if_found ~if_not_found] + + is equivalent to: + + [match find t ~compare k with Some v -> if_found v | None -> if_not_found k] + + except that it doesn't allocate the option. *) +val find_and_call + : ('k, 'v) t + -> compare:('k -> 'k -> int) + -> 'k + -> if_found:('v -> 'a) + -> if_not_found:('k -> 'a) + -> 'a + +val findi_and_call + : ('k, 'v) t + -> compare:('k -> 'k -> int) + -> 'k + -> if_found:(key:'k -> data:'v -> 'a) + -> if_not_found:('k -> 'a) + -> 'a + +(** Returns true if key is present in the tree, and false otherwise. *) +val mem : ('k, 'v) t -> compare:('k -> 'k -> int) -> 'k -> bool + +(** Removes key destructively from the tree if it exists, returning the new root node. + Previous root nodes are not usable anymore; do so at your peril. The [removed] ref + will be set to true if a node was actually removed, and false otherwise. *) +val remove + : ('k, 'v) t + -> removed:bool ref + -> compare:('k -> 'k -> int) + -> 'k + -> ('k, 'v) t + +(** Folds over the tree. *) +val fold : ('k, 'v) t -> init:'a -> f:(key:'k -> data:'v -> 'a -> 'a) -> 'a + +(** Iterates over the tree. *) +val iter : ('k, 'v) t -> f:(key:'k -> data:'v -> unit) -> unit + +(** Map over the the tree, changing the data in place. *) +val mapi_inplace : ('k, 'v) t -> f:(key:'k -> data:'v -> 'v) -> unit diff --git a/src/backtrace.ml b/src/backtrace.ml new file mode 100644 index 0000000..5a268a5 --- /dev/null +++ b/src/backtrace.ml @@ -0,0 +1,51 @@ +open! Import + +module Sys = Sys0 + +type t = Caml.Printexc.raw_backtrace + +let elide = ref am_testing +let elided_message = "<backtrace elided in test>" + +let get ?(at_most_num_frames = Int.max_value) () = + Caml.Printexc.get_callstack at_most_num_frames +;; + +let to_string t = + if !elide + then elided_message + else Caml.Printexc.raw_backtrace_to_string t +;; + +let to_string_list t = String.split_lines (to_string t) + +let sexp_of_t t = + Sexp.List (List.map (to_string_list t) ~f:(fun x -> Sexp.Atom x)) +;; + +module Exn = struct + + let set_recording = Caml.Printexc.record_backtrace + let am_recording = Caml.Printexc.backtrace_status + + let most_recent () = + Caml.Printexc.get_raw_backtrace () + ;; + + (* We turn on backtraces by default if OCAMLRUNPARAM isn't set. *) + let maybe_set_recording () = + match Sys.getenv "OCAMLRUNPARAM" with + | exception _ -> set_recording true + | (_ : string) -> () (* the caller set something, they are responsible *) + ;; + + let with_recording b ~f = + let saved = am_recording () in + set_recording b; + Exn.protect ~f ~finally:(fun () -> set_recording saved) + ;; +end + +let initialize_module () = + Exn.maybe_set_recording (); +;; diff --git a/src/backtrace.mli b/src/backtrace.mli new file mode 100644 index 0000000..e929816 --- /dev/null +++ b/src/backtrace.mli @@ -0,0 +1,79 @@ +(** Module for managing stack backtraces. + + The [Backtrace] module deals with two different kinds of backtraces: + + + Snapshots of the stack obtained on demand ([Backtrace.get]) + + The stack frames unwound when an exception is raised ([Backtrace.Exn]) +*) + +open! Import + +(** A [Backtrace.t] is a snapshot of the stack obtained by calling [Backtrace.get]. It is + represented as a string with newlines separating the frames. [sexp_of_t] splits the + string at newlines and removes some of the cruft, leaving a human-friendly list of + frames, but [to_string] does not. *) +type t [@@deriving_inline sexp_of] +include +sig [@@@ocaml.warning "-32"] val sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t +end[@@ocaml.doc "@inline"] +[@@@end] + +val get : ?at_most_num_frames:int -> unit -> t + +val to_string : t -> string + +val to_string_list : t -> string list + +(** The value of [elide] controls the behavior of backtrace serialization functions such + as {!to_string}, {!to_string_list}, and {!sexp_of_t}. When set to [false], these + functions behave as expected, returning a faithful representation of their argument. + When set to [true], these functions will ignore their argument and return a message + indicating that behavior. + + The default value is {!am_testing}. *) +val elide : bool ref + +(** [Backtrace.Exn] has functions for controlling and printing the backtrace of the most + recently raised exception. + + When an exception is raised, the runtime "unwinds" the stack, i.e., removes stack + frames, until it reaches a frame with an exception handler. It then matches the + exception against the patterns in the handler. If the exception matches, then the + program continues. If not, then the runtime continues unwinding the stack to the next + handler. + + If [am_recording () = true], then while the runtime is unwinding the stack, it keeps + track of the part of the stack that is unwound. This is available as a backtrace via + [most_recent ()]. Calling [most_recent] if [am_recording () = false] will yield the + empty backtrace. + + With [am_recording () = true], OCaml keeps only a backtrace for the most recently + raised exception. When one raises an exception, OCaml checks if it is physically equal + to the most recently raised exception. If it is, then OCaml appends the string + representation of the stack unwound by the current raise to the stored backtrace. If + the exception being raised is not physically equally to the most recently raised + exception, then OCaml starts recording a new backtrace. Thus one must call + [most_recent] before a subsequent [raise] of a (physically) distinct exception, or the + backtrace is lost. + + The initial value of [am_recording ()] is determined by the setting of the environment + variable OCAMLRUNPARAM. If OCAMLRUNPARAM is set, then [am_recording () = true] iff the + character "b" occurs in OCAMLRUNPARAM. If OCAMLRUNPARAM is not set (as is always the + case when running in a web browser), then [am_recording ()] is initially true. + + This is the same functionality as provided by the OCaml stdlib [Printexc] functions + [backtrace_status], [record_backtraces], [get_backtrace]. *) +module Exn : sig + val am_recording : unit -> bool + val set_recording : bool -> unit + + val with_recording : bool -> f:(unit -> 'a) -> 'a + + (** [most_recent ()] returns a backtrace containing the stack that was unwound by the + most recently raised exception. *) + val most_recent : unit -> t +end + +(** User code never calls this. It is called only in [std_kernel.ml], as a top-level side + effect, to initialize [am_recording ()] as specified above. *) +val initialize_module : unit -> unit diff --git a/src/base.ml b/src/base.ml new file mode 100644 index 0000000..1ef404b --- /dev/null +++ b/src/base.ml @@ -0,0 +1,471 @@ +(** This module is the toplevel of the Base library; it's what you get when you write + [open Base]. + + The recommended way to use Base is to build with [-open Base]. Files compiled this + way will have the environment described in this file as their initial environment. + + Base extends some modules and data structures from the standard library, like [Array], + [Buffer], [Bytes], [Char], [Hashtbl], [Int32], [Int64], [Lazy], [List], [Map], + [Nativeint], [Printf], [Random], [Set], [String], [Sys], and [Uchar]. One key + difference is that Base doesn't use exceptions as much as the standard library and + instead makes heavy use of the [Result] type, as in: + + {[ type ('a,'b) result = Ok of 'a | Error of 'b ]} + + Base also adds entirely new modules, most notably: + + - [Comparable], [Comparator], and [Comparisons] in lieu of polymorphic compare. + - [Container], which provides a consistent interface across container-like data + structures (arrays, lists, strings). + - [Result], [Error], and [Or_error], supporting the or-error pattern. + + Broadly the goal of Base is both to be a more complete standard library, with richer + APIs, and to be more consistent in its design. For instance, in the standard library + some things have modules and others don't; in Base, everything is a module. +*) + +(*_ We hide this from the web docs because the line wrapping is bad, making it + pretty much inscrutable. *) +(**/**) + +(** The intent is to shadow all of INRIA's standard library. Modules below would cause + compilation errors without being removed from [Shadow_stdlib] before inclusion. *) +include (Shadow_stdlib + : module type of struct include Shadow_stdlib end + (* Modules defined in Base *) + with module Array := Shadow_stdlib.Array + with module Bool := Shadow_stdlib.Bool + with module Buffer := Shadow_stdlib.Buffer + with module Bytes := Shadow_stdlib.Bytes + with module Char := Shadow_stdlib.Char + with module Float := Shadow_stdlib.Float + with module Hashtbl := Shadow_stdlib.Hashtbl + with module Int := Shadow_stdlib.Int + with module Int32 := Shadow_stdlib.Int32 + with module Int64 := Shadow_stdlib.Int64 + with module Lazy := Shadow_stdlib.Lazy + with module List := Shadow_stdlib.List + with module Map := Shadow_stdlib.Map + with module Nativeint := Shadow_stdlib.Nativeint + with module Option := Shadow_stdlib.Option + with module Printf := Shadow_stdlib.Printf + with module Queue := Shadow_stdlib.Queue + with module Random := Shadow_stdlib.Random + with module Result := Shadow_stdlib.Result + with module Set := Shadow_stdlib.Set + with module Stack := Shadow_stdlib.Stack + with module String := Shadow_stdlib.String + with module Sys := Shadow_stdlib.Sys + with module Uchar := Shadow_stdlib.Uchar + with module Unit := Shadow_stdlib.Unit + + (* Support for generated lexers *) + with module Lexing := Shadow_stdlib.Lexing + + with type ('a, 'b, 'c) format := ('a, 'b, 'c) format + with type ('a, 'b, 'c, 'd) format4 := ('a, 'b, 'c, 'd) format4 + with type ('a, 'b, 'c, 'd, 'e, 'f) format6 := ('a, 'b, 'c, 'd, 'e, 'f) format6 + + with type 'a ref := 'a ref + ) [@ocaml.warning "-3"] + +(**/**) + +open! Import + +module Applicative = Applicative +module Array = Array +module Avltree = Avltree +module Backtrace = Backtrace +module Binary_search = Binary_search +module Binary_searchable = Binary_searchable +module Blit = Blit +module Bool = Bool +module Buffer = Buffer +module Bytes = Bytes +module Char = Char +module Comparable = Comparable +module Comparator = Comparator +module Comparisons = Comparisons +module Container = Container +module Either = Either +module Equal = Equal +module Error = Error +module Exn = Exn +module Field = Field +module Float = Float +module Floatable = Floatable +module Fn = Fn +module Formatter = Formatter +module Hash = Hash +module Hash_set = Hash_set +module Hashable = Hashable +module Hasher = Hasher +module Hashtbl = Hashtbl +module Identifiable = Identifiable +module Indexed_container = Indexed_container +module Info = Info +module Int = Int +module Int32 = Int32 +module Int63 = Int63 +module Int64 = Int64 +module Intable = Intable +module Invariant = Invariant +module Lazy = Lazy +module List = List +module Map = Map +module Maybe_bound = Maybe_bound +module Monad = Monad +module Nativeint = Nativeint +module Option = Option +module Option_array = Option_array +module Or_error = Or_error +module Ordered_collection_common = Ordered_collection_common +module Ordering = Ordering +module Poly = Poly +module Polymorphic_compare = Poly +[@@deprecated "[since 2018-11] use [Poly] instead"] +module Popcount = Popcount +[@@deprecated "[since 2018-10] use [popcount] functions in the individual int modules"] +module Pretty_printer = Pretty_printer +module Printf = Printf +module Linked_queue = Linked_queue +module Queue = Queue +module Random = Random +module Ref = Ref +module Result = Result +module Sequence = Sequence +module Set = Set +module Sexpable = Sexpable +module Sign = Sign +module Sign_or_nan = Sign_or_nan +module Source_code_position = Source_code_position +module Stack = Stack +module Staged = Staged +module String = String +module Stringable = Stringable +module Sys = Sys +module T = T +module Type_equal = Type_equal +module Uniform_array = Uniform_array +module Unit = Unit +module Uchar = Uchar +module Validate = Validate +module Variant = Variant +module With_return = With_return +module Word_size = Word_size + +(* Avoid a level of indirection for uses of the signatures defined in [T]. *) +include T + +(* This is a hack so that odoc creates better documentation. *) +module Sexp = struct + include Sexp_with_comparable (** @inline *) +end + +(**/**) +module Exported_for_specific_uses = struct + module Fieldslib = Fieldslib + module Ppx_hash_lib = Ppx_hash_lib + module Sexplib = Sexplib + module Variantslib = Variantslib + module Ppx_compare_lib = Ppx_compare_lib + module Ppx_sexp_conv_lib = Ppx_sexp_conv_lib + let am_testing = am_testing +end +(**/**) + +module Export = struct + (* [deriving hash] is missing for [array] and [ref] since these types are mutable. + (string is also mutable, but we pretend it isn't for hashing purposes) *) + type 'a array = 'a Array. t [@@deriving_inline compare, equal, sexp] + let compare_array : 'a . ('a -> 'a -> int) -> 'a array -> 'a array -> int = + Array.compare + let equal_array : 'a . ('a -> 'a -> bool) -> 'a array -> 'a array -> bool = + Array.equal + let array_of_sexp : + 'a . + (Ppx_sexp_conv_lib.Sexp.t -> 'a) -> Ppx_sexp_conv_lib.Sexp.t -> 'a array + = Array.t_of_sexp + let sexp_of_array : + 'a . + ('a -> Ppx_sexp_conv_lib.Sexp.t) -> 'a array -> Ppx_sexp_conv_lib.Sexp.t + = Array.sexp_of_t + [@@@end] + type bool = Bool. t [@@deriving_inline compare, equal, hash, sexp] + let compare_bool : bool -> bool -> int = Bool.compare + let equal_bool : bool -> bool -> bool = Bool.equal + let (hash_fold_bool : + Ppx_hash_lib.Std.Hash.state -> bool -> Ppx_hash_lib.Std.Hash.state) = + Bool.hash_fold_t + and (hash_bool : bool -> Ppx_hash_lib.Std.Hash.hash_value) = + let func = Bool.hash in fun x -> func x + let bool_of_sexp : Ppx_sexp_conv_lib.Sexp.t -> bool = Bool.t_of_sexp + let sexp_of_bool : bool -> Ppx_sexp_conv_lib.Sexp.t = Bool.sexp_of_t + [@@@end] + type char = Char. t [@@deriving_inline compare, equal, hash, sexp] + let compare_char : char -> char -> int = Char.compare + let equal_char : char -> char -> bool = Char.equal + let (hash_fold_char : + Ppx_hash_lib.Std.Hash.state -> char -> Ppx_hash_lib.Std.Hash.state) = + Char.hash_fold_t + and (hash_char : char -> Ppx_hash_lib.Std.Hash.hash_value) = + let func = Char.hash in fun x -> func x + let char_of_sexp : Ppx_sexp_conv_lib.Sexp.t -> char = Char.t_of_sexp + let sexp_of_char : char -> Ppx_sexp_conv_lib.Sexp.t = Char.sexp_of_t + [@@@end] + type exn = Exn. t [@@deriving_inline sexp_of] + let sexp_of_exn : exn -> Ppx_sexp_conv_lib.Sexp.t = Exn.sexp_of_t + [@@@end] + type float = Float. t [@@deriving_inline compare, equal, hash, sexp] + let compare_float : float -> float -> int = Float.compare + let equal_float : float -> float -> bool = Float.equal + let (hash_fold_float : + Ppx_hash_lib.Std.Hash.state -> float -> Ppx_hash_lib.Std.Hash.state) = + Float.hash_fold_t + and (hash_float : float -> Ppx_hash_lib.Std.Hash.hash_value) = + let func = Float.hash in fun x -> func x + let float_of_sexp : Ppx_sexp_conv_lib.Sexp.t -> float = Float.t_of_sexp + let sexp_of_float : float -> Ppx_sexp_conv_lib.Sexp.t = Float.sexp_of_t + [@@@end] + type int = Int. t [@@deriving_inline compare, equal, hash, sexp] + let compare_int : int -> int -> int = Int.compare + let equal_int : int -> int -> bool = Int.equal + let (hash_fold_int : + Ppx_hash_lib.Std.Hash.state -> int -> Ppx_hash_lib.Std.Hash.state) = + Int.hash_fold_t + and (hash_int : int -> Ppx_hash_lib.Std.Hash.hash_value) = + let func = Int.hash in fun x -> func x + let int_of_sexp : Ppx_sexp_conv_lib.Sexp.t -> int = Int.t_of_sexp + let sexp_of_int : int -> Ppx_sexp_conv_lib.Sexp.t = Int.sexp_of_t + [@@@end] + type int32 = Int32. t [@@deriving_inline compare, equal, hash, sexp] + let compare_int32 : int32 -> int32 -> int = Int32.compare + let equal_int32 : int32 -> int32 -> bool = Int32.equal + let (hash_fold_int32 : + Ppx_hash_lib.Std.Hash.state -> int32 -> Ppx_hash_lib.Std.Hash.state) = + Int32.hash_fold_t + and (hash_int32 : int32 -> Ppx_hash_lib.Std.Hash.hash_value) = + let func = Int32.hash in fun x -> func x + let int32_of_sexp : Ppx_sexp_conv_lib.Sexp.t -> int32 = Int32.t_of_sexp + let sexp_of_int32 : int32 -> Ppx_sexp_conv_lib.Sexp.t = Int32.sexp_of_t + [@@@end] + type int64 = Int64. t [@@deriving_inline compare, equal, hash, sexp] + let compare_int64 : int64 -> int64 -> int = Int64.compare + let equal_int64 : int64 -> int64 -> bool = Int64.equal + let (hash_fold_int64 : + Ppx_hash_lib.Std.Hash.state -> int64 -> Ppx_hash_lib.Std.Hash.state) = + Int64.hash_fold_t + and (hash_int64 : int64 -> Ppx_hash_lib.Std.Hash.hash_value) = + let func = Int64.hash in fun x -> func x + let int64_of_sexp : Ppx_sexp_conv_lib.Sexp.t -> int64 = Int64.t_of_sexp + let sexp_of_int64 : int64 -> Ppx_sexp_conv_lib.Sexp.t = Int64.sexp_of_t + [@@@end] + type 'a list = 'a List. t [@@deriving_inline compare, equal, hash, sexp] + let compare_list : 'a . ('a -> 'a -> int) -> 'a list -> 'a list -> int = + List.compare + let equal_list : 'a . ('a -> 'a -> bool) -> 'a list -> 'a list -> bool = + List.equal + let hash_fold_list : + 'a . + (Ppx_hash_lib.Std.Hash.state -> 'a -> Ppx_hash_lib.Std.Hash.state) -> + Ppx_hash_lib.Std.Hash.state -> 'a list -> Ppx_hash_lib.Std.Hash.state + = List.hash_fold_t + let list_of_sexp : + 'a . + (Ppx_sexp_conv_lib.Sexp.t -> 'a) -> Ppx_sexp_conv_lib.Sexp.t -> 'a list + = List.t_of_sexp + let sexp_of_list : + 'a . + ('a -> Ppx_sexp_conv_lib.Sexp.t) -> 'a list -> Ppx_sexp_conv_lib.Sexp.t + = List.sexp_of_t + [@@@end] + type nativeint = Nativeint. t [@@deriving_inline compare, equal, hash, sexp] + let compare_nativeint : nativeint -> nativeint -> int = Nativeint.compare + let equal_nativeint : nativeint -> nativeint -> bool = Nativeint.equal + let (hash_fold_nativeint : + Ppx_hash_lib.Std.Hash.state -> nativeint -> Ppx_hash_lib.Std.Hash.state) = + Nativeint.hash_fold_t + and (hash_nativeint : nativeint -> Ppx_hash_lib.Std.Hash.hash_value) = + let func = Nativeint.hash in fun x -> func x + let nativeint_of_sexp : Ppx_sexp_conv_lib.Sexp.t -> nativeint = + Nativeint.t_of_sexp + let sexp_of_nativeint : nativeint -> Ppx_sexp_conv_lib.Sexp.t = + Nativeint.sexp_of_t + [@@@end] + type 'a option = 'a Option. t [@@deriving_inline compare, equal, hash, sexp] + let compare_option : 'a . ('a -> 'a -> int) -> 'a option -> 'a option -> int + = Option.compare + let equal_option : 'a . ('a -> 'a -> bool) -> 'a option -> 'a option -> bool + = Option.equal + let hash_fold_option : + 'a . + (Ppx_hash_lib.Std.Hash.state -> 'a -> Ppx_hash_lib.Std.Hash.state) -> + Ppx_hash_lib.Std.Hash.state -> 'a option -> Ppx_hash_lib.Std.Hash.state + = Option.hash_fold_t + let option_of_sexp : + 'a . + (Ppx_sexp_conv_lib.Sexp.t -> 'a) -> Ppx_sexp_conv_lib.Sexp.t -> 'a option + = Option.t_of_sexp + let sexp_of_option : + 'a . + ('a -> Ppx_sexp_conv_lib.Sexp.t) -> 'a option -> Ppx_sexp_conv_lib.Sexp.t + = Option.sexp_of_t + [@@@end] + type 'a ref = 'a Ref. t [@@deriving_inline compare, equal, sexp] + let compare_ref : 'a . ('a -> 'a -> int) -> 'a ref -> 'a ref -> int = + Ref.compare + let equal_ref : 'a . ('a -> 'a -> bool) -> 'a ref -> 'a ref -> bool = + Ref.equal + let ref_of_sexp : + 'a . (Ppx_sexp_conv_lib.Sexp.t -> 'a) -> Ppx_sexp_conv_lib.Sexp.t -> 'a ref + = Ref.t_of_sexp + let sexp_of_ref : + 'a . ('a -> Ppx_sexp_conv_lib.Sexp.t) -> 'a ref -> Ppx_sexp_conv_lib.Sexp.t + = Ref.sexp_of_t + [@@@end] + type string = String. t [@@deriving_inline compare, equal, hash, sexp] + let compare_string : string -> string -> int = String.compare + let equal_string : string -> string -> bool = String.equal + let (hash_fold_string : + Ppx_hash_lib.Std.Hash.state -> string -> Ppx_hash_lib.Std.Hash.state) = + String.hash_fold_t + and (hash_string : string -> Ppx_hash_lib.Std.Hash.hash_value) = + let func = String.hash in fun x -> func x + let string_of_sexp : Ppx_sexp_conv_lib.Sexp.t -> string = String.t_of_sexp + let sexp_of_string : string -> Ppx_sexp_conv_lib.Sexp.t = String.sexp_of_t + [@@@end] + type bytes = Bytes. t [@@deriving_inline compare, equal, sexp] + let compare_bytes : bytes -> bytes -> int = Bytes.compare + let equal_bytes : bytes -> bytes -> bool = Bytes.equal + let bytes_of_sexp : Ppx_sexp_conv_lib.Sexp.t -> bytes = Bytes.t_of_sexp + let sexp_of_bytes : bytes -> Ppx_sexp_conv_lib.Sexp.t = Bytes.sexp_of_t + [@@@end] + type unit = Unit. t [@@deriving_inline compare, equal, hash, sexp] + let compare_unit : unit -> unit -> int = Unit.compare + let equal_unit : unit -> unit -> bool = Unit.equal + let (hash_fold_unit : + Ppx_hash_lib.Std.Hash.state -> unit -> Ppx_hash_lib.Std.Hash.state) = + Unit.hash_fold_t + and (hash_unit : unit -> Ppx_hash_lib.Std.Hash.hash_value) = + let func = Unit.hash in fun x -> func x + let unit_of_sexp : Ppx_sexp_conv_lib.Sexp.t -> unit = Unit.t_of_sexp + let sexp_of_unit : unit -> Ppx_sexp_conv_lib.Sexp.t = Unit.sexp_of_t + [@@@end] + + (** Format stuff *) + + type nonrec ('a, 'b, 'c) format = ('a, 'b, 'c) format + type nonrec ('a, 'b, 'c, 'd) format4 = ('a, 'b, 'c, 'd) format4 + type nonrec ('a, 'b, 'c, 'd, 'e, 'f) format6 = ('a, 'b, 'c, 'd, 'e, 'f) format6 + + (** {2 Sexp} + + Exporting the ad-hoc types that are recognized by [ppx_sexp_*] converters. + [sexp_array], [sexp_list], and [sexp_option] allow a record field to be absent when + converting from a sexp, and if absent, the field will take a default value of the + appropriate type: + + {v + sexp_array [||] + sexp_bool false + sexp_list [] + sexp_option None + v} + + [sexp_opaque] causes the conversion to sexp to produce the atom [<opaque>]. + + For more documentation, see sexplib/README.md. *) + + type 'a sexp_array = 'a array + type 'a sexp_list = 'a list + type 'a sexp_opaque = 'a + type 'a sexp_option = 'a option + + (** List operators *) + + include List.Infix + + (** Int operators and comparisons *) + + include Int.O + include Int_replace_polymorphic_compare + + (** Float operators *) + + include Float.O_dot + + (** Reverse application operator. [x |> g |> f] is equivalent to [f (g (x))]. *) + (* This is declared as an external to be optimized away in more contexts. *) + external ( |> ) : 'a -> ( 'a -> 'b) -> 'b = "%revapply" + + (** Application operator. [g @@ f @@ x] is equivalent to [g (f (x))]. *) + external ( @@ ) : ('a -> 'b) -> 'a -> 'b = "%apply" + + (** Boolean operations *) + + (* These need to be declared as an external to get the lazy behavior *) + external ( && ) : bool -> bool -> bool = "%sequand" + external ( || ) : bool -> bool -> bool = "%sequor" + external not : bool -> bool = "%boolnot" + + (* This must be declared as an external for the warnings to work properly. *) + external ignore : _ -> unit = "%ignore" + + (** Common string operations *) + let ( ^ ) = String.( ^ ) + + (** Reference operations *) + + (* Declared as an externals so that the compiler skips the caml_modify when possible and + to keep reference unboxing working *) + external ( ! ) : 'a ref -> 'a = "%field0" + external ref : 'a -> 'a ref = "%makemutable" + external ( := ) : 'a ref -> 'a -> unit = "%setfield0" + + (** Pair operations *) + + let fst = fst + let snd = snd + + (** Exceptions stuff *) + + (* Declared as an external so that the compiler may rewrite '%raise' as '%reraise'. *) + external raise : exn -> _ = "%raise" + + let failwith = failwith + let invalid_arg = invalid_arg + let raise_s = Error.raise_s + + (** Misc *) + + let phys_equal = phys_equal + + external force : 'a Lazy.t -> 'a = "%lazy_force" +end + +include Export +include Container_intf.Export (** @inline *) + +exception Not_found_s = Not_found_s + +(* Various things to cleanup that were used without going through Base. *) +module Not_exposed_properly = struct + module Int63_emul = Int63_emul + module Float0 = Float0 + module Import = Import + module Int_conversions = Int_conversions + module Int_math = Int_math + module Pow_overflow_bounds = Pow_overflow_bounds + module Sexp_conv = Sexplib0.Sexp_conv + module Obj_array = Obj_array +end + +(* We perform these side effects here because we want them to run for any code that uses + [Base]. If this were in another module in [Base] that was not used in some program, + then the side effects might not be run in that program. This will run as long as the + program refers to at least one value directly in [Base]; referring to values in + [Base.Bool], for example, is not sufficient. *) +let () = + Backtrace.initialize_module (); +;; diff --git a/src/base.mld b/src/base.mld new file mode 100644 index 0000000..4a09082 --- /dev/null +++ b/src/base.mld @@ -0,0 +1,173 @@ +{1 Base} + +{b {{!Base} The full API is browsable here}}. + +Base is a standard library for OCaml. It provides a standard set of +general-purpose modules that are well tested, performant, and fully +portable across any environment that can run OCaml code. + +Unlike other standard library projects, Base is meant to be used as a +wholesale replacement of the standard library distributed with the +OCaml compiler. In particular, it makes different choices and doesn't +re-export features that are not fully portable such as I/O, which are +left to other libraries. + +Note that an API for OCaml's channel-based I/O can be found in the +{{!module:Stdio}[Stdio]} library. + +{2 Relationship to Core_kernel and Core} + +Base is the smallest, most self-contained version of Jane Street's +family of three standard library replacements. It is extended by +[Core_kernel], which is in turn extended by [Core]. + +In sum: + +- {b {!Base}}: Minimal stdlib replacement. Portable and lightweight and + intended to be highly stable. +- {b {!Core_kernel}}: Extension of Base. More fully featured, with more + code and dependencies, and APIs that evolve more quickly. Portable, + and works on Javascript. +- {b {!Core}}: Core_kernel extended with UNIX APIs. + + +{2 Using the OCaml standard library with Base} + +Base is intended as a full stdlib replacement. As a result, after an +[open Base], all the modules, values, types, etc., coming from the OCaml +standard library that one normally gets in the default environment are +deprecated. + +In order to access these values, one must use the [Caml] library, +which re-exports them all through the toplevel name +{{!module:Caml}[Caml]}: [Caml.String], [Caml.print_string], ... + +The new modules and values made available by Base are documented +{{!Base} here}. + +{2 Differences between Base and the OCaml standard library} + +Programmers who are used to the OCaml standard library should read +through this section to understand major differences between the two +libraries that one should be aware of when switching to Base. + +{3 Comparison operators} + +The comparison operators exposed by the OCaml standard library are +polymorphic: + +{[ +val compare : 'a -> 'a -> int +val ( <= ) : 'a -> 'a -> bool +(* ... *) +]} + +What they implement is structural comparison of the runtime +representation of values. Since these are often error-prone, +i.e., they don't correspond to what the user expects, they are not +exposed directly by Base. + +To use polymorphic comparison with Base, one should use the +{{!Base.Polymorphic_compare}[Polymorphic_compare]} module. The default +comparison operators exposed by Base are the integer ones, just like +the default arithmetic operators are the integer ones. + +The recommended way to compare arbitrary complex data structures is to +use the specific [compare] functions. For instance: + +{[ List.compare String.compare x y ]} + +The [ppx_compare] rewriter offers an alternative way to write this: + +{[ [%compare: string list] x y ]} + +{2 Base and ppx code generators} + +Base uses a few ppx code generators to implement: + +- reliable and customizable comparison of OCaml values; +- reliable and customizable hash of OCaml values; and +- conversions between OCaml values and s-expression. + +However, it doesn't need these code generators to build. Instead, it +uses ppx as a code verification tool during development. It works in a +very similar fashion to {{: https://github.com/janestreet/ppx_expect} +expect tests}. + +Whenever you see this in the code source: + +{[ +type t = ... [@@deriving_inline sexp_of] +let sexp_of_t = ... +[@@@end] +]} + +the code between the [[@@deriving_inline]] and the [[@@@end]] is +generated code. The generated code is currently quite big and hard to +read, however we are working on making it look like human-written +code. + +You can put the following elisp code in your [~/.emacs] file to hide +these blocks: + +{v +(defun deriving-inline-forward-sexp (&optional arg) + (search-forward-regexp "\\[@@@end\\]") nil nil arg) + +(defun setup-hide-deriving-inline () + (inline) + (hs-minor-mode t) + (let ((hs-hide-comments-when-hiding-all nil)) + (hs-hide-all))) + +(require 'hideshow) +(add-to-list 'hs-special-modes-alist + '(tuareg-mode "\\[@@deriving_inline[^]]*\\]" "\\[@@@end\\]" nil + deriving-inline-forward-sexp nil)) +(add-hook 'tuareg-mode-hook 'setup-hide-deriving-inline) +v} + +Things are not yet set up in the git repository to make it convenient +to change types and update the generated code, but they will be set up +soon. + +{2 Base coding rules} + +There are a few coding rules across the code base that are enforced by +lint tools. + +These rules are: + +{ul +{- Opening the [Caml] module is not allowed. Inside Base, the OCaml + stdlib is shadowed and accessible through the [Caml] module. We + forbid opening [Caml] so that we know exactly where things come + from.} +{- [Caml.Foo] modules cannot be aliased, one must use [Caml.Foo] + explicitly. This is to avoid having to remember a list of aliases + at the beginning of each file.} +{- For some modules that are both in the OCaml stdlib and Base, such as + [String], we define a module [String0] for common functions that + cannot be defined directly in [Base.String] to avoid creating a + circular dependency. Except for [String] itself, other modules + are not allowed to use [Caml.String] and must use either [String] or + [String0] instead.} +{- Indentation is exactly the one of [ocp-indent].} +{- A few other coding style rules enforced by + {{: https://github.com/janestreet/ppx_js_style} ppx_js_style}.} +} + + +The Base specific coding rules are checked by [ppx_base_lint], in the +[lint] subfolder. The indentation rules are checked by a wrapper around +[ocp-indent] and the coding style rules are checked by [ppx_js_style]. + +These checks are currently not run by [jbuilder], but it will soon get +a [-dev] flag to run them automatically. + +{2 Roadmap} + +Base is still under active development and there are several missing +feature that are yet to be added. Consult the +{{:https://github.com/janestreet/base/blob/master/ROADMAP.md}roadmap} to +see what is happening. diff --git a/src/binary_search.ml b/src/binary_search.ml new file mode 100644 index 0000000..0011158 --- /dev/null +++ b/src/binary_search.ml @@ -0,0 +1,109 @@ +open! Import + +(* These functions implement a search for the first (resp. last) element + satisfying a predicate, assuming that the predicate is increasing on + the container, meaning that, if the container is [u1...un], there exists a + k such that p(u1)=....=p(uk) = false and p(uk+1)=....=p(un)= true. + If this k = 1 (resp n), find_last_not_satisfying (resp find_first_satisfying) + will return None. *) + +let rec linear_search_first_satisfying t ~get ~lo ~hi ~pred = + if lo > hi + then None + else + if pred (get t lo) + then Some lo + else linear_search_first_satisfying t ~get ~lo:(lo + 1) ~hi ~pred +;; + +(* Takes a container [t], a predicate [pred] and two indices [lo < hi], such that + [pred] is increasing on [t] between [lo] and [hi]. + + return a range (lo, hi) where: + - lo and hi are close enough together for a linear search + - If [pred] is not constantly [false] on [t] between [lo] and [hi], the first element + on which [pred] is [true] is between [lo] and [hi]. *) +(* Invariant: the first element satisfying [pred], if it exists is between [lo] and [hi] *) +let rec find_range_near_first_satisfying t ~get ~lo ~hi ~pred = + (* Warning: this function will not terminate if the constant (currently 8) is + set <= 1 *) + if hi - lo <= 8 + then (lo,hi) + else + let mid = lo + ((hi - lo) / 2) in + if pred (get t mid) + (* INVARIANT check: it means the first satisfying element is between [lo] and [mid] *) + then find_range_near_first_satisfying t ~get ~lo ~hi:mid ~pred + (* INVARIANT check: it means the first satisfying element, if it exists, + is between [mid+1] and [hi] *) + else find_range_near_first_satisfying t ~get ~lo:(mid+1) ~hi ~pred +;; + +let find_first_satisfying ?pos ?len t ~get ~length ~pred = + let pos, len = + Ordered_collection_common.get_pos_len_exn () ?pos ?len ~total_length:(length t) + in + let lo = pos in + let hi = pos + len - 1 in + let (lo, hi) = find_range_near_first_satisfying t ~get ~lo ~hi ~pred in + linear_search_first_satisfying t ~get ~lo ~hi ~pred +;; + +(* Takes an array with shape [true,...true,false,...false] (i.e., the _reverse_ of what + is described above) and returns the index of the last true or None if there are no + true*) +let find_last_satisfying ?pos ?len t ~pred ~get ~length = + let pos, len = + Ordered_collection_common.get_pos_len_exn () ?pos ?len ~total_length:(length t) + in + if len = 0 + then None + else begin + (* The last satisfying is the one just before the first not satisfying *) + match find_first_satisfying ~pos ~len t ~get ~length ~pred:(Fn.non pred) with + | None -> Some (pos + len - 1) (* This means that all elements satisfy pred. + There is at least an element as (len > 0) *) + | Some i when i = pos -> None (* no element satisfies pred *) + | Some i -> Some (i - 1) + end +;; + +let binary_search ?pos ?len t ~length ~get ~compare how v = + match how with + | `Last_strictly_less_than -> + find_last_satisfying ?pos ?len t ~get ~length ~pred:(fun x -> compare x v < 0) + | `Last_less_than_or_equal_to -> + find_last_satisfying ?pos ?len t ~get ~length ~pred:(fun x -> compare x v <= 0) + | `First_equal_to -> + begin + match + find_first_satisfying ?pos ?len t ~get ~length ~pred:(fun x -> compare x v >= 0) + with + | Some x when compare (get t x) v = 0 -> Some x + | None | Some _ -> None + end + | `Last_equal_to -> + begin + match + find_last_satisfying ?pos ?len t ~get ~length ~pred:(fun x -> compare x v <= 0) + with + | Some x when compare (get t x) v = 0 -> Some x + | None | Some _ -> None + end + | `First_greater_than_or_equal_to -> + find_first_satisfying ?pos ?len t ~get ~length ~pred:(fun x -> compare x v >= 0) + | `First_strictly_greater_than -> + find_first_satisfying ?pos ?len t ~get ~length ~pred:(fun x -> compare x v > 0) +;; + +let binary_search_segmented ?pos ?len t ~length ~get ~segment_of how = + let is_left x = + match segment_of x with + | `Left -> true + | `Right -> false + in + let is_right x = not (is_left x) in + match how with + | `Last_on_left -> find_last_satisfying ?pos ?len t ~length ~get ~pred:is_left + | `First_on_right -> find_first_satisfying ?pos ?len t ~length ~get ~pred:is_right +;; diff --git a/src/binary_search.mli b/src/binary_search.mli new file mode 100644 index 0000000..abbe3da --- /dev/null +++ b/src/binary_search.mli @@ -0,0 +1,86 @@ +(** General functions for performing binary searches over ordered sequences given + [length] and [get] functions. + + These functions can be specialized and added to a data structure using the functors + supplied in {{!Base.Binary_searchable}[Binary_searchable]} and described in + {{!Base.Binary_searchable_intf}[Binary_searchable_intf]}. + + {2:examples Examples} + + Below we assume that the functions [get], [length] and [compare] are in scope: + + {[ + (* Find the index of an element [e] in [t] *) + binary_search t ~get ~length ~compare `First_equal_to e; + + (* Find the index where an element [e] should be inserted *) + binary_search t ~get ~length ~compare `First_greater_than_or_equal_to e; + + (* Find the index in [t] where all elements to the left are less than [e] *) + binary_search_segmented t ~get ~length ~segment_of:(fun e' -> + if compare e' e <= 0 then `Left else `Right) `First_on_right + ]} *) + +open! Import + +(** [binary_search ?pos ?len t ~length ~get ~compare which elt] takes [t] that is sorted + in increasing order according to [compare], where [compare] and [elt] divide [t] into + three (possibly empty) segments: + + {v + | < elt | = elt | > elt | + v} + + [binary_search] returns the index in [t] of an element on the boundary of segments + as specified by [which]. See the diagram below next to the [which] variants. + + By default, [binary_search] searches the entire [t]. One can supply [?pos] or + [?len] to search a slice of [t]. + + [binary_search] does not check that [compare] orders [t], and behavior is + unspecified if [compare] doesn't order [t]. Behavior is also unspecified if + [compare] mutates [t]. *) +val binary_search + : ?pos:int + -> ?len:int + -> 't + -> length:('t -> int) + -> get:('t -> int -> 'elt) + -> compare:('elt -> 'key -> int) + -> [ `Last_strictly_less_than (** {v | < elt X | v} *) + | `Last_less_than_or_equal_to (** {v | <= elt X | v} *) + | `Last_equal_to (** {v | = elt X | v} *) + | `First_equal_to (** {v | X = elt | v} *) + | `First_greater_than_or_equal_to (** {v | X >= elt | v} *) + | `First_strictly_greater_than (** {v | X > elt | v} *) + ] + -> 'key + -> int option + +(** [binary_search_segmented ?pos ?len t ~length ~get ~segment_of which] takes a + [segment_of] function that divides [t] into two (possibly empty) segments: + + {v + | segment_of elt = `Left | segment_of elt = `Right | + v} + + [binary_search_segmented] returns the index of the element on the boundary of the + segments as specified by [which]: [`Last_on_left] yields the index of the last + element of the left segment, while [`First_on_right] yields the index of the first + element of the right segment. It returns [None] if the segment is empty. + + By default, [binary_search] searches the entire [t]. One can supply [?pos] or + [?len] to search a slice of [t]. + + [binary_search_segmented] does not check that [segment_of] segments [t] as in the + diagram, and behavior is unspecified if [segment_of] doesn't segment [t]. Behavior + is also unspecified if [segment_of] mutates [t]. *) +val binary_search_segmented + : ?pos:int + -> ?len:int + -> 't + -> length:('t -> int) + -> get:('t -> int -> 'elt) + -> segment_of:('elt -> [ `Left | `Right ]) + -> [ `Last_on_left | `First_on_right ] + -> int option diff --git a/src/binary_searchable.ml b/src/binary_searchable.ml new file mode 100644 index 0000000..8779b98 --- /dev/null +++ b/src/binary_searchable.ml @@ -0,0 +1,36 @@ +open! Import + +include Binary_searchable_intf + +module type Arg = sig + type 'a elt + type 'a t + val get : 'a t -> int -> 'a elt + val length : _ t -> int +end + +module Make_gen (T : Arg) = struct + let get = T.get + let length = T.length + + let binary_search ?pos ?len t ~compare how v = + Binary_search.binary_search ?pos ?len t ~get ~length ~compare how v + + let binary_search_segmented ?pos ?len t ~segment_of how = + Binary_search.binary_search_segmented ?pos ?len t ~get ~length ~segment_of how +end + +module Make (T : Indexable) = + Make_gen (struct + type 'a elt = T.elt + type 'a t = T.t + include (T : Indexable with type elt := T.elt with type t := T.t) + end) + +module Make1 (T : Indexable1) = + Make_gen (struct + type 'a elt = 'a + type 'a t = 'a T.t + let get = T.get + let length = T.length + end) diff --git a/src/binary_searchable.mli b/src/binary_searchable.mli new file mode 100644 index 0000000..85837ea --- /dev/null +++ b/src/binary_searchable.mli @@ -0,0 +1 @@ +include Binary_searchable_intf.Binary_searchable (** @inline *) diff --git a/src/binary_searchable_intf.ml b/src/binary_searchable_intf.ml new file mode 100644 index 0000000..03b0725 --- /dev/null +++ b/src/binary_searchable_intf.ml @@ -0,0 +1,76 @@ +(** Module types for a [binary_search] function for a sequence, and functors for building + [binary_search] functions. *) + +open! Import + +(** An [Indexable] type is a finite sequence of elements indexed by consecutive integers + [0] ... [length t - 1]. [get] and [length] must be O(1) for the resulting + [binary_search] to be lg(n). *) +module type Indexable = sig + type elt + type t + + val get : t -> int -> elt + val length : t -> int +end + +module type Indexable1 = sig + type 'a t + + val get : 'a t -> int -> 'a + val length : _ t -> int +end + +type ('t, 'elt, 'key) binary_search = + ?pos:int + -> ?len:int + -> 't + -> compare:('elt -> 'key -> int) + -> [ `Last_strictly_less_than (** {v | < elt X | v} *) + | `Last_less_than_or_equal_to (** {v | <= elt X | v} *) + | `Last_equal_to (** {v | = elt X | v} *) + | `First_equal_to (** {v | X = elt | v} *) + | `First_greater_than_or_equal_to (** {v | X >= elt | v} *) + | `First_strictly_greater_than (** {v | X > elt | v} *) + ] + -> 'key + -> int option + +type ('t, 'elt) binary_search_segmented = + ?pos:int + -> ?len:int + -> 't + -> segment_of:('elt -> [ `Left | `Right ]) + -> [ `Last_on_left | `First_on_right ] + -> int option + +module type S = sig + type elt + type t + + (** See [Binary_search.binary_search] in binary_search.ml *) + val binary_search : (t, elt, 'key) binary_search + + (** See [Binary_search.binary_search_segmented] in binary_search.ml *) + val binary_search_segmented : (t, elt) binary_search_segmented +end + +module type S1 = sig + type 'a t + + val binary_search : ('a t, 'a, 'key) binary_search + val binary_search_segmented : ('a t, 'a) binary_search_segmented +end + +module type Binary_searchable = sig + module type S = S + module type S1 = S1 + module type Indexable = Indexable + module type Indexable1 = Indexable1 + + type nonrec ('t, 'elt, 'key) binary_search = ('t, 'elt, 'key) binary_search + type nonrec ('t, 'elt) binary_search_segmented = ('t, 'elt) binary_search_segmented + + module Make (T : Indexable) : S with type t := T.t with type elt := T.elt + module Make1 (T : Indexable1) : S1 with type 'a t := 'a T.t +end diff --git a/src/blit.ml b/src/blit.ml new file mode 100644 index 0000000..c2a3ae9 --- /dev/null +++ b/src/blit.ml @@ -0,0 +1,113 @@ +open! Import + + +include Blit_intf + +module type Sequence_gen = sig + type 'a t + val length : _ t -> int +end + +module Make_gen + (Src : Sequence_gen) + (Dst : sig + include Sequence_gen + val create_like : len:int -> 'a Src.t -> 'a t + val unsafe_blit : ('a Src.t, 'a t) blit + end) = struct + + let unsafe_blit = Dst.unsafe_blit + + let blit ~src ~src_pos ~dst ~dst_pos ~len = + Ordered_collection_common.check_pos_len_exn + ~pos:src_pos ~len ~total_length:(Src.length src); + Ordered_collection_common.check_pos_len_exn + ~pos:dst_pos ~len ~total_length:(Dst.length dst); + if len > 0 then unsafe_blit ~src ~src_pos ~dst ~dst_pos ~len; + ;; + + let blito + ~src ?(src_pos = 0) ?(src_len = Src.length src - src_pos) ~dst ?(dst_pos = 0) + () = + blit ~src ~src_pos ~len:src_len ~dst ~dst_pos; + ;; + + (* [sub] and [subo] ensure that every position of the created sequence is populated by + an element of the source array. Thus every element of [dst] below is well + defined. *) + let sub src ~pos ~len = + Ordered_collection_common.check_pos_len_exn ~pos ~len + ~total_length:(Src.length src); + let dst = Dst.create_like ~len src in + if len > 0 then unsafe_blit ~src ~src_pos:pos ~dst ~dst_pos:0 ~len; + dst + ;; + + let subo ?(pos = 0) ?len src = + sub src ~pos ~len:(match len with Some i -> i | None -> Src.length src - pos) + ;; +end + +module Make1 + (Sequence : sig + include Sequence_gen + val create_like : len:int -> 'a t -> 'a t + val unsafe_blit : ('a t, 'a t) blit + end) = + Make_gen + (Sequence) + (Sequence) + +module Make1_generic + (Sequence : Sequence1) = + Make_gen + (Sequence) + (Sequence) + +module Make + (Sequence : sig + include Sequence + val create : len:int -> t + val unsafe_blit : (t, t) blit + end) = struct + module Sequence = struct + type 'a t = Sequence.t + open Sequence + let create_like ~len _ = create ~len + let length = length + let unsafe_blit = unsafe_blit + end + include Make_gen (Sequence) (Sequence) +end + +module Make_distinct + (Src : Sequence) + (Dst : sig + include Sequence + val create : len:int -> t + val unsafe_blit : (Src.t, t) blit + end) = + Make_gen + (struct + type 'a t = Src.t + open Src + let length = length + end) + (struct + type 'a t = Dst.t + open Dst + let length = length + let create_like ~len _ = create ~len + let unsafe_blit = unsafe_blit + end) + +module Make_to_string + (T : sig type t end) + (To_bytes : S_distinct with type src := T.t with type dst := bytes) += struct + open To_bytes + let sub src ~pos ~len = + Bytes0.unsafe_to_string ~no_mutation_while_string_reachable:(sub src ~pos ~len) + let subo ?pos ?len src = + Bytes0.unsafe_to_string ~no_mutation_while_string_reachable:(subo ?pos ?len src) +end diff --git a/src/blit.mli b/src/blit.mli new file mode 100644 index 0000000..ea4c29c --- /dev/null +++ b/src/blit.mli @@ -0,0 +1 @@ +include Blit_intf.Blit (** @inline *) diff --git a/src/blit_intf.ml b/src/blit_intf.ml new file mode 100644 index 0000000..b99f2a3 --- /dev/null +++ b/src/blit_intf.ml @@ -0,0 +1,167 @@ +(** Standard type for [blit] functions, and reusable code for validating [blit] + arguments. *) + +open! Import + +(** If [blit : (src, dst) blit], then [blit ~src ~src_pos ~len ~dst ~dst_pos] blits [len] + values from [src] starting at position [src_pos] to [dst] at position [dst_pos]. + Furthermore, [blit] raises if [src_pos], [len], and [dst_pos] don't specify valid + slices of [src] and [dst]. *) +type ('src, 'dst) blit + = src : 'src + -> src_pos : int + -> dst : 'dst + -> dst_pos : int + -> len : int + -> unit + +(** [blito] is like [blit], except that the [src_pos], [src_len], and [dst_pos] are + optional (hence the "o" in "blito"). Also, we use [src_len] rather than [len] as a + reminder that if [src_len] isn't supplied, then the default is to take the slice + running from [src_pos] to the end of [src]. *) +type ('src, 'dst) blito + = src : 'src + -> ?src_pos : int (** default is [0] *) + -> ?src_len : int (** default is [length src - src_pos] *) + -> dst : 'dst + -> ?dst_pos : int (** default is [0] *) + -> unit + -> unit + +(** If [sub : (src, dst) sub], then [sub ~src ~pos ~len] returns a sequence of type [dst] + containing [len] characters of [src] starting at [pos]. + + [subo] is like [sub], except [pos] and [len] are optional. *) +type ('src, 'dst) sub = 'src -> pos:int -> len:int -> 'dst +type ('src, 'dst) subo + = ?pos : int (** default is [0] *) + -> ?len : int (** default is [length src - pos] *) + -> 'src + -> 'dst + +(*_ These are not implemented less-general-in-terms-of-more-general because odoc produces + unreadable documentation in that case, with or without [inline] on [include]. *) + +module type S = sig + type t + val blit : (t, t) blit + val blito : (t, t) blito + val unsafe_blit : (t, t) blit + val sub : (t, t) sub + val subo : (t, t) subo +end + +module type S1 = sig + type 'a t + val blit : ('a t, 'a t) blit + val blito : ('a t, 'a t) blito + val unsafe_blit : ('a t, 'a t) blit + val sub : ('a t, 'a t) sub + val subo : ('a t, 'a t) subo +end + +module type S_distinct = sig + type src + type dst + val blit : (src, dst) blit + val blito : (src, dst) blito + val unsafe_blit : (src, dst) blit + val sub : (src, dst) sub + val subo : (src, dst) subo +end + +module type S_to_string = sig + type t + val sub : (t, string) sub + val subo : (t, string) subo +end + +(** Users of modules matching the blit signatures [S], [S1], and [S1_distinct] only need + to understand the code above. The code below is only for those that need to implement + modules that match those signatures. *) + +module type Sequence = sig + type t + val length : t -> int +end + +type 'a poly = 'a + +module type Sequence1 = sig + type 'a t + + (** [Make1*] guarantees to only call [create_like ~len t] with [len > 0] if [length t > + 0]. *) + val create_like : len:int -> 'a t -> 'a t + val length : _ t -> int + + val unsafe_blit : ('a t, 'a t) blit +end + +module type Blit = sig + type nonrec ('src, 'dst) blit = ('src, 'dst) blit + type nonrec ('src, 'dst) blito = ('src, 'dst) blito + type nonrec ('src, 'dst) sub = ('src, 'dst) sub + type nonrec ('src, 'dst) subo = ('src, 'dst) subo + + module type S = S + module type S1 = S1 + module type S_distinct = S_distinct + module type S_to_string = S_to_string + module type Sequence = Sequence + module type Sequence1 = Sequence1 + + (** There are various [Make*] functors that turn an [unsafe_blit] function into a [blit] + function. The functors differ in whether the sequence type is monomorphic or + polymorphic, and whether the src and dst types are distinct or are the same. + + The blit functions make sure the slices are valid and then call [unsafe_blit]. They + guarantee at a call [unsafe_blit ~src ~src_pos ~dst ~dst_pos ~len] that: + + {[ + len > 0 + && src_pos >= 0 + && src_pos + len <= get_src_len src + && dst_pos >= 0 + && dst_pos + len <= get_dst_len dst + ]} + + The [Make*] functors also automatically create unit tests. *) + + (** [Make] is for blitting between two values of the same monomorphic type. *) + module Make + (Sequence : sig + include Sequence + val create : len:int -> t + val unsafe_blit : (t, t) blit + end) + : S with type t := Sequence.t + + (** [Make_distinct] is for blitting between values of distinct monomorphic types. *) + module Make_distinct + (Src : Sequence) + (Dst : sig + include Sequence + val create : len:int -> t + val unsafe_blit : (Src.t, t) blit + end) + : S_distinct + with type src := Src.t + with type dst := Dst.t + + module Make_to_string + (T : sig type t end) + (To_bytes : S_distinct with type src := T.t with type dst := bytes) + : S_to_string with type t := T.t + + (** [Make1] is for blitting between two values of the same polymorphic type. *) + module Make1 + (Sequence : Sequence1) + : S1 with type 'a t := 'a Sequence.t + + (** [Make1_generic] is for blitting between two values of the same container type that's + not fully polymorphic (in the sense of Container.Generic). *) + module Make1_generic + (Sequence : Sequence1) + : S1 with type 'a t := 'a Sequence.t +end diff --git a/src/bool.ml b/src/bool.ml new file mode 100644 index 0000000..1fbf23b --- /dev/null +++ b/src/bool.ml @@ -0,0 +1,89 @@ +open! Import + +let invalid_argf = Printf.invalid_argf + +module T = struct + type t = bool [@@deriving_inline compare, enumerate, hash, sexp] + let compare : t -> t -> int = compare_bool + let all : t list = [false; true] + let (hash_fold_t : + Ppx_hash_lib.Std.Hash.state -> t -> Ppx_hash_lib.Std.Hash.state) = + hash_fold_bool + and (hash : t -> Ppx_hash_lib.Std.Hash.hash_value) = + let func = hash_bool in fun x -> func x + let t_of_sexp : Ppx_sexp_conv_lib.Sexp.t -> t = bool_of_sexp + let sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t = sexp_of_bool + [@@@end] + + let of_string = function + | "true" -> true + | "false" -> false + | s -> invalid_argf "Bool.of_string: expected true or false but got %s" s () + ;; + + let to_string = Caml.string_of_bool +end + +include T +include Comparator.Make(T) +include Comparable.Validate(T) + +include Pretty_printer.Register (struct + type nonrec t = t + let to_string = to_string + let module_name = "Base.Bool" + end) + +(* Open replace_polymorphic_compare after including functor instantiations so they do not + shadow its definitions. This is here so that efficient versions of the comparison + functions are available within this module. *) +open! Bool_replace_polymorphic_compare + + +let between t ~low ~high = low <= t && t <= high +let clamp_unchecked t ~min ~max = + if t < min then min else if t <= max then t else max + +let clamp_exn t ~min ~max = + assert (min <= max); + clamp_unchecked t ~min ~max + +let clamp t ~min ~max = + if min > max then + Or_error.error_s + (Sexp.message "clamp requires [min <= max]" + [ "min", T.sexp_of_t min + ; "max", T.sexp_of_t max + ]) + else + Ok (clamp_unchecked t ~min ~max) + +(* We use [Obj.magic] here as other implementations generate a conditional jump and the + performance difference is noticeable. *) +let to_int (x : bool) = (Caml.Obj.magic x : int) + +module Non_short_circuiting = struct + (* We don't expose this, since we don't want to break the invariant mentioned below of + (to_int true = 1) and (to_int false = 0). *) + let unsafe_of_int (x : int) = (Caml.Obj.magic x : bool) + + let (||) a b = + unsafe_of_int (to_int a lor to_int b) + + let (&&) a b = + unsafe_of_int (to_int a land to_int b) +end + +(* We do this as a direct assert on the theory that it's a cheap thing to test and a + really core invariant that we never expect to break, and we should be happy for a + program to fail immediately if this is violated. *) +let () = + assert (Poly.(=) (to_int true ) 1 && + Poly.(=) (to_int false) 0); +;; + +(* Include type-specific [Replace_polymorphic_compare] at the end, after + including functor application that could shadow its definitions. This is + here so that efficient versions of the comparison functions are exported by + this module. *) +include Bool_replace_polymorphic_compare diff --git a/src/bool.mli b/src/bool.mli new file mode 100644 index 0000000..163d304 --- /dev/null +++ b/src/bool.mli @@ -0,0 +1,34 @@ +(** Boolean type extended to be enumerable, hashable, sexpable, comparable, and + stringable. *) + +open! Import + +type t = bool [@@deriving_inline enumerate, hash, sexp] +include +sig + [@@@ocaml.warning "-32"] + val all : t list + val hash_fold_t : + Ppx_hash_lib.Std.Hash.state -> t -> Ppx_hash_lib.Std.Hash.state + val hash : t -> Ppx_hash_lib.Std.Hash.hash_value + include Ppx_sexp_conv_lib.Sexpable.S with type t := t +end[@@ocaml.doc "@inline"] +[@@@end] + +include Identifiable.S with type t := t + +(** + - [to_int true = 1] + - [to_int false = 0] *) +val to_int : t -> int + +module Non_short_circuiting : sig + (** Non-short circuiting and branch-free boolean operators. + + The default versions of these infix operators are short circuiting, which + requires branching instructions to implement. The operators below are + instead branch-free, and therefore not short-circuiting. *) + + val (&&) : t -> t -> t + val (||) : t -> t -> t +end diff --git a/src/buffer.ml b/src/buffer.ml new file mode 100644 index 0000000..37338f7 --- /dev/null +++ b/src/buffer.ml @@ -0,0 +1,28 @@ +open! Import + +include Buffer_intf + +include Caml.Buffer + +let contents_bytes = to_bytes + +let add_substring t s ~pos ~len = add_substring t s pos len +let add_subbytes t s ~pos ~len = add_subbytes t s pos len +let sexp_of_t t = sexp_of_string (contents t) + +module To_bytes = + Blit.Make_distinct + (struct + type nonrec t = t + let length = length + end) + (struct + type t = Bytes.t + let create ~len = Bytes.create len + let length = Bytes.length + let unsafe_blit ~src ~src_pos ~dst ~dst_pos ~len = + Caml.Buffer.blit src src_pos dst dst_pos len + end) + +include To_bytes +module To_string = Blit.Make_to_string (Caml.Buffer) (To_bytes) diff --git a/src/buffer.mli b/src/buffer.mli new file mode 100644 index 0000000..a5212a3 --- /dev/null +++ b/src/buffer.mli @@ -0,0 +1,8 @@ +(** Extensible character buffers. + + This module implements character buffers that automatically expand as necessary. It + provides cumulative concatenation of strings in quasi-linear time (instead of + quadratic time when strings are concatenated pairwise). +*) + +include Buffer_intf.Buffer (** @inline *) diff --git a/src/buffer_intf.ml b/src/buffer_intf.ml new file mode 100644 index 0000000..37ee8e5 --- /dev/null +++ b/src/buffer_intf.ml @@ -0,0 +1,81 @@ +open! Import + +module type S = sig + (** The abstract type of buffers. *) + type t [@@deriving_inline sexp_of] + include + sig [@@@ocaml.warning "-32"] val sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t + end[@@ocaml.doc "@inline"] + [@@@end] + + (** [create n] returns a fresh buffer, initially empty. The [n] parameter is the + initial size of the internal storage medium that holds the buffer contents. That + storage is automatically reallocated when more than [n] characters are stored in the + buffer, but shrinks back to [n] characters when [reset] is called. + + For best performance, [n] should be of the same order of magnitude as the number of + characters that are expected to be stored in the buffer (for instance, 80 for a + buffer that holds one output line). Nothing bad will happen if the buffer grows + beyond that limit, however. In doubt, take [n = 16] for instance. *) + val create : int -> t + + (** Return a copy of the current contents of the buffer. The buffer itself is + unchanged. *) + val contents : t -> string + val contents_bytes : t -> bytes + + (** [blit ~src ~src_pos ~dst ~dst_pos ~len] copies [len] characters from the current + contents of the buffer [src], starting at offset [src_pos] to bytes [dst], starting + at character [dst_pos]. + + Raises [Invalid_argument] if [src_pos] and [len] do not designate a valid substring + of [src], or if [dst_pos] and [len] do not designate a valid substring of [dst]. *) + + include Blit.S_distinct with type src := t with type dst := bytes + module To_string : Blit.S_to_string with type t := t + + (** Gets the (zero-based) n-th character of the buffer. Raises [Invalid_argument] if + index out of bounds. *) + val nth : t -> int -> char + + (** Returns the number of characters currently contained in the buffer. *) + val length : t -> int + + (** Empties the buffer. *) + val clear : t -> unit + + (** Empties the buffer and deallocates the internal storage holding the buffer contents, + replacing it with the initial internal storage of length [n] that was allocated by + [create n]. For long-lived buffers that may have grown a lot, [reset] allows faster + reclamation of the space used by the buffer. *) + val reset : t -> unit + + (** [add_char b c] appends the character [c] at the end of the buffer [b]. *) + val add_char : t -> char -> unit + + (** [add_string b s] appends the string [s] at the end of the buffer [b]. *) + val add_string : t -> string -> unit + + (** [add_substring b s pos len] takes [len] characters from offset [pos] in string [s] + and appends them at the end of the buffer [b]. *) + val add_substring : t -> string -> pos:int -> len:int -> unit + + (** [add_bytes b s] appends the bytes [s] at the end of the buffer [b]. *) + val add_bytes : t -> bytes -> unit + + (** [add_subbytes b s pos len] takes [len] characters from offset [pos] in bytes [s] + and appends them at the end of the buffer [b]. *) + val add_subbytes : t -> bytes -> pos:int -> len:int -> unit + + (** [add_buffer b1 b2] appends the current contents of buffer [b2] at the end of buffer + [b1]. [b2] is not modified. *) + val add_buffer : t -> t -> unit +end + +module type Buffer = sig + module type S = S + + (** Buffers using strings as underlying storage medium: *) + + include S with type t = Caml.Buffer.t (** @open *) +end diff --git a/src/bytes.ml b/src/bytes.ml new file mode 100644 index 0000000..c701427 --- /dev/null +++ b/src/bytes.ml @@ -0,0 +1,125 @@ +open! Import + +let stage = Staged.stage + +module T = struct + type t = bytes [@@deriving_inline sexp] + let t_of_sexp : Ppx_sexp_conv_lib.Sexp.t -> t = bytes_of_sexp + let sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t = sexp_of_bytes + [@@@end] + include Bytes0 + + let module_name = "Base.Bytes" + + let pp fmt t = Caml.Format.fprintf fmt "%S" (to_string t) +end + +include T + +module To_bytes = + Blit.Make + (struct + include T + let create ~len = create len + end) +include To_bytes + +include Comparator.Make(T) +include Comparable.Validate(T) + +include Pretty_printer.Register_pp(T) + +(* Open replace_polymorphic_compare after including functor instantiations so they do not + shadow its definitions. This is here so that efficient versions of the comparison + functions are available within this module. *) +open! Bytes_replace_polymorphic_compare + +module To_string = Blit.Make_to_string (T) (To_bytes) + +module From_string = Blit.Make_distinct(struct + type t = string + let length = String.length + end) + (struct + type nonrec t = t + let create ~len = create len + let length = length + let unsafe_blit = unsafe_blit_string + end) + +let init n ~f = + if Int_replace_polymorphic_compare.(<) n 0 then Printf.invalid_argf "Bytes.init %d" n (); + let t = create n in + for i = 0 to n - 1 do + unsafe_set t i (f i) + done; + t + +let of_char_list l = + let t = create (List.length l) in + List.iteri l ~f:(fun i c -> set t i c); + t + +let to_list t = + let rec loop t i acc = + if Int_replace_polymorphic_compare.(<) i 0 + then acc + else loop t (i - 1) (unsafe_get t i :: acc) + in + loop t (length t - 1) [] + +let tr ~target ~replacement s = + for i = 0 to length s - 1 do + if Char.equal (unsafe_get s i) target + then unsafe_set s i replacement + done + +let tr_multi ~target ~replacement = + if Int_replace_polymorphic_compare.(=) (String.length target) 0 + then stage ignore + else if Int_replace_polymorphic_compare.(=) (String.length replacement) 0 + then invalid_arg "tr_multi: replacement is the empty string" + else + match Bytes_tr.tr_create_map ~target ~replacement with + | None -> stage ignore + | Some tr_map -> + stage (fun s -> + for i = 0 to length s - 1 do + unsafe_set s i (String.unsafe_get tr_map (Char.to_int (unsafe_get s i))) + done) + +let between t ~low ~high = low <= t && t <= high +let clamp_unchecked t ~min ~max = + if t < min then min else if t <= max then t else max + +let clamp_exn t ~min ~max = + assert (min <= max); + clamp_unchecked t ~min ~max + +let clamp t ~min ~max = + if min > max then + Or_error.error_s + (Sexp.message "clamp requires [min <= max]" + [ "min", T.sexp_of_t min + ; "max", T.sexp_of_t max + ]) + else + Ok (clamp_unchecked t ~min ~max) + +let contains ?pos ?len t char = + let (pos, len) = + Ordered_collection_common.get_pos_len_exn () ?pos ?len ~total_length:(length t) + in + let last = pos + len in + let rec loop i = + Int_replace_polymorphic_compare.(<) i last + && (Char.equal (get t i) char || loop (i + 1)) + in + loop pos +;; + +(* Include type-specific [Replace_polymorphic_compare] at the end, after + including functor application that could shadow its definitions. This is + here so that efficient versions of the comparison functions are exported by + this module. *) +include Bytes_replace_polymorphic_compare diff --git a/src/bytes.mli b/src/bytes.mli new file mode 100644 index 0000000..88f2480 --- /dev/null +++ b/src/bytes.mli @@ -0,0 +1,197 @@ +(** OCaml's byte sequence type, semantically similar to a [char array], but + taking less space in memory. + + A byte sequence is a mutable data structure that contains a fixed-length + sequence of bytes (of type [char]). Each byte can be indexed in constant + time for reading or writing. *) + +open! Import + +type t = bytes [@@deriving_inline sexp] +include +sig + [@@@ocaml.warning "-32"] + include Ppx_sexp_conv_lib.Sexpable.S with type t := t +end[@@ocaml.doc "@inline"] +[@@@end] + + +(** {1 Common Interfaces} *) + +include Blit .S with type t := t +include Comparable .S with type t := t +include Stringable .S with type t := t + +(** Note that [pp] allocates in order to preserve the state of the byte + sequence it was initially called with. *) +include Pretty_printer.S with type t := t + +module To_string : sig + val sub : (t, string) Blit.sub + val subo : (t, string) Blit.subo +end + +module From_string : Blit.S_distinct with type src := string and type dst := t + +(** [create len] returns a newly-allocated and uninitialized byte sequence of + length [len]. No guarantees are made about the contents of the return + value. *) +val create : int -> t + +(** [make len c] returns a newly-allocated byte sequence of length [len] filled + with the byte [c]. *) +val make : int -> char -> t + +(** [copy t] returns a newly-allocated byte sequence that contains the same + bytes as [t]. *) +val copy : t -> t + +(** [init len ~f] returns a newly-allocated byte sequence of length [len] with + index [i] in the sequence being initialized with the result of [f i]. *) +val init : int -> f:(int -> char) -> t + +(** [of_char_list l] returns a newly-allocated byte sequence where each byte in + the sequence corresponds to the byte in [l] at the same index. *) +val of_char_list : char list -> t + +(** [length t] returns the number of bytes in [t]. *) +val length : t -> int + +(** [get t i] returns the [i]th byte of [t]. *) +val get : t -> int -> char +external unsafe_get : t -> int -> char = "%bytes_unsafe_get" + +(** [set t i c] sets the [i]th byte of [t] to [c]. *) +val set : t -> int -> char -> unit +external unsafe_set : t -> int -> char -> unit = "%bytes_unsafe_set" + +(** [fill t ~pos ~len c] modifies [t] in place, replacing all the bytes from + [pos] to [pos + len] with [c]. *) +val fill : t -> pos:int -> len:int -> char -> unit + +(** [tr ~target ~replacement t] modifies [t] in place, replacing every instance + of [target] in [s] with [replacement]. *) +val tr : target:char -> replacement:char -> t -> unit + +(** [tr_multi ~target ~replacement] returns an in-place function that replaces + every instance of a character in [target] with the corresponding character + in [replacement]. + + If [replacement] is shorter than [target], it is lengthened by repeating + its last character. Empty [replacement] is illegal unless [target] also is. + + If [target] contains multiple copies of the same character, the last + corresponding [replacement] character is used. Note that character ranges + are {b not} supported, so [~target:"a-z"] means the literal characters ['a'], + ['-'], and ['z']. *) +val tr_multi : target:string -> replacement:string -> (t -> unit) Staged.t + +(** [to_list t] returns the bytes in [t] as a list of chars. *) +val to_list : t -> char list + +(** [contains ?pos ?len t c] returns [true] iff [c] appears in [t] between [pos] + and [pos + len]. *) +val contains : ?pos:int -> ?len:int -> t -> char -> bool + +(** Maximum length of a byte sequence, which is architecture-dependent. Attempting to + create a [Bytes] larger than this will raise an exception. *) +val max_length : int + +(** {2:unsafe Unsafe conversions (for advanced users)} + + This section describes unsafe, low-level conversion functions between + [bytes] and [string]. They might not copy the internal data; used + improperly, they can break the immutability invariant on strings provided + by the [-safe-string] option. They are available for expert library + authors, but for most purposes you should use the always-correct + {!Bytes.to_string} and {!Bytes.of_string} instead. +*) + +(** Unsafely convert a byte sequence into a string. + + To reason about the use of [unsafe_to_string], it is convenient to + consider an "ownership" discipline. A piece of code that + manipulates some data "owns" it; there are several disjoint ownership + modes, including: + {ul + {- Unique ownership: the data may be accessed and mutated} + {- Shared ownership: the data has several owners, that may only + access it, not mutate it.}} + Unique ownership is linear: passing the data to another piece of + code means giving up ownership (we cannot access the + data again). A unique owner may decide to make the data shared + (giving up mutation rights on it), but shared data may not become + uniquely-owned again. + [unsafe_to_string s] can only be used when the caller owns the byte + sequence [s] -- either uniquely or as shared immutable data. The + caller gives up ownership of [s], and gains (the same mode of) ownership + of the returned string. + There are two valid use-cases that respect this ownership + discipline: + {ol + {- The first is creating a string by initializing and mutating a byte + sequence that is never changed after initialization is performed. + {[ + let string_init len f : string = + let s = Bytes.create len in + for i = 0 to len - 1 do Bytes.set s i (f i) done; + Bytes.unsafe_to_string ~no_mutation_while_string_reachable:s + ]} + This function is safe because the byte sequence [s] will never be + accessed or mutated after [unsafe_to_string] is called. The + [string_init] code gives up ownership of [s], and returns the + ownership of the resulting string to its caller. + + Note that it would be unsafe if [s] was passed as an additional + parameter to the function [f] as it could escape this way and be + mutated in the future -- [string_init] would give up ownership of + [s] to pass it to [f], and could not call [unsafe_to_string] + safely. + + We have provided the {!String.init}, {!String.map} and + {!String.mapi} functions to cover most cases of building + new strings. You should prefer those over [to_string] or + [unsafe_to_string] whenever applicable.} + {- The second is temporarily giving ownership of a byte sequence to + a function that expects a uniquely owned string and returns ownership + back, so that we can mutate the sequence again after the call ended. + {[ + let bytes_length (s : bytes) = + String.length + (Bytes.unsafe_to_string ~no_mutation_while_string_reachable:s) + ]} + In this use-case, we do not promise that [s] will never be mutated + after the call to [bytes_length s]. The {!String.length} function + temporarily borrows unique ownership of the byte sequence + (and sees it as a [string]), but returns this ownership back to + the caller, which may assume that [s] is still a valid byte + sequence after the call. Note that this is only correct because we + know that {!String.length} does not capture its argument -- it could + escape by a side-channel such as a memoization combinator. + The caller may not mutate [s] while the string is borrowed (it has + temporarily given up ownership). This affects concurrent programs, + but also higher-order functions: if {!String.length} returned + a closure to be called later, [s] should not be mutated until this + closure is fully applied and returns ownership.}} +*) +val unsafe_to_string : no_mutation_while_string_reachable:t -> string + +(** Unsafely convert a shared string to a byte sequence that should + not be mutated. + + The same ownership discipline that makes [unsafe_to_string] + correct applies to [unsafe_of_string_promise_no_mutation], + however unique ownership of string values is extremely difficult + to reason about correctly in practice. As such, one should always + assume strings are shared, never uniquely owned (For example, + string literals are implicitly shared by the compiler, so you + never uniquely own them) + + The only case we have reasonable confidence is safe is if the + produced [bytes] is shared -- used as an immutable byte + sequence. This is possibly useful for incremental migration of + low-level programs that manipulate immutable sequences of bytes + (for example {!Marshal.from_bytes}) and previously used the + [string] type for this purpose. +*) +val unsafe_of_string_promise_no_mutation : string -> t diff --git a/src/bytes0.ml b/src/bytes0.ml new file mode 100644 index 0000000..40ae910 --- /dev/null +++ b/src/bytes0.ml @@ -0,0 +1,61 @@ +(* [Bytes0] defines string functions that are primitives or can be simply + defined in terms of [Caml.Bytes]. [Bytes0] is intended to completely express + the part of [Caml.Bytes] that [Base] uses -- no other file in Base other + than bytes0.ml should use [Caml.Bytes]. [Bytes0] has few dependencies, and + so is available early in Base's build order. + + All Base files that need to use strings and come before [Base.Bytes] in + build order should do: + + {[ + module Bytes = Bytes0 + ]} + + Defining [module Bytes = Bytes0] is also necessary because it prevents + ocamldep from mistakenly causing a file to depend on [Base.Bytes]. *) + +let blit_string = Caml.Bytes.blit_string + +let sub_string t ~pos ~len = + Caml.Bytes.sub_string t pos len + +open! Import0 + +module Sys = Sys0 + +module Primitives = struct + external get : bytes -> int -> char = "%bytes_safe_get" + external length : bytes -> int = "%bytes_length" + external unsafe_get : bytes -> int -> char = "%bytes_unsafe_get" + include Bytes_set_primitives + + (* [unsafe_blit_string] is not exported in the [stdlib] so we export it here *) + external unsafe_blit_string + : src:string -> src_pos:int -> dst:bytes -> dst_pos:int -> len:int -> unit + = "caml_blit_string" [@@noalloc] +end + +include Primitives + +let max_length = Sys.max_string_length + +let blit = Caml.Bytes.blit +let compare = Caml.Bytes.compare +let copy = Caml.Bytes.copy +let create = Caml.Bytes.create +let fill = Caml.Bytes.fill +let make = Caml.Bytes.make +let sub = Caml.Bytes.sub +let unsafe_blit = Caml.Bytes.unsafe_blit + +let to_string = Caml.Bytes.to_string +let of_string = Caml.Bytes.of_string + +let unsafe_to_string ~no_mutation_while_string_reachable:s = + Caml.Bytes.unsafe_to_string s +let unsafe_of_string_promise_no_mutation = Caml.Bytes.unsafe_of_string + +(* These are eta expanded in order to label arguments, following the + Base conventions. *) +let blit_string ~src ~src_pos ~dst ~dst_pos ~len = + blit_string src src_pos dst dst_pos len diff --git a/src/bytes_tr.ml b/src/bytes_tr.ml new file mode 100644 index 0000000..b8f6e89 --- /dev/null +++ b/src/bytes_tr.ml @@ -0,0 +1,40 @@ +open! Import0.Int_replace_polymorphic_compare + +module Bytes = Bytes0 +module String = String0 + +(* Construct a byte string of length 256, mapping every input character code to + its corresponding output character. + + Benchmarks indicate that this is faster than the lambda (including cost of + this function), even if target/replacement are just 2 characters each. + + Return None if the translation map is equivalent to just the identity. *) +let tr_create_map ~target ~replacement = + let tr_map = Bytes.create 256 in + for i = 0 to 255 do + Bytes.unsafe_set tr_map i (Char.of_int_exn i) + done; + for i = 0 to (min (String.length target) (String.length replacement)) - 1 do + let index = Char.to_int (String.unsafe_get target i) in + Bytes.unsafe_set tr_map index (String.unsafe_get replacement i) + done; + let last_replacement = String.unsafe_get replacement (String.length replacement - 1) in + for i = min (String.length target) (String.length replacement) to String.length target - 1 do + let index = Char.to_int (String.unsafe_get target i) in + Bytes.unsafe_set tr_map index last_replacement + done; + + let rec have_any_different tr_map i = + if i = 256 + then false + else if Char.(<>) (Bytes0.unsafe_get tr_map i) (Char.of_int_exn i) + then true + else have_any_different tr_map (i + 1) + in + (* quick check on the first target character which will 99% be true *) + let first_target = String.get target 0 in + if Char.(<>) (Bytes0.unsafe_get tr_map (Char.to_int first_target)) first_target + || have_any_different tr_map 0 + then Some (Bytes0.unsafe_to_string ~no_mutation_while_string_reachable:tr_map) + else None diff --git a/src/char.ml b/src/char.ml new file mode 100644 index 0000000..ffa06b1 --- /dev/null +++ b/src/char.ml @@ -0,0 +1,105 @@ +open! Import + +module Array = Array0 +module String = String0 + +include Char0 + +module T = struct + type t = char [@@deriving_inline compare, hash, sexp] + let compare : t -> t -> int = compare_char + let (hash_fold_t : + Ppx_hash_lib.Std.Hash.state -> t -> Ppx_hash_lib.Std.Hash.state) = + hash_fold_char + and (hash : t -> Ppx_hash_lib.Std.Hash.hash_value) = + let func = hash_char in fun x -> func x + let t_of_sexp : Ppx_sexp_conv_lib.Sexp.t -> t = char_of_sexp + let sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t = sexp_of_char + [@@@end] + + let to_string t = String.make 1 t + + let of_string s = + match String.length s with + | 1 -> String.get s 0 + | _ -> failwithf "Char.of_string: %S" s () +end + +include T + +include Identifiable.Make (struct + include T + let module_name = "Base.Char" + end) + +(* Open replace_polymorphic_compare after including functor instantiations so they do not + shadow its definitions. This is here so that efficient versions of the comparison + functions are available within this module. *) +open! Char_replace_polymorphic_compare + +let all = + Array.init 256 ~f:unsafe_of_int + |> Array.to_list + +let is_lowercase = function + | 'a' .. 'z' -> true + | _ -> false + +let is_uppercase = function + | 'A' .. 'Z' -> true + | _ -> false + +let is_print = function + | ' ' .. '~' -> true + | _ -> false + +let is_whitespace = function + | '\t' + | '\n' + | '\011' (* vertical tab *) + | '\012' (* form feed *) + | '\r' + | ' ' + -> true + | _ + -> false +;; + +let is_digit = function + | '0' .. '9' -> true + | _ -> false + +let is_alpha = function + | 'a' .. 'z' | 'A' .. 'Z' -> true + | _ -> false + +(* Writing these out, instead of calling [is_alpha] and [is_digit], reduces + runtime by approx. 30% *) +let is_alphanum = function + | 'a' .. 'z' | 'A' .. 'Z' | '0' .. '9' -> true + | _ -> false + +let get_digit_unsafe t = to_int t - to_int '0' + +let get_digit_exn t = + if is_digit t + then get_digit_unsafe t + else failwithf "Char.get_digit_exn %C: not a digit" t () +;; + +let get_digit t = if is_digit t then Some (get_digit_unsafe t) else None + +module O = struct + let ( >= ) = ( >= ) + let ( <= ) = ( <= ) + let ( = ) = ( = ) + let ( > ) = ( > ) + let ( < ) = ( < ) + let ( <> ) = ( <> ) +end + +(* Include type-specific [Replace_polymorphic_compare] at the end, after + including functor application that could shadow its definitions. This is + here so that efficient versions of the comparison functions are exported by + this module. *) +include Char_replace_polymorphic_compare diff --git a/src/char.mli b/src/char.mli new file mode 100644 index 0000000..adfd960 --- /dev/null +++ b/src/char.mli @@ -0,0 +1,73 @@ +(** A type for 8-bit characters. *) + +open! Import + +(** An alias for the type of characters. *) +type t = char [@@deriving_inline enumerate, hash, sexp] +include +sig + [@@@ocaml.warning "-32"] + val all : t list + val hash_fold_t : + Ppx_hash_lib.Std.Hash.state -> t -> Ppx_hash_lib.Std.Hash.state + val hash : t -> Ppx_hash_lib.Std.Hash.hash_value + include Ppx_sexp_conv_lib.Sexpable.S with type t := t +end[@@ocaml.doc "@inline"] +[@@@end] + +include Identifiable.S with type t := t + +module O : Comparisons.Infix with type t := t + +(** Returns the ASCII code of the argument. *) +val to_int : t -> int + +(** Returns the character with the given ASCII code or [None] is the argument is outside + the range 0 to 255. *) +val of_int : int -> t option + +(** Returns the character with the given ASCII code. Raises [Failure] if the argument is + outside the range 0 to 255. *) +val of_int_exn : int -> t + +val unsafe_of_int : int -> t + +(** Returns a string representing the given character, with special characters escaped + following the lexical conventions of OCaml. *) +val escaped : t -> string + +(** Converts the given character to its equivalent lowercase character. *) +val lowercase : t -> t + +(** Converts the given character to its equivalent uppercase character. *) +val uppercase : t -> t + +(** '0' - '9' *) +val is_digit : t -> bool + +(** 'a' - 'z' *) +val is_lowercase : t -> bool + +(** 'A' - 'Z' *) +val is_uppercase : t -> bool + +(** 'a' - 'z' or 'A' - 'Z' *) +val is_alpha : t -> bool + +(** 'a' - 'z' or 'A' - 'Z' or '0' - '9' *) +val is_alphanum : t -> bool + +(** ' ' - '~' *) +val is_print : t -> bool + +(** ' ' or '\t' or '\r' or '\n' *) +val is_whitespace : t -> bool + +(** Returns [Some i] if [is_digit c] and [None] otherwise. *) +val get_digit : t -> int option + +(** Returns [i] if [is_digit c] and raises [Failure] otherwise. *) +val get_digit_exn : t -> int + +val min_value : t +val max_value : t diff --git a/src/char0.ml b/src/char0.ml new file mode 100644 index 0000000..1e90989 --- /dev/null +++ b/src/char0.ml @@ -0,0 +1,39 @@ +(* [Char0] defines char functions that are primitives or can be simply defined in terms of + [Caml.Char]. [Char0] is intended to completely express the part of [Caml.Char] that + [Base] uses -- no other file in Base other than char0.ml should use [Caml.Char]. + [Char0] has few dependencies, and so is available early in Base's build order. All + Base files that need to use chars and come before [Base.Char] in build order should do + [module Char = Char0]. Defining [module Char = Char0] is also necessary because it + prevents ocamldep from mistakenly causing a file to depend on [Base.Char]. *) + +open! Import0 + +let failwithf = Printf.failwithf + +let escaped = Caml.Char.escaped +let lowercase = Caml.Char.lowercase_ascii +let to_int = Caml.Char.code +let unsafe_of_int = Caml.Char.unsafe_chr +let uppercase = Caml.Char.uppercase_ascii + +(* We use our own range test when converting integers to chars rather than + calling [Caml.Char.chr] because it's simple and it saves us a function call + and the try-with (exceptions cost, especially in the world with backtraces). *) +let int_is_ok i = 0 <= i && i <= 255 + +let min_value = unsafe_of_int 0 +let max_value = unsafe_of_int 255 + +let of_int i = + if int_is_ok i + then Some (unsafe_of_int i) + else None +;; + +let of_int_exn i = + if int_is_ok i + then unsafe_of_int i + else failwithf "Char.of_int_exn got integer out of range: %d" i () +;; + +let equal (t1 : char) t2 = Poly.equal t1 t2 diff --git a/src/comparable.ml b/src/comparable.ml new file mode 100644 index 0000000..97c9d6d --- /dev/null +++ b/src/comparable.ml @@ -0,0 +1,217 @@ +open! Import + +include Comparable_intf + +module Validate + (T : sig type t [@@deriving_inline compare, sexp_of] + include + sig + [@@@ocaml.warning "-32"] + val compare : t -> t -> int + val sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t + end[@@ocaml.doc "@inline"] + [@@@end] end) : Validate with type t := T.t = +struct + + module V = Validate + open Maybe_bound + + let to_string t = Sexp.to_string (T.sexp_of_t t) + + let validate_bound ~min ~max t = + V.bounded ~name:to_string ~lower:min ~upper:max ~compare:T.compare t + ;; + + let validate_lbound ~min t = validate_bound ~min ~max:Unbounded t + let validate_ubound ~max t = validate_bound ~max ~min:Unbounded t +end + +module With_zero + (T : sig + type t [@@deriving_inline compare, sexp_of] + include + sig + [@@@ocaml.warning "-32"] + val compare : t -> t -> int + val sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t + end[@@ocaml.doc "@inline"] + [@@@end] + val zero : t + include Validate with type t := t + end) = struct + open T + + (* Preallocate the interesting bounds to minimize allocation in the implementations of + [validate_*]. *) + let excl_zero = Maybe_bound.Excl zero + let incl_zero = Maybe_bound.Incl zero + + let validate_positive t = validate_lbound ~min:excl_zero t + let validate_non_negative t = validate_lbound ~min:incl_zero t + let validate_negative t = validate_ubound ~max:excl_zero t + let validate_non_positive t = validate_ubound ~max:incl_zero t + let is_positive t = compare t zero > 0 + let is_non_negative t = compare t zero >= 0 + let is_negative t = compare t zero < 0 + let is_non_positive t = compare t zero <= 0 + let sign t = Sign0.of_int (compare t zero) +end + +module Validate_with_zero + (T : sig + type t [@@deriving_inline compare, sexp_of] + include + sig + [@@@ocaml.warning "-32"] + val compare : t -> t -> int + val sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t + end[@@ocaml.doc "@inline"] + [@@@end] + val zero : t + end) = struct + module V = Validate (T) + include V + include With_zero (struct include T include V end) +end + +module Poly (T : sig type t [@@deriving_inline sexp_of] + include + sig [@@@ocaml.warning "-32"] val sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t + end[@@ocaml.doc "@inline"] + [@@@end] end) = struct + module Replace_polymorphic_compare = struct + type t = T.t [@@deriving_inline sexp_of] + let sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t = T.sexp_of_t + [@@@end] + include Poly + end + include Poly + + let between t ~low ~high = low <= t && t <= high + + let clamp_unchecked t ~min ~max = + if t < min then min else if t <= max then t else max + + let clamp_exn t ~min ~max = + assert (min <= max); + clamp_unchecked t ~min ~max + + let clamp t ~min ~max = + if min > max then + Or_error.error_s + (Sexp.message "clamp requires [min <= max]" + [ "min", T.sexp_of_t min + ; "max", T.sexp_of_t max + ]) + else + Ok (clamp_unchecked t ~min ~max) + + module C = struct + include T + include Comparator.Make (Replace_polymorphic_compare) + end + include C + include Validate (struct type nonrec t = t [@@deriving_inline compare, sexp_of] + let compare : t -> t -> int = compare + let sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t = sexp_of_t + [@@@end] end) +end + +module Make_using_comparator (T : sig + type t [@@deriving_inline sexp_of] + include + sig [@@@ocaml.warning "-32"] val sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t + end[@@ocaml.doc "@inline"] + [@@@end] + include Comparator.S with type t := t + end) : S with type t := T.t and type comparator_witness = T.comparator_witness = struct + module T = struct + include T + let compare = comparator.compare + end + include T + + module Replace_polymorphic_compare = struct + module Without_squelch = struct + let compare = T.compare + let (>) a b = compare a b > 0 + let (<) a b = compare a b < 0 + let (>=) a b = compare a b >= 0 + let (<=) a b = compare a b <= 0 + let (=) a b = compare a b = 0 + let (<>) a b = compare a b <> 0 + let equal = (=) + let min t t' = if t <= t' then t else t' + let max t t' = if t >= t' then t else t' + end + include Without_squelch + end + include Replace_polymorphic_compare.Without_squelch + let ascending = compare + let descending t t' = compare t' t + let between t ~low ~high = low <= t && t <= high + + let clamp_unchecked t ~min ~max = + if t < min then min else if t <= max then t else max + + let clamp_exn t ~min ~max = + assert (min <= max); + clamp_unchecked t ~min ~max + + let clamp t ~min ~max = + if min > max then + Or_error.error_s + (Sexp.message "clamp requires [min <= max]" + [ "min", T.sexp_of_t min + ; "max", T.sexp_of_t max + ]) + else + Ok (clamp_unchecked t ~min ~max) + + include Validate (T) +end + +module Make (T : sig + type t [@@deriving_inline compare, sexp_of] + include + sig + [@@@ocaml.warning "-32"] + val compare : t -> t -> int + val sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t + end[@@ocaml.doc "@inline"] + [@@@end] + end) = Make_using_comparator(struct + include T + include Comparator.Make (T) + end) + +module Inherit + (C : sig type t [@@deriving_inline compare] + include sig [@@@ocaml.warning "-32"] val compare : t -> t -> int end[@@ocaml.doc + "@inline"] + [@@@end] end) + (T : sig + type t [@@deriving_inline sexp_of] + include + sig [@@@ocaml.warning "-32"] val sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t + end[@@ocaml.doc "@inline"] + [@@@end] + val component : t -> C.t + end) = + Make (struct + type t = T.t [@@deriving_inline sexp_of] + let sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t = T.sexp_of_t + [@@@end] + let compare t t' = C.compare (T.component t) (T.component t') + end) + +(* compare [x] and [y] lexicographically using functions in the list [cmps] *) +let lexicographic cmps x y = + let rec loop = function + | cmp :: cmps -> let res = cmp x y in if res = 0 then loop cmps else res + | [] -> 0 + in + loop cmps +;; + +let lift cmp ~f x y = cmp (f x) (f y) diff --git a/src/comparable.mli b/src/comparable.mli new file mode 100644 index 0000000..da677d9 --- /dev/null +++ b/src/comparable.mli @@ -0,0 +1 @@ +include Comparable_intf.Comparable (** @inline *) diff --git a/src/comparable_intf.ml b/src/comparable_intf.ml new file mode 100644 index 0000000..b09359b --- /dev/null +++ b/src/comparable_intf.ml @@ -0,0 +1,217 @@ +open! Import + +module type Infix = Comparisons.Infix +module type Polymorphic_compare = Comparisons.S + +module type Validate = sig + type t + + val validate_lbound : min : t Maybe_bound.t -> t Validate.check + val validate_ubound : max : t Maybe_bound.t -> t Validate.check + val validate_bound + : min : t Maybe_bound.t + -> max : t Maybe_bound.t + -> t Validate.check +end + +module type With_zero = sig + type t + + val validate_positive : t Validate.check + val validate_non_negative : t Validate.check + val validate_negative : t Validate.check + val validate_non_positive : t Validate.check + val is_positive : t -> bool + val is_non_negative : t -> bool + val is_negative : t -> bool + val is_non_positive : t -> bool + + (** Returns [Neg], [Zero], or [Pos] in a way consistent with the above functions. *) + val sign : t -> Sign0.t +end + +module type S = sig + include Polymorphic_compare + + (** [ascending] is identical to [compare]. [descending x y = ascending y x]. These are + intended to be mnemonic when used like [List.sort ~compare:ascending] and [List.sort + ~cmp:descending], since they cause the list to be sorted in ascending or descending + order, respectively. *) + val ascending : t -> t -> int + val descending : t -> t -> int + + (** [between t ~low ~high] means [low <= t <= high] *) + val between : t -> low:t -> high:t -> bool + + (** [clamp_exn t ~min ~max] returns [t'], the closest value to [t] such that + [between t' ~low:min ~high:max] is true. + + Raises if [not (min <= max)]. *) + val clamp_exn : t -> min:t -> max:t -> t + val clamp : t -> min:t -> max:t -> t Or_error.t + + include Comparator.S with type t := t + + include Validate with type t := t +end + +(** Usage example: + + {[ + module Foo : sig + type t = ... + include Comparable.S with type t := t + end + ]} + + Then use [Comparable.Make] in the struct (see comparable.mli for an example). *) + +module type Comparable = sig + (** Defines functors for making modules comparable. *) + + (** Usage example: + + {[ + module Foo = struct + module T = struct + type t = ... [@@deriving_inline compare, sexp][@@@end] + end + include T + include Comparable.Make (T) + end + ]} + + Then include [Comparable.S] in the signature + + {[ + module Foo : sig + type t = ... + include Comparable.S with type t := t + end + ]} + + To add an [Infix] submodule: + + {[ + module C = Comparable.Make (T) + include C + module Infix = (C : Comparable.Infix with type t := t) + ]} + + A common pattern is to define a module [O] with a restricted signature. It aims to be + (locally) opened to bring useful operators into scope without shadowing unexpected + variable names. E.g., in the [Date] module: + + {[ + module O = struct + include (C : Comparable.Infix with type t := t) + let to_string t = .. + end + ]} + + Opening [Date] would shadow [now], but opening [Date.O] doesn't: + + {[ + let now = .. in + let someday = .. in + Date.O.(now > someday) + ]} *) + + + module type Infix = Infix + module type S = S + module type Polymorphic_compare = Polymorphic_compare + module type Validate = Validate + module type With_zero = With_zero + + (** [lexicographic cmps x y] compares [x] and [y] lexicographically using functions in the + list [cmps]. *) + val lexicographic : ('a -> 'a -> int) list -> 'a -> 'a -> int + + (** [lift cmp ~f x y] compares [x] and [y] by comparing [f x] and [f y] via [cmp]. *) + val lift : ('a -> 'a -> 'int_or_bool) -> f:('b -> 'a) -> ('b -> 'b -> 'int_or_bool) + + (** Inherit comparability from a component. *) + module Inherit + (C : sig type t [@@deriving_inline compare] + include sig [@@@ocaml.warning "-32"] val compare : t -> t -> int end[@@ocaml.doc + "@inline"] + [@@@end] end) + (T : sig + type t [@@deriving_inline sexp_of] + include + sig [@@@ocaml.warning "-32"] val sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t + end[@@ocaml.doc "@inline"] + [@@@end] + val component : t -> C.t + end) : S with type t := T.t + + module Make (T : sig + type t [@@deriving_inline compare, sexp_of] + include + sig + [@@@ocaml.warning "-32"] + val compare : t -> t -> int + val sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t + end[@@ocaml.doc "@inline"] + [@@@end] + end) : S with type t := T.t + + module Make_using_comparator (T : sig + type t [@@deriving_inline sexp_of] + include + sig [@@@ocaml.warning "-32"] val sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t + end[@@ocaml.doc "@inline"] + [@@@end] + include Comparator.S with type t := t + end) : S + with type t := T.t + with type comparator_witness := T.comparator_witness + + module Poly (T : sig type t [@@deriving_inline sexp_of] + include + sig [@@@ocaml.warning "-32"] val sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t + end[@@ocaml.doc "@inline"] + [@@@end] end) : S with type t := T.t + + module Validate (T : sig type t [@@deriving_inline compare, sexp_of] + include + sig + [@@@ocaml.warning "-32"] + val compare : t -> t -> int + val sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t + end[@@ocaml.doc "@inline"] + [@@@end] end) + : Validate with type t := T.t + + module With_zero + (T : sig + type t [@@deriving_inline compare, sexp_of] + include + sig + [@@@ocaml.warning "-32"] + val compare : t -> t -> int + val sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t + end[@@ocaml.doc "@inline"] + [@@@end] + val zero : t + include Validate with type t := t + end) : With_zero with type t := T.t + + module Validate_with_zero + (T : sig + type t [@@deriving_inline compare, sexp_of] + include + sig + [@@@ocaml.warning "-32"] + val compare : t -> t -> int + val sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t + end[@@ocaml.doc "@inline"] + [@@@end] + val zero : t + end) + : sig + include Validate with type t := T.t + include With_zero with type t := T.t + end +end diff --git a/src/comparator.ml b/src/comparator.ml new file mode 100644 index 0000000..f843bb4 --- /dev/null +++ b/src/comparator.ml @@ -0,0 +1,127 @@ +open! Import + +type ('a, 'witness) t = + { compare : 'a -> 'a -> int + ; sexp_of_t : 'a -> Sexp.t + } + +type ('a, 'b) comparator = ('a, 'b) t + +module type S = sig + type t + type comparator_witness + val comparator : (t, comparator_witness) comparator +end + +module type S1 = sig + type 'a t + type comparator_witness + val comparator : ('a t, comparator_witness) comparator +end + +module type S_fc = sig + type comparable_t + include S with type t := comparable_t +end + +let make (type t) ~compare ~sexp_of_t = + (module struct + type comparable_t = t + type comparator_witness + let comparator = { compare; sexp_of_t } + end : S_fc with type comparable_t = t) + +module S_to_S1 (S : S) = struct + type 'a t = S.t + type comparator_witness = S.comparator_witness + open S + let comparator = comparator +end + +module Make (M : sig type t [@@deriving_inline compare, sexp_of] + include + sig + [@@@ocaml.warning "-32"] + val compare : t -> t -> int + val sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t + end[@@ocaml.doc "@inline"] + [@@@end] end) = struct + include M + type comparator_witness + let comparator = M.({ compare; sexp_of_t }) +end + +module Make1 (M : sig + type 'a t + val compare : 'a t -> 'a t -> int + val sexp_of_t : 'a t -> Sexp.t + end) = struct + type comparator_witness + let comparator = M.({ compare; sexp_of_t }) +end + +module Poly = struct + type 'a t = 'a + include Make1 (struct + type 'a t = 'a + let compare = Poly.compare + let sexp_of_t _ = Sexp.Atom "_" + end) +end + +module type Derived = sig + type 'a t + type 'cmp comparator_witness + + val comparator + : ('a, 'cmp) comparator + -> ('a t, 'cmp comparator_witness) comparator +end + +module Derived (M : sig type 'a t [@@deriving_inline compare, sexp_of] + include + sig + [@@@ocaml.warning "-32"] + val compare : ('a -> 'a -> int) -> 'a t -> 'a t -> int + val sexp_of_t : + ('a -> Ppx_sexp_conv_lib.Sexp.t) -> 'a t -> Ppx_sexp_conv_lib.Sexp.t + end[@@ocaml.doc "@inline"] + [@@@end] end) = struct + type 'cmp comparator_witness + + let comparator a = { + compare = M.compare a.compare; + sexp_of_t = M.sexp_of_t a.sexp_of_t; + } +end + +module type Derived2 = sig + type ('a, 'b) t + type ('cmp_a, 'cmp_b) comparator_witness + + val comparator + : ('a, 'cmp_a) comparator + -> ('b, 'cmp_b) comparator + -> (('a, 'b) t, ('cmp_a, 'cmp_b) comparator_witness) comparator +end + +module Derived2 (M : sig type ('a, 'b) t [@@deriving_inline compare, sexp_of] + include + sig + [@@@ocaml.warning "-32"] + val compare : + ('a -> 'a -> int) -> + ('b -> 'b -> int) -> ('a, 'b) t -> ('a, 'b) t -> int + val sexp_of_t : + ('a -> Ppx_sexp_conv_lib.Sexp.t) -> + ('b -> Ppx_sexp_conv_lib.Sexp.t) -> + ('a, 'b) t -> Ppx_sexp_conv_lib.Sexp.t + end[@@ocaml.doc "@inline"] + [@@@end] end) = struct + type ('cmp_a, 'cmp_b) comparator_witness + + let comparator a b = { + compare = M.compare a.compare b.compare; + sexp_of_t = M.sexp_of_t a.sexp_of_t b.sexp_of_t; + } +end diff --git a/src/comparator.mli b/src/comparator.mli new file mode 100644 index 0000000..c069f62 --- /dev/null +++ b/src/comparator.mli @@ -0,0 +1,121 @@ +(** A type-indexed value that allows one to compare (and for generating error messages, + serialize) values of the type in question. + + One of the type parameters is a phantom parameter used to distinguish comparators + potentially built on different comparison functions. In particular, we want to + distinguish those using polymorphic compare from those using a monomorphic compare. *) + +open! Import + +type ('a, 'witness) t = + private + { compare : 'a -> 'a -> int + ; sexp_of_t : 'a -> Sexp.t + } + +type ('a, 'b) comparator = ('a, 'b) t + +module type S = sig + type t + type comparator_witness + val comparator : (t, comparator_witness) comparator +end + +module type S1 = sig + type 'a t + type comparator_witness + val comparator : ('a t, comparator_witness) comparator +end + +module type S_fc = sig + type comparable_t + include S with type t := comparable_t +end + +(** [make] creates a comparator witness for the given comparison. It is intended as a + lightweight alternative to the functors below, to be used like so: + [include (val Comparator.make ~compare ~sexp_of_t)] *) +val make + : compare:('a -> 'a -> int) + -> sexp_of_t:('a -> Sexp.t) + -> (module S_fc with type comparable_t = 'a) + +module Poly : S1 with type 'a t = 'a + +module S_to_S1 (S : S) : S1 + with type 'a t = S.t + with type comparator_witness = S.comparator_witness + +(** [Make] creates a [comparator] value and its phantom [comparator_witness] type for a + nullary type. *) +module Make (M : sig + type t [@@deriving_inline compare, sexp_of] + include + sig + [@@@ocaml.warning "-32"] + val compare : t -> t -> int + val sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t + end[@@ocaml.doc "@inline"] + [@@@end] + end) : S with type t := M.t + +(** [Make1] creates a [comparator] value and its phantom [comparator_witness] type for a + unary type. It takes a [compare] and [sexp_of_t] that have + non-standard types because the [Comparator.t] type doesn't allow passing in + additional values for the type argument. *) +module Make1 (M : sig + type 'a t + val compare : 'a t -> 'a t -> int + val sexp_of_t : _ t -> Sexp.t + end) : S1 with type 'a t := 'a M.t + +module type Derived = sig + type 'a t + type 'cmp comparator_witness + + val comparator + : ('a, 'cmp) comparator + -> ('a t, 'cmp comparator_witness) comparator +end + +(** [Derived] creates a [comparator] function that constructs a comparator for the type + ['a t] given a comparator for the type ['a]. *) +module Derived (M : sig + type 'a t [@@deriving_inline compare, sexp_of] + include + sig + [@@@ocaml.warning "-32"] + val compare : ('a -> 'a -> int) -> 'a t -> 'a t -> int + val sexp_of_t : + ('a -> Ppx_sexp_conv_lib.Sexp.t) -> 'a t -> Ppx_sexp_conv_lib.Sexp.t + end[@@ocaml.doc "@inline"] + [@@@end] + end) : Derived with type 'a t := 'a M.t + +module type Derived2 = sig + type ('a, 'b) t + type ('cmp_a, 'cmp_b) comparator_witness + + val comparator + : ('a, 'cmp_a) comparator + -> ('b, 'cmp_b) comparator + -> (('a, 'b) t, ('cmp_a, 'cmp_b) comparator_witness) comparator +end + +(** [Derived2] creates a [comparator] function that constructs a comparator for the type + [('a, 'b) t] given comparators for the type ['a] and ['b]. *) +module Derived2 (M : sig + type ('a, 'b) t [@@deriving_inline compare, sexp_of] + include + sig + [@@@ocaml.warning "-32"] + val compare : + ('a -> 'a -> int) -> + ('b -> 'b -> int) -> ('a, 'b) t -> ('a, 'b) t -> int + val sexp_of_t : + ('a -> Ppx_sexp_conv_lib.Sexp.t) -> + ('b -> Ppx_sexp_conv_lib.Sexp.t) -> + ('a, 'b) t -> Ppx_sexp_conv_lib.Sexp.t + end[@@ocaml.doc "@inline"] + [@@@end] + end) : Derived2 with type ('a, 'b) t := ('a, 'b) M.t diff --git a/src/comparisons.ml b/src/comparisons.ml new file mode 100644 index 0000000..1cbf91c --- /dev/null +++ b/src/comparisons.ml @@ -0,0 +1,29 @@ +(** Interfaces for infix comparison operators and comparison functions. *) + +open! Import + +(** [Infix] lists the typical infix comparison operators. These functions are provided by + [<M>.O] modules, i.e., modules that expose monomorphic infix comparisons over some + [<M>.t]. *) +module type Infix = sig + type t + val ( >= ) : t -> t -> bool + val ( <= ) : t -> t -> bool + val ( = ) : t -> t -> bool + val ( > ) : t -> t -> bool + val ( < ) : t -> t -> bool + val ( <> ) : t -> t -> bool +end + +module type S = sig + include Infix + + val equal : t -> t -> bool + + (** [compare t1 t2] returns 0 if [t1] is equal to [t2], a negative integer if [t1] is + less than [t2], and a positive integer if [t1] is greater than [t2]. *) + val compare : t -> t -> int + + val min : t -> t -> t + val max : t -> t -> t +end diff --git a/src/container.ml b/src/container.ml new file mode 100644 index 0000000..7a253cd --- /dev/null +++ b/src/container.ml @@ -0,0 +1,175 @@ +open! Import + +module Array = Array0 +module List = List0 + +include Container_intf + +let with_return = With_return.with_return + +type ('t, 'a, 'accum) fold = 't -> init:'accum -> f:('accum -> 'a -> 'accum) -> 'accum +type ('t, 'a) iter = 't -> f:('a -> unit) -> unit +type 't length = 't -> int + +let iter ~fold t ~f = fold t ~init:() ~f:(fun () a -> f a) + +let count ~fold t ~f = fold t ~init:0 ~f:(fun n a -> if f a then n + 1 else n) + +let sum (type a) ~fold (module M : Summable with type t = a) t ~f = + fold t ~init:M.zero ~f:(fun n a -> M.(+) n (f a)) +;; + +let fold_result ~fold ~init ~f t = + with_return (fun {return} -> + Result.Ok (fold t ~init ~f:(fun acc item -> + match f acc item with + | Result.Ok x -> x + | Error _ as e -> return e))) +;; + +let fold_until ~fold ~init ~f ~finish t = + with_return (fun {return} -> + finish (fold t ~init ~f:(fun acc item -> + match f acc item with + | Continue_or_stop.Continue x -> x + | Stop x -> return x))) +;; + +let min_elt ~fold t ~compare = + fold t ~init:None ~f:(fun acc elt -> + match acc with + | None -> Some elt + | Some min -> if compare min elt > 0 then Some elt else acc) +;; + +let max_elt ~fold t ~compare = + fold t ~init:None ~f:(fun acc elt -> + match acc with + | None -> Some elt + | Some max -> if compare max elt < 0 then Some elt else acc) +;; + +let length ~fold c = fold c ~init:0 ~f:(fun acc _ -> acc + 1) + +let is_empty ~iter c = + with_return (fun r -> + iter c ~f:(fun _ -> r.return false); + true) +;; + +let exists ~iter c ~f = + with_return (fun r -> + iter c ~f:(fun x -> if f x then r.return true); + false) +;; + +let for_all ~iter c ~f = + with_return (fun r -> + iter c ~f:(fun x -> if not (f x) then r.return false); + true) +;; + +let find_map ~iter t ~f = + with_return (fun r -> + iter t ~f:(fun x -> match f x with None -> () | Some _ as res -> r.return res); + None) +;; + +let find ~iter c ~f = + with_return (fun r -> + iter c ~f:(fun x -> if f x then r.return (Some x)); + None) +;; + +let to_list ~fold c = List.rev (fold c ~init:[] ~f:(fun acc x -> x :: acc)) + +let to_array ~length ~iter c = + let array = ref [||] in + let i = ref 0 in + iter c ~f:(fun x -> + if !i = 0 then (array := Array.create ~len:(length c) x); + !array.(!i) <- x; + incr i); + !array +;; + +module Make_gen (T : Make_gen_arg) : sig + include Generic with type 'a t := 'a T.t with type 'a elt := 'a T.elt +end = struct + let fold = T.fold + + let iter = + match T.iter with + | `Custom iter -> iter + | `Define_using_fold -> fun t ~f -> iter ~fold t ~f + ;; + + let length = + match T.length with + | `Custom length -> length + | `Define_using_fold -> fun t -> length ~fold t + ;; + + let is_empty t = is_empty ~iter t + let sum m t = sum ~fold m t + let count t ~f = count ~fold t ~f + let exists t ~f = exists ~iter t ~f + let for_all t ~f = for_all ~iter t ~f + let find_map t ~f = find_map ~iter t ~f + let find t ~f = find ~iter t ~f + let to_list t = to_list ~fold t + let to_array t = to_array ~length ~iter t + + let min_elt t ~compare = min_elt ~fold t ~compare + let max_elt t ~compare = max_elt ~fold t ~compare + + let fold_result t ~init ~f = fold_result t ~fold ~init ~f + let fold_until t ~init ~f ~finish = fold_until t ~fold ~init ~f ~finish +end + +module Make (T : Make_arg) = struct + include + Make_gen (struct + include T + type 'a elt = 'a + end) + + let mem t a ~equal = exists t ~f:(equal a) +end + +module Make0 (T : Make0_arg) = struct + include + Make_gen (struct + include (T : Make0_arg with type t := T.t with module Elt := T.Elt) + type 'a t = T.t + type 'a elt = T.Elt.t + end) + + let mem t elt = exists t ~f:(T.Elt.equal elt) +end + +open T + + +(* The following functors exist as a consistency check among all the various [S?] + interfaces. They ensure that each particular [S?] is an instance of a more generic + signature. *) +module Check (T : T1) (Elt : T1) + (M : Generic with type 'a t := 'a T.t with type 'a elt := 'a Elt.t) = struct end + +module Check_S0 (M : S0) = + Check (struct type 'a t = M.t end) (struct type 'a t = M.elt end) (M) + +module Check_S0_phantom (M : S0_phantom) = + Check (struct type 'a t = 'a M.t end) (struct type 'a t = M.elt end) (M) + +module Check_S1 (M : S1) = + Check (struct type 'a t = 'a M.t end) (struct type 'a t = 'a end) (M) + +type phantom + +module Check_S1_phantom (M : S1_phantom) = + Check (struct type 'a t = ('a, phantom) M.t end) (struct type 'a t = 'a end) (M) + +module Check_S1_phantom_invariant (M : S1_phantom_invariant) = + Check (struct type 'a t = ('a, phantom) M.t end) (struct type 'a t = 'a end) (M) diff --git a/src/container.mli b/src/container.mli new file mode 100644 index 0000000..37313b1 --- /dev/null +++ b/src/container.mli @@ -0,0 +1 @@ +include Container_intf.Container (** @inline *) diff --git a/src/container_intf.ml b/src/container_intf.ml new file mode 100644 index 0000000..9c41fe4 --- /dev/null +++ b/src/container_intf.ml @@ -0,0 +1,624 @@ +(** Provides generic signatures for container data structures. + + These signatures include functions ([iter], [fold], [exists], [for_all], ...) that + you would expect to find in any container. Used by including [Container.S0] or + [Container.S1] in the signature for every container-like data structure ([Array], + [List], [String], ...) to ensure a consistent interface. *) + +open! Import + +module Export = struct + (** [Continue_or_stop.t] is used by the [f] argument to [fold_until] in order to + indicate whether folding should continue, or stop early. *) + module Continue_or_stop = struct + type ('a, 'b) t = + | Continue of 'a + | Stop of 'b + end +end +include Export + +module type Summable = sig + type t + + (** The result of summing no values. *) + val zero : t + + (** An operation that combines two [t]'s and handles [zero + x] by just returning [x], + as well as in the symmetric case. *) + val (+) : t -> t -> t +end + +(** Signature for monomorphic container, e.g., string. *) +module type S0 = sig + type t + type elt + + (** Checks whether the provided element is there, using equality on [elt]s. *) + val mem : t -> elt -> bool + + val length : t -> int + + val is_empty : t -> bool + + (** [iter] must allow exceptions raised in [f] to escape, terminating the iteration + cleanly. The same holds for all functions below taking an [f]. *) + val iter : t -> f:(elt -> unit) -> unit + + (** [fold t ~init ~f] returns [f (... f (f (f init e1) e2) e3 ...) en], where [e1..en] + are the elements of [t]. *) + val fold : t -> init:'accum -> f:('accum -> elt -> 'accum) -> 'accum + + (** [fold_result t ~init ~f] is a short-circuiting version of [fold] that runs in the + [Result] monad. If [f] returns an [Error _], that value is returned without any + additional invocations of [f]. *) + val fold_result + : t + -> init:'accum + -> f:('accum -> elt -> ('accum, 'e) Result.t) + -> ('accum, 'e) Result.t + + (** [fold_until t ~init ~f ~finish] is a short-circuiting version of [fold]. If [f] + returns [Stop _] the computation ceases and results in that value. If [f] returns + [Continue _], the fold will proceed. If [f] never returns [Stop _], the final result + is computed by [finish]. + + Example: + + {[ + type maybe_negative = + | Found_negative of int + | All_nonnegative of { sum : int } + + (** [first_neg_or_sum list] returns the first negative number in [list], if any, + otherwise returns the sum of the list. *) + let first_neg_or_sum = + List.fold_until ~init:0 + ~f:(fun sum x -> + if x < 0 + then Stop (Found_negative x) + else Continue (sum + x)) + ~finish:(fun sum -> All_nonnegative { sum }) + ;; + + let x = first_neg_or_sum [1; 2; 3; 4; 5] + val x : maybe_negative = All_nonnegative {sum = 15} + + let y = first_neg_or_sum [1; 2; -3; 4; 5] + val y : maybe_negative = Found_negative -3 + ]} *) + val fold_until + : t + -> init:'accum + -> f:('accum -> elt -> ('accum, 'final) Continue_or_stop.t) + -> finish:('accum -> 'final) + -> 'final + + (** Returns [true] if and only if there exists an element for which the provided + function evaluates to [true]. This is a short-circuiting operation. *) + val exists : t -> f:(elt -> bool) -> bool + + (** Returns [true] if and only if the provided function evaluates to [true] for all + elements. This is a short-circuiting operation. *) + val for_all : t -> f:(elt -> bool) -> bool + + (** Returns the number of elements for which the provided function evaluates to true. *) + val count : t -> f:(elt -> bool) -> int + + (** Returns the sum of [f i] for all [i] in the container. *) + val sum + : (module Summable with type t = 'sum) + -> t -> f:(elt -> 'sum) -> 'sum + + (** Returns as an [option] the first element for which [f] evaluates to true. *) + val find : t -> f:(elt -> bool) -> elt option + + (** Returns the first evaluation of [f] that returns [Some], and returns [None] if there + is no such element. *) + val find_map : t -> f:(elt -> 'a option) -> 'a option + + val to_list : t -> elt list + val to_array : t -> elt array + + (** Returns a min (resp. max) element from the collection using the provided [compare] + function. In case of a tie, the first element encountered while traversing the + collection is returned. The implementation uses [fold] so it has the same + complexity as [fold]. Returns [None] iff the collection is empty. *) + val min_elt : t -> compare:(elt -> elt -> int) -> elt option + val max_elt : t -> compare:(elt -> elt -> int) -> elt option +end + +module type S0_phantom = sig + type elt + type 'a t + + (** Checks whether the provided element is there, using equality on [elt]s. *) + val mem : _ t -> elt -> bool + + val length : _ t -> int + + val is_empty : _ t -> bool + + val iter : _ t -> f:(elt -> unit) -> unit + + (** [fold t ~init ~f] returns [f (... f (f (f init e1) e2) e3 ...) en], where [e1..en] + are the elements of [t]. *) + val fold : _ t -> init:'accum -> f:('accum -> elt -> 'accum) -> 'accum + + (** [fold_result t ~init ~f] is a short-circuiting version of [fold] that runs in the + [Result] monad. If [f] returns an [Error _], that value is returned without any + additional invocations of [f]. *) + val fold_result + : _ t + -> init:'accum + -> f:('accum -> elt -> ('accum, 'e) Result.t) + -> ('accum, 'e) Result.t + + (** [fold_until t ~init ~f ~finish] is a short-circuiting version of [fold]. If [f] + returns [Stop _] the computation ceases and results in that value. If [f] returns + [Continue _], the fold will proceed. If [f] never returns [Stop _], the final result + is computed by [finish]. + + Example: + + {[ + type maybe_negative = + | Found_negative of int + | All_nonnegative of { sum : int } + + (** [first_neg_or_sum list] returns the first negative number in [list], if any, + otherwise returns the sum of the list. *) + let first_neg_or_sum = + List.fold_until ~init:0 + ~f:(fun sum x -> + if x < 0 + then Stop (Found_negative x) + else Continue (sum + x)) + ~finish:(fun sum -> All_nonnegative { sum }) + ;; + + let x = first_neg_or_sum [1; 2; 3; 4; 5] + val x : maybe_negative = All_nonnegative {sum = 15} + + let y = first_neg_or_sum [1; 2; -3; 4; 5] + val y : maybe_negative = Found_negative -3 + ]} *) + val fold_until + : _ t + -> init:'accum + -> f:('accum -> elt -> ('accum, 'final) Continue_or_stop.t) + -> finish:('accum -> 'final) + -> 'final + + (** Returns [true] if and only if there exists an element for which the provided + function evaluates to [true]. This is a short-circuiting operation. *) + val exists : _ t -> f:(elt -> bool) -> bool + + (** Returns [true] if and only if the provided function evaluates to [true] for all + elements. This is a short-circuiting operation. *) + val for_all : _ t -> f:(elt -> bool) -> bool + + (** Returns the number of elements for which the provided function evaluates to true. *) + val count : _ t -> f:(elt -> bool) -> int + + (** Returns the sum of [f i] for all [i] in the container. The order in which the + elements will be summed is unspecified. *) + val sum + : (module Summable with type t = 'sum) + -> _ t -> f:(elt -> 'sum) -> 'sum + + (** Returns as an [option] the first element for which [f] evaluates to true. *) + val find : _ t -> f:(elt -> bool) -> elt option + + (** Returns the first evaluation of [f] that returns [Some], and returns [None] if there + is no such element. *) + val find_map : _ t -> f:(elt -> 'a option) -> 'a option + + val to_list : _ t -> elt list + val to_array : _ t -> elt array + + (** Returns a min (resp max) element from the collection using the provided [compare] + function, or [None] if the collection is empty. In case of a tie, the first element + encountered while traversing the collection is returned. *) + val min_elt : _ t -> compare:(elt -> elt -> int) -> elt option + val max_elt : _ t -> compare:(elt -> elt -> int) -> elt option +end + +(** Signature for polymorphic container, e.g., ['a list] or ['a array]. *) +module type S1 = sig + type 'a t + + (** Checks whether the provided element is there, using [equal]. *) + val mem : 'a t -> 'a -> equal:('a -> 'a -> bool) -> bool + + val length : 'a t -> int + + val is_empty : 'a t -> bool + + val iter : 'a t -> f:('a -> unit) -> unit + + (** [fold t ~init ~f] returns [f (... f (f (f init e1) e2) e3 ...) en], where [e1..en] + are the elements of [t] *) + val fold : 'a t -> init:'accum -> f:('accum -> 'a -> 'accum) -> 'accum + + (** [fold_result t ~init ~f] is a short-circuiting version of [fold] that runs in the + [Result] monad. If [f] returns an [Error _], that value is returned without any + additional invocations of [f]. *) + val fold_result + : 'a t + -> init:'accum + -> f:('accum -> 'a -> ('accum, 'e) Result.t) + -> ('accum, 'e) Result.t + + (** [fold_until t ~init ~f ~finish] is a short-circuiting version of [fold]. If [f] + returns [Stop _] the computation ceases and results in that value. If [f] returns + [Continue _], the fold will proceed. If [f] never returns [Stop _], the final result + is computed by [finish]. + + Example: + + {[ + type maybe_negative = + | Found_negative of int + | All_nonnegative of { sum : int } + + (** [first_neg_or_sum list] returns the first negative number in [list], if any, + otherwise returns the sum of the list. *) + let first_neg_or_sum = + List.fold_until ~init:0 + ~f:(fun sum x -> + if x < 0 + then Stop (Found_negative x) + else Continue (sum + x)) + ~finish:(fun sum -> All_nonnegative { sum }) + ;; + + let x = first_neg_or_sum [1; 2; 3; 4; 5] + val x : maybe_negative = All_nonnegative {sum = 15} + + let y = first_neg_or_sum [1; 2; -3; 4; 5] + val y : maybe_negative = Found_negative -3 + ]} *) + val fold_until + : 'a t + -> init:'accum + -> f:('accum -> 'a -> ('accum, 'final) Continue_or_stop.t) + -> finish:('accum -> 'final) + -> 'final + + (** Returns [true] if and only if there exists an element for which the provided + function evaluates to [true]. This is a short-circuiting operation. *) + val exists : 'a t -> f:('a -> bool) -> bool + + (** Returns [true] if and only if the provided function evaluates to [true] for all + elements. This is a short-circuiting operation. *) + val for_all : 'a t -> f:('a -> bool) -> bool + + (** Returns the number of elements for which the provided function evaluates to true. *) + val count : 'a t -> f:('a -> bool) -> int + + (** Returns the sum of [f i] for all [i] in the container. *) + val sum + : (module Summable with type t = 'sum) + -> 'a t -> f:('a -> 'sum) -> 'sum + + (** Returns as an [option] the first element for which [f] evaluates to true. *) + val find : 'a t -> f:('a -> bool) -> 'a option + + (** Returns the first evaluation of [f] that returns [Some], and returns [None] if there + is no such element. *) + val find_map : 'a t -> f:('a -> 'b option) -> 'b option + + val to_list : 'a t -> 'a list + val to_array : 'a t -> 'a array + + (** Returns a minimum (resp maximum) element from the collection using the provided + [compare] function, or [None] if the collection is empty. In case of a tie, the first + element encountered while traversing the collection is returned. The implementation + uses [fold] so it has the same complexity as [fold]. *) + val min_elt : 'a t -> compare:('a -> 'a -> int) -> 'a option + val max_elt : 'a t -> compare:('a -> 'a -> int) -> 'a option +end + +module type S1_phantom_invariant = sig + type ('a, 'phantom) t + + (** Checks whether the provided element is there, using [equal]. *) + val mem : ('a, _) t -> 'a -> equal:('a -> 'a -> bool) -> bool + + val length : (_, _) t -> int + val is_empty : (_, _) t -> bool + val iter : ('a, _) t -> f:('a -> unit) -> unit + + (** [fold t ~init ~f] returns [f (... f (f (f init e1) e2) e3 ...) en], where [e1..en] + are the elements of [t]. *) + val fold : ('a, _) t -> init:'accum -> f:('accum -> 'a -> 'accum) -> 'accum + + (** [fold_result t ~init ~f] is a short-circuiting version of [fold] that runs in the + [Result] monad. If [f] returns an [Error _], that value is returned without any + additional invocations of [f]. *) + val fold_result + : ('a, _) t + -> init:'accum + -> f:('accum -> 'a -> ('accum, 'e) Result.t) + -> ('accum, 'e) Result.t + + (** [fold_until t ~init ~f ~finish] is a short-circuiting version of [fold]. If [f] + returns [Stop _] the computation ceases and results in that value. If [f] returns + [Continue _], the fold will proceed. If [f] never returns [Stop _], the final result + is computed by [finish]. + + Example: + + {[ + type maybe_negative = + | Found_negative of int + | All_nonnegative of { sum : int } + + (** [first_neg_or_sum list] returns the first negative number in [list], if any, + otherwise returns the sum of the list. *) + let first_neg_or_sum = + List.fold_until ~init:0 + ~f:(fun sum x -> + if x < 0 + then Stop (Found_negative x) + else Continue (sum + x)) + ~finish:(fun sum -> All_nonnegative { sum }) + ;; + + let x = first_neg_or_sum [1; 2; 3; 4; 5] + val x : maybe_negative = All_nonnegative {sum = 15} + + let y = first_neg_or_sum [1; 2; -3; 4; 5] + val y : maybe_negative = Found_negative -3 + ]} *) + val fold_until + : ('a, _) t + -> init:'accum + -> f:('accum -> 'a -> ('accum, 'final) Continue_or_stop.t) + -> finish:('accum -> 'final) + -> 'final + + (** Returns [true] if and only if there exists an element for which the provided + function evaluates to [true]. This is a short-circuiting operation. *) + val exists : ('a, _) t -> f:('a -> bool) -> bool + + (** Returns [true] if and only if the provided function evaluates to [true] for all + elements. This is a short-circuiting operation. *) + val for_all : ('a, _) t -> f:('a -> bool) -> bool + + (** Returns the number of elements for which the provided function evaluates to true. *) + val count : ('a, _) t -> f:('a -> bool) -> int + + (** Returns the sum of [f i] for all [i] in the container. *) + val sum + : (module Summable with type t = 'sum) + -> ('a, _) t -> f:('a -> 'sum) -> 'sum + + (** Returns as an [option] the first element for which [f] evaluates to true. *) + val find : ('a, _) t -> f:('a -> bool) -> 'a option + + (** Returns the first evaluation of [f] that returns [Some], and returns [None] if there + is no such element. *) + val find_map : ('a, _) t -> f:('a -> 'b option) -> 'b option + + val to_list : ('a, _) t -> 'a list + val to_array : ('a, _) t -> 'a array + + (** Returns a min (resp max) element from the collection using the provided [compare] + function. In case of a tie, the first element encountered while traversing the + collection is returned. The implementation uses [fold] so it has the same complexity + as [fold]. Returns [None] iff the collection is empty. *) + val min_elt : ('a, _) t -> compare:('a -> 'a -> int) -> 'a option + val max_elt : ('a, _) t -> compare:('a -> 'a -> int) -> 'a option +end + +module type S1_phantom = sig + type ('a, +'phantom) t + include S1_phantom_invariant with type ('a, 'phantom) t := ('a, 'phantom) t +end + +module type Generic = sig + type 'a t + type 'a elt + val length : _ t -> int + val is_empty : _ t -> bool + val iter : 'a t -> f:('a elt -> unit) -> unit + val fold : 'a t -> init:'accum -> f:('accum -> 'a elt -> 'accum) -> 'accum + val fold_result + : 'a t + -> init:'accum + -> f:('accum -> 'a elt -> ('accum, 'e) Result.t) + -> ('accum, 'e) Result.t + + val fold_until + : 'a t + -> init:'accum + -> f:('accum -> 'a elt -> ('accum, 'final) Continue_or_stop.t) + -> finish:('accum -> 'final) + -> 'final + + val exists : 'a t -> f:('a elt -> bool) -> bool + val for_all : 'a t -> f:('a elt -> bool) -> bool + val count : 'a t -> f:('a elt -> bool) -> int + val sum + : (module Summable with type t = 'sum) + -> 'a t -> f:('a elt -> 'sum) -> 'sum + val find : 'a t -> f:('a elt -> bool) -> 'a elt option + val find_map : 'a t -> f:('a elt -> 'b option) -> 'b option + val to_list : 'a t -> 'a elt list + val to_array : 'a t -> 'a elt array + val min_elt : 'a t -> compare:('a elt -> 'a elt -> int) -> 'a elt option + val max_elt : 'a t -> compare:('a elt -> 'a elt -> int) -> 'a elt option +end + +module type Generic_phantom = sig + type ('a, 'phantom) t + type 'a elt + val length : (_, _) t -> int + val is_empty : (_, _) t -> bool + val iter : ('a, _) t -> f:('a elt -> unit) -> unit + val fold : ('a, _) t -> init:'accum -> f:('accum -> 'a elt -> 'accum) -> 'accum + val fold_result + : ('a, _) t + -> init:'accum + -> f:('accum -> 'a elt -> ('accum, 'e) Result.t) + -> ('accum, 'e) Result.t + val fold_until + : ('a, _) t + -> init:'accum + -> f:('accum -> 'a elt -> ('accum, 'final) Continue_or_stop.t) + -> finish:('accum -> 'final) + -> 'final + val exists : ('a, _) t -> f:('a elt -> bool) -> bool + val for_all : ('a, _) t -> f:('a elt -> bool) -> bool + val count : ('a, _) t -> f:('a elt -> bool) -> int + val sum + : (module Summable with type t = 'sum) + -> ('a, _) t -> f:('a elt -> 'sum) -> 'sum + val find : ('a, _) t -> f:('a elt -> bool) -> 'a elt option + val find_map : ('a, _) t -> f:('a elt -> 'b option) -> 'b option + val to_list : ('a, _) t -> 'a elt list + val to_array : ('a, _) t -> 'a elt array + val min_elt : ('a, _) t -> compare:('a elt -> 'a elt -> int) -> 'a elt option + val max_elt : ('a, _) t -> compare:('a elt -> 'a elt -> int) -> 'a elt option +end + +module type Make_gen_arg = sig + type 'a t + type 'a elt + + val fold : 'a t -> init:'accum -> f:('accum -> 'a elt -> 'accum) -> 'accum + + (** The [iter] argument to [Container.Make] specifies how to implement the + container's [iter] function. [`Define_using_fold] means to define [iter] + via: + + {[ + iter t ~f = Container.iter ~fold t ~f + ]} + + [`Custom] overrides the default implementation, presumably with something more + efficient. Several other functions returned by [Container.Make] are defined in + terms of [iter], so passing in a more efficient [iter] will improve their efficiency + as well. *) + val iter : [ `Define_using_fold + | `Custom of 'a t -> f:('a elt -> unit) -> unit + ] + + (** The [length] argument to [Container.Make] specifies how to implement the + container's [length] function. [`Define_using_fold] means to define + [length] via: + + {[ + length t ~f = Container.length ~fold t ~f + ]} + + [`Custom] overrides the default implementation, presumably with something more + efficient. Several other functions returned by [Container.Make] are defined in + terms of [length], so passing in a more efficient [length] will improve their + efficiency as well. *) + val length : [ `Define_using_fold + | `Custom of 'a t -> int + ] +end + +module type Make_arg = Make_gen_arg with type 'a elt := 'a Monad.Ident.t + +module type Make0_arg = sig + module Elt : sig + type t + val equal : t -> t -> bool + end + + type t + + val fold : t -> init:'accum -> f:('accum -> Elt.t -> 'accum) -> 'accum + val iter : [ `Define_using_fold + | `Custom of t -> f:(Elt.t -> unit) -> unit + ] + val length : [ `Define_using_fold + | `Custom of t -> int + ] +end + +module type Container = sig + include module type of struct include Export end + + module type S0 = S0 + module type S0_phantom = S0_phantom + module type S1 = S1 + module type S1_phantom_invariant = S1_phantom_invariant + module type S1_phantom = S1_phantom + module type Generic = Generic + module type Generic_phantom = Generic_phantom + + module type Summable = Summable + + (** Generic definitions of container operations in terms of [fold]. + + E.g.: [iter ~fold t ~f = fold t ~init:() ~f:(fun () a -> f a)]. *) + + type ('t, 'a, 'accum) fold = 't -> init:'accum -> f:('accum -> 'a -> 'accum) -> 'accum + type ('t, 'a) iter = 't -> f:('a -> unit) -> unit + type 't length = 't -> int + + val iter : fold:('t, 'a, unit ) fold -> ('t, 'a) iter + val count : fold:('t, 'a, int ) fold -> 't -> f:('a -> bool) -> int + val min_elt : fold:('t, 'a, 'a option) fold -> 't -> compare:('a -> 'a -> int) -> 'a option + val max_elt : fold:('t, 'a, 'a option) fold -> 't -> compare:('a -> 'a -> int) -> 'a option + val length : fold:('t, _, int ) fold -> 't -> int + val to_list : fold:('t, 'a, 'a list ) fold -> 't -> 'a list + val sum + : fold : ('t, 'a, 'sum) fold + -> (module Summable with type t = 'sum) + -> 't -> f:('a -> 'sum) -> 'sum + + val fold_result + : fold:('t, 'a, 'b) fold + -> init:'b + -> f:('b -> 'a -> ('b, 'e) Result.t) + -> 't + -> ('b, 'e) Result.t + + val fold_until + : fold:('t, 'a, 'b) fold + -> init:'b + -> f:('b -> 'a -> ('b, 'final) Continue_or_stop.t) + -> finish:('b -> 'final) + -> 't + -> 'final + + (** Generic definitions of container operations in terms of [iter] and [length]. *) + val is_empty : iter:('t, 'a) iter -> 't -> bool + val exists : iter:('t, 'a) iter -> 't -> f:('a -> bool) -> bool + val for_all : iter:('t, 'a) iter -> 't -> f:('a -> bool) -> bool + val find : iter:('t, 'a) iter -> 't -> f:('a -> bool) -> 'a option + val find_map : iter:('t, 'a) iter -> 't -> f:('a -> 'b option) -> 'b option + val to_array : length:'t length -> iter:('t, 'a) iter -> 't -> 'a array + + (** The idiom for using [Container.Make] is to bind the resulting module and to + explicitly import each of the functions that one wants: + + {[ + module C = Container.Make (struct ... end) + let count = C.count + let exists = C.exists + let find = C.find + (* ... *) + ]} + + This is preferable to: + + {[ + include Container.Make (struct ... end) + ]} + + because the [include] makes it too easy to shadow specialized implementations of + container functions ([length] being a common one). + + [Container.Make0] is like [Container.Make], but for monomorphic containers like + [string]. *) + module Make (T : Make_arg) : S1 with type 'a t := 'a T.t + module Make0 (T : Make0_arg) : S0 with type t := T.t and type elt := T.Elt.t +end diff --git a/src/discover/discover.ml b/src/discover/discover.ml new file mode 100644 index 0000000..f2e0b92 --- /dev/null +++ b/src/discover/discover.ml @@ -0,0 +1,20 @@ +open Configurator.V1 + +let program = + {| +int main(int argc, char ** argv) +{ + return __builtin_popcount(argc); +} +|} +;; + +let () = + let output = ref "" in + main + ~name:"discover" + ~args:[ "-o", Set_string output, "FILENAME output file" ] + (fun c -> + let has_popcnt = c_test c ~c_flags:[ "-mpopcnt" ] program in + Flags.write_sexp !output (if has_popcnt then [ "-mpopcnt" ] else [])) +;; diff --git a/src/discover/discover.mli b/src/discover/discover.mli new file mode 100644 index 0000000..e790aeb --- /dev/null +++ b/src/discover/discover.mli @@ -0,0 +1 @@ +(* empty *) diff --git a/src/discover/dune b/src/discover/dune new file mode 100644 index 0000000..9e74a1e --- /dev/null +++ b/src/discover/dune @@ -0,0 +1,2 @@ +(executables (names discover) (libraries dune.configurator) + (preprocess no_preprocessing))
\ No newline at end of file diff --git a/src/dune b/src/dune new file mode 100644 index 0000000..422e43d --- /dev/null +++ b/src/dune @@ -0,0 +1,30 @@ +(rule (targets int63_backend.ml) + (deps (:first_dep select-int63-backend/select.ml)) + (action + (run %{ocaml} %{first_dep} -portable-int63 + !%{lib-available:base-native-int63} -arch-sixtyfour %{arch_sixtyfour} -o + %{targets}))) + +(rule (targets bytes_set_primitives.ml) + (deps (:first_dep select-bytes-set-primitives/select.ml)) + (action + (run %{ocaml} %{first_dep} -ocaml-version %{ocaml_version} -o %{targets}))) + +(rule (targets pow_overflow_bounds.ml) + (deps (:first_dep ../generate/generate_pow_overflow_bounds.exe)) + (action (run %{first_dep} -atomic -o %{targets})) (mode fallback)) + +(library (name base) (public_name base) + (libraries caml sexplib0 shadow_stdlib) (install_c_headers internalhash) + (c_flags :standard -D_LARGEFILE64_SOURCE (:include mpopcnt.sexp)) + (c_names exn_stubs int_math_stubs internalhash_stubs hash_stubs am_testing) + (preprocess no_preprocessing) + (lint + (pps ppx_base ppx_base_lint -check-doc-comments -type-conv-keep-w32=impl + -apply=js_style,base_lint,type_conv)) + (js_of_ocaml (javascript_files runtime.js))) + +(rule (targets mpopcnt.sexp) (deps discover/discover.exe) + (action (run ./discover/discover.exe -o %{targets}))) + +(ocamllex hex_lexer) diff --git a/src/either.ml b/src/either.ml new file mode 100644 index 0000000..3a0165a --- /dev/null +++ b/src/either.ml @@ -0,0 +1,289 @@ +open! Import + +include Either_intf + +module Array = Array0 + +type ('f, 's) t = + | First of 'f + | Second of 's +[@@deriving_inline compare, hash, sexp] +let compare : + 'f 's . + ('f -> 'f -> int) -> ('s -> 's -> int) -> ('f, 's) t -> ('f, 's) t -> int + = + fun _cmp__f -> + fun _cmp__s -> + fun a__001_ -> + fun b__002_ -> + if Ppx_compare_lib.phys_equal a__001_ b__002_ + then 0 + else + (match (a__001_, b__002_) with + | (First _a__003_, First _b__004_) -> _cmp__f _a__003_ _b__004_ + | (First _, _) -> (-1) + | (_, First _) -> 1 + | (Second _a__005_, Second _b__006_) -> + _cmp__s _a__005_ _b__006_) +let hash_fold_t : type f s. + (Ppx_hash_lib.Std.Hash.state -> f -> Ppx_hash_lib.Std.Hash.state) -> + (Ppx_hash_lib.Std.Hash.state -> s -> Ppx_hash_lib.Std.Hash.state) -> + Ppx_hash_lib.Std.Hash.state -> (f, s) t -> Ppx_hash_lib.Std.Hash.state + = + fun _hash_fold_f -> + fun _hash_fold_s -> + fun hsv -> + fun arg -> + match arg with + | First _a0 -> + let hsv = Ppx_hash_lib.Std.Hash.fold_int hsv 0 in + let hsv = hsv in _hash_fold_f hsv _a0 + | Second _a0 -> + let hsv = Ppx_hash_lib.Std.Hash.fold_int hsv 1 in + let hsv = hsv in _hash_fold_s hsv _a0 +let t_of_sexp : type f s. + (Ppx_sexp_conv_lib.Sexp.t -> f) -> + (Ppx_sexp_conv_lib.Sexp.t -> s) -> Ppx_sexp_conv_lib.Sexp.t -> (f, s) t + = + let _tp_loc = "src/either.ml.t" in + fun _of_f -> + fun _of_s -> + function + | Ppx_sexp_conv_lib.Sexp.List ((Ppx_sexp_conv_lib.Sexp.Atom + ("first"|"First" as _tag))::sexp_args) as _sexp -> + (match sexp_args with + | v0::[] -> let v0 = _of_f v0 in First v0 + | _ -> + Ppx_sexp_conv_lib.Conv_error.stag_incorrect_n_args _tp_loc + _tag _sexp) + | Ppx_sexp_conv_lib.Sexp.List ((Ppx_sexp_conv_lib.Sexp.Atom + ("second"|"Second" as _tag))::sexp_args) as _sexp -> + (match sexp_args with + | v0::[] -> let v0 = _of_s v0 in Second v0 + | _ -> + Ppx_sexp_conv_lib.Conv_error.stag_incorrect_n_args _tp_loc + _tag _sexp) + | Ppx_sexp_conv_lib.Sexp.Atom ("first"|"First") as sexp -> + Ppx_sexp_conv_lib.Conv_error.stag_takes_args _tp_loc sexp + | Ppx_sexp_conv_lib.Sexp.Atom ("second"|"Second") as sexp -> + Ppx_sexp_conv_lib.Conv_error.stag_takes_args _tp_loc sexp + | Ppx_sexp_conv_lib.Sexp.List ((Ppx_sexp_conv_lib.Sexp.List _)::_) as + sexp -> + Ppx_sexp_conv_lib.Conv_error.nested_list_invalid_sum _tp_loc sexp + | Ppx_sexp_conv_lib.Sexp.List [] as sexp -> + Ppx_sexp_conv_lib.Conv_error.empty_list_invalid_sum _tp_loc sexp + | sexp -> Ppx_sexp_conv_lib.Conv_error.unexpected_stag _tp_loc sexp +let sexp_of_t : type f s. + (f -> Ppx_sexp_conv_lib.Sexp.t) -> + (s -> Ppx_sexp_conv_lib.Sexp.t) -> (f, s) t -> Ppx_sexp_conv_lib.Sexp.t + = + fun _of_f -> + fun _of_s -> + function + | First v0 -> + let v0 = _of_f v0 in + Ppx_sexp_conv_lib.Sexp.List + [Ppx_sexp_conv_lib.Sexp.Atom "First"; v0] + | Second v0 -> + let v0 = _of_s v0 in + Ppx_sexp_conv_lib.Sexp.List + [Ppx_sexp_conv_lib.Sexp.Atom "Second"; v0] +[@@@end] + +let swap = function + | First x -> Second x + | Second x -> First x +;; + +let is_first = function + | First _ -> true + | Second _ -> false +;; + +let is_second = function + | First _ -> false + | Second _ -> true +;; + +let value (First x | Second x) = x +;; + +let value_map t ~first ~second = + match t with + | First x -> first x + | Second x -> second x +;; + +let iter = value_map +;; + +let map t ~first ~second = + match t with + | First x -> First (first x) + | Second x -> Second (second x) +;; + +let first x = First x +let second x = Second x +;; + +let equal eq1 eq2 t1 t2 = + match t1, t2 with + | First x, First y -> eq1 x y + | Second x, Second y -> eq2 x y + | First _, Second _ + | Second _, First _ -> false +;; + +let invariant f s = function + | First x -> f x + | Second y -> s y +;; + +module Make_focused (M : sig + type (+'a, +'b) t + + val return : 'a -> ('a, _) t + val other : 'b -> (_, 'b) t + + val either : ('a, 'b) t -> return:('a -> 'c) -> other:('b -> 'c) -> 'c + + val combine + : ('a, 'd) t + -> ('b, 'd) t + -> f:('a -> 'b -> 'c) + -> other:('d -> 'd -> 'd) + -> ('c, 'd) t + end) = struct + include M + open With_return + + let map t ~f = either t ~return:(fun x -> return (f x)) ~other + + include Monad.Make2 (struct + type nonrec ('a, 'b) t = ('a, 'b) t + + let return = return + ;; + + let bind t ~f = either t ~return:f ~other + ;; + + let map = `Custom map + end) + + module App = Applicative.Make2 (struct + type nonrec ('a, 'b) t = ('a, 'b) t + + let return = return + ;; + + let apply t1 t2 = + let return f = either t2 ~return:(fun x -> return (f x)) ~other in + either t1 ~return ~other + ;; + + let map = `Custom map + end) + + include App + + module Args = Applicative.Make_args2 (struct + type nonrec ('a, 'b) t = ('a, 'b) t + include App + end) + [@@warning "-3"] + + let combine_all = + let rec other_loop f acc = function + | [] -> other acc + | t :: ts -> + either t ~return:(fun _ -> other_loop f acc ts) + ~other:(fun o -> other_loop f (f acc o) ts) + in + let rec return_loop f acc = function + | [] -> return (List.rev acc) + | t :: ts -> + either t ~return:(fun x -> return_loop f (x :: acc) ts) + ~other:(fun o -> other_loop f o ts) + in + fun ts ~f -> return_loop f [] ts + ;; + + let combine_all_unit = + let rec other_loop f acc = function + | [] -> other acc + | t :: ts -> + either t ~return:(fun () -> other_loop f acc ts) + ~other:(fun o -> other_loop f (f acc o) ts) + in + let rec return_loop f = function + | [] -> return () + | t :: ts -> + either t ~return:(fun () -> return_loop f ts) + ~other:(fun o -> other_loop f o ts) + in + fun ts ~f -> return_loop f ts + ;; + + let to_option t = either t ~return:Option.some ~other:(fun _ -> None) + ;; + + let value t ~default = either t ~return:Fn.id ~other:(fun _ -> default) + ;; + + let with_return f = + with_return (fun ret -> + other (f (With_return.prepend ret ~f:return))) + ;; + +end + +module First = Make_focused (struct + type nonrec ('a, 'b) t = ('a, 'b) t + + let return = first + let other = second + ;; + + let either t ~return ~other = + match t with + | First x -> return x + | Second y -> other y + ;; + + let combine t1 t2 ~f ~other = + match t1, t2 with + | First x, First y -> First (f x y) + | Second x, Second y -> Second (other x y) + | Second x, _ + | _, Second x -> Second x + end) + +module Second = Make_focused (struct + type nonrec ('a, 'b) t = ('b, 'a) t + + let return = second + let other = first + ;; + + let either t ~return ~other = + match t with + | Second y -> return y + | First x -> other x + ;; + + let combine t1 t2 ~f ~other = + match t1, t2 with + | Second x, Second y -> Second (f x y) + | First x, First y -> First (other x y) + | First x, _ + | _, First x -> First x + end) + +module Export = struct + type ('f, 's) _either + = ('f, 's) t + = First of 'f + | Second of 's +end diff --git a/src/either.mli b/src/either.mli new file mode 100644 index 0000000..e3587aa --- /dev/null +++ b/src/either.mli @@ -0,0 +1 @@ +include Either_intf.Either (** @inline *) diff --git a/src/either_intf.ml b/src/either_intf.ml new file mode 100644 index 0000000..05a9d80 --- /dev/null +++ b/src/either_intf.ml @@ -0,0 +1,93 @@ +(** A type that represents values with two possibilities. + + [Either] can be seen as a generic sum type, the dual of [Tuple]. [First] is neither + more important nor less important than [Second]. + + Many functions in [Either] focus on just one constructor. The [Focused] signature + abstracts over which constructor is the focus. To use these functions, use the + [First] or [Second] modules in [S]. *) + +open! Import + +module type Focused = sig + type (+'focus, +'other) t + + include Monad.S2 with type ('a, 'b) t := ('a, 'b) t + include Applicative.S2 with type ('a, 'b) t := ('a, 'b) t + + module Args : Applicative.Args2 with type ('a, 'e) arg := ('a, 'e) t + [@@warning "-3"] + [@@deprecated "[since 2018-09] Use [ppx_let] instead."] + + val value : ('a, _) t -> default:'a -> 'a + + val to_option : ('a, _) t -> 'a option + + val with_return : ('a With_return.return -> 'b) -> ('a, 'b) t + + val combine + : ('a, 'd) t + -> ('b, 'd) t + -> f:('a -> 'b -> 'c) + -> other:('d -> 'd -> 'd) + -> ('c, 'd) t + + val combine_all : ('a, 'b) t list -> f:('b -> 'b -> 'b) -> ('a list, 'b) t + + val combine_all_unit : (unit, 'b) t list -> f:('b -> 'b -> 'b) -> (unit, 'b) t +end + +module type Either = sig + + type ('f, 's) t = + | First of 'f + | Second of 's + [@@deriving_inline compare, hash, sexp] + include + sig + [@@@ocaml.warning "-32"] + val compare : + ('f -> 'f -> int) -> + ('s -> 's -> int) -> ('f, 's) t -> ('f, 's) t -> int + val hash_fold_t : + (Ppx_hash_lib.Std.Hash.state -> 'f -> Ppx_hash_lib.Std.Hash.state) -> + (Ppx_hash_lib.Std.Hash.state -> 's -> Ppx_hash_lib.Std.Hash.state) -> + Ppx_hash_lib.Std.Hash.state -> + ('f, 's) t -> Ppx_hash_lib.Std.Hash.state + include Ppx_sexp_conv_lib.Sexpable.S2 with type ('f,'s) t := ('f, 's) t + end[@@ocaml.doc "@inline"] + [@@@end] + + include Invariant.S2 with type ('a, 'b) t := ('a, 'b) t + + val swap : ('f, 's) t -> ('s, 'f) t + + val value : ('a, 'a) t -> 'a + + val iter : ('a, 'b) t -> first:('a -> unit) -> second:('b -> unit) -> unit + val value_map : ('a, 'b) t -> first:('a -> 'c) -> second:('b -> 'c) -> 'c + val map : ('a, 'b) t -> first:('a -> 'c) -> second:('b -> 'd) -> ('c, 'd) t + + val equal : ('f -> 'f -> bool) -> ('s -> 's -> bool) -> ('f, 's) t -> ('f, 's) t -> bool + + module type Focused = Focused + + module First : Focused with type ('a, 'b) t = ('a, 'b) t + module Second : Focused with type ('a, 'b) t = ('b, 'a) t + + val is_first : (_, _) t -> bool + val is_second : (_, _) t -> bool + + (** [first] and [second] are [First.return] and [Second.return]. *) + val first : 'f -> ('f, _) t + val second : 's -> (_, 's) t + + (**/**) + + module Export : sig + type ('f, 's) _either + = ('f, 's) t + = First of 'f + | Second of 's + end +end diff --git a/src/equal.ml b/src/equal.ml new file mode 100644 index 0000000..df359e6 --- /dev/null +++ b/src/equal.ml @@ -0,0 +1,41 @@ +(** This module defines signatures that are to be included in other signatures to ensure a + consistent interface to [equal] functions. There is a signature ([S], [S1], [S2], + [S3]) for each arity of type. Usage looks like: + + {[ + type t + include Equal.S with type t := t + ]} + + or + + {[ + type 'a t + include Equal.S1 with type 'a t := 'a t + ]} *) + +open! Import + +type 'a t = 'a -> 'a -> bool + +type 'a equal = 'a t + +module type S = sig + type t + val equal : t equal +end + +module type S1 = sig + type 'a t + val equal : 'a equal -> 'a t equal +end + +module type S2 = sig + type ('a, 'b) t + val equal : 'a equal -> 'b equal -> ('a, 'b) t equal +end + +module type S3 = sig + type ('a, 'b, 'c) t + val equal : 'a equal -> 'b equal -> 'c equal -> ('a, 'b, 'c) t equal +end diff --git a/src/error.ml b/src/error.ml new file mode 100644 index 0000000..34b6b1a --- /dev/null +++ b/src/error.ml @@ -0,0 +1,20 @@ +(* This module is trying to minimize dependencies on modules in Core, so as to allow + [Error] and [Or_error] to be used in various places. Please avoid adding new + dependencies. *) + +open! Import + +include Info + +let raise t = raise (to_exn t) + +let raise_s sexp = raise (create_s sexp) + +let to_info t = t +let of_info t = t + +include Pretty_printer.Register_pp(struct + type nonrec t = t + let module_name = "Base.Error" + let pp = pp + end) diff --git a/src/error.mli b/src/error.mli new file mode 100644 index 0000000..49d34da --- /dev/null +++ b/src/error.mli @@ -0,0 +1,15 @@ +(** A lazy string, implemented with [Info], but intended specifically for error + messages. *) + +open! Import + +include Info_intf.S with type t = private Info.t (** @open *) + +(** Note that the exception raised by this function maintains a reference to the [t] + passed in. *) +val raise : t -> _ + +val raise_s : Sexp.t -> _ + +val to_info : t -> Info.t +val of_info : Info.t -> t diff --git a/src/exn.ml b/src/exn.ml new file mode 100644 index 0000000..3f22fb6 --- /dev/null +++ b/src/exn.ml @@ -0,0 +1,151 @@ +open! Import + +type t = exn [@@deriving_inline sexp_of] +let sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t = sexp_of_exn +[@@@end] + +let exit = Caml.exit + +exception Finally of t * t [@@deriving_inline sexp] +let () = + Ppx_sexp_conv_lib.Conv.Exn_converter.add ([%extension_constructor Finally]) + (function + | Finally (v0, v1) -> + let v0 = sexp_of_t v0 + and v1 = sexp_of_t v1 in + Ppx_sexp_conv_lib.Sexp.List + [Ppx_sexp_conv_lib.Sexp.Atom "src/exn.ml.Finally"; v0; v1] + | _ -> assert false) +[@@@end] +exception Reraised of string * t [@@deriving_inline sexp] +let () = + Ppx_sexp_conv_lib.Conv.Exn_converter.add + ([%extension_constructor Reraised]) + (function + | Reraised (v0, v1) -> + let v0 = sexp_of_string v0 + and v1 = sexp_of_t v1 in + Ppx_sexp_conv_lib.Sexp.List + [Ppx_sexp_conv_lib.Sexp.Atom "src/exn.ml.Reraised"; v0; v1] + | _ -> assert false) +[@@@end] + +exception Sexp of Sexp.t + +(* We install a custom exn-converter rather than use: + + {[ + exception Sexp of Sexp.t [@@deriving_inline sexp][@@@end] + ]} + + to eliminate the extra wrapping of [(Sexp ...)]. *) +let () = + Sexplib.Conv.Exn_converter.add [%extension_constructor Sexp] + (function + | Sexp t -> t + | _ -> + (* Reaching this branch indicates a bug in sexplib. *) + assert false) +;; + +let create_s sexp = Sexp sexp + +let reraise exc str = + raise (Reraised (str, exc)) + +let reraisef exc format = + Printf.ksprintf (fun str () -> reraise exc str) format + +let to_string exc = Sexp.to_string_hum ~indent:2 (sexp_of_exn exc) +let to_string_mach exc = Sexp.to_string_mach (sexp_of_exn exc) + +let sexp_of_t = sexp_of_exn + +let protectx ~f x ~(finally : _ -> unit) = + let res = + try f x + with exn -> + (try finally x with final_exn -> raise (Finally (exn, final_exn))); + raise exn + in + finally x; + res +;; + +let protect ~f ~finally = protectx ~f () ~finally + +let does_raise (type a) (f : unit -> a) = + try + ignore (f () : a); + false + with _ -> + true +;; + +include Pretty_printer.Register_pp (struct + type t = exn + let pp ppf t = + match sexp_of_exn_opt t with + | Some sexp -> Sexp.pp_hum ppf sexp + | None -> Caml.Format.pp_print_string ppf (Caml.Printexc.to_string t) + ;; + let module_name = "Base.Exn" + end) + +let print_with_backtrace exc raw_backtrace = + Caml.Format.eprintf "@[<2>Uncaught exception:@\n@\n@[%a@]@]@\n@." pp exc; + if Caml.Printexc.backtrace_status () + then Caml.Printexc.print_raw_backtrace Caml.stderr raw_backtrace; + Caml.flush Caml.stderr; +;; + +let set_uncaught_exception_handler () = + Caml.Printexc.set_uncaught_exception_handler print_with_backtrace +;; + +let handle_uncaught_aux ~do_at_exit ~exit f = + try f () + with exc -> + let raw_backtrace = Caml.Printexc.get_raw_backtrace () in + (* One reason to run [do_at_exit] handlers before printing out the error message is + that it helps curses applications bring the terminal in a good state, otherwise the + error message might get corrupted. Also, the OCaml top-level uncaught exception + handler does the same. *) + if do_at_exit then (try Caml.do_at_exit () with _ -> ()); + begin + try + print_with_backtrace exc raw_backtrace + with _ -> + try + Caml.Printf.eprintf "Exn.handle_uncaught could not print; exiting anyway\n%!"; + with _ -> () + end; + exit 1 +;; + +let handle_uncaught_and_exit f = handle_uncaught_aux f ~exit ~do_at_exit:true + +let handle_uncaught ~exit:must_exit f = + handle_uncaught_aux f ~exit:(if must_exit then exit else ignore) + ~do_at_exit:must_exit + +let reraise_uncaught str func = + try func () with + | exn -> raise (Reraised (str, exn)) + +external clear_backtrace : unit -> unit = "Base_clear_caml_backtrace_pos" [@@noalloc] + +let raise_without_backtrace e = + (* We clear the backtrace to reduce confusion, so that people don't think whatever + is stored corresponds to this raise. *) + clear_backtrace (); + Caml.raise_notrace e +;; + +let initialize_module () = + set_uncaught_exception_handler (); +;; + +module Private = struct + let clear_backtrace = clear_backtrace +end diff --git a/src/exn.mli b/src/exn.mli new file mode 100644 index 0000000..bba935c --- /dev/null +++ b/src/exn.mli @@ -0,0 +1,95 @@ +(** Exceptions. + + [sexp_of_t] uses a global table of sexp converters. To register a converter for a new + exception, add [[@@deriving_inline sexp][@@@end]] to its definition. If no suitable converter is + found, the standard converter in [Printexc] will be used to generate an atomic + S-expression. *) + +open! Import + +type t = exn [@@deriving_inline sexp_of] +include +sig [@@@ocaml.warning "-32"] val sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t +end[@@ocaml.doc "@inline"] +[@@@end] + +include Pretty_printer.S with type t := t + +(** Raised when finalization after an exception failed, too. + The first exception argument is the one raised by the initial + function, the second exception the one raised by the finalizer. *) +exception Finally of t * t + +exception Reraised of string * t + +(** [create_s sexp] returns an exception [t] such that [phys_equal (sexp_of_t t) sexp]. + This is useful when one wants to create an exception that serves as a message and the + particular exn constructor doesn't matter. *) +val create_s : Sexp.t -> t + +(** Same as [raise], except that the backtrace is not recorded. *) +val raise_without_backtrace : t -> _ + +val reraise : t -> string -> _ + +(** Types with [format4] are hard to read, so here's an example. + + {[ + let foobar str = + try + ... + with exn -> + Exn.reraisef exn "Foobar is buggy on: %s" str () + ]} *) +val reraisef : t -> ('a, unit, string, unit -> _) format4 -> 'a + +val to_string : t -> string (** Human-readable, multi-line. *) + +val to_string_mach : t -> string (** Machine format, single-line. *) + +(** Executes [f] and afterwards executes [finally], whether [f] throws an exception or + not. *) +val protectx : f:('a -> 'b) -> 'a -> finally:('a -> unit) -> 'b + +val protect : f:(unit -> 'a) -> finally:(unit -> unit) -> 'a + +(** [handle_uncaught ~exit f] catches an exception escaping [f] and prints an error + message to stderr. Exits with return code 1 if [exit] is [true], and returns unit + otherwise. + + Note that since OCaml 4.02.0, you don't need to use this at the entry point of your + program, as the OCaml runtime will do better than this function. *) +val handle_uncaught : exit:bool -> (unit -> unit) -> unit + +(** [handle_uncaught_and_exit f] returns [f ()], unless that raises, in which case it + prints the exception and exits nonzero. *) +val handle_uncaught_and_exit : (unit -> 'a) -> 'a + +(** Traces exceptions passing through. Useful because in practice, backtraces still don't + seem to work. + + Example: + {[ + let rogue_function () = if Random.bool () then failwith "foo" else 3 + let traced_function () = Exn.reraise_uncaught "rogue_function" rogue_function + traced_function ();; + ]} + {v : Program died with Reraised("rogue_function", Failure "foo") v} *) +val reraise_uncaught : string -> (unit -> 'a) -> 'a + +(** [does_raise f] returns [true] iff [f ()] raises, which is often useful in unit + tests. *) +val does_raise : (unit -> _) -> bool + +(** User code never calls this. It is called in [std_kernel.ml] as a top-level side + effect to change the display of exceptions and install an uncaught-exception + printer. *) +val initialize_module : unit -> unit + +(**/**) +(*_ See the Jane Street Style Guide for an explanation of [Private] submodules: + + https://opensource.janestreet.com/standards/#private-submodules *) +module Private : sig + val clear_backtrace : unit -> unit +end diff --git a/src/exn_stubs.c b/src/exn_stubs.c new file mode 100644 index 0000000..e6f905a --- /dev/null +++ b/src/exn_stubs.c @@ -0,0 +1,8 @@ +#include <caml/mlvalues.h> + +extern int caml_backtrace_pos; + +CAMLprim value Base_clear_caml_backtrace_pos () { + caml_backtrace_pos = 0; + return Val_unit; +} diff --git a/src/field.ml b/src/field.ml new file mode 100644 index 0000000..a4b2406 --- /dev/null +++ b/src/field.ml @@ -0,0 +1,67 @@ +(* The type [t] should be abstract to make the fset and set functions unavailable + for private types at the level of types (and not by putting None in the field). + Unfortunately, making the type abstract means that when creating fields (through + a [create] function) value restriction kicks in. This is worked around by instead + not making the type abstract, but forcing anyone breaking the abstraction to use + the [For_generated_code] module, making it obvious to any reader that something ugly + is going on. + t_with_perm (and derivatives) is the type that users really use. It is a constructor + because: + 1. it makes type errors more readable (less aliasing) + 2. the typer in ocaml 4.01 allows this: + + {[ + module A = struct + type t = {a : int} + end + type t = A.t + let f (x : t) = x.a + ]} + + (although with Warning 40: a is used out of scope) + which means that if [t_with_perm] was really an alias on [For_generated_code.t], + people could say [t.setter] and break the abstraction with no indication that + something ugly is going on in the source code. + The warning is (I think) for people who want to make their code compatible with + previous versions of ocaml, so we may very well turn it off. + + The type t_with_perm could also have been a [unit -> For_generated_code.t] to work + around value restriction and then [For_generated_code.t] would have been a proper + abstract type, but it looks like it could impact performance (for example, a fold on a + record type with 40 fields would actually allocate the 40 [For_generated_code.t]'s at + every single fold.) *) + +module For_generated_code = struct + type ('perm, 'record, 'field) t = { + force_variance : 'perm -> unit; + (* force [t] to be contravariant in ['perm], because phantom type variables on + concrete types don't work that well otherwise (using :> can remove them easily) *) + name : string; + setter : ('record -> 'field -> unit) option; + getter : ('record -> 'field); + fset : ('record -> 'field -> 'record); + } +end + +type ('perm, 'record, 'field) t_with_perm = + | Field of ('perm, 'record, 'field) For_generated_code.t +type ('record, 'field) t = ([ `Read | `Set_and_create], 'record, 'field) t_with_perm +type ('record, 'field) readonly_t = ([ `Read ], 'record, 'field) t_with_perm + +let name (Field field) = field.name + +let get (Field field) r = field.getter r + +let fset (Field field) r v = field.fset r v + +let setter (Field field) = field.setter + +type ('perm, 'record, 'result) user = + { f : 'field. ('perm, 'record, 'field) t_with_perm -> 'result } + +let map (Field field) r ~f = field.fset r (f (field.getter r)) + +let updater (Field field) = + match field.setter with + | None -> None + | Some setter -> Some (fun r ~f -> setter r (f (field.getter r))) diff --git a/src/field.mli b/src/field.mli new file mode 100644 index 0000000..f4f211c --- /dev/null +++ b/src/field.mli @@ -0,0 +1,39 @@ +(** OCaml record field. *) + +(**/**) +module For_generated_code : sig + (*_ don't use this by hand, it is only meant for ppx_fields_conv *) + type ('perm, 'record, 'field) t = { + force_variance : 'perm -> unit; + name : string; + setter : ('record -> 'field -> unit) option; + getter : ('record -> 'field); + fset : ('record -> 'field -> 'record); + } +end +(**/**) + +(** ['record] is the type of the record. ['field] is the type of the + values stored in the record field with name [name]. ['perm] is a way + of restricting the operations that can be used. *) +type ('perm, 'record, 'field) t_with_perm = + | Field of ('perm, 'record, 'field) For_generated_code.t + +(** A record field with no restrictions. *) +type ('record, 'field) t = ([ `Read | `Set_and_create], 'record, 'field) t_with_perm + +(** A record that can only be read, because it belongs to a private type. *) +type ('record, 'field) readonly_t = ([ `Read ], 'record, 'field) t_with_perm + +val name : (_, _, _) t_with_perm -> string +val get : (_, 'r, 'a) t_with_perm -> 'r -> 'a +val fset : ([> `Set_and_create], 'r, 'a) t_with_perm -> 'r -> 'a -> 'r +val setter : ([> `Set_and_create], 'r, 'a) t_with_perm -> ('r -> 'a -> unit) option + +val map : ([> `Set_and_create], 'r, 'a) t_with_perm -> 'r -> f:('a -> 'a) -> 'r +val updater + : ([> `Set_and_create], 'r, 'a) t_with_perm + -> ('r -> f:('a -> 'a) -> unit) option + +type ('perm, 'record, 'result) user = + { f : 'field. ('perm, 'record, 'field) t_with_perm -> 'result } diff --git a/src/fieldslib.ml b/src/fieldslib.ml new file mode 100644 index 0000000..dd41b03 --- /dev/null +++ b/src/fieldslib.ml @@ -0,0 +1,3 @@ +(** This module is for use by ppx_fields_conv, and is thus not in the interface of + Base. *) +module Field = Field diff --git a/src/float.ml b/src/float.ml new file mode 100644 index 0000000..41e277b --- /dev/null +++ b/src/float.ml @@ -0,0 +1,1099 @@ +open! Import +open! Printf + +module Bytes = Bytes0 +include Float0 + +let ceil = Caml.ceil +let floor = Caml.floor +let mod_float = Caml.mod_float +let modf = Caml.modf + +let raise_s = Error.raise_s + +module T = struct + type t = float [@@deriving_inline hash, sexp] + let (hash_fold_t : + Ppx_hash_lib.Std.Hash.state -> t -> Ppx_hash_lib.Std.Hash.state) = + hash_fold_float + and (hash : t -> Ppx_hash_lib.Std.Hash.hash_value) = + let func = hash_float in fun x -> func x + let t_of_sexp : Ppx_sexp_conv_lib.Sexp.t -> t = float_of_sexp + let sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t = sexp_of_float + [@@@end] + let compare = Float_replace_polymorphic_compare.compare +end + +include T +include Comparator.Make(T) + +(* Open replace_polymorphic_compare after including functor instantiations so they do not + shadow its definitions. This is here so that efficient versions of the comparison + functions are available within this module. *) +open Float_replace_polymorphic_compare + +let to_float x = x +let of_float x = x + +let of_string s = + try Caml.float_of_string s with + | _ -> invalid_argf "Float.of_string %s" s () +;; + +external format_float : string -> float -> string = "caml_format_float" + +(* Stolen from [pervasives.ml]. Adds a "." at the end if needed. It is in + [pervasives.mli], but it also says not to use it directly, so we copy and paste the + code. It makes the assumption on the string passed in argument that it was returned by + [format_float]. *) +let valid_float_lexem s = + let l = String.length s in + let rec loop i = + if Int_replace_polymorphic_compare.(>=) i l then s ^ "." else + match s.[i] with + | '0' .. '9' | '-' -> loop (i + 1) + | _ -> s + in + loop 0 +;; + +(* Let [y] be a power of 2. Then the next representable float is: + [z = y * (1 + 2 ** -52)] + and the previous one is + [x = y * (1 - 2 ** -53)] + + In general, every two adjacent floats are within a factor of between [1 + 2**-53] + and [1 + 2**-52] from each other, that is within [1 + 1.1e-16] and [1 + 2.3e-16]. + + So if the decimal representation of a float starts with "1", then its adjacent floats + will usually differ from it by 1, and sometimes by 2, at the 17th significant digit + (counting from 1). + + On the other hand, if the decimal representation starts with "9", then the adjacent + floats will be off by no more than 23 at the 16th and 17th significant digits. + + E.g.: + + {v + # sprintf "%.17g" (1024. *. (1. -. 2.** (-53.)));; + 11111111 + 1234 5678901234567 + - : string = "1023.9999999999999" + v} + Printing a couple of extra digits reveals that the difference indeed is roughly 11 at + digits 17th and 18th (that is, 13th and 14th after "."): + + {v + # sprintf "%.19g" (1024. *. (1. -. 2.** (-53.)));; + 1111111111 + 1234 567890123456789 + - : string = "1023.999999999999886" + v} + + The ulp (the difference between adjacent floats) is twice as big on the other side of + 1024.: + + {v + # sprintf "%.19g" (1024. *. (1. +. 2.** (-52.)));; + 1111111111 + 1234 567890123456789 + - : string = "1024.000000000000227" + v} + + Now take a power of 2 which starts with 99: + + {v + # 2.**93. ;; + 1111111111 + 1 23456789012345678 + - : float = 9.9035203142830422e+27 + + # 2.**93. *. (1. +. 2.** (-52.));; + - : float = 9.9035203142830444e+27 + + # 2.**93. *. (1. -. 2.** (-53.));; + - : float = 9.9035203142830411e+27 + v} + + The difference between 2**93 and its two neighbors is slightly more than, respectively, + 1 and 2 at significant digit 16. + + Those examples show that: + - 17 significant digits is always sufficient to represent a float without ambiguity + - 15th significant digit can always be represented accurately + - converting a decimal number with 16 significant digits to its nearest float and back + can change the last decimal digit by no more than 1 + + To make sure that floats obtained by conversion from decimal fractions (e.g. "3.14") + are printed without trailing non-zero digits, one should choose the first among the + '%.15g', '%.16g', and '%.17g' representations which does round-trip: + + {v + # sprintf "%.15g" 3.14;; + - : string = "3.14" (* pick this one *) + # sprintf "%.16g" 3.14;; + - : string = "3.14" + # sprintf "%.17g" 3.14;; + - : string = "3.1400000000000001" (* do not pick this one *) + + # sprintf "%.15g" 8.000000000000002;; + - : string = "8" (* do not pick this one--does not round-trip *) + # sprintf "%.16g" 8.000000000000002;; + - : string = "8.000000000000002" (* prefer this one *) + # sprintf "%.17g" 8.000000000000002;; + - : string = "8.0000000000000018" (* this one has one digit of junk at the end *) + v} + + Skipping the '%.16g' in the above procedure saves us some time, but it means that, as + seen in the second example above, occasionally numbers with exactly 16 significant + digits will have an error introduced at the 17th digit. That is probably OK for + typical use, because a number with 16 significant digits is "ugly" already. Adding one + more doesn't make it much worse for a human reader. + + On the other hand, we cannot skip '%.15g' and only look at '%.16g' and '%.17g', since + the inaccuracy at the 16th digit might introduce the noise we want to avoid: + + {v + # sprintf "%.15g" 9.992;; + - : string = "9.992" (* pick this one *) + # sprintf "%.16g" 9.992;; + - : string = "9.992000000000001" (* do not pick this one--junk at the end *) + # sprintf "%.17g" 9.992;; + - : string = "9.9920000000000009" + v} +*) +let to_string x = + valid_float_lexem ( + let y = format_float "%.15g" x in + if float_of_string y = x then + y + else + format_float "%.17g" x) +;; + +let nan = Caml.nan + +let infinity = Caml.infinity +let neg_infinity = Caml.neg_infinity + +let max_value = infinity +let min_value = neg_infinity + +let max_finite_value = Caml.max_float + +let min_positive_subnormal_value = 2. ** -1074. +let min_positive_normal_value = 2. ** -1022. + +let zero = 0. +let one = 1. +let minus_one = -1. + +let pi = 0x3.243F6A8885A308D313198A2E037073 +let sqrt_pi = 0x1.C5BF891B4EF6AA79C3B0520D5DB938 +let sqrt_2pi = 0x2.81B263FEC4E0B2CAF9483F5CE459DC +let euler = 0x0.93C467E37DB0C7A4D1BE3F810152CB + +(* The bits of INRIA's [Pervasives] that we just want to expose in + [Float]. Most are already deprecated in [Pervasives], and + eventually all of them should be. *) +include (Caml : sig + external frexp : float -> float * int = "caml_frexp_float" + external ldexp : (float [@unboxed]) -> (int [@untagged]) -> (float [@unboxed]) = "caml_ldexp_float" "caml_ldexp_float_unboxed" [@@noalloc] + external log10 : float -> float = "caml_log10_float" "log10" + [@@unboxed] [@@noalloc] + external expm1 : float -> float = "caml_expm1_float" "caml_expm1" + [@@unboxed] [@@noalloc] + external log1p : float -> float = "caml_log1p_float" "caml_log1p" + [@@unboxed] [@@noalloc] + external copysign : float -> float -> float = "caml_copysign_float" "caml_copysign" + [@@unboxed] [@@noalloc] + external cos : float -> float = "caml_cos_float" "cos" + [@@unboxed] [@@noalloc] + external sin : float -> float = "caml_sin_float" "sin" + [@@unboxed] [@@noalloc] + external tan : float -> float = "caml_tan_float" "tan" + [@@unboxed] [@@noalloc] + external acos : float -> float = "caml_acos_float" "acos" + [@@unboxed] [@@noalloc] + external asin : float -> float = "caml_asin_float" "asin" + [@@unboxed] [@@noalloc] + external atan : float -> float = "caml_atan_float" "atan" + [@@unboxed] [@@noalloc] + external atan2 : float -> float -> float = "caml_atan2_float" "atan2" + [@@unboxed] [@@noalloc] + external hypot : float -> float -> float = "caml_hypot_float" "caml_hypot" + [@@unboxed] [@@noalloc] + external cosh : float -> float = "caml_cosh_float" "cosh" + [@@unboxed] [@@noalloc] + external sinh : float -> float = "caml_sinh_float" "sinh" + [@@unboxed] [@@noalloc] + external tanh : float -> float = "caml_tanh_float" "tanh" + [@@unboxed] [@@noalloc] + external sqrt : float -> float = "caml_sqrt_float" "sqrt" + [@@unboxed] [@@noalloc] + external exp : float -> float = "caml_exp_float" "exp" + [@@unboxed] [@@noalloc] + external log : float -> float = "caml_log_float" "log" + [@@unboxed] [@@noalloc] + end) + +(* We need this indirection because these are exposed as "val" instead of "external" *) +let frexp = frexp +let ldexp = ldexp + +let epsilon_float = Caml.epsilon_float + +let of_int = Int.to_float +let to_int = Int.of_float + +let of_int63 i = Int63.to_float i + +let of_int64 i = Caml.Int64.to_float i + +let to_int64 = Caml.Int64.of_float + +let iround_lbound = lower_bound_for_int Int.num_bits +let iround_ubound = upper_bound_for_int Int.num_bits + +(* The performance of the "exn" rounding functions is important, so they are written + out separately, and tuned individually. (We could have the option versions call + the "exn" versions, but that imposes arguably gratuitous overhead---especially + in the case where the capture of backtraces is enabled upon "with"---and that seems + not worth it when compared to the relatively small amount of code duplication.) *) + +(* Error reporting below is very carefully arranged so that, e.g., [iround_nearest_exn] + itself can be inlined into callers such that they don't need to allocate a box for the + [float] argument. This is done with a box [box] function carefully chosen to allow the + compiler to create a separate box for the float only in error cases. See, e.g., + [../../zero/test/price_test.ml] for a mechanical test of this property when building + with [X_LIBRARY_INLINING=true]. *) + +let iround_up t = + if t > 0.0 then begin + let t' = ceil t in + if t' <= iround_ubound then + Some (Int.of_float_unchecked t') + else + None + end + else begin + if t >= iround_lbound then + Some (Int.of_float_unchecked t) + else + None + end + +let iround_up_exn t = + if t > 0.0 then begin + let t' = ceil t in + if t' <= iround_ubound then + Int.of_float_unchecked t' + else + invalid_argf "Float.iround_up_exn: argument (%f) is too large" (box t) () + end + else begin + if t >= iround_lbound then + Int.of_float_unchecked t + else + invalid_argf "Float.iround_up_exn: argument (%f) is too small or NaN" (box t) () + end +[@@ocaml.inline always] + +let iround_down t = + if t >= 0.0 then begin + if t <= iround_ubound then + Some (Int.of_float_unchecked t) + else + None + end + else begin + let t' = floor t in + if t' >= iround_lbound then + Some (Int.of_float_unchecked t') + else + None + end + +let iround_down_exn t = + if t >= 0.0 then begin + if t <= iround_ubound then + Int.of_float_unchecked t + else + invalid_argf "Float.iround_down_exn: argument (%f) is too large" (box t) () + end + else begin + let t' = floor t in + if t' >= iround_lbound then + Int.of_float_unchecked t' + else + invalid_argf "Float.iround_down_exn: argument (%f) is too small or NaN" (box t) () + end +[@@ocaml.inline always] + +let iround_towards_zero t = + if t >= iround_lbound && t <= iround_ubound then + Some (Int.of_float_unchecked t) + else + None + +let iround_towards_zero_exn t = + if t >= iround_lbound && t <= iround_ubound then + Int.of_float_unchecked t + else + invalid_argf "Float.iround_towards_zero_exn: argument (%f) is out of range or NaN" + (box t) + () +[@@ocaml.inline always] + +(* Outside of the range (round_nearest_lb..round_nearest_ub), all representable doubles + are integers in the mathematical sense, and [round_nearest] should be identity. + + However, for odd numbers with the absolute value between 2**52 and 2**53, the formula + [round_nearest x = floor (x + 0.5)] does not hold: + + {v + # let naive_round_nearest x = floor (x +. 0.5);; + # let x = 2. ** 52. +. 1.;; + val x : float = 4503599627370497. + # naive_round_nearest x;; + - : float = 4503599627370498. + v} +*) + +let round_nearest_lb = -.(2. ** 52.) +let round_nearest_ub = 2. ** 52. + +(* For [x = one_ulp `Down 0.5], the formula [floor (x +. 0.5)] for rounding to nearest + does not work, because the exact result is halfway between [one_ulp `Down 1.] and [1.], + and it gets rounded up to [1.] due to the round-ties-to-even rule. *) +let one_ulp_less_than_half = one_ulp `Down 0.5 +let add_half_for_round_nearest t = + t +. (if t = one_ulp_less_than_half then + one_ulp_less_than_half (* since t < 0.5, make sure the result is < 1.0 *) + else + 0.5) + +let iround_nearest_32 t = + if t >= 0. then + let t' = add_half_for_round_nearest t in + if t' <= iround_ubound then + Some (Int.of_float_unchecked t') + else + None + else + let t' = floor (t +. 0.5) in + if t' >= iround_lbound then + Some (Int.of_float_unchecked t') + else + None + +let iround_nearest_64 t = + if t >= 0. then + if t < round_nearest_ub then + Some (Int.of_float_unchecked (add_half_for_round_nearest t)) + else + if t <= iround_ubound then + Some (Int.of_float_unchecked t) + else + None + else + if t > round_nearest_lb then + Some (Int.of_float_unchecked (floor (t +. 0.5))) + else + if t >= iround_lbound then + Some (Int.of_float_unchecked t) + else + None + +let iround_nearest = + match Word_size.word_size with + | W64 -> iround_nearest_64 + | W32 -> iround_nearest_32 + +let iround_nearest_exn_32 t = + if t >= 0. then + let t' = add_half_for_round_nearest t in + if t' <= iround_ubound then + Int.of_float_unchecked t' + else + invalid_argf "Float.iround_nearest_exn: argument (%f) is too large" (box t) () + else + let t' = floor (t +. 0.5) in + if t' >= iround_lbound then + Int.of_float_unchecked t' + else + invalid_argf "Float.iround_nearest_exn: argument (%f) is too small" (box t) () + +let iround_nearest_exn_64 t = + if t >= 0. then + if t < round_nearest_ub then + Int.of_float_unchecked (add_half_for_round_nearest t) + else + if t <= iround_ubound then + Int.of_float_unchecked t + else + invalid_argf "Float.iround_nearest_exn: argument (%f) is too large" (box t) () + else + if t > round_nearest_lb then + Int.of_float_unchecked (floor (t +. 0.5)) + else + if t >= iround_lbound then + Int.of_float_unchecked t + else + invalid_argf "Float.iround_nearest_exn: argument (%f) is too small or NaN" (box t) () +[@@ocaml.inline always] + +let iround_nearest_exn = + match Word_size.word_size with + | W64 -> iround_nearest_exn_64 + | W32 -> iround_nearest_exn_32 + +(* The following [iround_exn] and [iround] functions are slower than the ones above. + Their equivalence to those functions is tested in the unit tests below. *) + +let iround_exn ?(dir=`Nearest) t = + match dir with + | `Zero -> iround_towards_zero_exn t + | `Nearest -> iround_nearest_exn t + | `Up -> iround_up_exn t + | `Down -> iround_down_exn t +[@@inline] + +let iround ?(dir=`Nearest) t = + try Some (iround_exn ~dir t) + with _ -> None + +let is_inf x = + match Caml.classify_float x with + | FP_infinite -> true + | _ -> false + +let min_inan (x : t) y = + if is_nan y then x + else if is_nan x then y + else if x < y then x else y + +let max_inan (x : t) y = + if is_nan y then x + else if is_nan x then y + else if x > y then x else y + +let add = (+.) +let sub = (-.) +let neg = (~-.) +let abs = Caml.abs_float +let scale = ( *. ) + +let square x = x *. x + +module Parts : sig + type t + + val fractional : t -> float + val integral : t -> float + val modf : float -> t +end = struct + type t = float * float + + let fractional t = fst t + let integral t = snd t + let modf = modf +end +let modf = Parts.modf + +let round_down = floor + +let round_up = ceil + +let round_towards_zero t = + if t >= 0. + then round_down t + else round_up t + +(* see the comment above [round_nearest_lb] and [round_nearest_ub] for an explanation *) +let round_nearest t = + if t > round_nearest_lb && t < round_nearest_ub then + floor (add_half_for_round_nearest t) + else + t +. 0. + +let round_nearest_half_to_even t = + if t <= round_nearest_lb || t >= round_nearest_ub then + t +. 0. + else + let floor = floor t in + (* [ceil_or_succ = if t is an integer then t +. 1. else ceil t]. Faster than [ceil]. *) + let ceil_or_succ = floor +. 1. in + let diff_floor = t -. floor in + let diff_ceil = ceil_or_succ -. t in + if diff_floor < diff_ceil then + floor + else + if diff_floor > diff_ceil then + ceil_or_succ + else + (* exact tie, pick the even *) + if mod_float floor 2. = 0. then + floor + else + ceil_or_succ + +let int63_round_lbound = lower_bound_for_int Int63.num_bits +let int63_round_ubound = upper_bound_for_int Int63.num_bits + +let int63_round_up_exn t = + if t > 0.0 then begin + let t' = ceil t in + if t' <= int63_round_ubound then + Int63.of_float_unchecked t' + else + invalid_argf "Float.int63_round_up_exn: argument (%f) is too large" (Float0.box t) () + end + else begin + if t >= int63_round_lbound then + Int63.of_float_unchecked t + else + invalid_argf "Float.int63_round_up_exn: argument (%f) is too small or NaN" + (Float0.box t) () + end + +let int63_round_down_exn t = + if t >= 0.0 then begin + if t <= int63_round_ubound then + Int63.of_float_unchecked t + else + invalid_argf "Float.int63_round_down_exn: argument (%f) is too large" + (Float0.box t) () + end + else begin + let t' = floor t in + if t' >= int63_round_lbound then + Int63.of_float_unchecked t' + else + invalid_argf "Float.int63_round_down_exn: argument (%f) is too small or NaN" + (Float0.box t) () + end + +let int63_round_nearest_portable_alloc_exn t0 = + let t = round_nearest t0 in + if t > 0. + then begin + if t <= int63_round_ubound + then Int63.of_float_unchecked t + else invalid_argf + "Float.int63_round_nearest_portable_alloc_exn: argument (%f) is too large" + (box t0) + () + end + else begin + if t >= int63_round_lbound + then Int63.of_float_unchecked t + else invalid_argf + "Float.int63_round_nearest_portable_alloc_exn: argument (%f) is too small or NaN" + (box t0) + () + end + +let int63_round_nearest_arch64_noalloc_exn f = Int63.of_int (iround_nearest_exn f) + +let int63_round_nearest_exn = + match Word_size.word_size with + | W64 -> int63_round_nearest_arch64_noalloc_exn + | W32 -> int63_round_nearest_portable_alloc_exn + +let round ?(dir=`Nearest) t = + match dir with + | `Nearest -> round_nearest t + | `Down -> round_down t + | `Up -> round_up t + | `Zero -> round_towards_zero t + +module Class = struct + type t = + | Infinite + | Nan + | Normal + | Subnormal + | Zero + [@@deriving_inline compare, enumerate, sexp] + let compare : t -> t -> int = Ppx_compare_lib.polymorphic_compare + let all : t list = [Infinite; Nan; Normal; Subnormal; Zero] + let t_of_sexp : Ppx_sexp_conv_lib.Sexp.t -> t = + let _tp_loc = "src/float.ml.Class.t" in + function + | Ppx_sexp_conv_lib.Sexp.Atom ("infinite"|"Infinite") -> Infinite + | Ppx_sexp_conv_lib.Sexp.Atom ("nan"|"Nan") -> Nan + | Ppx_sexp_conv_lib.Sexp.Atom ("normal"|"Normal") -> Normal + | Ppx_sexp_conv_lib.Sexp.Atom ("subnormal"|"Subnormal") -> Subnormal + | Ppx_sexp_conv_lib.Sexp.Atom ("zero"|"Zero") -> Zero + | Ppx_sexp_conv_lib.Sexp.List ((Ppx_sexp_conv_lib.Sexp.Atom + ("infinite"|"Infinite"))::_) as sexp -> + Ppx_sexp_conv_lib.Conv_error.stag_no_args _tp_loc sexp + | Ppx_sexp_conv_lib.Sexp.List ((Ppx_sexp_conv_lib.Sexp.Atom + ("nan"|"Nan"))::_) as sexp -> + Ppx_sexp_conv_lib.Conv_error.stag_no_args _tp_loc sexp + | Ppx_sexp_conv_lib.Sexp.List ((Ppx_sexp_conv_lib.Sexp.Atom + ("normal"|"Normal"))::_) as sexp -> + Ppx_sexp_conv_lib.Conv_error.stag_no_args _tp_loc sexp + | Ppx_sexp_conv_lib.Sexp.List ((Ppx_sexp_conv_lib.Sexp.Atom + ("subnormal"|"Subnormal"))::_) as sexp -> + Ppx_sexp_conv_lib.Conv_error.stag_no_args _tp_loc sexp + | Ppx_sexp_conv_lib.Sexp.List ((Ppx_sexp_conv_lib.Sexp.Atom + ("zero"|"Zero"))::_) as sexp -> + Ppx_sexp_conv_lib.Conv_error.stag_no_args _tp_loc sexp + | Ppx_sexp_conv_lib.Sexp.List ((Ppx_sexp_conv_lib.Sexp.List _)::_) as sexp + -> Ppx_sexp_conv_lib.Conv_error.nested_list_invalid_sum _tp_loc sexp + | Ppx_sexp_conv_lib.Sexp.List [] as sexp -> + Ppx_sexp_conv_lib.Conv_error.empty_list_invalid_sum _tp_loc sexp + | sexp -> Ppx_sexp_conv_lib.Conv_error.unexpected_stag _tp_loc sexp + let sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t = + function + | Infinite -> Ppx_sexp_conv_lib.Sexp.Atom "Infinite" + | Nan -> Ppx_sexp_conv_lib.Sexp.Atom "Nan" + | Normal -> Ppx_sexp_conv_lib.Sexp.Atom "Normal" + | Subnormal -> Ppx_sexp_conv_lib.Sexp.Atom "Subnormal" + | Zero -> Ppx_sexp_conv_lib.Sexp.Atom "Zero" + [@@@end] + + let to_string t = string_of_sexp (sexp_of_t t) + let of_string s = t_of_sexp (sexp_of_string s) +end + +let classify t = + let module C = Class in + match Caml.classify_float t with + | FP_normal -> C.Normal + | FP_subnormal -> C.Subnormal + | FP_zero -> C.Zero + | FP_infinite -> C.Infinite + | FP_nan -> C.Nan +;; + +let is_finite t = + not (t = infinity || t = neg_infinity || is_nan t) +;; + +let insert_underscores ?(delimiter='_') ?(strip_zero=false) string = + match String.lsplit2 string ~on:'.' with + | None -> + Int_conversions.insert_delimiter string ~delimiter + | Some (left, right) -> + let left = Int_conversions.insert_delimiter left ~delimiter in + let right = + if strip_zero + then String.rstrip right ~drop:(fun c -> Char.(=) c '0') + else right + in + match right with + | "" -> left + | _ -> left ^ "." ^ right +;; + +let to_string_hum ?delimiter ?(decimals=3) ?strip_zero f = + if Int_replace_polymorphic_compare.(<) decimals 0 then + invalid_argf "to_string_hum: invalid argument ~decimals=%d" decimals (); + match classify f with + | Class.Infinite -> if f > 0. then "inf" else "-inf" + | Class.Nan -> "nan" + | Class.Normal + | Class.Subnormal + | Class.Zero -> insert_underscores (sprintf "%.*f" decimals f) ?delimiter ?strip_zero +;; + +let sexp_of_t t = + let sexp = sexp_of_t t in + match !Sexp.of_float_style with + | `No_underscores -> sexp + | `Underscores -> + match sexp with + | List _ -> raise_s (Sexp.message "[sexp_of_float] produced strange sexp" + ["sexp", Sexp.sexp_of_t sexp]) + | Atom string -> + if String.contains string 'E' + then sexp + else Atom (insert_underscores string) +;; + +let to_padded_compact_string t = + + (* Round a ratio toward the nearest integer, resolving ties toward the nearest even + number. For sane inputs (in particular, when [denominator] is an integer and + [abs numerator < 2e52]) this should be accurate. Otherwise, the result might be a + little bit off, but we don't really use that case. *) + let iround_ratio_exn ~numerator ~denominator = + let k = floor (numerator /. denominator) in + (* if [abs k < 2e53], then both [k] and [k +. 1.] are accurately represented, and in + particular [k +. 1. > k]. If [denominator] is also an integer, and + [abs (denominator *. (k +. 1)) < 2e53] (and in some other cases, too), then [lower] + and [higher] are actually both accurate. Since (roughly) + [numerator = denominator *. k] then for [abs numerator < 2e52] we should be + fine. *) + let lower = denominator *. k in + let higher = denominator *. (k +. 1.) in + (* Subtracting numbers within a factor of two from each other is accurate. + So either the two subtractions below are accurate, or k = 0, or k = -1. + In case of a tie, round to even. *) + let diff_right = higher -. numerator in + let diff_left = numerator -. lower in + let k = iround_nearest_exn k in + if diff_right < diff_left then + k + 1 + else if diff_right > diff_left then + k + else + (* a tie *) + if Int_replace_polymorphic_compare.(=) (k mod 2) 0 then k else k + 1 + in + + match classify t with + | Class.Infinite -> if t < 0.0 then "-inf " else "inf " + | Class.Nan -> "nan " + | Class.Subnormal | Class.Normal | Class.Zero -> + let go t = + let conv_one t = + assert (0. <= t && t < 999.95); + let x = format_float "%.1f" t in + (* Fix the ".0" suffix *) + if String.is_suffix x ~suffix:".0" + then begin + let x = Bytes.of_string x in + let n = Bytes.length x in + Bytes.set x (n - 1) ' '; + Bytes.set x (n - 2) ' '; + Bytes.unsafe_to_string ~no_mutation_while_string_reachable:x + end + else x + in + let conv mag t denominator = + assert (denominator = 100. && t >= 999.95 + || denominator >= 100_000. && t >= round_nearest (denominator *. 9.999_5)); + assert (t < round_nearest (denominator *. 9_999.5)); + let i, d = + let k = iround_ratio_exn ~numerator:t ~denominator in + (* [mod] is okay here because we know i >= 0. *) + k / 10, k mod 10 + in + let open Int_replace_polymorphic_compare in + assert (0 <= i && i < 1000); + assert (0 <= d && d < 10); + if d = 0 then + sprintf "%d%c " i mag + else + sprintf "%d%c%d" i mag d + in + (* While the standard metric prefixes (e.g. capital "M" rather than "m", [1]) are + nominally more correct, this hinders readability in our case. E.g., 10G6 and + 1066 look too similar. That's an extreme example, but in general k,m,g,t,p + probably stand out better than K,M,G,T,P when interspersed with digits. + + [1] http://en.wikipedia.org/wiki/Metric_prefix *) + (* The trick here is that: + - the first boundary (999.95) as a float is slightly over-represented (so it is + better approximated as "1k" than as "999.9"), + - the other boundaries are accurately represented, because they are integers. + That's why the strict equalities below do exactly what we want. *) + if t < 999.95E0 then conv_one t + else if t < 999.95E3 then conv 'k' t 100. + else if t < 999.95E6 then conv 'm' t 100_000. + else if t < 999.95E9 then conv 'g' t 100_000_000. + else if t < 999.95E12 then conv 't' t 100_000_000_000. + else if t < 999.95E15 then conv 'p' t 100_000_000_000_000. + else sprintf "%.1e" t + in + if t >= 0. + then go t + else "-" ^ (go ~-.t) + +(* Performance note: Initializing the accumulator to 1 results in one extra + multiply; e.g., to compute x ** 4, we in principle only need 2 multiplies, + but this function will have 3 multiplies. However, attempts to avoid this + (like decrementing n and initializing accum to be x, or handling small + exponents as a special case) have not yielded anything that is a net + improvement. +*) +let int_pow x n = + let open Int_replace_polymorphic_compare in + if n = 0 then + 1. + else begin + (* Using [x +. (-0.)] on the following line convinces the compiler to avoid a certain + boxing (that would result in allocation in each iteration). Soon, the compiler + shouldn't need this "hint" to avoid the boxing. The reason we add -0 rather than 0 + is that [x +. (-0.)] is apparently always the same as [x], whereas [x +. 0.] is + not, in that it sends [-0.] to [0.]. This makes a difference because we want + [int_pow (-0.) (-1)] to return neg_infinity just like [-0. ** -1.] would. *) + let x = ref (x +. (-0.)) in + let n = ref n in + let accum = ref 1. in + if !n < 0 then begin + (* x ** n = (1/x) ** -n *) + x := 1. /. !x; + n := ~- !n; + if !n < 0 then begin + (* n must have been min_int, so it is now so big that it has wrapped around. + We decrement it so that it looks positive again, but accordingly have + to put an extra factor of x in the accumulator. + *) + accum := !x; + decr n + end + end; + (* Letting [a] denote (the original value of) [x ** n], we maintain + the invariant that [(x ** n) *. accum = a]. *) + while !n > 1 do + if !n land 1 <> 0 then accum := !x *. !accum; + x := !x *. !x; + n := !n lsr 1 + done; + (* n is necessarily 1 at this point, so there is one additional + multiplication by x. *) + !x *. !accum + end + +let round_gen x ~how = + if x = 0. then 0. + else if not (is_finite x) then x + else begin + (* Significant digits and decimal digits. *) + let sd, dd = + match how with + | `significant_digits sd -> + let dd = sd - to_int (round_up (log10 (abs x))) in + sd, dd + | `decimal_digits dd -> + let sd = dd + to_int (round_up (log10 (abs x))) in + sd, dd + in + let open Int_replace_polymorphic_compare in + if sd < 0 + then 0. + else if sd >= 17 + then x + else + (* Choose the order that is exactly representable as a float. Small positive + integers are, but their inverses in most cases are not. *) + let abs_dd = Int.abs dd in + if abs_dd > 22 || sd >= 16 + (* 10**22 is exactly representable as a float, but 10**23 is not, so use the slow + path. Similarly, if we need 16 significant digits in the result, then the integer + [round_nearest (x <op> order)] might not be exactly representable as a float, since + for some ranges we only have 15 digits of precision guaranteed. + + That said, we are still rounding twice here: + + 1) first time when rounding [x *. order] or [x /. order] to the nearest float + (just the normal way floating-point multiplication or division works), + + 2) second time when applying [round_nearest_half_to_even] to the result of the + above operation + + So for arguments within an ulp from a tie we might still produce an off-by-one + result. *) + then of_string (sprintf "%.*g" sd x) + else + let order = int_pow 10. abs_dd in + if dd >= 0 + then round_nearest_half_to_even (x *. order) /. order + else round_nearest_half_to_even (x /. order) *. order + end + +let round_significant x ~significant_digits = + if Int_replace_polymorphic_compare.(<=) significant_digits 0 then + raise (Invalid_argument + ("Float.round_significant: invalid argument significant_digits:" + ^ Int.to_string significant_digits)) + else + round_gen x ~how:(`significant_digits significant_digits) + +let round_decimal x ~decimal_digits = + round_gen x ~how:(`decimal_digits decimal_digits) + +let between t ~low ~high = low <= t && t <= high + +let clamp_exn t ~min ~max = + (* Also fails if [min] or [max] is nan *) + assert (min <= max); + (* clamp_unchecked is in float0.ml *) + clamp_unchecked t ~min ~max + +let clamp t ~min ~max = + (* Also fails if [min] or [max] is nan *) + if min <= max then + Ok (clamp_unchecked t ~min ~max) + else + Or_error.error_s + (Sexp.message "clamp requires [min <= max]" + [ "min", T.sexp_of_t min + ; "max", T.sexp_of_t max + ]) + +let ( + ) = ( +. ) +let ( - ) = ( -. ) +let ( * ) = ( *. ) +let ( ** ) = Caml.( ** ) +let ( / ) = ( /. ) +let ( ~- ) = ( ~-. ) + +let sign_exn t : Sign.t = + if t > 0. + then Pos + else if t < 0. + then Neg + else if t = 0. + then Zero + else Error.raise_s (Sexp.message "Float.sign_exn of NAN" + ["", sexp_of_t t]) + +let sign_or_nan t : Sign_or_nan.t = + if t > 0. + then Pos + else if t < 0. + then Neg + else if t = 0. + then Zero + else Nan + +let ieee_negative t = + let bits = Caml.Int64.bits_of_float t in + Poly.(bits < Caml.Int64.zero) + +let exponent_bits = 11 +let mantissa_bits = 52 + +let exponent_mask64 = Int64.((shift_left one exponent_bits) - one) +let exponent_mask = Int64.to_int_exn exponent_mask64 +let mantissa_mask = Int63.((shift_left one mantissa_bits) - one) +let mantissa_mask64 = Int63.to_int64 mantissa_mask + +let ieee_exponent t = + let bits = Caml.Int64.bits_of_float t in + Int64.((bit_and (shift_right_logical bits mantissa_bits) exponent_mask64)) + |> Caml.Int64.to_int + +let ieee_mantissa t = + let bits = Caml.Int64.bits_of_float t in + Int63.of_int64_exn Caml.Int64.(logand bits mantissa_mask64) + +let create_ieee_exn ~negative ~exponent ~mantissa = + if Int.(bit_and exponent exponent_mask <> exponent) + then failwithf "exponent %d out of range [0, %d]" + exponent exponent_mask () + else if Int63.(bit_and mantissa mantissa_mask <> mantissa) + then failwithf "mantissa %s out of range [0, %s]" + (Int63.to_string mantissa) (Int63.to_string mantissa_mask) () + else + let sign_bits = if negative then Caml.Int64.min_int else Caml.Int64.zero in + let expt_bits = Caml.Int64.shift_left (Caml.Int64.of_int exponent) mantissa_bits in + let mant_bits = Int63.to_int64 mantissa in + let bits = Caml.Int64.(logor sign_bits (logor expt_bits mant_bits)) in + Caml.Int64.float_of_bits bits + +let create_ieee ~negative ~exponent ~mantissa = + Or_error.try_with (fun () -> create_ieee_exn ~negative ~exponent ~mantissa) + +module Terse = struct + type nonrec t = t + let t_of_sexp = t_of_sexp + + let to_string x = Printf.sprintf "%.8G" x + let sexp_of_t x = Sexp.Atom (to_string x) + let of_string x = of_string x +end + +let validate_ordinary t = + Validate.of_error_opt ( + let module C = Class in + match classify t with + | C.Normal | C.Subnormal | C.Zero -> None + | C.Infinite -> Some "value is infinite" + | C.Nan -> Some "value is NaN") +;; + +module V = struct + module ZZ = Comparable.Validate (T) + + let validate_bound ~min ~max t = + Validate.first_failure (validate_ordinary t) (ZZ.validate_bound t ~min ~max) + ;; + + let validate_lbound ~min t = + Validate.first_failure (validate_ordinary t) (ZZ.validate_lbound t ~min) + ;; + + let validate_ubound ~max t = + Validate.first_failure (validate_ordinary t) (ZZ.validate_ubound t ~max) + ;; +end + +include V + +include Comparable.With_zero (struct + include T + let zero = zero + include V + end) + +(* These are partly here as a performance hack to avoid some boxing we're getting with + the versions we get from [With_zero]. They also make [Float.is_negative nan] and + [Float.is_non_positive nan] return [false]; the versions we get from [With_zero] return + [true]. *) +let is_positive t = t > 0. +let is_non_negative t = t >= 0. +let is_negative t = t < 0. +let is_non_positive t = t <= 0. + +include Pretty_printer.Register(struct + include T + let module_name = "Base.Float" + let to_string = to_string + end) + +module O = struct + let ( + ) = ( + ) + let ( - ) = ( - ) + let ( * ) = ( * ) + let ( / ) = ( / ) + let ( ~- ) = ( ~- ) + let ( ** ) = Caml.( ** ) + include (Float_replace_polymorphic_compare : Comparisons.Infix with type t := t) + let abs = abs + let neg = neg + let zero = zero + let of_int = of_int + let of_float x = x +end + +module O_dot = struct + let ( *. ) = ( * ) + let ( +. ) = ( + ) + let ( -. ) = ( - ) + let ( /. ) = ( / ) + let ( ~-. ) = ( ~- ) + let ( **. ) = Caml.( ** ) +end + +module Private = struct + let lower_bound_for_int = lower_bound_for_int + let upper_bound_for_int = upper_bound_for_int + let specialized_hash = hash_float + let one_ulp_less_than_half = one_ulp_less_than_half + let int63_round_nearest_portable_alloc_exn = int63_round_nearest_portable_alloc_exn + let int63_round_nearest_arch64_noalloc_exn = int63_round_nearest_arch64_noalloc_exn + let iround_nearest_exn_64 = iround_nearest_exn_64 +end + + +(* Include type-specific [Replace_polymorphic_compare] at the end, after + including functor application that could shadow its definitions. This is + here so that efficient versions of the comparison functions are exported by + this module. *) +include Float_replace_polymorphic_compare + +(* These functions specifically replace defaults in replace_polymorphic_compare *) +let min (x : t) y = + if is_nan x || is_nan y then nan + else if x < y then x else y + +let max (x : t) y = + if is_nan x || is_nan y then nan + else if x > y then x else y diff --git a/src/float.mli b/src/float.mli new file mode 100644 index 0000000..db3185e --- /dev/null +++ b/src/float.mli @@ -0,0 +1,587 @@ +(** Floating-point representation and utilities. + + If using 32-bit OCaml, you cannot quite assume operations act as you'd expect for IEEE + 64-bit floats. E.g., one can have [let x = ~-. (2. ** 62.) in x = x -. 1.] evaluate + to [false] while [let x = ~-. (2. ** 62.) in let y = x -. 1 in x = y] evaluates to + [true]. This is related to 80-bit registers being used for calculations; you can + force representation as a 64-bit value by let-binding. *) + +open! Import + +type t = float [@@deriving_inline hash] +include +sig + [@@@ocaml.warning "-32"] + val hash_fold_t : + Ppx_hash_lib.Std.Hash.state -> t -> Ppx_hash_lib.Std.Hash.state + val hash : t -> Ppx_hash_lib.Std.Hash.hash_value +end[@@ocaml.doc "@inline"] +[@@@end] + +include Floatable.S with type t := t + +(** [max] and [min] will return nan if either argument is nan. + + The [validate_*] functions always fail if class is [Nan] or [Infinite]. *) +include Identifiable.S with type t := t +include Comparable.With_zero with type t := t + +(** [validate_ordinary] fails if class is [Nan] or [Infinite]. *) +val validate_ordinary : t Validate.check + +val nan : t + +val infinity : t +val neg_infinity : t + +val max_value : t (** Equal to [infinity]. *) + +val min_value : t (** Equal to [neg_infinity]. *) + +val zero : t +val one : t +val minus_one : t + +val pi : t (** The constant pi. *) + +val sqrt_pi : t (** The constant sqrt(pi). *) + +val sqrt_2pi : t (** The constant sqrt(2 * pi). *) + +val euler : t (** Euler-Mascheroni constant (γ). *) + +(** The difference between 1.0 and the smallest exactly representable floating-point + number greater than 1.0. That is: + + [epsilon_float = (one_ulp `Up 1.0) -. 1.0] + + This gives the relative accuracy of type [t], in the sense that for numbers on the + order of [x], the roundoff error is on the order of [x *. float_epsilon]. + + See also: {{:http://en.wikipedia.org/wiki/Machine_epsilon} Machine epsilon}. +*) +val epsilon_float : t + +val max_finite_value : t + +(** + - [min_positive_subnormal_value = 2 ** -1074] + - [min_positive_normal_value = 2 ** -1022] *) + +val min_positive_subnormal_value : t +val min_positive_normal_value : t + +(** An order-preserving bijection between all floats except for nans, and all int64s with + absolute value smaller than or equal to [2**63 - 2**52]. Note both 0. and -0. map to + 0L. *) +val to_int64_preserve_order : t -> int64 option +val to_int64_preserve_order_exn : t -> int64 + +(** Returns [nan] if the absolute value of the argument is too large. *) +val of_int64_preserve_order : int64 -> t + +(** The next or previous representable float. ULP stands for "unit of least precision", + and is the spacing between floating point numbers. Both [one_ulp `Up infinity] and + [one_ulp `Down neg_infinity] return a nan. *) +val one_ulp : [`Up | `Down] -> t -> t + +(** Note that this doesn't round trip in either direction. For example, [Float.to_int + (Float.of_int max_int) <> max_int]. *) +val of_int : int -> t +val to_int : t -> int + +val of_int63 : Int63.t -> t + +val of_int64 : int64 -> t +val to_int64 : t -> int64 + +(** [round] rounds a float to an integer float. [iround{,_exn}] rounds a float to an + int. Both round according to a direction [dir], with default [dir] being [`Nearest]. + + {v + | `Down | rounds toward Float.neg_infinity | + | `Up | rounds toward Float.infinity | + | `Nearest | rounds to the nearest int ("round half-integers up") | + | `Zero | rounds toward zero | + v} + + [iround_exn] raises when trying to handle nan or trying to handle a float outside the + range \[float min_int, float max_int). + + + Here are some examples for [round] for each direction: + + {v + | `Down | [-2.,-1.) to -2. | [-1.,0.) to -1. | [0.,1.) to 0., [1.,2.) to 1. | + | `Up | (-2.,-1.] to -1. | (-1.,0.] to -0. | (0.,1.] to 1., (1.,2.] to 2. | + | `Zero | (-2.,-1.] to -1. | (-1.,1.) to 0. | [1.,2.) to 1. | + | `Nearest | [-1.5,-0.5) to -1. | [-0.5,0.5) to 0. | [0.5,1.5) to 1. | + v} + + For convenience, versions of these functions with the [dir] argument hard-coded are + provided. If you are writing performance-critical code you should use the + versions with the hard-coded arguments (e.g. [iround_down_exn]). The [_exn] ones + are the fastest. + + The following properties hold: + + - [of_int (iround_*_exn i) = i] for any float [i] that is an integer with + [min_int <= i <= max_int]. + + - [round_* i = i] for any float [i] that is an integer. + + - [iround_*_exn (of_int i) = i] for any int [i] with [-2**52 <= i <= 2**52]. *) +val round : ?dir:[`Zero|`Nearest|`Up|`Down] -> t -> t +val iround : ?dir:[`Zero|`Nearest|`Up|`Down] -> t -> int option +val iround_exn : ?dir:[`Zero|`Nearest|`Up|`Down] -> t -> int + +val round_towards_zero : t -> t +val round_down : t -> t +val round_up : t -> t +val round_nearest : t -> t (** Rounds half integers up. *) + +val round_nearest_half_to_even : t -> t (** Rounds half integers to the even integer. *) + +val iround_towards_zero : t -> int option +val iround_down : t -> int option +val iround_up : t -> int option +val iround_nearest : t -> int option + +val iround_towards_zero_exn : t -> int +val iround_down_exn : t -> int +val iround_up_exn : t -> int +val iround_nearest_exn : t -> int + +val int63_round_down_exn : t -> Int63.t +val int63_round_up_exn : t -> Int63.t +val int63_round_nearest_exn : t -> Int63.t + +(** If [f <= iround_lbound || f >= iround_ubound], then [iround*] functions will refuse + to round [f], returning [None] or raising as appropriate. *) +val iround_lbound : t +val iround_ubound : t + +(** [round_significant x ~significant_digits:n] rounds to the nearest number with [n] + significant digits. More precisely: it returns the representable float closest to [x + rounded to n significant digits]. It is meant to be equivalent to [sprintf "%.*g" n x + |> Float.of_string] but faster (10x-15x). Exact ties are resolved as round-to-even. + + However, it might in rare cases break the contract above. + + + It might in some cases appear as if it violates the round-to-even rule: + + {[ + let x = 4.36083208835;; + let z = 4.3608320883;; + assert (z = fast_approx_round_significant x ~sf:11) + ]} + + But in this case so does sprintf, since [x] as a float is slightly + under-represented: + + {[ + sprintf "%.11g" x = "4.3608320883";; + sprintf "%.30g" x = "4.36083208834999958014577714493" + ]} + + More importantly, [round_significant] might sometimes give a different + result than [sprintf ... |> Float.of_string] because it round-trips through an + integer. For example, the decimal fraction 0.009375 is slightly under-represented as + a float: + + {[ sprintf "%.17g" 0.009375 = "0.0093749999999999997" ]} + + But: + + {[ 0.009375 *. 1e5 = 937.5 ]} + + Therefore: + + {[ round_significant 0.009375 ~significant_digits:3 = 0.00938 ]} + + whereas: + + {[ sprintf "%.3g" 0.009375 = "0.00937" ]} + + + In general we believe (and have tested on numerous examples) that the following + holds for all x: + + {[ + let s = sprintf "%.*g" significant_digits x |> Float.of_string in + s = round_significant ~significant_digits x + || s = round_significant ~significant_digits (one_ulp `Up x) + || s = round_significant ~significant_digits (one_ulp `Down x) + ]} + + Also, for float representations of decimal fractions (like 0.009375), + [round_significant] is more likely to give the "desired" result than [sprintf ... |> + of_string] (that is, the result of rounding the decimal fraction, rather than its + float representation). But it's not guaranteed either--see the [4.36083208835] + example above. + +*) +val round_significant : float -> significant_digits:int -> float + +(** [round_decimal x ~decimal_digits:n] rounds [x] to the nearest [10**(-n)]. For positive + [n] it is meant to be equivalent to [sprintf "%.*f" n x |> Float.of_string], but + faster. + + All the considerations mentioned in [round_significant] apply (both functions use the + same code path). +*) +val round_decimal : float -> decimal_digits:int -> float + + +val is_nan : t -> bool + +(** Includes positive and negative [Float.infinity]. *) +val is_inf : t -> bool + +(** [min_inan] and [max_inan] return, respectively, the min and max of the two given + values, except when one of the values is a [nan], in which case the other is + returned. (Returns [nan] if both arguments are [nan].) *) + +val min_inan : t -> t -> t +val max_inan : t -> t -> t + +val ( + ) : t -> t -> t +val ( - ) : t -> t -> t +val ( / ) : t -> t -> t +val ( * ) : t -> t -> t +val ( ** ) : t -> t -> t + +val ( ~- ) : t -> t + +(** Returns the fractional part and the whole (i.e., integer) part. For example, [modf + (-3.14)] returns [{ fractional = -0.14; integral = -3.; }]! *) +module Parts : sig + type outer + type t + val fractional : t -> outer + val integral : t -> outer +end with type outer := t +val modf : t -> Parts.t + +(** [mod_float x y] returns a result with the same sign as [x]. It returns [nan] if [y] + is [0]. It is basically + + {[ let mod_float x y = x -. float(truncate(x/.y)) *. y]} + + not + + {[ let mod_float x y = x -. floor(x/.y) *. y ]} + + and therefore resembles [mod] on integers more than [%]. *) +val mod_float : t -> t -> t + +(** {6 Ordinary functions for arithmetic operations} + + These are for modules that inherit from [t], since the infix operators are more + convenient. *) +val add : t -> t -> t +val sub : t -> t -> t +val neg : t -> t +val scale : t -> t -> t +val abs : t -> t + + +(** A sub-module designed to be opened to make working with floats more convenient. *) +module O : sig + val ( + ) : t -> t -> t + val ( - ) : t -> t -> t + val ( * ) : t -> t -> t + val ( / ) : t -> t -> t + val ( ** ) : t -> t -> t + val ( ~- ) : t -> t + include Comparisons.Infix with type t := t + + val abs : t -> t + val neg : t -> t + val zero : t + val of_int : int -> t + val of_float : float -> t +end + +(** Similar to [O], except that operators are suffixed with a dot, allowing one to have + both int and float operators in scope simultaneously. *) +module O_dot : sig + val ( +. ) : t -> t -> t + val ( -. ) : t -> t -> t + val ( *. ) : t -> t -> t + val ( /. ) : t -> t -> t + val ( **. ) : t -> t -> t + val ( ~-. ) : t -> t +end + +(** [to_string x] builds a string [s] representing the float [x] that guarantees the round + trip, that is such that [Float.equal x (Float.of_string s)]. + + It usually yields as few significant digits as possible. That is, it won't print + [3.14] as [3.1400000000000001243]. The only exception is that occasionally it will + output 17 significant digits when the number can be represented with just 16 (but not + 15 or less) of them. *) +val to_string : t -> string + +(** Pretty print float, for example [to_string_hum ~decimals:3 1234.1999 = "1_234.200"] + [to_string_hum ~decimals:3 ~strip_zero:true 1234.1999 = "1_234.2" ]. No delimiters + are inserted to the right of the decimal. *) +val to_string_hum + : ?delimiter:char (** defaults to ['_'] *) + -> ?decimals:int (** defaults to [3] *) + -> ?strip_zero:bool (** defaults to [false] *) + -> t + -> string + +(** Produce a lossy compact string representation of the float. The float is scaled by + an appropriate power of 1000 and rendered with one digit after the decimal point, + except that the decimal point is written as '.', 'k', 'm', 'g', 't', or 'p' to + indicate the scale factor. (However, if the digit after the "decimal" point is 0, + it is suppressed.) + + The smallest scale factor that allows the number to be rendered with at most 3 digits + to the left of the decimal is used. If the number is too large for this format (i.e., + the absolute value is at least 999.95e15), scientific notation is used instead. E.g.: + + - [to_padded_compact_string (-0.01) = "-0 "] + - [to_padded_compact_string 1.89 = "1.9"] + - [to_padded_compact_string 999_949.99 = "999k9"] + - [to_padded_compact_string 999_950. = "1m "] + + In the case where the digit after the "decimal", or the "decimal" itself is omitted, + the numbers are padded on the right with spaces to ensure the last two columns of the + string always correspond to the decimal and the digit afterward (except in the case of + scientific notation, where the exponent is the right-most element in the string and + could take up to four characters). + + - [to_padded_compact_string 1. = "1 "] + - [to_padded_compact_string 1.e6 = "1m "] + - [to_padded_compact_string 1.e16 = "1.e+16"] + - [to_padded_compact_string max_finite_value = "1.8e+308"] + + Numbers in the range -.05 < x < .05 are rendered as "0 " or "-0 ". + + Other cases: + + - [to_padded_compact_string nan = "nan "] + - [to_padded_compact_string infinity = "inf "] + - [to_padded_compact_string neg_infinity = "-inf "] + + Exact ties are resolved to even in the decimal: + + - [to_padded_compact_string 3.25 = "3.2"] + - [to_padded_compact_string 3.75 = "3.8"] + - [to_padded_compact_string 33_250. = "33k2"] + - [to_padded_compact_string 33_350. = "33k4"] *) +val to_padded_compact_string : t -> string + +(** [int_pow x n] computes [x ** float n] via repeated squaring. It is generally much + faster than [**]. + + Note that [int_pow x 0] always returns [1.], even if [x = nan]. This + coincides with [x ** 0.] and is intentional. + + For [n >= 0] the result is identical to an n-fold product of [x] with itself under + [*.], with a certain placement of parentheses. For [n < 0] the result is identical + to [int_pow (1. /. x) (-n)]. + + The error will be on the order of [|n|] ulps, essentially the same as if you + perturbed [x] by up to a ulp and then exponentiated exactly. + + Benchmarks show a factor of 5-10 speedup (relative to [**]) for exponents up to about + 1000 (approximately 10ns vs. 70ns). For larger exponents the advantage is smaller but + persists into the trillions. For a recent or more detailed comparison, run the + benchmarks. + + Depending on context, calling this function might or might not allocate 2 minor words. + Even if called in a way that causes allocation, it still appears to be faster than + [**]. *) +val int_pow : t -> int -> t + +(** [square x] returns [x *. x]. *) +val square : t -> t + +(** [ldexp x n] returns [x *. 2 ** n] *) +val ldexp : t -> int -> t + +(** [frexp f] returns the pair of the significant and the exponent of [f]. When [f] is + zero, the significant [x] and the exponent [n] of [f] are equal to zero. When [f] is + non-zero, they are defined by [f = x *. 2 ** n] and [0.5 <= x < 1.0]. *) +val frexp : t -> t * int + +(** Base 10 logarithm. *) +external log10 : t -> t = "caml_log10_float" "log10" +[@@unboxed] [@@noalloc] + +(** [expm1 x] computes [exp x -. 1.0], giving numerically-accurate results even if [x] is + close to [0.0]. *) +external expm1 : t -> t = "caml_expm1_float" "caml_expm1" +[@@unboxed] [@@noalloc] + +(** [log1p x] computes [log(1.0 +. x)] (natural logarithm), giving numerically-accurate + results even if [x] is close to [0.0]. *) +external log1p : t -> t = "caml_log1p_float" "caml_log1p" +[@@unboxed] [@@noalloc] + +(** [copysign x y] returns a float whose absolute value is that of [x] and whose sign is + that of [y]. If [x] is [nan], returns [nan]. If [y] is [nan], returns either [x] or + [-. x], but it is not specified which. *) +external copysign : t -> t -> t = "caml_copysign_float" "caml_copysign" +[@@unboxed] [@@noalloc] + +(** Cosine. Argument is in radians. *) +external cos : t -> t = "caml_cos_float" "cos" +[@@unboxed] [@@noalloc] + +(** Sine. Argument is in radians. *) +external sin : t -> t = "caml_sin_float" "sin" +[@@unboxed] [@@noalloc] + +(** Tangent. Argument is in radians. *) +external tan : t -> t = "caml_tan_float" "tan" +[@@unboxed] [@@noalloc] + +(** Arc cosine. The argument must fall within the range [[-1.0, 1.0]]. Result is in + radians and is between [0.0] and [pi]. *) +external acos : t -> t = "caml_acos_float" "acos" +[@@unboxed] [@@noalloc] + +(** Arc sine. The argument must fall within the range [[-1.0, 1.0]]. Result is in + radians and is between [-pi/2] and [pi/2]. *) +external asin : t -> t = "caml_asin_float" "asin" +[@@unboxed] [@@noalloc] + +(** Arc tangent. Result is in radians and is between [-pi/2] and [pi/2]. *) +external atan : t -> t = "caml_atan_float" "atan" +[@@unboxed] [@@noalloc] + +(** [atan2 y x] returns the arc tangent of [y /. x]. The signs of [x] and [y] are used to + determine the quadrant of the result. Result is in radians and is between [-pi] and + [pi]. *) +external atan2 : t -> t -> t = "caml_atan2_float" "atan2" +[@@unboxed] [@@noalloc] + +(** [hypot x y] returns [sqrt(x *. x + y *. y)], that is, the length of the hypotenuse of + a right-angled triangle with sides of length [x] and [y], or, equivalently, the + distance of the point [(x,y)] to origin. *) +external hypot : t -> t -> t = "caml_hypot_float" "caml_hypot" +[@@unboxed] [@@noalloc] + +(** Hyperbolic cosine. Argument is in radians. *) +external cosh : t -> t = "caml_cosh_float" "cosh" +[@@unboxed] [@@noalloc] + +(** Hyperbolic sine. Argument is in radians. *) +external sinh : t -> t = "caml_sinh_float" "sinh" +[@@unboxed] [@@noalloc] + +(** Hyperbolic tangent. Argument is in radians. *) +external tanh : t -> t = "caml_tanh_float" "tanh" +[@@unboxed] [@@noalloc] + +(** Square root. *) +external sqrt : t -> t = "caml_sqrt_float" "sqrt" +[@@unboxed] [@@noalloc] + +(** Exponential. *) +external exp : t -> t = "caml_exp_float" "exp" +[@@unboxed] [@@noalloc] + +(** Natural logarithm. *) +external log : t -> t = "caml_log_float" "log" +[@@unboxed] [@@noalloc] + + +(** Excluding nan the floating-point "number line" looks like: + {v + t Class.t example + ^ neg_infinity Infinite neg_infinity + | neg normals Normal -3.14 + | neg subnormals Subnormal -.2. ** -1023. + | (-/+) zero Zero 0. + | pos subnormals Subnormal 2. ** -1023. + | pos normals Normal 3.14 + v infinity Infinite infinity + v} *) +module Class : sig + type t = + | Infinite + | Nan + | Normal + | Subnormal + | Zero + [@@deriving_inline compare, enumerate, sexp] + include + sig + [@@@ocaml.warning "-32"] + val compare : t -> t -> int + val all : t list + include Ppx_sexp_conv_lib.Sexpable.S with type t := t + end[@@ocaml.doc "@inline"] + [@@@end] + + include Stringable.S with type t := t +end + +val classify : t -> Class.t + +(** [is_finite t] returns [true] iff [classify t] is in [Normal; Subnormal; Zero;]. *) +val is_finite : t -> bool + +(*_ Caution: If we remove this sig item, [sign] will still be present from + [Comparable.With_zero]. *) +val sign : t -> Sign.t +[@@deprecated "[since 2016-01] Replace [sign] with [robust_sign] or [sign_exn]"] + +(** The sign of a float. Both [-0.] and [0.] map to [Zero]. Raises on nan. All other + values map to [Neg] or [Pos]. *) +val sign_exn : t -> Sign.t + +(** The sign of a float, with support for NaN. Both [-0.] and [0.] map to [Zero]. All NaN + values map to [Nan]. All other values map to [Neg] or [Pos]. *) +val sign_or_nan : t -> Sign_or_nan.t + +(** These functions construct and destruct 64-bit floating point numbers based on their + IEEE representation with a sign bit, an 11-bit non-negative (biased) exponent, and a + 52-bit non-negative mantissa (or significand). See + {{:http://en.wikipedia.org/wiki/Double-precision_floating-point_format} Wikipedia} for + details of the encoding. + + In particular, if 1 <= exponent <= 2046, then: + + {[ + create_ieee_exn ~negative:false ~exponent ~mantissa + = 2 ** (exponent - 1023) * (1 + (2 ** -52) * mantissa) + ]} *) +val create_ieee : negative:bool -> exponent:int -> mantissa:Int63.t -> t Or_error.t +val create_ieee_exn : negative:bool -> exponent:int -> mantissa:Int63.t -> t +val ieee_negative : t -> bool +val ieee_exponent : t -> int +val ieee_mantissa : t -> Int63.t + +(** S-expressions contain at most 8 significant digits. *) +module Terse : sig + type nonrec t = t [@@deriving_inline sexp] + include + sig + [@@@ocaml.warning "-32"] + include Ppx_sexp_conv_lib.Sexpable.S with type t := t + end[@@ocaml.doc "@inline"] + [@@@end] + include Stringable.S with type t := t +end + +(**/**) +(*_ See the Jane Street Style Guide for an explanation of [Private] submodules: + + https://opensource.janestreet.com/standards/#private-submodules *) +module Private : sig + val lower_bound_for_int : int -> t + val upper_bound_for_int : int -> t + val specialized_hash : t -> int + val one_ulp_less_than_half : t + val int63_round_nearest_portable_alloc_exn : t -> Int63.t + val int63_round_nearest_arch64_noalloc_exn : t -> Int63.t + val iround_nearest_exn_64 : t -> int +end diff --git a/src/float0.ml b/src/float0.ml new file mode 100644 index 0000000..4a405f6 --- /dev/null +++ b/src/float0.ml @@ -0,0 +1,125 @@ +open! Import + +(* Open replace_polymorphic_compare after including functor instantiations so they do not + shadow its definitions. This is here so that efficient versions of the comparison + functions are available within this module. *) +open! Float_replace_polymorphic_compare + +let is_nan x = (x : float) <> x + +(* An order-preserving bijection between all floats except for NaNs, and 99.95% of + int64s. + + Note we don't distinguish 0. and -0. as separate values here, they both map to 0L, which + maps back to 0. + + This should work both on little-endian and high-endian CPUs. Wikipedia says: "on + modern standard computers (i.e., implementing IEEE 754), one may in practice safely + assume that the endianness is the same for floating point numbers as for integers" + (http://en.wikipedia.org/wiki/Endianness#Floating-point_and_endianness). +*) +let to_int64_preserve_order t = + if is_nan t then + None + else + if t = 0. then (* also includes -0. *) + Some 0L + else + if t > 0. then + Some (Caml.Int64.bits_of_float t) + else + Some (Caml.Int64.neg (Caml.Int64.bits_of_float (~-. t))) +;; + +let to_int64_preserve_order_exn x = + Option.value_exn (to_int64_preserve_order x) +;; + +let of_int64_preserve_order x = + if Int64_replace_polymorphic_compare.(>=) x 0L then + Caml.Int64.float_of_bits x + else + ~-. (Caml.Int64.float_of_bits (Caml.Int64.neg x)) +;; + +let one_ulp dir t = + match to_int64_preserve_order t with + | None -> Caml.nan + | Some x -> + of_int64_preserve_order (Caml.Int64.add x (match dir with `Up -> 1L | `Down -> -1L)) +;; + +(* [upper_bound_for_int] and [lower_bound_for_int] are for calculating the max/min float + that fits in a given-size integer when rounded towards 0 (using [int_of_float]). + + max_int/min_int depend on [num_bits], e.g. +/- 2^30, +/- 2^62 if 31-bit, 63-bit + (respectively) while float is IEEE standard for double (52 significant bits). + + In all cases, we want to guarantee that + [lower_bound_for_int <= x <= upper_bound_for_int] + iff [int_of_float x] fits in an int with [num_bits] bits. + + [2 ** (num_bits - 1)] is the first float greater that max_int, we use the preceding + float as upper bound. + + [- (2 ** (num_bits - 1))] is equal to min_int. + For lower bound we look for the smallest float [f] satisfying [f > min_int - 1] so that + [f] rounds toward zero to [min_int] + + So in particular we will have: + [lower_bound_for_int x <= - (2 ** (1-x))] + [upper_bound_for_int x < 2 ** (1-x) ] +*) +let upper_bound_for_int num_bits = + let exp = Caml.float_of_int ( num_bits - 1 ) in + one_ulp `Down (2. ** exp) + +let is_x_minus_one_exact x = + (* [x = x -. 1.] does not work with x87 floating point arithmetic backend (which is used + on 32-bit ocaml) because of 80-bit register precision of intermediate computations. + + An alternative way of computing this: [x -. one_ulp `Down x <= 1.] is also prone to + the same precision issues: you need to make sure [x] is 64-bit. + *) + let open Int64_replace_polymorphic_compare in + not (Caml.Int64.bits_of_float x = Caml.Int64.bits_of_float (x -. 1.)) + +let lower_bound_for_int num_bits = + let exp = Caml.float_of_int ( num_bits - 1 ) in + let min_int_as_float = ~-. (2. ** exp) in + let open Int_replace_polymorphic_compare in + if num_bits - 1 < 53 (* 53 = #bits in the float's mantissa with sign included *) + then + begin + (* The smallest float that rounds towards zero to [min_int] is + [min_int - 1 + epsilon] *) + assert (is_x_minus_one_exact min_int_as_float); + one_ulp `Up (min_int_as_float -. 1.) + end + else + begin + (* [min_int_as_float] is already the smallest float [f] satisfying [f > min_int - 1]. *) + assert (not (is_x_minus_one_exact min_int_as_float)); + min_int_as_float + end + + +(* Float clamping is structured slightly differently than clamping for other types, so + that we get the behavior of [clamp_unchecked nan ~min ~max = nan] (for any [min] and + [max]) for free. +*) +let clamp_unchecked (t : float) ~min ~max = + if t < min then min + else if max < t then max + else t + +let box = + (* Prevent potential constant folding of [+. 0.] in the near ocamlopt future. *) + let x = if Random.bool () then 0. else 0. in + (fun f -> f +. x) + +(* Include type-specific [Replace_polymorphic_compare] at the end, after + including functor application that could shadow its definitions. This is + here so that efficient versions of the comparison functions are exported by + this module. *) +include Float_replace_polymorphic_compare diff --git a/src/floatable.ml b/src/floatable.ml new file mode 100644 index 0000000..1c77d93 --- /dev/null +++ b/src/floatable.ml @@ -0,0 +1,10 @@ +(** Functor that adds float conversion functions to a module. *) + +open! Import + +module type S = sig + type t + + val of_float : float -> t + val to_float : t -> float +end diff --git a/src/fn.ml b/src/fn.ml new file mode 100644 index 0000000..d7a2b32 --- /dev/null +++ b/src/fn.ml @@ -0,0 +1,30 @@ +open! Import + +let const c _ = c + +external ignore : _ -> unit = "%ignore" (* this has the same behavior as [Caml.ignore] *) + +let non f x = not (f x) + +let forever f = + let rec forever () = + f (); + forever () + in + try forever () + with e -> e + +external id : 'a -> 'a = "%identity" + +external ( |> ) : 'a -> ( 'a -> 'b) -> 'b = "%revapply" + +(* The typical use case for these functions is to pass in functional arguments and get + functions as a result. *) +let compose f g x = f (g x) + +let flip f x y = f y x + +let rec apply_n_times ~n f x = + if n <= 0 + then x + else apply_n_times ~n:(n - 1) f (f x) diff --git a/src/fn.mli b/src/fn.mli new file mode 100644 index 0000000..8ed42fa --- /dev/null +++ b/src/fn.mli @@ -0,0 +1,32 @@ +(** Various combinators for functions. *) + +open! Import + +(** A "pipe" operator. *) +external ( |> ) : 'a -> ( 'a -> 'b) -> 'b = "%revapply" + +(** Produces a function that just returns its first argument. *) +val const : 'a -> _ -> 'a + +(** [ignore] is the same as [Caml.ignore]. It is useful to have here so that code + that rebinds [ignore] can still refer to [Fn.ignore]. *) +external ignore : _ -> unit = "%ignore" + +(** Negates a function. *) +val non : ('a -> bool) -> 'a -> bool + +(** [forever f] runs [f ()] until it throws an exception and returns the + exception. This function is useful for read_line loops, etc. *) +val forever : (unit -> unit) -> exn + +(** [apply_n_times ~n f x] is the [n]-fold application of [f] to [x]. *) +val apply_n_times : n:int -> ('a -> 'a) -> ('a -> 'a) + +(** The identity function. Also see [Sys.opaque_identity]. *) +external id : 'a -> 'a = "%identity" + +(** [compose f g x] is [f (g x)]. *) +val compose : ('b -> 'c) -> ('a -> 'b) -> ('a -> 'c) + +(** Reverses the order of arguments for a binary function. *) +val flip : ('a -> 'b -> 'c) -> ('b -> 'a -> 'c) diff --git a/src/formatter.ml b/src/formatter.ml new file mode 100644 index 0000000..0f4b939 --- /dev/null +++ b/src/formatter.ml @@ -0,0 +1 @@ +type t = Caml.Format.formatter diff --git a/src/formatter.mli b/src/formatter.mli new file mode 100644 index 0000000..fdd1daf --- /dev/null +++ b/src/formatter.mli @@ -0,0 +1,9 @@ +(** The [Format.formatter] type from OCaml's standard library, exported here + for convenience and compatibility with other libraries. + + The [Format] module itself is deprecated in Base. You may refer to it + explicitly through [Caml.Format], though you may wish to search for other + alternatives for constructing pretty-printers using the [Format.formatter] + type. *) + +type t = Caml.Format.formatter diff --git a/src/hash.ml b/src/hash.ml new file mode 100644 index 0000000..fcd56db --- /dev/null +++ b/src/hash.ml @@ -0,0 +1,231 @@ +(* + This is the interface to the runtime support for [ppx_hash]. + + The [ppx_hash] syntax extension supports: [@@deriving_inline hash][@@@end] and [%hash_fold: TYPE] and + [%hash: TYPE] + + For type [t] a function [hash_fold_t] of type [Hash.state -> t -> Hash.state] is + generated. + + The generated [hash_fold_<T>] function is compositional, following the structure of the + type; allowing user overrides at every level. This is in contrast to ocaml's builtin + polymorphic hashing [Hashtbl.hash] which ignores user overrides. + + The generator also provides a direct hash-function [hash] (named [hash_<T>] when <T> != + "t") of type: [t -> Hash.hash_value]. + + The folding hash function can be accessed as [%hash_fold: TYPE] + The direct hash function can be accessed as [%hash: TYPE] +*) + +open! Import0 + +module Array = Array0 +module Char = Char0 +module Int = Int0 +module List = List0 + +include Hash_intf + +(** Builtin folding-style hash functions, abstracted over [Hash_intf.S] *) +module Folding (Hash : Hash_intf.S) + : Hash_intf.Builtin_intf + with type state = Hash.state + and type hash_value = Hash.hash_value += struct + + type state = Hash.state + type hash_value = Hash.hash_value + type 'a folder = state -> 'a -> state + + let hash_fold_unit s () = s + + let hash_fold_int = Hash.fold_int + let hash_fold_int64 = Hash.fold_int64 + let hash_fold_float = Hash.fold_float + let hash_fold_string = Hash.fold_string + + let as_int f s x = hash_fold_int s (f x) + + (* This ignores the sign bit on 32-bit architectures, but it's unlikely to lead to + frequent collisions (min_value colliding with 0 is the most likely one). *) + let hash_fold_int32 = as_int Caml.Int32.to_int + + let hash_fold_char = as_int Char.to_int + let hash_fold_bool = as_int (function true -> 1 | false -> 0) + + let hash_fold_nativeint s x = hash_fold_int64 s (Caml.Int64.of_nativeint x) + + let hash_fold_option hash_fold_elem s = function + | None -> hash_fold_int s 0 + | Some x -> hash_fold_elem (hash_fold_int s 1) x + + let rec hash_fold_list_body hash_fold_elem s list = + match list with + | [] -> s + | x::xs -> hash_fold_list_body hash_fold_elem (hash_fold_elem s x) xs + + let hash_fold_list hash_fold_elem s list = + (* The [length] of the list must be incorporated into the hash-state so values of + types such as [unit list] - ([], [()], [();()],..) are hashed differently. *) + (* The [length] must come before the elements to avoid a violation of the rule + enforced by Perfect_hash. *) + let s = hash_fold_int s (List.length list) in + let s = hash_fold_list_body hash_fold_elem s list in + s + + let hash_fold_lazy_t hash_fold_elem s x = + hash_fold_elem s (Caml.Lazy.force x) + + let hash_fold_ref_frozen hash_fold_elem s x = hash_fold_elem s (!x) + + let rec hash_fold_array_frozen_i hash_fold_elem s array i = + if i = Array.length array + then s + else + let e = Array.unsafe_get array i in + hash_fold_array_frozen_i hash_fold_elem (hash_fold_elem s e) array (i + 1) + + let hash_fold_array_frozen hash_fold_elem s array = + hash_fold_array_frozen_i + (* [length] must be incorporated for arrays, as it is for lists. See comment above *) + hash_fold_elem (hash_fold_int s (Array.length array)) array 0 + + (* the duplication here is because we think + ocaml can't eliminate indirect function calls otherwise. *) + let hash_nativeint x = + Hash.get_hash_value (hash_fold_nativeint (Hash.reset (Hash.alloc ())) x) + let hash_int64 x = + Hash.get_hash_value (hash_fold_int64 (Hash.reset (Hash.alloc ())) x) + let hash_int32 x = + Hash.get_hash_value (hash_fold_int32 (Hash.reset (Hash.alloc ())) x) + let hash_char x = + Hash.get_hash_value (hash_fold_char (Hash.reset (Hash.alloc ())) x) + let hash_int x = + Hash.get_hash_value (hash_fold_int (Hash.reset (Hash.alloc ())) x) + let hash_bool x = + Hash.get_hash_value (hash_fold_bool (Hash.reset (Hash.alloc ())) x) + let hash_string x = + Hash.get_hash_value (hash_fold_string (Hash.reset (Hash.alloc ())) x) + let hash_float x = + Hash.get_hash_value (hash_fold_float (Hash.reset (Hash.alloc ())) x) + let hash_unit x = + Hash.get_hash_value (hash_fold_unit (Hash.reset (Hash.alloc ())) x) + +end + +module F (Hash : Hash_intf.S) : + Hash_intf.Full + with type hash_value = Hash.hash_value + and type state = Hash.state + and type seed = Hash.seed += struct + + include Hash + + type 'a folder = state -> 'a -> state + + let create ?seed () = reset ?seed (alloc ()) + + let of_fold hash_fold_t = (fun t -> get_hash_value (hash_fold_t (create ()) t)) + + module Builtin = Folding(Hash) + + let run ?seed folder x = + Hash.get_hash_value (folder (Hash.reset ?seed (Hash.alloc ())) x) + +end + +module Internalhash : sig + include Hash_intf.S + with type state = private int (* allow optimizations for immediate type *) + and type seed = int + and type hash_value = int + + external fold_int64 : state -> int64 -> state = "Base_internalhash_fold_int64" [@@noalloc] + external fold_int : state -> int -> state = "Base_internalhash_fold_int" [@@noalloc] + external fold_float : state -> float -> state = "Base_internalhash_fold_float" [@@noalloc] + external fold_string : state -> string -> state = "Base_internalhash_fold_string" [@@noalloc] + external get_hash_value : state -> hash_value = "Base_internalhash_get_hash_value" [@@noalloc] +end = struct + let description = "internalhash" + + type state = int + type hash_value = int + type seed = int + + external create_seeded : seed -> state = "%identity" [@@noalloc] + external fold_int64 : state -> int64 -> state = "Base_internalhash_fold_int64" [@@noalloc] + external fold_int : state -> int -> state = "Base_internalhash_fold_int" [@@noalloc] + external fold_float : state -> float -> state = "Base_internalhash_fold_float" [@@noalloc] + external fold_string : state -> string -> state = "Base_internalhash_fold_string" [@@noalloc] + external get_hash_value : state -> hash_value = "Base_internalhash_get_hash_value" [@@noalloc] + + let alloc () = create_seeded 0 + + let reset ?(seed=0) _t = create_seeded seed + + module For_tests = struct + let compare_state = compare + let state_to_string = Int.to_string + end +end + +module T = struct + include Internalhash + type 'a folder = state -> 'a -> state + + let create ?seed () = reset ?seed (alloc ()) + + let run ?seed folder x = + get_hash_value (folder (reset ?seed (alloc ())) x) + + let of_fold hash_fold_t = (fun t -> get_hash_value (hash_fold_t (create ()) t)) + + module Builtin = struct + module Folding = Folding(Internalhash) + include + (Folding : Hash_intf.Builtin_hash_fold_intf + with type state := state + and type 'a folder := 'a folder) + + let hash_nativeint = Folding.hash_nativeint + let hash_int64 = Folding.hash_int64 + let hash_int32 = Folding.hash_int32 + let hash_string = Folding.hash_string + + (* [Folding] provides some default implementations for the [hash_*] functions below, + but they are inefficient for some use-cases because of the use of the [hash_fold] + functions. At this point, the [hash_value] type has been fixed to [int], so this + module can provide specialized implementations. *) + + let hash_char = Char0.to_int + + (* This hash was chosen from here: https://gist.github.com/badboy/6267743 + + It attempts to fulfill the primary goals of a non-cryptographic hash function: + + - a bit change in the input should change ~1/2 of the output bits + - the output should be uniformly distributed across the output range + - inputs that are close to each other shouldn't lead to outputs that are close to + each other. + - all bits of the input are used in generating the output + + In our case we also want it to be fast, non-allocating, and inlinable. *) + let [@inline always] hash_int (t : int) = + let t = (lnot t) + (t lsl 21) in + let t = t lxor (t lsr 24) in + let t = (t + (t lsl 3)) + (t lsl 8) in + let t = t lxor (t lsr 14) in + let t = (t + (t lsl 2)) + (t lsl 4) in + let t = t lxor (t lsr 28) in + t + (t lsl 31) + ;; + + let hash_bool x = if x then 1 else 0 + external hash_float : float -> int = "Base_hash_double" [@@noalloc] + let hash_unit () = 0 + end +end + +include T diff --git a/src/hash.mli b/src/hash.mli new file mode 100644 index 0000000..1b586b4 --- /dev/null +++ b/src/hash.mli @@ -0,0 +1 @@ +include Hash_intf.Hash (** @inline *) diff --git a/src/hash_intf.ml b/src/hash_intf.ml new file mode 100644 index 0000000..f0f9536 --- /dev/null +++ b/src/hash_intf.ml @@ -0,0 +1,196 @@ +(** [Hash_intf.S] is the interface which a hash function must support. + + The functions of [Hash_intf.S] are only allowed to be used in specific sequence: + + [alloc], [reset ?seed], [fold_..*], [get_hash_value], [reset ?seed], [fold_..*], + [get_hash_value], ... + + (The optional [seed]s passed to each reset may differ.) + + The chain of applications from [reset] to [get_hash_value] must be done in a + single-threaded manner (you can't use [fold_*] on a state that's been used + before). More precisely, [alloc ()] creates a new family of states. All functions that + take [t] and produce [t] return a new state from the same family. + + At any point in time, at most one state in the family is "valid". The other states are + "invalid". + + - The state returned by [alloc] is invalid. + - The state returned by [reset] is valid (all of the other states become invalid). + - The [fold_*] family of functions requires a valid state and produces a valid state + (thereby making the input state invalid). + - [get_hash_value] requires a valid state and makes it invalid. + + These requirements are currently formally encoded in the [Check_initialized_correctly] + module in bench/bench.ml. *) + +open! Import0 + +module type S = sig + + (** Name of the hash-function, e.g., "internalhash", "siphash" *) + val description : string + + (** [state] is the internal hash-state used by the hash function. *) + type state + + (** [fold_<T> state v] incorporates a value [v] of type <T> into the hash-state, + returning a modified hash-state. Implementations of the [fold_<T>] functions may + mutate the [state] argument in place, and return a reference to it. Implementations + of the fold_<T> functions should not allocate. *) + val fold_int : state -> int -> state + val fold_int64 : state -> int64 -> state + val fold_float : state -> float -> state + val fold_string : state -> string -> state + + (** [seed] is the type used to seed the initial hash-state. *) + type seed + + (** [alloc ()] returns a fresh uninitialized hash-state. May allocate. *) + val alloc : unit -> state + + (** [reset ?seed state] initializes/resets a hash-state with the given [seed], or else a + default-seed. Argument [state] may be mutated. Should not allocate. *) + val reset : ?seed:seed -> state -> state + + (** [hash_value] The type of hash values, returned by [get_hash_value]. *) + type hash_value + + (** [get_hash_value] extracts a hash-value from the hash-state. *) + val get_hash_value : state -> hash_value + + module For_tests : sig + val compare_state : state -> state -> int + val state_to_string : state -> string + end +end + +module type Builtin_hash_fold_intf = sig + type state + type 'a folder = state -> 'a -> state + + val hash_fold_nativeint : nativeint folder + val hash_fold_int64 : int64 folder + val hash_fold_int32 : int32 folder + val hash_fold_char : char folder + val hash_fold_int : int folder + val hash_fold_bool : bool folder + val hash_fold_string : string folder + val hash_fold_float : float folder + val hash_fold_unit : unit folder + + val hash_fold_option : 'a folder -> 'a option folder + val hash_fold_list : 'a folder -> 'a list folder + val hash_fold_lazy_t : 'a folder -> 'a lazy_t folder + + (** Hash support for [array] and [ref] is provided, but is potentially DANGEROUS, since + it incorporates the current contents of the array/ref into the hash value. Because + of this we add a [_frozen] suffix to the function name. + + Hash support for [string] is also potentially DANGEROUS, but strings are mutated + less often, so we don't append [_frozen] to it. + + Also note that we don't support [bytes]. *) + val hash_fold_ref_frozen : 'a folder -> 'a ref folder + val hash_fold_array_frozen : 'a folder -> 'a array folder + +end + +module type Builtin_hash_intf = sig + type hash_value + + val hash_nativeint : nativeint -> hash_value + val hash_int64 : int64 -> hash_value + val hash_int32 : int32 -> hash_value + val hash_char : char -> hash_value + val hash_int : int -> hash_value + val hash_bool : bool -> hash_value + val hash_string : string -> hash_value + val hash_float : float -> hash_value + val hash_unit : unit -> hash_value + +end + +module type Builtin_intf = sig + include Builtin_hash_fold_intf + include Builtin_hash_intf +end + +module type Full = sig + + include S (** @inline *) + + type 'a folder = state -> 'a -> state + + (** [create ?seed ()] is a convenience. Equivalent to [reset ?seed (alloc ())]. *) + val create : ?seed:seed -> unit -> state + + (** [of_fold fold] constructs a standard hash function from an existing fold + function. *) + val of_fold : (state -> 'a -> state) -> ('a -> hash_value) + + module Builtin : Builtin_intf + with type state := state + and type 'a folder := 'a folder + and type hash_value := hash_value + + (** [run ?seed folder x] runs [folder] on [x] in a newly allocated hash-state, + initialized using optional [seed] or a default-seed. + + The following identity exists: [run [%hash_fold: T]] == [[%hash: T]] + + [run] can be used if we wish to run a hash-folder with a non-default seed. *) + val run : ?seed:seed -> 'a folder -> 'a -> hash_value + +end + +module type Hash = sig + module type Full = Full + module type S = S + + module F (Hash : S) : Full + with type hash_value = Hash.hash_value + and type state = Hash.state + and type seed = Hash.seed + + (** The code of [ppx_hash] is agnostic to the choice of hash algorithm that is + used. However, it is not currently possible to mix various choices of hash algorithms + in a given code base. + + We experimented with: + - (a) custom hash algorithms implemented in OCaml and + - (b) in C; + - (c) OCaml's internal hash function (which is a custom version of Murmur3, + implemented in C); + - (d) siphash, a modern hash function implemented in C. + + Our findings were as follows: + + - Implementing our own custom hash algorithms in OCaml and C yielded very little + performance improvement over the (c) proposal, without providing the benefit of being + a peer-reviewed, widely used hash function. + + - Siphash (a modern hash function with an internal state of 32 bytes) has a worse + performance profile than (a,b,c) above (hashing takes more time). Since its internal + state is bigger than an OCaml immediate value, one must either manage allocation of + such state explicitly, or paying the cost of allocation each time a hash is computed. + While being a supposedly good hash function (with good hash quality), this quality was + not translated in measurable improvements in our macro benchmarks. (Also, based on + the data available at the time of writing, it's unclear that other hash algorithms in + this class would be more than marginally faster.) + + - By contrast, using the internal combinators of OCaml hash function means that we do + not allocate (the internal state of this hash function is 32 bit) and have the same + quality and performance as Hashtbl.hash. + + Hence, we are here making the choice of using this Internalhash (that is, Murmur3, the + OCaml hash algorithm as of 4.03) as our hash algorithm. It means that the state of the + hash function does not need to be preallocated, and makes for simpler use in hash + tables and other structures. *) + + include Full + with type state = private int + and type seed = int + + and type hash_value = int (** @open *) +end diff --git a/src/hash_set.ml b/src/hash_set.ml new file mode 100644 index 0000000..8c416df --- /dev/null +++ b/src/hash_set.ml @@ -0,0 +1,220 @@ +open! Import + +include Hash_set_intf + +let hashable_s = Hashtbl.hashable_s +let hashable = Hashtbl.Private.hashable +let poly_hashable = Hashtbl.Poly.hashable + +let with_return = With_return.with_return + +type 'a t = ('a, unit) Hashtbl.t +type 'a hash_set = 'a t +type 'a elt = 'a + +module Accessors = struct + + let hashable = hashable + let clear = Hashtbl.clear + let length = Hashtbl.length + let mem = Hashtbl.mem + + let is_empty t = Hashtbl.is_empty t + + let find_map t ~f = + with_return (fun r -> + Hashtbl.iter_keys t ~f:(fun elt -> + match f elt with + | None -> () + | Some _ as o -> r.return o); + None) + ;; + + let find t ~f = find_map t ~f:(fun a -> if f a then Some a else None) + + let add t k = Hashtbl.set t ~key:k ~data:() + + let strict_add t k = + if mem t k then Or_error.error_string "element already exists" + else begin + Hashtbl.set t ~key:k ~data:(); + Result.Ok () + end + ;; + + let strict_add_exn t k = Or_error.ok_exn (strict_add t k) + + let remove = Hashtbl.remove + + let strict_remove t k = + if mem t k then begin + remove t k; + Result.Ok () + end else + Or_error.error "element not in set" k (Hashtbl.sexp_of_key t) + ;; + + let strict_remove_exn t k = Or_error.ok_exn (strict_remove t k) + + let fold t ~init ~f = Hashtbl.fold t ~init ~f:(fun ~key ~data:() acc -> f acc key) + let iter t ~f = Hashtbl.iter_keys t ~f + + let count t ~f = Container.count ~fold t ~f + let sum m t ~f = Container.sum ~fold m t ~f + let min_elt t ~compare = Container.min_elt ~fold t ~compare + let max_elt t ~compare = Container.max_elt ~fold t ~compare + let fold_result t ~init ~f = Container.fold_result ~fold ~init ~f t + let fold_until t ~init ~f = Container.fold_until ~fold ~init ~f t + + let to_list = Hashtbl.keys + + let sexp_of_t sexp_of_e t = + sexp_of_list sexp_of_e (to_list t |> List.sort ~compare:(hashable t).compare) + + let to_array t = + let len = length t in + let index = ref (len - 1) in + fold t ~init:[||] ~f:(fun acc key -> + if Array.length acc = 0 then Array.create ~len key + else begin + index := !index - 1; + Array.set acc (!index) key; + acc + end) + + let exists t ~f = Hashtbl.existsi t ~f:(fun ~key ~data:() -> f key) + let for_all t ~f = not (Hashtbl.existsi t ~f:(fun ~key ~data:() -> not (f key))) + + let equal t1 t2 = Hashtbl.equal t1 t2 (fun () () -> true) + + let copy t = Hashtbl.copy t + + let filter t ~f = Hashtbl.filteri t ~f:(fun ~key ~data:() -> f key) + + let diff t1 t2 = filter t1 ~f:(fun key -> not (Hashtbl.mem t2 key)) + + let inter t1 t2 = + let smaller, larger = if length t1 > length t2 then (t2, t1) else (t1, t2) in + Hashtbl.filteri smaller ~f:(fun ~key ~data:() -> Hashtbl.mem larger key) + + let filter_inplace t ~f = + let to_remove = + fold t ~init:[] ~f:(fun ac x -> + if f x then ac else x :: ac) + in + List.iter to_remove ~f:(fun x -> remove t x) + ;; + + let of_hashtbl_keys hashtbl = Hashtbl.map hashtbl ~f:ignore + + let to_hashtbl t ~f = Hashtbl.mapi t ~f:(fun ~key ~data:() -> f key) +end + +include Accessors + +let create ?growth_allowed ?size m = + Hashtbl.create ?growth_allowed ?size m +;; + +let of_list ?growth_allowed ?size m l = + let size = match size with Some x -> x | None -> List.length l in + let t = Hashtbl.create ?growth_allowed ~size m in + List.iter l ~f:(fun k -> add t k); + t +;; + +let t_of_sexp m e_of_sexp sexp = + match sexp with + | Sexp.Atom _ -> + raise (Of_sexp_error (Failure "Hash_set.t_of_sexp requires a list", sexp)) + | Sexp.List list -> + let t = create m ~size:(List.length list) in + List.iter list ~f:(fun sexp -> + let e = e_of_sexp sexp in + match strict_add t e with + | Ok () -> () + | Error _ -> + raise (Of_sexp_error + (Error.to_exn + (Error.create "Hash_set.t_of_sexp got a duplicate element" + sexp Fn.id), + sexp))); + t +;; + +module Creators (Elt : sig + type 'a t + + val hashable : 'a t Hashable.t + end) : sig + + type 'a t_ = 'a Elt.t t + + val t_of_sexp : (Sexp.t -> 'a Elt.t) -> Sexp.t -> 'a t_ + + include Creators_generic + with type 'a t := 'a t_ + with type 'a elt := 'a Elt.t + with type ('elt, 'z) create_options := ('elt, 'z) create_options_without_first_class_module + +end = struct + + type 'a t_ = 'a Elt.t t + + let create ?growth_allowed ?size () = + create ?growth_allowed ?size (Hashable.to_key Elt.hashable) + + let of_list ?growth_allowed ?size l = + of_list ?growth_allowed ?size (Hashable.to_key Elt.hashable) l + + let t_of_sexp e_of_sexp sexp = + t_of_sexp (Hashable.to_key Elt.hashable) e_of_sexp sexp +end + +module Poly = struct + + type 'a t = 'a hash_set + + type 'a elt = 'a + + let hashable = poly_hashable + + include Creators (struct + type 'a t = 'a + let hashable = hashable + end) + + include Accessors + + let sexp_of_t = sexp_of_t + +end + +module M (Elt : T.T) = struct + type nonrec t = Elt.t t +end +module type Sexp_of_m = sig + type t [@@deriving_inline sexp_of] + include + sig [@@@ocaml.warning "-32"] val sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t + end[@@ocaml.doc "@inline"] + [@@@end] +end +module type M_of_sexp = sig + type t [@@deriving_inline of_sexp] + include + sig [@@@ocaml.warning "-32"] val t_of_sexp : Ppx_sexp_conv_lib.Sexp.t -> t + end[@@ocaml.doc "@inline"] + [@@@end] + include Hashtbl_intf.Key with type t := t +end + +let sexp_of_m__t (type elt) (module Elt : Sexp_of_m with type t = elt) t = + sexp_of_t Elt.sexp_of_t t + +let m__t_of_sexp (type elt) (module Elt : M_of_sexp with type t = elt) sexp = + t_of_sexp (module Elt) Elt.t_of_sexp sexp + +module Private = struct + let hashable = Hashtbl.Private.hashable +end diff --git a/src/hash_set.mli b/src/hash_set.mli new file mode 100644 index 0000000..f69262f --- /dev/null +++ b/src/hash_set.mli @@ -0,0 +1 @@ +include Hash_set_intf.Hash_set (** @inline *) diff --git a/src/hash_set_intf.ml b/src/hash_set_intf.ml new file mode 100644 index 0000000..cbda416 --- /dev/null +++ b/src/hash_set_intf.ml @@ -0,0 +1,204 @@ +open! Import + +module type Key = Hashtbl_intf.Key + +module type Accessors = sig + include Container.Generic + + val mem : 'a t -> 'a -> bool (** override [Container.Generic.mem] *) + + val copy : 'a t -> 'a t (** preserves the equality function *) + + val add : 'a t -> 'a -> unit + + (** [strict_add t x] returns [Ok ()] if the [x] was not in [t], or an [Error] if it + was. *) + val strict_add : 'a t -> 'a -> unit Or_error.t + val strict_add_exn : 'a t -> 'a -> unit + + val remove : 'a t -> 'a -> unit + + (** [strict_remove t x] returns [Ok ()] if the [x] was in [t], or an [Error] if it + was not. *) + val strict_remove : 'a t -> 'a -> unit Or_error.t + val strict_remove_exn : 'a t -> 'a -> unit + + val clear : 'a t -> unit + val equal : 'a t -> 'a t -> bool + val filter : 'a t -> f:('a -> bool) -> 'a t + val filter_inplace : 'a t -> f:('a -> bool) -> unit + + (** [inter t1 t2] computes the set intersection of [t1] and [t2]. Runs in O(min(length + t1, length t2)). Behavior is undefined if [t1] and [t2] don't have the same + equality function. *) + val inter : 'key t -> 'key t -> 'key t + val diff : 'a t -> 'a t -> 'a t + + val of_hashtbl_keys : ('a, _) Hashtbl.t -> 'a t + val to_hashtbl : 'key t -> f:('key -> 'data) -> ('key, 'data) Hashtbl.t +end + +type ('key, 'z) create_options = + ('key, unit, 'z) Hashtbl_intf.create_options + +type ('key, 'z) create_options_without_first_class_module = + ('key, unit, 'z) Hashtbl_intf.create_options_without_first_class_module + +module type Creators = sig + type 'a t + + val create + : ?growth_allowed:bool (** defaults to [true] *) + -> ?size:int (** initial size -- default 128 *) + -> (module Key with type t = 'a) + -> 'a t + val of_list + : ?growth_allowed:bool (** defaults to [true] *) + -> ?size:int (** initial size -- default 128 *) + -> (module Key with type t = 'a) + -> 'a list + -> 'a t +end + +module type Creators_generic = sig + type 'a t + type 'a elt + type ('a, 'z) create_options + + val create : ('a, unit -> 'a t) create_options + val of_list : ('a, 'a elt list -> 'a t) create_options +end + +module Check = struct + module Make_creators_check (Type : T.T1) (Elt : T.T1) (Options : T.T2) + (M : Creators_generic + with type 'a t := 'a Type.t + with type 'a elt := 'a Elt.t + with type ('a, 'z) create_options := ('a, 'z) Options.t) + = struct end + + module Check_creators_is_specialization_of_creators_generic (M : Creators) = + Make_creators_check + (struct type 'a t = 'a M.t end) + (struct type 'a t = 'a end) + (struct type ('a, 'z) t = ('a, 'z) create_options end) + (struct + include M + + let create ?growth_allowed ?size m () = + create ?growth_allowed ?size m + end) +end + +module type Hash_set = sig + + type 'a t [@@deriving_inline sexp_of] + include + sig + [@@@ocaml.warning "-32"] + val sexp_of_t : + ('a -> Ppx_sexp_conv_lib.Sexp.t) -> 'a t -> Ppx_sexp_conv_lib.Sexp.t + end[@@ocaml.doc "@inline"] + [@@@end] + + (** We use [[@@deriving_inline sexp_of][@@@end]] but not [[@@deriving sexp]] because we want people to be + explicit about the hash and comparison functions used when creating hashtables. One + can use [Hash_set.Poly.t], which does have [[@@deriving_inline sexp][@@@end]], to use polymorphic + comparison and hashing. *) + + module type Creators = Creators + module type Creators_generic = Creators_generic + + type nonrec ('key, 'z) create_options = + ('key, 'z) create_options + + include Creators + with type 'a t := 'a t (** @open *) + + module type Accessors = Accessors + + include Accessors with type 'a t := 'a t with type 'a elt = 'a (** @open *) + + val hashable_s : 'key t -> (module Hashtbl_intf.Key with type t = 'key) + + type nonrec ('key, 'z) create_options_without_first_class_module = + ('key, 'z) create_options_without_first_class_module + + (** A hash set that uses polymorphic comparison *) + module Poly : sig + + type nonrec 'a t = 'a t [@@deriving_inline sexp] + include + sig + [@@@ocaml.warning "-32"] + include Ppx_sexp_conv_lib.Sexpable.S1 with type 'a t := 'a t + end[@@ocaml.doc "@inline"] + [@@@end] + + include Creators_generic + with type 'a t := 'a t + with type 'a elt = 'a + with type ('key, 'z) create_options := + ('key, 'z) create_options_without_first_class_module + + include Accessors with type 'a t := 'a t with type 'a elt := 'a elt + + end + + (** [M] is meant to be used in combination with OCaml applicative functor types: + + {[ + type string_hash_set = Hash_set.M(String).t + ]} + + which stands for: + + {[ + type string_hash_set = (String.t, int) Hash_set.t + ]} + + The point is that [Hash_set.M(String).t] supports deriving, whereas the second + syntax doesn't (because [t_of_sexp] doesn't know what comparison/hash function to + use). *) + module M (Elt : T.T) : sig + type nonrec t = Elt.t t + end + module type Sexp_of_m = sig + type t [@@deriving_inline sexp_of] + include + sig [@@@ocaml.warning "-32"] val sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t + end[@@ocaml.doc "@inline"] + [@@@end] + end + module type M_of_sexp = sig + type t [@@deriving_inline of_sexp] + include + sig [@@@ocaml.warning "-32"] val t_of_sexp : Ppx_sexp_conv_lib.Sexp.t -> t + end[@@ocaml.doc "@inline"] + [@@@end] + include Hashtbl_intf.Key with type t := t + end + val sexp_of_m__t : (module Sexp_of_m with type t = 'elt) -> 'elt t -> Sexp.t + val m__t_of_sexp : (module M_of_sexp with type t = 'elt) -> Sexp.t -> 'elt t + + module Creators (Elt : sig + type 'a t + val hashable : 'a t Hashable.t + end) : sig + type 'a t_ = 'a Elt.t t + val t_of_sexp : (Sexp.t -> 'a Elt.t) -> Sexp.t -> 'a t_ + include Creators_generic + with type 'a t := 'a t_ + with type 'a elt := 'a Elt.t + with type ('elt, 'z) create_options := + ('elt, 'z) create_options_without_first_class_module + end + + (**/**) + (*_ See the Jane Street Style Guide for an explanation of [Private] submodules: + + https://opensource.janestreet.com/standards/#private-submodules *) + module Private : sig + val hashable : 'a t -> 'a Hashable.t + end +end diff --git a/src/hash_stubs.c b/src/hash_stubs.c new file mode 100644 index 0000000..3aefec5 --- /dev/null +++ b/src/hash_stubs.c @@ -0,0 +1,26 @@ +#include <stdint.h> +#include <caml/mlvalues.h> +#include <caml/hash.h> + +/* Final mix and return from the hash.c implementation from INRIA */ +#define FINAL_MIX_AND_RETURN(h) \ + h ^= h >> 16; \ + h *= 0x85ebca6b; \ + h ^= h >> 13; \ + h *= 0xc2b2ae35; \ + h ^= h >> 16; \ + return Val_int(h & 0x3FFFFFFFU); + +CAMLprim value Base_hash_string (value string) +{ + uint32_t h; + h = caml_hash_mix_string (0, string); + FINAL_MIX_AND_RETURN(h) +} + +CAMLprim value Base_hash_double (value d) +{ + uint32_t h; + h = caml_hash_mix_double (0, Double_val(d)); + FINAL_MIX_AND_RETURN (h); +} diff --git a/src/hashable.ml b/src/hashable.ml new file mode 100644 index 0000000..3b7b275 --- /dev/null +++ b/src/hashable.ml @@ -0,0 +1,4 @@ +open! Import + + +include Hashable_intf diff --git a/src/hashable.mli b/src/hashable.mli new file mode 100644 index 0000000..5678db4 --- /dev/null +++ b/src/hashable.mli @@ -0,0 +1,5 @@ +open! Import + +module type Key = Hashable_intf.Key +module type Hashable = Hashable_intf.Hashable +include Hashable (** @inline *) diff --git a/src/hashable_intf.ml b/src/hashable_intf.ml new file mode 100644 index 0000000..cca6dfe --- /dev/null +++ b/src/hashable_intf.ml @@ -0,0 +1,93 @@ +open! Import + +module type Key = sig + type t [@@deriving_inline compare, sexp_of] + include + sig + [@@@ocaml.warning "-32"] + val compare : t -> t -> int + val sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t + end[@@ocaml.doc "@inline"] + [@@@end] + + (** Values returned by [hash] must be non-negative. An exception will be raised in the + case that [hash] returns a negative value. *) + val hash : t -> int +end + +module Hashable = struct + type 'a t = + { hash : 'a -> int; + compare : 'a -> 'a -> int; + sexp_of_t : 'a -> Sexp.t; + } + + (** This function is sound but not complete, meaning that if it returns [true] then it's + safe to use the two interchangeably. If it's [false], you have no guarantees. For + example: + + {[ + > utop + open Core;; + let equal (a : 'a Hashtbl_intf.Hashable.t) b = + phys_equal a b + || (phys_equal a.hash b.hash + && phys_equal a.compare b.compare + && phys_equal a.sexp_of_t b.sexp_of_t) + ;; + let a = Hashtbl_intf.Hashable.{ hash; compare; sexp_of_t = Int.sexp_of_t };; + let b = Hashtbl_intf.Hashable.{ hash; compare; sexp_of_t = Int.sexp_of_t };; + equal a b;; (* false?! *) + ]} + *) + let equal a b = + phys_equal a b + || (phys_equal a.hash b.hash + && phys_equal a.compare b.compare + && phys_equal a.sexp_of_t b.sexp_of_t) + ;; + + let hash_param = Caml.Hashtbl.hash_param + let hash = Caml.Hashtbl.hash + + let poly = { hash; + compare = Poly.compare; + sexp_of_t = (fun _ -> Sexp.Atom "_"); + } + + let of_key (type a) (module Key : Key with type t = a) = + { hash = Key.hash; + compare = Key.compare; + sexp_of_t = Key.sexp_of_t; + } + ;; + + let to_key (type a) { hash; compare; sexp_of_t } = + (module struct + type t = a + let hash = hash + let compare = compare + let sexp_of_t = sexp_of_t + end : Key with type t = a) + ;; +end +include Hashable + +module type Hashable = sig + type 'a t = 'a Hashable.t = + { hash : 'a -> int; + compare : 'a -> 'a -> int; + sexp_of_t : 'a -> Sexp.t; + } + + val equal : 'a t -> 'a t -> bool + + val poly : 'a t + + val of_key : (module Key with type t = 'a) -> 'a t + val to_key : 'a t -> (module Key with type t = 'a) + + val hash_param : int -> int -> 'a -> int + + val hash : 'a -> int +end diff --git a/src/hasher.ml b/src/hasher.ml new file mode 100644 index 0000000..06e04ef --- /dev/null +++ b/src/hasher.ml @@ -0,0 +1,56 @@ +open! Import + +(** Signatures required of types which can be used in [[@@deriving_inline hash][@@@end]]. *) +(*_ JS-only: For a more in-depth discussion, see documentation of ppx_hash, available in + ppx/ppx_hash/README.md and ppx/ppx_hash/doc/design.notes. *) + +module type S = sig + + (** The type that is hashed. *) + type t + + (** [hash_fold_t state x] mixes the content of [x] into the [state]. + + By default, all our [hash_fold_t] functions (derived or not) should satisfy the + following properties. + + 1. [hash_fold_t state x] should mix all the information present in [x] in the state. + That is, by default, [hash_fold_t] will traverse the full term [x] (this is a + significant change for Hashtbl.hash which by default stops traversing the term after + after considering a small number of "significant values"). [hash_fold_t] must not + discard the [state]. + + 2. [hash_fold_t] must be compatible with the associated [compare] function: that is, + for all [x] [y] and [s], [compare x y = 0] must imply [hash_fold_t s x = hash_fold_t + s y]. + + 3. To avoid avoid systematic collisions, [hash_fold_t] should expand to different + sequences of built-in mixing functions for different values of [x]. No such sequence + is allowed to be a prefix of another. + + A common mistake is to implement [hash_fold_t] of a collection by just folding all + the elements. This makes the folding sequence of [a] be a prefix of [a @ b], thereby + violating the requirement. This creates large families of collisions: all of the + following collections would hash the same: + + {v + [[]; [1;2;3]] + [[1]; [2;3]] + [[1; 2]; [3]] + [[1; 2; 3]; []] + [[1]; [2]; []; [3];] + ... + v} + + A good way to avoid this is to mix in the size of the collection to the beginning + ([fold ~init:(hash_fold_int state length) ~f:hash_fold_elem]). The default in our + libraries is to mix the length of the structure before folding. To prevent the + aforementioned collisions, one should respect this ordering. + *) + val hash_fold_t : Hash.state -> t -> Hash.state +end + +module type S1 = sig + type 'a t + val hash_fold_t : (Hash.state -> 'a -> Hash.state) -> Hash.state -> 'a t -> Hash.state +end diff --git a/src/hashtbl.ml b/src/hashtbl.ml new file mode 100644 index 0000000..dcc09ac --- /dev/null +++ b/src/hashtbl.ml @@ -0,0 +1,880 @@ +open! Import + +include Hashtbl_intf + +let with_return = With_return.with_return + +let hash_param = Hashable.hash_param +let hash = Hashable.hash + + +type ('k, 'v) t = + { mutable table : ('k, 'v) Avltree.t array + ; mutable length : int + (* [recently_added] is the reference passed to [Avltree.add]. We put it in the hash + table to avoid allocating it at every [set]. *) + ; recently_added : bool ref + ; growth_allowed : bool + ; hashable : 'k Hashable.t + ; mutable mutation_allowed : bool (* Set during all iteration operations *) + } + +type 'a key = 'a + +let sexp_of_key t = t.hashable.Hashable.sexp_of_t +let compare_key t = t.hashable.Hashable.compare + +let ensure_mutation_allowed t = + if not t.mutation_allowed then failwith "Hashtbl: mutation not allowed during iteration" +;; + +let without_mutating t f = + if t.mutation_allowed + then + begin + t.mutation_allowed <- false; + match f () with + | x -> t.mutation_allowed <- true; x + | exception exn -> t.mutation_allowed <- true; raise exn + end + else + f () +;; + +(** Internally use a maximum size that is a power of 2. Reverses the above to find the + floor power of 2 below the system max array length *) +let max_table_length = Int.floor_pow2 Array.max_length ;; + +let create ?(growth_allowed = true) ?(size = 128) ~hashable () = + let size = Int.min (Int.max 1 size) max_table_length in + let size = Int.ceil_pow2 size in + { table = Array.create ~len:size Avltree.empty + ; length = 0 + ; growth_allowed = growth_allowed + ; recently_added = ref false + ; hashable + ; mutation_allowed = true + } +;; + +(** Supplemental hash. This may not be necessary, it is intended as a defense against poor + hash functions, for which the power of 2 sized table will be especially sensitive. + With some testing we may choose to add it, but this table is designed to be robust to + collisions, and in most of my testing this degrades performance. *) +let _supplemental_hash h = + let h = h lxor ((h lsr 20) lxor (h lsr 12)) in + h lxor (h lsr 7) lxor (h lsr 4) +;; + +let slot t key = + let hash = t.hashable.Hashable.hash key in + (* this is always non-negative because we do [land] with non-negative number *) + hash land ((Array.length t.table) - 1) +;; + +let add_worker t ~replace ~key ~data = + let i = slot t key in + let root = t.table.(i) in + let added = t.recently_added in + added := false; + let new_root = + (* The avl tree might replace the value [replace=true] or do nothing [replace=false] + to the entry, in that case the table did not get bigger, so we should not + increment length, we pass in the bool ref t.added so that it can tell us whether + it added or replaced. We do it this way to avoid extra allocation. Since the bool + is an immediate it does not go through the write barrier. *) + Avltree.add ~replace root ~compare:(compare_key t) ~added ~key ~data + in + if !added then + t.length <- t.length + 1; + (* This little optimization saves a caml_modify when the tree + hasn't been rebalanced. *) + if not (phys_equal new_root root) then + t.table.(i) <- new_root +;; + +let maybe_resize_table t = + let len = Array.length t.table in + let should_grow = t.length > len in + if should_grow && t.growth_allowed then begin + let new_array_length = Int.min (len * 2) max_table_length in + if new_array_length > len then begin + let new_table = + Array.create ~len:new_array_length Avltree.empty + in + let old_table = t.table in + t.table <- new_table; + t.length <- 0; + let f ~key ~data = add_worker ~replace:true t ~key ~data in + for i = 0 to Array.length old_table - 1 do + Avltree.iter old_table.(i) ~f + done + end + end +;; + +let set t ~key ~data = + ensure_mutation_allowed t; + add_worker ~replace:true t ~key ~data; + maybe_resize_table t +;; + +let replace = set + +let add t ~key ~data = + ensure_mutation_allowed t; + add_worker ~replace:false t ~key ~data; + if !(t.recently_added) then begin + maybe_resize_table t; + `Ok + end else + `Duplicate +;; + +let add_exn t ~key ~data = + match add t ~key ~data with + | `Ok -> () + | `Duplicate -> + let sexp_of_key = sexp_of_key t in + let error = Error.create "Hashtbl.add_exn got key already present" key sexp_of_key in + Error.raise error +;; + +let clear t = + ensure_mutation_allowed t; + for i = 0 to Array.length t.table - 1 do + t.table.(i) <- Avltree.empty; + done; + t.length <- 0 +;; + +let find_and_call t key ~if_found ~if_not_found = + (* with a good hash function these first two cases will be the overwhelming majority, + and Avltree.find is recursive, so it can't be inlined, so doing this avoids a + function call in most cases. *) + match t.table.(slot t key) with + | Avltree.Empty -> if_not_found key + | Avltree.Leaf { key = k; value = v } -> + if compare_key t k key = 0 then if_found v + else if_not_found key + | tree -> + Avltree.find_and_call tree ~compare:(compare_key t) key ~if_found ~if_not_found +;; + +let findi_and_call t key ~if_found ~if_not_found = + (* with a good hash function these first two cases will be the overwhelming majority, + and Avltree.find is recursive, so it can't be inlined, so doing this avoids a + function call in most cases. *) + match t.table.(slot t key) with + | Avltree.Empty -> if_not_found key + | Avltree.Leaf { key = k; value = v } -> + if compare_key t k key = 0 then if_found ~key:k ~data:v + else if_not_found key + | tree -> + Avltree.findi_and_call tree ~compare:(compare_key t) key ~if_found ~if_not_found +;; + +let find = + let if_found v = Some v in + let if_not_found _ = None in + fun t key -> + find_and_call t key ~if_found ~if_not_found +;; + +let mem t key = + match t.table.(slot t key) with + | Avltree.Empty -> false + | Avltree.Leaf { key = k; value = _ } -> compare_key t k key = 0 + | tree -> Avltree.mem tree ~compare:(compare_key t) key +;; + +let remove t key = + ensure_mutation_allowed t; + let i = slot t key in + let root = t.table.(i) in + let added_or_removed = t.recently_added in + added_or_removed := false; + let new_root = + Avltree.remove root + ~removed:added_or_removed ~compare:(compare_key t) key + in + if not (phys_equal root new_root) then + t.table.(i) <- new_root; + if !added_or_removed then + t.length <- t.length - 1 +;; + +let length t = t.length + +let is_empty t = length t = 0 + +let fold t ~init ~f = + if length t = 0 then init + else begin + let n = Array.length t.table in + let acc = ref init in + let m = t.mutation_allowed in + match + t.mutation_allowed <- false; + for i = 0 to n - 1 do + match Array.unsafe_get t.table i with + | Avltree.Empty -> () + | Avltree.Leaf { key; value = data } -> acc := f ~key ~data !acc + | bucket -> acc := Avltree.fold bucket ~init:!acc ~f + done + with + | () -> + t.mutation_allowed <- m; + !acc + | exception exn -> + t.mutation_allowed <- m; + raise exn + end +;; + +let iteri t ~f = + if t.length = 0 then () + else begin + let n = Array.length t.table in + let m = t.mutation_allowed in + match + t.mutation_allowed <- false; + for i = 0 to n - 1 do + match Array.unsafe_get t.table i with + | Avltree.Empty -> () + | Avltree.Leaf { key; value = data } -> f ~key ~data + | bucket -> Avltree.iter bucket ~f + done + with + | () -> + t.mutation_allowed <- m + | exception exn -> + t.mutation_allowed <- m; + raise exn + end +;; + +let iter t ~f = iteri t ~f:(fun ~key:_ ~data -> f data) +let iter_keys t ~f = iteri t ~f:(fun ~key ~data:_ -> f key) + +let invariant invariant_key invariant_data t = + for i = 0 to Array.length t.table - 1 do + Avltree.invariant t.table.(i) ~compare:(compare_key t) + done; + let real_len = + fold t ~init:0 ~f:(fun ~key ~data i -> + invariant_key key; + invariant_data data; + i + 1) + in + assert (real_len = t.length); +;; + +let find_exn = + let if_found v = v in + let if_not_found _ = raise Caml.Not_found in + fun t key -> + find_and_call t key ~if_found ~if_not_found +;; + +let existsi t ~f = + with_return (fun r -> + iteri t ~f:(fun ~key ~data -> if f ~key ~data then r.return true); + false) +;; + +let exists t ~f = existsi t ~f:(fun ~key:_ ~data -> f data) +;; + +let for_alli t ~f = not (existsi t ~f:(fun ~key ~data -> not (f ~key ~data))) +let for_all t ~f = not (existsi t ~f:(fun ~key:_ ~data -> not (f data))) + +let counti t ~f = + fold t ~init:0 ~f:(fun ~key ~data acc -> if f ~key ~data then acc+1 else acc) +let count t ~f = + fold t ~init:0 ~f:(fun ~key:_ ~data acc -> if f data then acc+1 else acc) + +let mapi t ~f = + let new_t = + create ~growth_allowed:t.growth_allowed + ~hashable:t.hashable ~size:t.length () + in + iteri t ~f:(fun ~key ~data -> replace new_t ~key ~data:(f ~key ~data)); + new_t + +let map t ~f = mapi t ~f:(fun ~key:_ ~data -> f data) + +let copy t = map t ~f:Fn.id + +let filter_mapi t ~f = + let new_t = + create ~growth_allowed:t.growth_allowed + ~hashable:t.hashable ~size:t.length () + in + iteri t ~f:(fun ~key ~data -> + match f ~key ~data with + | Some new_data -> replace new_t ~key ~data:new_data + | None -> ()); + new_t + +let filter_map t ~f = filter_mapi t ~f:(fun ~key:_ ~data -> f data) + +let filteri t ~f = + filter_mapi t ~f:(fun ~key ~data -> if f ~key ~data then Some data else None) +;; + +let filter t ~f = filteri t ~f:(fun ~key:_ ~data -> f data) +let filter_keys t ~f = filteri t ~f:(fun ~key ~data:_ -> f key) + +let partition_mapi t ~f = + let t0 = + create ~growth_allowed:t.growth_allowed + ~hashable:t.hashable ~size:t.length () + in + let t1 = + create ~growth_allowed:t.growth_allowed + ~hashable:t.hashable ~size:t.length () + in + iteri t ~f:(fun ~key ~data -> + match f ~key ~data with + | `Fst new_data -> replace t0 ~key ~data:new_data + | `Snd new_data -> replace t1 ~key ~data:new_data); + (t0, t1) +;; + +let partition_map t ~f = partition_mapi t ~f:(fun ~key:_ ~data -> f data) + +let partitioni_tf t ~f = + partition_mapi t ~f:(fun ~key ~data -> if f ~key ~data then `Fst data else `Snd data) +;; + +let partition_tf t ~f = partitioni_tf t ~f:(fun ~key:_ ~data -> f data) + +let find_or_add t id ~default = + match find t id with + | Some x -> x + | None -> + let default = default () in + replace t ~key:id ~data:default; + default + +let findi_or_add t id ~default = + match find t id with + | Some x -> x + | None -> + let default = default id in + replace t ~key:id ~data:default; + default + +(* Some hashtbl implementations may be able to perform this more efficiently than two + separate lookups *) +let find_and_remove t id = + let result = find t id in + if Option.is_some result then remove t id; + result + + +let change t id ~f = + match f (find t id) with + | None -> remove t id + | Some data -> replace t ~key:id ~data +;; + +let update t id ~f = + set t ~key:id ~data:(f (find t id)) +;; + +let incr_by ~remove_if_zero t key by = + if remove_if_zero + then + change t key ~f:(fun opt -> + match by + Option.value opt ~default:0 with + | 0 -> None + | n -> Some n) + else + update t key ~f:(function + | None -> by + | Some i -> by + i) +;; + +let incr ?(by = 1) ?(remove_if_zero = false) t key = incr_by ~remove_if_zero t key by +let decr ?(by = 1) ?(remove_if_zero = false) t key = incr_by ~remove_if_zero t key (-by) +;; + +let add_multi t ~key ~data = + update t key ~f:(function + | None -> [ data ] + | Some l -> data :: l) +;; + +let remove_multi t key = + match find t key with + | None -> () + | Some [] | Some [_] -> remove t key + | Some (_ :: tl) -> replace t ~key ~data:tl + +let find_multi t key = + match find t key with + | None -> [] + | Some l -> l + +let create_mapped ?growth_allowed ?size ~hashable ~get_key ~get_data rows = + let size = match size with Some s -> s | None -> List.length rows in + let res = create ?growth_allowed ~hashable ~size () in + let dupes = ref [] in + List.iter rows ~f:(fun r -> + let key = get_key r in + let data = get_data r in + if mem res key then + dupes := key :: !dupes + else + replace res ~key ~data); + match !dupes with + | [] -> `Ok res + | keys -> `Duplicate_keys (List.dedup_and_sort ~compare:hashable.Hashable.compare keys) +;; + +(* + {[ + let create_mapped_exn ?growth_allowed ?size ~hashable ~get_key ~get_data rows = + let size = match size with Some s -> s | None -> List.length rows in + let res = create ?growth_allowed ~size ~hashable () in + List.iter rows ~f:(fun r -> + let key = get_key r in + let data = get_data r in + if mem res key then + let sexp_of_key = hashable.Hashable.sexp_of_t in + failwiths "Hashtbl.create_mapped_exn: duplicate key" key <:sexp_of< key >> + else + replace res ~key ~data); + res + ;; + ]} *) + +let create_mapped_multi ?growth_allowed ?size ~hashable ~get_key ~get_data rows = + let size = match size with Some s -> s | None -> List.length rows in + let res = create ?growth_allowed ~size ~hashable () in + List.iter rows ~f:(fun r -> + let key = get_key r in + let data = get_data r in + add_multi res ~key ~data); + res +;; + +let of_alist ?growth_allowed ?size ~hashable lst = + match create_mapped ?growth_allowed ?size ~hashable ~get_key:fst ~get_data:snd lst with + | `Ok t -> `Ok t + | `Duplicate_keys k -> `Duplicate_key (List.hd_exn k) +;; + +let of_alist_report_all_dups ?growth_allowed ?size ~hashable lst = + create_mapped ?growth_allowed ?size ~hashable ~get_key:fst ~get_data:snd lst +;; + +let of_alist_or_error ?growth_allowed ?size ~hashable lst = + match of_alist ?growth_allowed ?size ~hashable lst with + | `Ok v -> Result.Ok v + | `Duplicate_key key -> + let sexp_of_key = hashable.Hashable.sexp_of_t in + Or_error.error "Hashtbl.of_alist_exn: duplicate key" key sexp_of_key +;; + +let of_alist_exn ?growth_allowed ?size ~hashable lst = + match of_alist_or_error ?growth_allowed ?size ~hashable lst with + | Result.Ok v -> v + | Result.Error e -> Error.raise e +;; + +let of_alist_multi ?growth_allowed ?size ~hashable lst = + create_mapped_multi ?growth_allowed ?size ~hashable ~get_key:fst ~get_data:snd lst +;; + +let to_alist t = fold ~f:(fun ~key ~data list -> (key, data) :: list) ~init:[] t + +let sexp_of_t sexp_of_key sexp_of_data t = + t + |> to_alist + |> List.sort ~compare:(fun (k1, _) (k2, _) -> t.hashable.compare k1 k2) + |> sexp_of_list (sexp_of_pair sexp_of_key sexp_of_data) +;; + +let t_of_sexp ~hashable k_of_sexp d_of_sexp sexp = + let alist = list_of_sexp (pair_of_sexp k_of_sexp d_of_sexp) sexp in + of_alist_exn ~hashable alist ~size:(List.length alist) +;; + +let validate ~name f t = Validate.alist ~name f (to_alist t) + +let keys t = fold t ~init:[] ~f:(fun ~key ~data:_ acc -> key :: acc) + +let data t = fold ~f:(fun ~key:_ ~data list -> data::list) ~init:[] t + +let add_to_groups groups ~get_key ~get_data ~combine ~rows = + List.iter rows ~f:(fun row -> + let key = get_key row in + let data = get_data row in + let data = + match find groups key with + | None -> data + | Some old -> combine old data + in + replace groups ~key ~data) +;; + +let group ?growth_allowed ?size ~hashable ~get_key ~get_data ~combine rows = + let res = create ?growth_allowed ?size ~hashable () in + add_to_groups res ~get_key ~get_data ~combine ~rows; + res +;; + +let create_with_key ?growth_allowed ?size ~hashable ~get_key rows = + create_mapped ?growth_allowed ?size ~hashable ~get_key ~get_data:Fn.id rows +;; + +let create_with_key_or_error ?growth_allowed ?size ~hashable ~get_key rows = + match create_with_key ?growth_allowed ?size ~hashable ~get_key rows with + | `Ok t -> Result.Ok t + | `Duplicate_keys keys -> + let sexp_of_key = hashable.Hashable.sexp_of_t in + Or_error.error_s + (Sexp.message "Hashtbl.create_with_key: duplicate keys" + [ "keys", sexp_of_list sexp_of_key keys ]) +;; + +let create_with_key_exn ?growth_allowed ?size ~hashable ~get_key rows = + Or_error.ok_exn (create_with_key_or_error ?growth_allowed ?size ~hashable ~get_key rows) +;; + +let merge = + let maybe_set t ~key ~f d = + match f ~key d with + | None -> () + | Some v -> + set t ~key ~data:v + in + fun t_left t_right ~f -> + if not (Hashable.equal t_left.hashable t_right.hashable) + then invalid_arg "Hashtbl.merge: different 'hashable' values"; + let new_t = + create ~growth_allowed:t_left.growth_allowed + ~hashable:t_left.hashable ~size:t_left.length () + in + without_mutating t_left (fun () -> + without_mutating t_right (fun () -> + iteri t_left ~f:(fun ~key ~data:left -> + match find t_right key with + | None -> + maybe_set new_t ~key ~f (`Left left) + | Some right -> + maybe_set new_t ~key ~f (`Both (left, right)) + ); + iteri t_right ~f:(fun ~key ~data:right -> + match find t_left key with + | None -> + maybe_set new_t ~key ~f (`Right right) + | Some _ -> () (* already done above *) + ))); + new_t +;; + +type 'a merge_into_action = Remove | Set_to of 'a + +let merge_into ~src ~dst ~f = + iteri src ~f:(fun ~key ~data -> + let dst_data = find dst key in + let action = without_mutating dst (fun () -> f ~key data dst_data) in + match action with + | Remove -> remove dst key + | Set_to data -> + match dst_data with + | None -> replace dst ~key ~data + | Some dst_data -> + if not (phys_equal dst_data data) + then replace dst ~key ~data) +;; + +let filteri_inplace t ~f = + let to_remove = + fold t ~init:[] ~f:(fun ~key ~data ac -> + if f ~key ~data then ac else key :: ac) + in + List.iter to_remove ~f:(fun key -> remove t key); +;; + +let filter_inplace t ~f = + filteri_inplace t ~f:(fun ~key:_ ~data -> f data) +;; + +let filter_keys_inplace t ~f = + filteri_inplace t ~f:(fun ~key ~data:_ -> f key) +;; + +let filter_mapi_inplace t ~f = + let map_results = + fold t ~init:[] ~f:(fun ~key ~data ac -> (key, f ~key ~data) :: ac) + in + List.iter map_results ~f:(fun (key,result) -> + match result with + | None -> remove t key + | Some data -> set t ~key ~data + ); +;; + +let filter_map_inplace t ~f = + filter_mapi_inplace t ~f:(fun ~key:_ ~data -> f data) + +let mapi_inplace t ~f = + ensure_mutation_allowed t; + without_mutating t (fun () -> Array.iter t.table ~f:(Avltree.mapi_inplace ~f)) + +let map_inplace t ~f = + mapi_inplace t ~f:(fun ~key:_ ~data -> f data) + +let equal t t' equal = + length t = length t' && + with_return (fun r -> + without_mutating t' (fun () -> + iteri t ~f:(fun ~key ~data -> + match find t' key with + | None -> r.return false + | Some data' -> + if not (equal data data') + then r.return false)); + true) +;; + +let similar = equal + +module Accessors = struct + + type nonrec 'a merge_into_action = 'a merge_into_action = Remove | Set_to of 'a + + let invariant = invariant + let clear = clear + let copy = copy + let remove = remove + let set = set + let add = add + let add_exn = add_exn + let change = change + let update = update + let add_multi = add_multi + let remove_multi = remove_multi + let find_multi = find_multi + let mem = mem + let iter_keys = iter_keys + let iter = iter + let iteri = iteri + let exists = exists + let existsi = existsi + let for_all = for_all + let for_alli = for_alli + let count = count + let counti = counti + let fold = fold + let length = length + let is_empty = is_empty + let map = map + let mapi = mapi + let filter_map = filter_map + let filter_mapi = filter_mapi + let filter_keys = filter_keys + let filter = filter + let filteri = filteri + let partition_map = partition_map + let partition_mapi = partition_mapi + let partition_tf = partition_tf + let partitioni_tf = partitioni_tf + let find_or_add = find_or_add + let findi_or_add = findi_or_add + let find = find + let find_exn = find_exn + let find_and_call = find_and_call + let findi_and_call = findi_and_call + let find_and_remove = find_and_remove + let to_alist = to_alist + let validate = validate + let merge = merge + let merge_into = merge_into + let keys = keys + let data = data + let filter_keys_inplace = filter_keys_inplace + let filter_inplace = filter_inplace + let filteri_inplace = filteri_inplace + let map_inplace = map_inplace + let mapi_inplace = mapi_inplace + let filter_map_inplace = filter_map_inplace + let filter_mapi_inplace = filter_mapi_inplace + let equal = equal + let similar = similar + let incr = incr + let decr = decr + let sexp_of_key = sexp_of_key +end + +module Creators (Key : sig + type 'a t + + val hashable : 'a t Hashable.t + end) : sig + + type ('a, 'b) t_ = ('a Key.t, 'b) t + + val t_of_sexp : (Sexp.t -> 'a Key.t) -> (Sexp.t -> 'b) -> Sexp.t -> ('a, 'b) t_ + + include Creators_generic + with type ('a, 'b) t := ('a, 'b) t_ + with type 'a key := 'a Key.t + with type ('key, 'data, 'a) create_options := + ('key, 'data, 'a) create_options_without_first_class_module + +end = struct + + let hashable = Key.hashable + + type ('a, 'b) t_ = ('a Key.t, 'b) t + + let create ?growth_allowed ?size () = create ?growth_allowed ?size ~hashable () + + let of_alist ?growth_allowed ?size l = + of_alist ?growth_allowed ~hashable ?size l + ;; + + let of_alist_report_all_dups ?growth_allowed ?size l = + of_alist_report_all_dups ?growth_allowed ~hashable ?size l + ;; + + let of_alist_or_error ?growth_allowed ?size l = + of_alist_or_error ?growth_allowed ~hashable ?size l + ;; + + let of_alist_exn ?growth_allowed ?size l = + of_alist_exn ?growth_allowed ~hashable ?size l + ;; + + let t_of_sexp k_of_sexp d_of_sexp sexp = + t_of_sexp ~hashable k_of_sexp d_of_sexp sexp + ;; + + let of_alist_multi ?growth_allowed ?size l = + of_alist_multi ?growth_allowed ~hashable ?size l + ;; + + let create_mapped ?growth_allowed ?size ~get_key ~get_data l = + create_mapped ?growth_allowed ~hashable ?size ~get_key ~get_data l + ;; + + let create_with_key ?growth_allowed ?size ~get_key l = + create_with_key ?growth_allowed ~hashable ?size ~get_key l + ;; + + let create_with_key_or_error ?growth_allowed ?size ~get_key l = + create_with_key_or_error ?growth_allowed ~hashable ?size ~get_key l + ;; + + let create_with_key_exn ?growth_allowed ?size ~get_key l = + create_with_key_exn ?growth_allowed ~hashable ?size ~get_key l + ;; + + let group ?growth_allowed ?size ~get_key ~get_data ~combine l = + group ?growth_allowed ~hashable ?size ~get_key ~get_data ~combine l + ;; +end + +module Poly = struct + + type nonrec ('a, 'b) t = ('a, 'b) t + + type 'a key = 'a + + let hashable = Hashable.poly + + include Creators (struct + type 'a t = 'a + let hashable = hashable + end) + + include Accessors + + let sexp_of_t = sexp_of_t +end + +module Private = struct + module type Creators_generic = Creators_generic + module type Hashable = Hashable.Hashable + + type nonrec ('key, 'data, 'z) create_options_without_first_class_module = + ('key, 'data, 'z) create_options_without_first_class_module + + let hashable t = t.hashable +end + +let create ?growth_allowed ?size m = + create ~hashable:(Hashable.of_key m) ?growth_allowed ?size () +let of_alist ?growth_allowed ?size m l = + of_alist ~hashable:(Hashable.of_key m) ?growth_allowed ?size l +let of_alist_report_all_dups ?growth_allowed ?size m l = + of_alist_report_all_dups ~hashable:(Hashable.of_key m) ?growth_allowed ?size l +let of_alist_or_error ?growth_allowed ?size m l = + of_alist_or_error ~hashable:(Hashable.of_key m) ?growth_allowed ?size l +let of_alist_exn ?growth_allowed ?size m l = + of_alist_exn ~hashable:(Hashable.of_key m) ?growth_allowed ?size l +let of_alist_multi ?growth_allowed ?size m l = + of_alist_multi ~hashable:(Hashable.of_key m) ?growth_allowed ?size l +let create_mapped ?growth_allowed ?size m ~get_key ~get_data l = + create_mapped ~hashable:(Hashable.of_key m) ?growth_allowed ?size ~get_key ~get_data l +let create_with_key ?growth_allowed ?size m ~get_key l = + create_with_key ~hashable:(Hashable.of_key m) ?growth_allowed ?size ~get_key l +let create_with_key_or_error ?growth_allowed ?size m ~get_key l = + create_with_key_or_error ~hashable:(Hashable.of_key m) ?growth_allowed ?size ~get_key l +let create_with_key_exn ?growth_allowed ?size m ~get_key l = + create_with_key_exn ~hashable:(Hashable.of_key m) ?growth_allowed ?size ~get_key l +let group ?growth_allowed ?size m ~get_key ~get_data ~combine l = + group ~hashable:(Hashable.of_key m) ?growth_allowed ?size ~get_key ~get_data ~combine l + +let hashable_s t = Hashable.to_key t.hashable + +module M (K : T.T) = struct + type nonrec 'v t = (K.t, 'v) t +end +module type Sexp_of_m = sig type t [@@deriving_inline sexp_of] + include + sig [@@@ocaml.warning "-32"] val sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t + end[@@ocaml.doc "@inline"] + [@@@end] end +module type M_of_sexp = sig + type t [@@deriving_inline of_sexp] + include + sig [@@@ocaml.warning "-32"] val t_of_sexp : Ppx_sexp_conv_lib.Sexp.t -> t + end[@@ocaml.doc "@inline"] + [@@@end] include Key with type t := t +end + +let sexp_of_m__t (type k) (module K : Sexp_of_m with type t = k) sexp_of_v t = + sexp_of_t K.sexp_of_t sexp_of_v t + +let m__t_of_sexp (type k) (module K : M_of_sexp with type t = k) v_of_sexp sexp = + t_of_sexp ~hashable:(Hashable.of_key (module K)) K.t_of_sexp v_of_sexp sexp + +(* typechecking this code is a compile-time test that [Creators] is a specialization of + [Creators_generic]. *) +module Check : sig end = struct + module Make_creators_check (Type : T.T2) (Key : T.T1) (Options : T.T3) + (M : Creators_generic + with type ('a, 'b) t := ('a, 'b) Type.t + with type 'a key := 'a Key.t + with type ('a, 'b, 'z) create_options := ('a, 'b, 'z) Options.t) + = struct end + + module Check_creators_is_specialization_of_creators_generic (M : Creators) = + Make_creators_check + (struct type ('a, 'b) t = ('a, 'b) M.t end) + (struct type 'a t = 'a end) + (struct type ('a, 'b, 'z) t = ('a, 'b, 'z) create_options end) + (struct + include M + + let create ?growth_allowed ?size m () = + create ?growth_allowed ?size m + end) +end diff --git a/src/hashtbl.mli b/src/hashtbl.mli new file mode 100644 index 0000000..e011dea --- /dev/null +++ b/src/hashtbl.mli @@ -0,0 +1 @@ +include Hashtbl_intf.Hashtbl (** @inline *) diff --git a/src/hashtbl_intf.ml b/src/hashtbl_intf.ml new file mode 100644 index 0000000..99f7913 --- /dev/null +++ b/src/hashtbl_intf.ml @@ -0,0 +1,769 @@ +open! Import + +module type Key = sig + type t [@@deriving_inline compare, sexp_of] + include + sig + [@@@ocaml.warning "-32"] + val compare : t -> t -> int + val sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t + end[@@ocaml.doc "@inline"] + [@@@end] + + (** Two [t]s that [compare] equal must have equal hashes for the hashtable + to behave properly. *) + val hash : t -> int +end + +module type Accessors = sig + (** {2 Accessors} *) + + type ('a, 'b) t + type 'a key + + val sexp_of_key : ('a, _) t -> 'a key -> Sexp.t + val clear : (_, _) t -> unit + val copy : ('a, 'b) t -> ('a, 'b) t + + (** Attempting to modify ([set], [remove], etc.) the hashtable during iteration ([fold], + [iter], [iter_keys], [iteri]) will raise an exception. *) + val fold : ('a, 'b) t -> init:'c -> f:(key:'a key -> data:'b -> 'c -> 'c) -> 'c + + val iter_keys : ('a, _) t -> f:( 'a key -> unit) -> unit + val iter : ( _, 'b) t -> f:( 'b -> unit) -> unit + + (** Iterates over both keys and values. + + Example: + + {v + let h = Hashtbl.of_alist_exn (module Int) [(1, 4); (5, 6)] in + Hashtbl.iteri h ~f:(fun ~key ~data -> + print_endline (Printf.sprintf "%d-%d" key data));; + 1-4 + 5-6 + - : unit = () + v} *) + val iteri : ('a, 'b) t -> f:(key:'a key -> data:'b -> unit) -> unit + + val existsi : ('a, 'b) t -> f:(key:'a key -> data:'b -> bool) -> bool + val exists : (_ , 'b) t -> f:( 'b -> bool) -> bool + val for_alli : ('a, 'b) t -> f:(key:'a key -> data:'b -> bool) -> bool + val for_all : (_ , 'b) t -> f:( 'b -> bool) -> bool + val counti : ('a, 'b) t -> f:(key:'a key -> data:'b -> bool) -> int + val count : (_ , 'b) t -> f:( 'b -> bool) -> int + + val length : (_, _) t -> int + val is_empty : (_, _) t -> bool + val mem : ('a, _) t -> 'a key -> bool + val remove : ('a, _) t -> 'a key -> unit + + (** Sets the given [key] to [data]. *) + val set : ('a, 'b) t -> key:'a key -> data:'b -> unit + + (** [add] and [add_exn] leave the table unchanged if the key was already present. *) + val add : ('a, 'b) t -> key:'a key -> data:'b -> [ `Ok | `Duplicate ] + val add_exn : ('a, 'b) t -> key:'a key -> data:'b -> unit + + (** [change t key ~f] changes [t]'s value for [key] to be [f (find t key)]. *) + val change : ('a, 'b) t -> 'a key -> f:('b option -> 'b option) -> unit + + (** [update t key ~f] is [change t key ~f:(fun o -> Some (f o))]. *) + val update : ('a, 'b) t -> 'a key -> f:('b option -> 'b) -> unit + + (** [map t f] returns a new table with values replaced by the result of applying [f] + to the current values. + + Example: + + {v + let h = Hashtbl.of_alist_exn (module Int) [(1, 4); (5, 6)] in + let h' = Hashtbl.map h ~f:(fun x -> x * 2) in + Hashtbl.to_alist h';; + - : (int * int) list = [(5, 12); (1, 8)] + v} *) + val map : ('a, 'b) t -> f:('b -> 'c) -> ('a, 'c) t + + (** Like [map], but the function [f] takes both key and data as arguments. *) + val mapi : ('a, 'b) t -> f:(key:'a key -> data:'b -> 'c) -> ('a, 'c) t + + (** Returns a new table by filtering the given table's values by [f]: the keys for which + [f] applied to the current value returns [Some] are kept, and those for which it + returns [None] are discarded. + + Example: + + {v + let h = Hashtbl.of_alist_exn (module Int) [(1, 4); (5, 6)] in + Hashtbl.filter_map h ~f:(fun x -> if x > 5 then Some x else None) + |> Hashtbl.to_alist;; + - : (int * int) list = [(5, 6)] + v} *) + val filter_map : ('a, 'b) t -> f:('b -> 'c option) -> ('a, 'c) t + + (** Like [filter_map], but the function [f] takes both key and data as arguments. *) + val filter_mapi + : ('a, 'b) t -> f:(key:'a key -> data:'b -> 'c option) -> ('a, 'c) t + + val filter_keys : ('a, 'b) t -> f:('a key -> bool) -> ('a, 'b) t + val filter : ('a, 'b) t -> f:('b -> bool) -> ('a, 'b) t + val filteri : ('a, 'b) t -> f:(key:'a key -> data:'b -> bool) -> ('a, 'b) t + + (** Returns new tables with bound values partitioned by [f] applied to the bound + values. *) + val partition_map + : ('a, 'b) t + -> f:('b -> [`Fst of 'c | `Snd of 'd]) + -> ('a, 'c) t * ('a, 'd) t + + (** Like [partition_map], but the function [f] takes both key and data as arguments. *) + val partition_mapi + : ('a, 'b) t + -> f:(key:'a key -> data:'b -> [`Fst of 'c | `Snd of 'd]) + -> ('a, 'c) t * ('a, 'd) t + + (** Returns a pair of tables [(t1, t2)], where [t1] contains all the elements of the + initial table which satisfy the predicate [f], and [t2] contains the rest. *) + val partition_tf : ('a, 'b) t -> f:('b -> bool) -> ('a, 'b) t * ('a, 'b) t + + (** Like [partition_tf], but the function [f] takes both key and data as arguments. *) + val partitioni_tf + : ('a, 'b) t -> f:(key:'a key -> data:'b -> bool) -> ('a, 'b) t * ('a, 'b) t + + (** [find_or_add t k ~default] returns the data associated with key [k] if it is in the + table [t], and otherwise assigns [k] the value returned by [default ()]. *) + val find_or_add : ('a, 'b) t -> 'a key -> default:(unit -> 'b) -> 'b + + (** Like [find_or_add] but [default] takes the key as an argument. *) + val findi_or_add : ('a, 'b) t -> 'a key -> default:('a key -> 'b) -> 'b + + (** [find t k] returns [Some] (the current binding) of [k] in [t], or [None] if no such + binding exists. *) + val find : ('a, 'b) t -> 'a key -> 'b option + + (** [find_exn t k] returns the current binding of [k] in [t], or raises [Caml.Not_found] + or [Not_found_s] if no such binding exists. *) + val find_exn : ('a, 'b) t -> 'a key -> 'b + + (** [find_and_call t k ~if_found ~if_not_found] + + is equivalent to: + + [match find t k with Some v -> if_found v | None -> if_not_found k] + + except that it doesn't allocate the option. *) + val find_and_call + : ('a, 'b) t + -> 'a key + -> if_found:('b -> 'c) + -> if_not_found:('a key -> 'c) + -> 'c + + val findi_and_call + : ('a, 'b) t + -> 'a key + -> if_found:(key:'a key -> data:'b -> 'c) + -> if_not_found:('a key -> 'c) + -> 'c + + (** [find_and_remove t k] returns Some (the current binding) of k in t and removes it, + or None is no such binding exists. *) + val find_and_remove : ('a, 'b) t -> 'a key -> 'b option + + (** Merges two hashtables. + + The result of [merge f h1 h2] has as keys the set of all [k] in the union of the + sets of keys of [h1] and [h2] for which [d(k)] is not None, where: + + d(k) = + - [f ~key:k (Some d1) None] + if [k] in [h1] maps to d1, and [h2] does not have data for [k]; + + - [f ~key:k None (Some d2)] + if [k] in [h2] maps to d2, and [h1] does not have data for [k]; + + - [f ~key:k (Some d1) (Some d2)] + otherwise, where [k] in [h1] maps to [d1] and [k] in [h2] maps to [d2]. + + Each key [k] is mapped to a single piece of data [x], where [d(k) = Some x]. + + Example: + + {v + let h1 = Hashtbl.of_alist_exn (module Int) [(1, 5); (2, 3232)] in + let h2 = Hashtbl.of_alist_exn (module Int) [(1, 3)] in + Hashtbl.merge h1 h2 ~f:(fun ~key:_ -> function + | `Left x -> Some (`Left x) + | `Right x -> Some (`Right x) + | `Both (x, y) -> if x=y then None else Some (`Both (x,y)) + ) |> Hashtbl.to_alist;; + - : (int * [> `Both of int * int | `Left of int | `Right of int ]) list = + [(2, `Left 3232); (1, `Both (5, 3))] + v} *) + val merge + : ('k, 'a) t + -> ('k, 'b) t + -> f:(key:'k key -> [ `Left of 'a | `Right of 'b | `Both of 'a * 'b ] -> 'c option) + -> ('k, 'c) t + + (** Every [key] in [src] will be removed or set in [dst] according to the return value + of [f]. *) + type 'a merge_into_action = Remove | Set_to of 'a + + val merge_into + : src:('k, 'a) t + -> dst:('k, 'b) t + -> f:(key:'k key -> 'a -> 'b option -> 'b merge_into_action) + -> unit + + (** Returns the list of all keys for given hashtable. *) + val keys : ('a, _) t -> 'a key list + + (** Returns the list of all data for given hashtable. *) + val data : (_, 'b) t -> 'b list + + (** [filter_inplace t ~f] removes all the elements from [t] that don't satisfy [f]. *) + val filter_keys_inplace : ('a, _) t -> f:('a key -> bool) -> unit + val filter_inplace : ( _, 'b) t -> f:('b -> bool) -> unit + val filteri_inplace : ('a, 'b) t -> f:(key:'a key -> data:'b -> bool) -> unit + + (** [map_inplace t ~f] applies [f] to all elements in [t], transforming them in + place. *) + val map_inplace : (_, 'b) t -> f:( 'b -> 'b) -> unit + val mapi_inplace : ('a, 'b) t -> f:(key:'a key -> data:'b -> 'b) -> unit + + (** [filter_map_inplace] combines the effects of [map_inplace] and [filter_inplace]. *) + val filter_map_inplace : (_, 'b) t -> f:( 'b -> 'b option) -> unit + val filter_mapi_inplace : ('a, 'b) t -> f:(key:'a key -> data:'b -> 'b option) -> unit + + (** [equal t1 t2 f] and [similar t1 t2 f] both return true iff [t1] and [t2] have the + same keys and for all keys [k], [f (find_exn t1 k) (find_exn t2 k)]. [equal] and + [similar] only differ in their types. *) + val equal : ('a, 'b ) t -> ('a, 'b ) t -> ('b -> 'b -> bool) -> bool + val similar : ('a, 'b1) t -> ('a, 'b2) t -> ('b1 -> 'b2 -> bool) -> bool + + (** Returns the list of all (key, data) pairs for given hashtable. *) + val to_alist : ('a, 'b) t -> ('a key * 'b) list + + val validate + : name:('a key -> string) + -> 'b Validate.check + -> ('a, 'b) t Validate.check + + (** [remove_if_zero]'s default is [false]. *) + val incr : ?by:int -> ?remove_if_zero:bool -> ('a, int) t -> 'a key -> unit + val decr : ?by:int -> ?remove_if_zero:bool -> ('a, int) t -> 'a key -> unit +end + +module type Multi = sig + type ('a, 'b) t + type 'a key + + (** [add_multi t ~key ~data] if [key] is present in the table then cons + [data] on the list, otherwise add [key] with a single element list. *) + val add_multi : ('a, 'b list) t -> key:'a key -> data:'b -> unit + + (** [remove_multi t key] updates the table, removing the head of the list bound to + [key]. If the list has only one element (or is empty) then the binding is + removed. *) + val remove_multi : ('a, _ list) t -> 'a key -> unit + + (** [find_multi t key] returns the empty list if [key] is not present in the table, + returns [t]'s values for [key] otherwise. *) + val find_multi : ('a, 'b list) t -> 'a key -> 'b list +end + +type ('key, 'data, 'z) create_options = + ?growth_allowed:bool (** defaults to [true] *) + -> ?size:int (** initial size -- default 128 *) + -> (module Key with type t = 'key) + -> 'z + +type ('key, 'data, 'z) create_options_without_first_class_module = + ?growth_allowed:bool (** defaults to [true] *) + -> ?size:int (** initial size -- default 128 *) + -> 'z + +module type Creators_generic = sig + type ('a, 'b) t + type 'a key + type ('key, 'data, 'z) create_options + + val create : ('a key, 'b, unit -> ('a, 'b) t) create_options + + val of_alist + : ('a key, + 'b, + ('a key * 'b) list + -> [ `Ok of ('a, 'b) t + | `Duplicate_key of 'a key + ]) create_options + + val of_alist_report_all_dups + : ('a key, + 'b, + ('a key * 'b) list + -> [ `Ok of ('a, 'b) t + | `Duplicate_keys of 'a key list + ]) create_options + + val of_alist_or_error + : ('a key, 'b, ('a key * 'b) list -> ('a, 'b) t Or_error.t) create_options + + val of_alist_exn : ('a key, 'b, ('a key * 'b) list -> ('a, 'b) t) create_options + + val of_alist_multi + : ('a key, 'b list, ('a key * 'b) list -> ('a, 'b list) t) create_options + + + (** {[ create_mapped get_key get_data [x1,...,xn] + = of_alist [get_key x1, get_data x1; ...; get_key xn, get_data xn] ]} *) + val create_mapped + : ('a key, + 'b, + get_key:('r -> 'a key) + -> get_data:('r -> 'b) + -> 'r list + -> [ `Ok of ('a, 'b) t + | `Duplicate_keys of 'a key list ]) create_options + + (** {[ create_with_key ~get_key [x1,...,xn] + = of_alist [get_key x1, x1; ...; get_key xn, xn] ]} *) + val create_with_key + : ('a key, + 'r, + get_key:('r -> 'a key) + -> 'r list + -> [ `Ok of ('a, 'r) t + | `Duplicate_keys of 'a key list ]) create_options + + val create_with_key_or_error + : ('a key, + 'r, + get_key:('r -> 'a key) + -> 'r list + -> ('a, 'r) t Or_error.t) create_options + + val create_with_key_exn + : ('a key, + 'r, + get_key:('r -> 'a key) + -> 'r list + -> ('a, 'r) t) create_options + + val group + : ('a key, + 'b, + get_key:('r -> 'a key) + -> get_data:('r -> 'b) + -> combine:('b -> 'b -> 'b) + -> 'r list + -> ('a, 'b) t) create_options +end + +module type Creators = sig + type ('a, 'b) t + + (** {2 Creators} *) + + (** The module you pass to [create] must have a type that is hashable, sexpable, and + comparable. + + Example: + + {v + Hashtbl.create (module Int);; + - : (int, '_a) Hashtbl.t = <abstr>;; + v} *) + val create + : ?growth_allowed:bool (** defaults to [true] *) + -> ?size:int (** initial size -- default 128 *) + -> (module Key with type t = 'a) + -> ('a, 'b) t + + (** Example: + + {v + Hashtbl.of_alist (module Int) [(3, "something"); (2, "whatever")] + - : [ `Duplicate_key of int | `Ok of (int, string) Hashtbl.t ] = `Ok <abstr> + v} *) + val of_alist + : ?growth_allowed:bool (** defaults to [true] *) + -> ?size:int (** initial size -- default 128 *) + -> (module Key with type t = 'a) + -> ('a * 'b) list + -> [ `Ok of ('a, 'b) t + | `Duplicate_key of 'a + ] + + (** Whereas [of_alist] will report [Duplicate_key] no matter how many dups there are in + your list, [of_alist_report_all_dups] will report each and every duplicate entry. + + For example: + + {v + Hashtbl.of_alist (module Int) [(1, "foo"); (1, "bar"); (2, "foo"); (2, "bar")];; + - : [ `Duplicate_key of int | `Ok of (int, string) Hashtbl.t ] = `Duplicate_key 1 + + Hashtbl.of_alist_report_all_dups (module Int) [(1, "foo"); (1, "bar"); (2, "foo"); (2, "bar")];; + - : [ `Duplicate_keys of int list | `Ok of (int, string) Hashtbl.t ] = `Duplicate_keys [1; 2] + v} *) + val of_alist_report_all_dups + : ?growth_allowed:bool (** defaults to [true] *) + -> ?size:int (** initial size -- default 128 *) + -> (module Key with type t = 'a) + -> ('a * 'b) list + -> [ `Ok of ('a, 'b) t + | `Duplicate_keys of 'a list + ] + + val of_alist_or_error + : ?growth_allowed:bool (** defaults to [true] *) + -> ?size:int (** initial size -- default 128 *) + -> (module Key with type t = 'a) + -> ('a * 'b) list + -> ('a, 'b) t Or_error.t + + val of_alist_exn + : ?growth_allowed:bool (** defaults to [true] *) + -> ?size:int (** initial size -- default 128 *) + -> (module Key with type t = 'a) + -> ('a * 'b) list + -> ('a, 'b) t + + (** Creates a {{!Multi} "multi"} hashtable, i.e., a hashtable where each key points to a + list potentially containing multiple values. So instead of short-circuiting with a + [`Duplicate_key] variant on duplicates, as in [of_alist], [of_alist_multi] folds + those values into a list for the given key: + + {v + let h = Hashtbl.of_alist_multi (module Int) [(1, "a"); (1, "b"); (2, "c"); (2, "d")];; + val h : (int, string list) Hashtbl.t = <abstr> + + Hashtbl.find_exn h 1;; + - : string list = ["b"; "a"] + v} *) + val of_alist_multi + : ?growth_allowed:bool (** defaults to [true] *) + -> ?size:int (** initial size -- default 128 *) + -> (module Key with type t = 'a) + -> ('a * 'b) list + -> ('a, 'b list) t + + (** Applies the [get_key] and [get_data] functions to the ['r list] to create the + initial keys and values, respectively, for the new hashtable. + + {[ create_mapped get_key get_data [x1;...;xn] + = of_alist [get_key x1, get_data x1; ...; get_key xn, get_data xn] + ]} + + Example: + + {v + let h = + Hashtbl.create_mapped (module Int) + ~get_key:(fun x -> x) + ~get_data:(fun x -> x + 1) + [1; 2; 3];; + val h : [ `Duplicate_keys of int list | `Ok of (int, int) Hashtbl.t ] = `Ok <abstr> + + let h = + match h with + | `Ok x -> x + | `Duplicate_keys _ -> failwith "" + in + Hashtbl.find_exn h 1;; + - : int = 2 + v} *) + val create_mapped + : ?growth_allowed:bool (** defaults to [true] *) + -> ?size:int (** initial size -- default 128 *) + -> (module Key with type t = 'a) + -> get_key:('r -> 'a) + -> get_data:('r -> 'b) + -> 'r list + -> [ `Ok of ('a, 'b) t + | `Duplicate_keys of 'a list ] + + (** {[ create_with_key ~get_key [x1;...;xn] + = of_alist [get_key x1, x1; ...; get_key xn, xn] ]} *) + val create_with_key + : ?growth_allowed:bool (** defaults to [true] *) + -> ?size:int (** initial size -- default 128 *) + -> (module Key with type t = 'a) + -> get_key:('r -> 'a) + -> 'r list + -> [ `Ok of ('a, 'r) t + | `Duplicate_keys of 'a list ] + + val create_with_key_or_error + : ?growth_allowed:bool (** defaults to [true] *) + -> ?size:int (** initial size -- default 128 *) + -> (module Key with type t = 'a) + -> get_key:('r -> 'a) + -> 'r list + -> ('a, 'r) t Or_error.t + + val create_with_key_exn + : ?growth_allowed:bool (** defaults to [true] *) + -> ?size:int (** initial size -- default 128 *) + -> (module Key with type t = 'a) + -> get_key:('r -> 'a) + -> 'r list + -> ('a, 'r) t + + (** Like [create_mapped], applies the [get_key] and [get_data] functions to the ['r + list] to create the initial keys and values, respectively, for the new hashtable -- + and then, like [add_multi], folds together values belonging to the same keys. Here, + though, the function used for the folding is given by [combine] (instead of just + being a [cons]). + + Example: + + {v + Hashtbl.group (module Int) + ~get_key:(fun x -> x / 2) + ~get_data:(fun x -> x) + ~combine:(fun x y -> x * y) + [ 1; 2; 3; 4] + |> Hashtbl.to_alist;; + - : (int * int) list = [(2, 4); (1, 6); (0, 1)] + v} *) + val group + : ?growth_allowed:bool (** defaults to [true] *) + -> ?size:int (** initial size -- default 128 *) + -> (module Key with type t = 'a) + -> get_key:('r -> 'a) + -> get_data:('r -> 'b) + -> combine:('b -> 'b -> 'b) + -> 'r list + -> ('a, 'b) t +end + +module type S_without_submodules = sig + + val hash : 'a -> int + val hash_param : int -> int -> 'a -> int + + + type ('a, 'b) t + + (** We provide a [sexp_of_t] but not a [t_of_sexp] for this type because one needs to be + explicit about the hash and comparison functions used when creating a hashtable. + Note that [Hashtbl.Poly.t] does have [[@@deriving_inline sexp][@@@end]], and uses OCaml's built-in + polymorphic comparison and and polymorphic hashing. *) + val sexp_of_t : ('a -> Sexp.t) -> ('b -> Sexp.t) -> ('a, 'b) t -> Sexp.t + + include Creators + with type ('a, 'b) t := ('a, 'b) t + (** @inline *) + + include Accessors + with type ('a, 'b) t := ('a, 'b) t + with type 'a key = 'a + (** @inline *) + + include Multi + with type ('a, 'b) t := ('a, 'b) t + with type 'a key := 'a key + (** @inline *) + + val hashable_s : ('key, _) t -> (module Key with type t = 'key) + + include Invariant.S2 with type ('a, 'b) t := ('a, 'b) t + +end + +module type S_poly = sig + + type ('a, 'b) t [@@deriving_inline sexp] + include + sig + [@@@ocaml.warning "-32"] + include Ppx_sexp_conv_lib.Sexpable.S2 with type ('a,'b) t := ('a, 'b) t + end[@@ocaml.doc "@inline"] + [@@@end] + + val hashable : 'a Hashable.t + + include Invariant.S2 with type ('a, 'b) t := ('a, 'b) t + + include Creators_generic + with type ('a, 'b) t := ('a, 'b) t + with type 'a key = 'a + with type ('key, 'data, 'z) create_options + := ('key, 'data, 'z) create_options_without_first_class_module + + include Accessors + with type ('a, 'b) t := ('a, 'b) t + with type 'a key := 'a key + + include Multi + with type ('a, 'b) t := ('a, 'b) t + with type 'a key := 'a key +end + +module type For_deriving = sig + type ('k, 'v) t + + module type Sexp_of_m = sig type t [@@deriving_inline sexp_of] + include + sig [@@@ocaml.warning "-32"] val sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t + end[@@ocaml.doc "@inline"] + [@@@end] end + module type M_of_sexp = sig + type t [@@deriving_inline of_sexp] + include + sig [@@@ocaml.warning "-32"] val t_of_sexp : Ppx_sexp_conv_lib.Sexp.t -> t + end[@@ocaml.doc "@inline"] + [@@@end] include Key with type t := t + end + val sexp_of_m__t + : (module Sexp_of_m with type t = 'k) + -> ('v -> Sexp.t) + -> ('k, 'v) t + -> Sexp.t + val m__t_of_sexp + : (module M_of_sexp with type t = 'k) + -> (Sexp.t -> 'v) + -> Sexp.t + -> ('k, 'v) t +end + +module type Hashtbl = sig + + (** A hash table is a mutable data structure implementing a map between keys and values. + It supports constant-time lookup and in-place modification. + + {1 Usage} + + As a simple example, we'll create a hash table with string keys using the + {{!create}[create]} constructor, which expects a module defining the key's type: + + {[ + let h = Hashtbl.create (module String);; + val h : (string, '_a) Hashtbl.t = <abstr> + ]} + + We can set the values of individual keys with {{!set}[set]}. If the key already has + a value, it will be overwritten. + + {v + Hashtbl.set h ~key:"foo" ~data:5;; + - : unit = () + + Hashtbl.set h ~key:"foo" ~data:6;; + - : unit = () + + Hashtbl.set h ~key:"bar" ~data:6;; + - : unit = () + v} + + We can access values by key, or dump all of the hash table's data: + + {v + Hashtbl.find h "foo";; + - : int option = Some 6 + + Hashtbl.find_exn h "foo";; + - : int = 6 + + Hashtbl.to_alist h;; + - : (string * int) list = [("foo", 6); ("bar", 6)] + v} + + {{!change}[change]} lets us change a key's value by applying the given function: + + {v + Hashtbl.change h "foo" (fun x -> + match x with + | Some x -> Some (x * 2) + | None -> None + );; + - : unit = () + + Hashtbl.to_alist h;; + - : (string * int) list = [("foo", 12); ("bar", 6)] + v} + + + We can use {{!merge}[merge]} to merge two hashtables with fine-grained control over + how we choose values when a key is present in the first ("left") hashtable, the + second ("right"), or both. Here, we'll cons the values when both hashtables have a + key: + + {v + let h1 = Hashtbl.of_alist_exn (module Int) [(1, 5); (2, 3232)] in + let h2 = Hashtbl.of_alist_exn (module Int) [(1, 3)] in + Hashtbl.merge h1 h2 ~f:(fun ~key:_ -> function + | `Left x -> Some (`Left x) + | `Right x -> Some (`Right x) + | `Both (x, y) -> if x=y then None else Some (`Both (x,y)) + ) |> Hashtbl.to_alist;; + - : (int * [> `Both of int * int | `Left of int | `Right of int ]) list = + [(2, `Left 3232); (1, `Both (5, 3))] + v} + + {1 Interface} *) + + include S_without_submodules (** @inline *) + + module type Accessors = Accessors + module type Creators = Creators + module type Key = Key + module type Multi = Multi + module type S_poly = S_poly + module type S_without_submodules = S_without_submodules + + module type For_deriving = For_deriving + + type nonrec ('key, 'data, 'z) create_options = + ('key, 'data, 'z) create_options + + module Creators (Key : sig type 'a t val hashable : 'a t Hashable.t end) : sig + type ('a, 'b) t_ = ('a Key.t, 'b) t + val t_of_sexp : (Sexp.t -> 'a Key.t) -> (Sexp.t -> 'b) -> Sexp.t -> ('a, 'b) t_ + include Creators_generic + with type ('a, 'b) t := ('a, 'b) t_ + with type 'a key := 'a Key.t + with type ('key, 'data, 'a) create_options := + ('key, 'data, 'a) create_options_without_first_class_module + end + + module Poly : S_poly with type ('a, 'b) t = ('a, 'b) t + + (** [M] is meant to be used in combination with OCaml applicative functor types: + + {[ + type string_to_int_table = int Hashtbl.M(String).t + ]} + + which stands for: + + {[ + type string_to_int_table = (String.t, int) Hashtbl.t + ]} + + The point is that [int Hashtbl.M(String).t] supports deriving, whereas the second + syntax doesn't (because [t_of_sexp] doesn't know what comparison/hash function to + use). *) + module M (K : T.T) : sig + type nonrec 'v t = (K.t, 'v) t + end + + include For_deriving with type ('a, 'b) t := ('a, 'b) t + + (**/**) + (*_ See the Jane Street Style Guide for an explanation of [Private] submodules: + + https://opensource.janestreet.com/standards/#private-submodules *) + module Private : sig + + module type Creators_generic = Creators_generic + + type nonrec ('key, 'data, 'z) create_options_without_first_class_module = + ('key, 'data, 'z) create_options_without_first_class_module + + val hashable : ('key, _) t -> 'key Hashable.t + end +end diff --git a/src/hex_lexer.mll b/src/hex_lexer.mll new file mode 100644 index 0000000..0551e8d --- /dev/null +++ b/src/hex_lexer.mll @@ -0,0 +1,15 @@ +{ +type result = +| Neg of string +| Pos of string +} + +let hex_digit = ['0' - '9' 'A' - 'F' 'a' - 'f'] +let body = (hex_digit (hex_digit | '_')*) as body +let body_with_suffix = '0' ['X' 'x'] body +let pos = body_with_suffix +let neg = '-' body_with_suffix + +rule parse_hex = parse +| neg { Neg body } +| pos { Pos body } diff --git a/src/identifiable.ml b/src/identifiable.ml new file mode 100644 index 0000000..f4e059e --- /dev/null +++ b/src/identifiable.ml @@ -0,0 +1,58 @@ +open! Import + +module type S = sig + type t [@@deriving_inline hash, sexp] + include + sig + [@@@ocaml.warning "-32"] + val hash_fold_t : + Ppx_hash_lib.Std.Hash.state -> t -> Ppx_hash_lib.Std.Hash.state + val hash : t -> Ppx_hash_lib.Std.Hash.hash_value + include Ppx_sexp_conv_lib.Sexpable.S with type t := t + end[@@ocaml.doc "@inline"] + [@@@end] + include Stringable .S with type t := t + include Comparable .S with type t := t + include Pretty_printer.S with type t := t +end + +module Make (T : sig + type t [@@deriving_inline compare, hash, sexp] + include + sig + [@@@ocaml.warning "-32"] + val compare : t -> t -> int + val hash_fold_t : + Ppx_hash_lib.Std.Hash.state -> t -> Ppx_hash_lib.Std.Hash.state + val hash : t -> Ppx_hash_lib.Std.Hash.hash_value + include Ppx_sexp_conv_lib.Sexpable.S with type t := t + end[@@ocaml.doc "@inline"] + [@@@end] + include Stringable.S with type t := t + val module_name : string + end) = struct + include T + include Comparable .Make (T) + include Pretty_printer.Register (T) +end + +module Make_using_comparator (T : sig + type t [@@deriving_inline compare, hash, sexp] + include + sig + [@@@ocaml.warning "-32"] + val compare : t -> t -> int + val hash_fold_t : + Ppx_hash_lib.Std.Hash.state -> t -> Ppx_hash_lib.Std.Hash.state + val hash : t -> Ppx_hash_lib.Std.Hash.hash_value + include Ppx_sexp_conv_lib.Sexpable.S with type t := t + end[@@ocaml.doc "@inline"] + [@@@end] + include Comparator.S with type t := t + include Stringable.S with type t := t + val module_name : string + end) = struct + include T + include Comparable .Make_using_comparator (T) + include Pretty_printer.Register (T) +end diff --git a/src/identifiable.mli b/src/identifiable.mli new file mode 100644 index 0000000..a0770bf --- /dev/null +++ b/src/identifiable.mli @@ -0,0 +1,77 @@ +(** A signature combining functionality that is commonly used for types that are intended + to act as names or identifiers. + + Modules that satisfy [Identifiable] can be printed and parsed (both through string and + s-expression converters) and can be used in hash-based and comparison-based + containers (e.g., hashtables and maps). + + This module also provides functors for conveniently constructing identifiable + modules. *) + +open! Import + +module type S = sig + type t [@@deriving_inline hash, sexp] + include + sig + [@@@ocaml.warning "-32"] + val hash_fold_t : + Ppx_hash_lib.Std.Hash.state -> t -> Ppx_hash_lib.Std.Hash.state + val hash : t -> Ppx_hash_lib.Std.Hash.hash_value + include Ppx_sexp_conv_lib.Sexpable.S with type t := t + end[@@ocaml.doc "@inline"] + [@@@end] + include Stringable.S with type t := t + include Comparable.S with type t := t + include Pretty_printer.S with type t := t +end + +(** Used for making an Identifiable module. Here's an example. + + {[ + module Id = struct + module T = struct + type t = A | B [@@deriving_inline compare, hash, sexp][@@@end] + let of_string s = t_of_sexp (sexp_of_string s) + let to_string t = string_of_sexp (sexp_of_t t) + let module_name = "My_library.Std.Id" + end + include T + include Identifiable.Make (T) + end + ]} *) +module Make (M : sig + type t [@@deriving_inline compare, hash, sexp] + include + sig + [@@@ocaml.warning "-32"] + val compare : t -> t -> int + val hash_fold_t : + Ppx_hash_lib.Std.Hash.state -> t -> Ppx_hash_lib.Std.Hash.state + val hash : t -> Ppx_hash_lib.Std.Hash.hash_value + include Ppx_sexp_conv_lib.Sexpable.S with type t := t + end[@@ocaml.doc "@inline"] + [@@@end] + include Stringable.S with type t := t + val module_name : string (** For registering the pretty printer. *) + end) : S + with type t := M.t + +module Make_using_comparator (M : sig + type t [@@deriving_inline compare, hash, sexp] + include + sig + [@@@ocaml.warning "-32"] + val compare : t -> t -> int + val hash_fold_t : + Ppx_hash_lib.Std.Hash.state -> t -> Ppx_hash_lib.Std.Hash.state + val hash : t -> Ppx_hash_lib.Std.Hash.hash_value + include Ppx_sexp_conv_lib.Sexpable.S with type t := t + end[@@ocaml.doc "@inline"] + [@@@end] + include Comparator.S with type t := t + include Stringable.S with type t := t + val module_name : string + end) : S + with type t := M.t + with type comparator_witness := M.comparator_witness diff --git a/src/import.ml b/src/import.ml new file mode 100644 index 0000000..e3ba407 --- /dev/null +++ b/src/import.ml @@ -0,0 +1,8 @@ +include Import0 + +include Sexplib0.Sexp_conv +include Hash.Builtin +include Ppx_compare_lib.Builtin +include Int_replace_polymorphic_compare + +exception Not_found_s = Sexp.Not_found_s diff --git a/src/import0.ml b/src/import0.ml new file mode 100644 index 0000000..0431d2e --- /dev/null +++ b/src/import0.ml @@ -0,0 +1,392 @@ +(* This module is included in [Import]. It is aimed at modules that define the standard + combinators for [sexp_of], [of_sexp], [compare] and [hash] and are included in + [Import]. *) + +include + (Shadow_stdlib + : (module type of struct include Shadow_stdlib end + with type 'a ref := 'a ref + with type ('a, 'b, 'c) format := ('a, 'b, 'c) format + with type ('a, 'b, 'c, 'd) format4 := ('a, 'b, 'c, 'd) format4 + with type ('a, 'b, 'c, 'd, 'e, 'f) format6 := ('a, 'b, 'c, 'd, 'e, 'f) format6 + (* These modules are redefined in Base *) + with module Array := Shadow_stdlib.Array + with module Bool := Shadow_stdlib.Bool + with module Buffer := Shadow_stdlib.Buffer + with module Bytes := Shadow_stdlib.Bytes + with module Char := Shadow_stdlib.Char + with module Float := Shadow_stdlib.Float + with module Hashtbl := Shadow_stdlib.Hashtbl + with module Int := Shadow_stdlib.Int + with module Int32 := Shadow_stdlib.Int32 + with module Int64 := Shadow_stdlib.Int64 + with module Lazy := Shadow_stdlib.Lazy + with module List := Shadow_stdlib.List + with module Map := Shadow_stdlib.Map + with module Nativeint := Shadow_stdlib.Nativeint + with module Option := Shadow_stdlib.Option + with module Printf := Shadow_stdlib.Printf + with module Queue := Shadow_stdlib.Queue + with module Random := Shadow_stdlib.Random + with module Result := Shadow_stdlib.Result + with module Set := Shadow_stdlib.Set + with module Stack := Shadow_stdlib.Stack + with module String := Shadow_stdlib.String + with module Sys := Shadow_stdlib.Sys + with module Uchar := Shadow_stdlib.Uchar + with module Unit := Shadow_stdlib.Unit + )) [@ocaml.warning "-3"] +type 'a ref = 'a Caml.ref = { mutable contents: 'a } + +(* Reshuffle [Caml] so that we choose the modules using labels when available. *) +module Caml = struct + + + (** @canonical Caml.Arg *) + module Arg = Caml.Arg + + (** @canonical Caml.StdLabels.Array *) + module Array = Caml.StdLabels.Array + + (** @canonical Caml.Bool *) + module Bool = Caml.Bool + + (** @canonical Caml.Buffer *) + module Buffer = Caml.Buffer + + (** @canonical Caml.StdLabels.Bytes *) + module Bytes = Caml.StdLabels.Bytes + + (** @canonical Caml.Char *) + module Char = Caml.Char + + (** @canonical Caml.Ephemeron *) + module Ephemeron = Caml.Ephemeron + + (** @canonical Caml.Float *) + module Float = Caml.Float + + (** @canonical Caml.Format *) + module Format = Caml.Format + + (** @canonical Caml.Fun *) + module Fun = Caml.Fun + + (** @canonical Caml.Gc *) + module Gc = Caml.Gc + + (** @canonical Caml.MoreLabels.Hashtbl *) + module Hashtbl = Caml.MoreLabels.Hashtbl + + (** @canonical Caml.Int32 *) + module Int32 = Caml.Int32 + + (** @canonical Caml.Int *) + module Int = Caml.Int + + (** @canonical Caml.Int64 *) + module Int64 = Caml.Int64 + + (** @canonical Caml.Lazy *) + module Lazy = Caml.Lazy + + (** @canonical Caml.Lexing *) + module Lexing = Caml.Lexing + + (** @canonical Caml.StdLabels.List *) + module List = Caml.StdLabels.List + + (** @canonical Caml.MoreLabels.Map *) + module Map = Caml.MoreLabels.Map + + (** @canonical Caml.Nativeint *) + module Nativeint = Caml.Nativeint + + (** @canonical Caml.Obj *) + module Obj = Caml.Obj + + (** @canonical Caml.Option *) + module Option = Caml.Option + + (** @canonical Caml.Parsing *) + module Parsing = Caml.Parsing + + (** @canonical Caml.Printexc *) + module Printexc = Caml.Printexc + + (** @canonical Caml.Printf *) + module Printf = Caml.Printf + + (** @canonical Caml.Queue *) + module Queue = Caml.Queue + + (** @canonical Caml.Random *) + module Random = Caml.Random + + (** @canonical Caml.Result *) + module Result = Caml.Result + + (** @canonical Caml.Scanf *) + module Scanf = Caml.Scanf + + (** @canonical Caml.MoreLabels.Set *) + module Set = Caml.MoreLabels.Set + + (** @canonical Caml.Stack *) + module Stack = Caml.Stack + + (** @canonical Caml.Stream *) + module Stream = Caml.Stream + + (** @canonical Caml.StdLabels.String *) + module String = Caml.StdLabels.String + + (** @canonical Caml.Sys *) + module Sys = Caml.Sys + + (** @canonical Caml.Uchar *) + module Uchar = Caml.Uchar + + (** @canonical Caml.Unit *) + module Unit = Caml.Unit + + include Pervasives [@ocaml.warning "-3"] + + exception Not_found = Caml.Not_found +end + +external ( |> ) : 'a -> ( 'a -> 'b) -> 'b = "%revapply" + +(* These need to be declared as an external to get the lazy behavior *) +external ( && ) : bool -> bool -> bool = "%sequand" +external ( || ) : bool -> bool -> bool = "%sequor" +external not : bool -> bool = "%boolnot" + +(* This need to be declared as an external for the warnings to work properly *) +external ignore : _ -> unit = "%ignore" + +let ( != ) = Caml.( != ) +let ( * ) = Caml.( * ) +let ( ** ) = Caml.( ** ) +let ( *. ) = Caml.( *. ) +let ( + ) = Caml.( + ) +let ( +. ) = Caml.( +. ) +let ( - ) = Caml.( - ) +let ( -. ) = Caml.( -. ) +let ( / ) = Caml.( / ) +let ( /. ) = Caml.( /. ) + +(** @canonical Base.Poly *) +module Poly = Poly0 + +module Int_replace_polymorphic_compare = struct + let ( < ) (x : int) y = Poly.( < ) x y + let ( <= ) (x : int) y = Poly.( <= ) x y + let ( <> ) (x : int) y = Poly.( <> ) x y + let ( = ) (x : int) y = Poly.( = ) x y + let ( > ) (x : int) y = Poly.( > ) x y + let ( >= ) (x : int) y = Poly.( >= ) x y + + let ascending (x : int) y = Poly.ascending x y + let descending (x : int) y = Poly.descending x y + let compare (x : int) y = Poly.compare x y + let equal (x : int) y = Poly.equal x y + let max (x : int) y = if x >= y then x else y + let min (x : int) y = if x <= y then x else y +end + +include Int_replace_polymorphic_compare + +module Int32_replace_polymorphic_compare = struct + let ( < ) (x : Caml.Int32.t) y = Poly.( < ) x y + let ( <= ) (x : Caml.Int32.t) y = Poly.( <= ) x y + let ( <> ) (x : Caml.Int32.t) y = Poly.( <> ) x y + let ( = ) (x : Caml.Int32.t) y = Poly.( = ) x y + let ( > ) (x : Caml.Int32.t) y = Poly.( > ) x y + let ( >= ) (x : Caml.Int32.t) y = Poly.( >= ) x y + + let ascending (x : Caml.Int32.t) y = Poly.ascending x y + let descending (x : Caml.Int32.t) y = Poly.descending x y + let compare (x : Caml.Int32.t) y = Poly.compare x y + let equal (x : Caml.Int32.t) y = Poly.equal x y + let max (x : Caml.Int32.t) y = if x >= y then x else y + let min (x : Caml.Int32.t) y = if x <= y then x else y +end + +module Int64_replace_polymorphic_compare = struct + let ( < ) (x : Caml.Int64.t) y = Poly.( < ) x y + let ( <= ) (x : Caml.Int64.t) y = Poly.( <= ) x y + let ( <> ) (x : Caml.Int64.t) y = Poly.( <> ) x y + let ( = ) (x : Caml.Int64.t) y = Poly.( = ) x y + let ( > ) (x : Caml.Int64.t) y = Poly.( > ) x y + let ( >= ) (x : Caml.Int64.t) y = Poly.( >= ) x y + + let ascending (x : Caml.Int64.t) y = Poly.ascending x y + let descending (x : Caml.Int64.t) y = Poly.descending x y + let compare (x : Caml.Int64.t) y = Poly.compare x y + let equal (x : Caml.Int64.t) y = Poly.equal x y + let max (x : Caml.Int64.t) y = if x >= y then x else y + let min (x : Caml.Int64.t) y = if x <= y then x else y +end + +module Nativeint_replace_polymorphic_compare = struct + let ( < ) (x : Caml.Nativeint.t) y = Poly.( < ) x y + let ( <= ) (x : Caml.Nativeint.t) y = Poly.( <= ) x y + let ( <> ) (x : Caml.Nativeint.t) y = Poly.( <> ) x y + let ( = ) (x : Caml.Nativeint.t) y = Poly.( = ) x y + let ( > ) (x : Caml.Nativeint.t) y = Poly.( > ) x y + let ( >= ) (x : Caml.Nativeint.t) y = Poly.( >= ) x y + + let ascending (x : Caml.Nativeint.t) y = Poly.ascending x y + let descending (x : Caml.Nativeint.t) y = Poly.descending x y + let compare (x : Caml.Nativeint.t) y = Poly.compare x y + let equal (x : Caml.Nativeint.t) y = Poly.equal x y + let max (x : Caml.Nativeint.t) y = if x >= y then x else y + let min (x : Caml.Nativeint.t) y = if x <= y then x else y +end + +module Bool_replace_polymorphic_compare = struct + let ( < ) (x : bool) y = Poly.( < ) x y + let ( <= ) (x : bool) y = Poly.( <= ) x y + let ( <> ) (x : bool) y = Poly.( <> ) x y + let ( = ) (x : bool) y = Poly.( = ) x y + let ( > ) (x : bool) y = Poly.( > ) x y + let ( >= ) (x : bool) y = Poly.( >= ) x y + + let ascending (x : bool) y = Poly.ascending x y + let descending (x : bool) y = Poly.descending x y + let compare (x : bool) y = Poly.compare x y + let equal (x : bool) y = Poly.equal x y + let max (x : bool) y = if x >= y then x else y + let min (x : bool) y = if x <= y then x else y +end + +module Char_replace_polymorphic_compare = struct + let ( < ) (x : char) y = Poly.( < ) x y + let ( <= ) (x : char) y = Poly.( <= ) x y + let ( <> ) (x : char) y = Poly.( <> ) x y + let ( = ) (x : char) y = Poly.( = ) x y + let ( > ) (x : char) y = Poly.( > ) x y + let ( >= ) (x : char) y = Poly.( >= ) x y + + let ascending (x : char) y = Poly.ascending x y + let descending (x : char) y = Poly.descending x y + let compare (x : char) y = Poly.compare x y + let equal (x : char) y = Poly.equal x y + let max (x : char) y = if x >= y then x else y + let min (x : char) y = if x <= y then x else y +end + +module Uchar_replace_polymorphic_compare = struct + let i x = Caml.Uchar.to_int x + + let ( < ) (x : Caml.Uchar.t) y = Int_replace_polymorphic_compare.( < ) (i x) (i y) + let ( <= ) (x : Caml.Uchar.t) y = Int_replace_polymorphic_compare.( <= ) (i x) (i y) + let ( <> ) (x : Caml.Uchar.t) y = Int_replace_polymorphic_compare.( <> ) (i x) (i y) + let ( = ) (x : Caml.Uchar.t) y = Int_replace_polymorphic_compare.( = ) (i x) (i y) + let ( > ) (x : Caml.Uchar.t) y = Int_replace_polymorphic_compare.( > ) (i x) (i y) + let ( >= ) (x : Caml.Uchar.t) y = Int_replace_polymorphic_compare.( >= ) (i x) (i y) + + let ascending (x : Caml.Uchar.t) y = Int_replace_polymorphic_compare.ascending (i x) (i y) + let descending (x : Caml.Uchar.t) y = Int_replace_polymorphic_compare.descending (i x) (i y) + let compare (x : Caml.Uchar.t) y = Int_replace_polymorphic_compare.compare (i x) (i y) + let equal (x : Caml.Uchar.t) y = Int_replace_polymorphic_compare.equal (i x) (i y) + let max (x : Caml.Uchar.t) y = if x >= y then x else y + let min (x : Caml.Uchar.t) y = if x <= y then x else y +end + +module Float_replace_polymorphic_compare = struct + let ( < ) (x : float) y = Poly.( < ) x y + let ( <= ) (x : float) y = Poly.( <= ) x y + let ( <> ) (x : float) y = Poly.( <> ) x y + let ( = ) (x : float) y = Poly.( = ) x y + let ( > ) (x : float) y = Poly.( > ) x y + let ( >= ) (x : float) y = Poly.( >= ) x y + + let ascending (x : float) y = Poly.ascending x y + let descending (x : float) y = Poly.descending x y + let compare (x : float) y = Poly.compare x y + let equal (x : float) y = Poly.equal x y + let max (x : float) y = if x >= y then x else y + let min (x : float) y = if x <= y then x else y +end + +module String_replace_polymorphic_compare = struct + let ( < ) (x : string) y = Poly.( < ) x y + let ( <= ) (x : string) y = Poly.( <= ) x y + let ( <> ) (x : string) y = Poly.( <> ) x y + let ( = ) (x : string) y = Poly.( = ) x y + let ( > ) (x : string) y = Poly.( > ) x y + let ( >= ) (x : string) y = Poly.( >= ) x y + + let ascending (x : string) y = Poly.ascending x y + let descending (x : string) y = Poly.descending x y + let compare (x : string) y = Poly.compare x y + let equal (x : string) y = Poly.equal x y + let max (x : string) y = if x >= y then x else y + let min (x : string) y = if x <= y then x else y +end + +module Bytes_replace_polymorphic_compare = struct + let ( < ) (x : bytes) y = Poly.( < ) x y + let ( <= ) (x : bytes) y = Poly.( <= ) x y + let ( <> ) (x : bytes) y = Poly.( <> ) x y + let ( = ) (x : bytes) y = Poly.( = ) x y + let ( > ) (x : bytes) y = Poly.( > ) x y + let ( >= ) (x : bytes) y = Poly.( >= ) x y + + let ascending (x : bytes) y = Poly.ascending x y + let descending (x : bytes) y = Poly.descending x y + let compare (x : bytes) y = Poly.compare x y + let equal (x : bytes) y = Poly.equal x y + let max (x : bytes) y = if x >= y then x else y + let min (x : bytes) y = if x <= y then x else y +end + +(* This needs to be defined as an external so that the compiler can specialize it as a + direct set or caml_modify *) +external ( := ) : 'a ref -> 'a -> unit = "%setfield0" + +(* These need to be defined as an external otherwise the compiler won't unbox + references *) +external ( ! ) : 'a ref -> 'a = "%field0" +external ref : 'a -> 'a ref = "%makemutable" + +let ( @ ) = Caml.( @ ) +let ( ^ ) = Caml.( ^ ) +let ( ~- ) = Caml.( ~- ) +let ( ~-. ) = Caml.( ~-. ) + +let ( asr ) = Caml.( asr ) +let ( land ) = Caml.( land ) +let lnot = Caml.lnot +let ( lor ) = Caml.( lor ) +let ( lsl ) = Caml.( lsl ) +let ( lsr ) = Caml.( lsr ) +let ( lxor ) = Caml.( lxor ) +let ( mod ) = Caml.( mod ) + +let abs = Caml.abs +let failwith = Caml.failwith +let fst = Caml.fst +let invalid_arg = Caml.invalid_arg +let snd = Caml.snd + +(* [raise] needs to be defined as an external as the compiler automatically replaces + '%raise' by '%reraise' when appropriate. *) +external raise : exn -> _ = "%raise" + +let phys_equal = Caml.( == ) + +let decr = Caml.decr +let incr = Caml.incr + +(* used by sexp_conv, which float0 depends on through option *) +let float_of_string = Caml.float_of_string + +(* [am_testing] is used in a few places to behave differently when in testing mode, such + as in [random.ml]. [am_testing] is implemented using [Base_am_testing], a weak C/js + primitive that returns [false], but when linking an inline-test-runner executable, is + overridden by another primitive that returns [true]. *) +external am_testing : unit -> bool = "Base_am_testing" +let am_testing = am_testing () diff --git a/src/indexed_container.ml b/src/indexed_container.ml new file mode 100644 index 0000000..809bfdf --- /dev/null +++ b/src/indexed_container.ml @@ -0,0 +1,65 @@ +include Indexed_container_intf + +let with_return = With_return.with_return + +let iteri ~fold t ~f = + ignore (fold t ~init:0 ~f:(fun i x -> f i x; i + 1) : int) +;; + +let foldi ~fold t ~init ~f = + let i = ref 0 in + fold t ~init ~f:(fun acc v -> + let acc = f !i acc v in + i := !i + 1; + acc) +;; + +let counti ~foldi t ~f = + foldi t ~init:0 ~f:(fun i n a -> if f i a then n + 1 else n) +;; + +let existsi ~iteri c ~f = + with_return (fun r -> + iteri c ~f:(fun i x -> if f i x then r.return true); + false) +;; + +let for_alli ~iteri c ~f = + with_return (fun r -> + iteri c ~f:(fun i x -> if not (f i x) then r.return false); + true) +;; + +let find_mapi ~iteri t ~f = + with_return (fun r -> + iteri t ~f:(fun i x -> match f i x with None -> () | Some _ as res -> r.return res); + None) +;; + +let findi ~iteri c ~f = + with_return (fun r -> + iteri c ~f:(fun i x -> if f i x then r.return (Some (i, x))); + None) +;; + +module Make (T : Make_arg) : S1 with type 'a t := 'a T.t = struct + include (Container.Make (T)) + + let iteri = + match T.iteri with + | `Custom iteri -> iteri + | `Define_using_fold -> fun t ~f -> iteri ~fold t ~f + ;; + + let foldi = + match T.foldi with + | `Custom foldi -> foldi + | `Define_using_fold -> fun t ~init ~f -> foldi ~fold t ~init ~f + ;; + + let counti t ~f = counti ~foldi t ~f + let existsi t ~f = existsi ~iteri t ~f + let for_alli t ~f = for_alli ~iteri t ~f + let find_mapi t ~f = find_mapi ~iteri t ~f + let findi t ~f = findi ~iteri t ~f +end diff --git a/src/indexed_container.mli b/src/indexed_container.mli new file mode 100644 index 0000000..f5a29e5 --- /dev/null +++ b/src/indexed_container.mli @@ -0,0 +1 @@ +include Indexed_container_intf.Indexed_container (** @inline *) diff --git a/src/indexed_container_intf.ml b/src/indexed_container_intf.ml new file mode 100644 index 0000000..3ae88d2 --- /dev/null +++ b/src/indexed_container_intf.ml @@ -0,0 +1,66 @@ + +type ('t, 'a, 'accum) fold = 't -> init:'accum -> f:('accum -> 'a -> 'accum) -> 'accum +type ('t, 'a, 'accum) foldi = + 't -> init:'accum -> f:(int -> 'accum -> 'a -> 'accum) -> 'accum +type ('t, 'a) iteri = 't -> f:(int -> 'a -> unit) -> unit + +module type S1 = sig + include Container.S1 + + (** These are all like their equivalents in [Container] except that an index starting at + 0 is added as the first argument to [f]. *) + + val foldi : ('a t, 'a, _) foldi + + val iteri : ('a t, 'a) iteri + val existsi : 'a t -> f:(int -> 'a -> bool) -> bool + val for_alli : 'a t -> f:(int -> 'a -> bool) -> bool + val counti : 'a t -> f:(int -> 'a -> bool) -> int + val findi : 'a t -> f:(int -> 'a -> bool) -> (int * 'a) option + val find_mapi : 'a t -> f:(int -> 'a -> 'b option) -> 'b option + +end + +module type Make_arg = sig + include Container_intf.Make_arg + + val iteri : + [ `Define_using_fold + | `Custom of ('a t, 'a) iteri + ] + + val foldi : + [ `Define_using_fold + | `Custom of ('a t, 'a, _) foldi + ] +end + +module type Indexed_container = sig + + (** Provides generic signatures for containers that support indexed iteration ([iteri], + [foldi], ...). In principle, any container that has [iter] can also implement [iteri], + but the idea is that [Indexed_container_intf] should be included only for containers + that have a meaningful underlying ordering. *) + + module type S1 = S1 + + (** Generic definitions of [foldi] and [iteri] in terms of [fold]. + + E.g., [iteri ~fold t ~f = ignore (fold t ~init:0 ~f:(fun i x -> f i x; i + 1))]. *) + + val foldi : fold:('t, 'a, 'accum) fold -> ('t, 'a, 'accum) foldi + val iteri : fold:('t, 'a, int) fold -> ('t, 'a) iteri + + (** Generic definitions of indexed container operations in terms of [foldi]. *) + + val counti : foldi:('t, 'a, int) foldi -> 't -> f:(int -> 'a -> bool) -> int + + (** Generic definitions of indexed container operations in terms of [iteri]. *) + + val existsi : iteri:('t, 'a) iteri -> 't -> f:(int -> 'a -> bool) -> bool + val for_alli : iteri:('t, 'a) iteri -> 't -> f:(int -> 'a -> bool) -> bool + val findi : iteri:('t, 'a) iteri -> 't -> f:(int -> 'a -> bool) -> (int * 'a) option + val find_mapi : iteri:('t, 'a) iteri -> 't -> f:(int -> 'a -> 'b option) -> 'b option + + module Make (T : Make_arg) : S1 with type 'a t := 'a T.t +end diff --git a/src/info.ml b/src/info.ml new file mode 100644 index 0000000..ce12e85 --- /dev/null +++ b/src/info.ml @@ -0,0 +1,240 @@ +(* This module is trying to minimize dependencies on modules in Core, so as to allow + [Info], [Error], and [Or_error] to be used in as many places as possible. Please avoid + adding new dependencies. *) + +open! Import + +include Info_intf + +module String = String0 + +module Message = struct + type t = + | Could_not_construct of Sexp.t + | String of string + | Exn of exn + | Sexp of Sexp.t + | Tag_sexp of string * Sexp.t * Source_code_position0.t option + | Tag_t of string * t + | Tag_arg of string * Sexp.t * t + | Of_list of int option * t list + | With_backtrace of t * string (* backtrace *) + [@@deriving_inline sexp_of] + let rec sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t = + function + | Could_not_construct v0 -> + let v0 = Sexp.sexp_of_t v0 in + Ppx_sexp_conv_lib.Sexp.List + [Ppx_sexp_conv_lib.Sexp.Atom "Could_not_construct"; v0] + | String v0 -> + let v0 = sexp_of_string v0 in + Ppx_sexp_conv_lib.Sexp.List [Ppx_sexp_conv_lib.Sexp.Atom "String"; v0] + | Exn v0 -> + let v0 = sexp_of_exn v0 in + Ppx_sexp_conv_lib.Sexp.List [Ppx_sexp_conv_lib.Sexp.Atom "Exn"; v0] + | Sexp v0 -> + let v0 = Sexp.sexp_of_t v0 in + Ppx_sexp_conv_lib.Sexp.List [Ppx_sexp_conv_lib.Sexp.Atom "Sexp"; v0] + | Tag_sexp (v0, v1, v2) -> + let v0 = sexp_of_string v0 + and v1 = Sexp.sexp_of_t v1 + and v2 = sexp_of_option Source_code_position0.sexp_of_t v2 in + Ppx_sexp_conv_lib.Sexp.List + [Ppx_sexp_conv_lib.Sexp.Atom "Tag_sexp"; v0; v1; v2] + | Tag_t (v0, v1) -> + let v0 = sexp_of_string v0 + and v1 = sexp_of_t v1 in + Ppx_sexp_conv_lib.Sexp.List + [Ppx_sexp_conv_lib.Sexp.Atom "Tag_t"; v0; v1] + | Tag_arg (v0, v1, v2) -> + let v0 = sexp_of_string v0 + and v1 = Sexp.sexp_of_t v1 + and v2 = sexp_of_t v2 in + Ppx_sexp_conv_lib.Sexp.List + [Ppx_sexp_conv_lib.Sexp.Atom "Tag_arg"; v0; v1; v2] + | Of_list (v0, v1) -> + let v0 = sexp_of_option sexp_of_int v0 + and v1 = sexp_of_list sexp_of_t v1 in + Ppx_sexp_conv_lib.Sexp.List + [Ppx_sexp_conv_lib.Sexp.Atom "Of_list"; v0; v1] + | With_backtrace (v0, v1) -> + let v0 = sexp_of_t v0 + and v1 = sexp_of_string v1 in + Ppx_sexp_conv_lib.Sexp.List + [Ppx_sexp_conv_lib.Sexp.Atom "With_backtrace"; v0; v1] + [@@@end] + + let rec to_strings_hum t ac = + (* We use [Sexp.to_string_mach], despite the fact that we are implementing + [to_strings_hum], because we want the info to fit on a single line, and once we've + had to resort to sexps, the message is going to start not looking so pretty + anyway. *) + match t with + | Could_not_construct sexp -> + "could not construct info: " :: Sexp.to_string_mach sexp :: ac + | String string -> string :: ac + | Exn exn -> Sexp.to_string_mach (Exn.sexp_of_t exn) :: ac + | Sexp sexp -> Sexp.to_string_mach sexp :: ac + | Tag_sexp (tag, sexp, _) -> tag :: ": " :: Sexp.to_string_mach sexp :: ac + | Tag_t (tag, t) -> tag :: ": " :: to_strings_hum t ac + | Tag_arg (tag, sexp, t) -> + tag :: ": " :: Sexp.to_string_mach sexp :: ": " :: to_strings_hum t ac + | With_backtrace (t, backtrace) -> + to_strings_hum t ("\nBacktrace:\n" :: backtrace :: ac) + | Of_list (trunc_after, ts) -> + let ts = + match trunc_after with + | None -> ts + | Some max -> + let n = List.length ts in + if n <= max then + ts + else + List.take ts max @ [ String (Printf.sprintf "and %d more info" (n - max)) ] + in + List.fold (List.rev ts) ~init:ac ~f:(fun ac t -> + to_strings_hum t (if List.is_empty ac then ac else ("; " :: ac))) + ;; + + let to_string_hum_deprecated t = String.concat (to_strings_hum t []) + + let rec to_sexps_hum t ac = + match t with + | Could_not_construct _ as t -> sexp_of_t t :: ac + | String string -> Atom string :: ac + | Exn exn -> Exn.sexp_of_t exn :: ac + | Sexp sexp -> sexp :: ac + | Tag_sexp (tag, sexp, here) -> + List ( Atom tag + :: sexp + :: (match here with + | None -> [] + | Some here -> [ Source_code_position0.sexp_of_t here ])) + :: ac + | Tag_t (tag, t) -> List (Atom tag :: to_sexps_hum t []) :: ac + | Tag_arg (tag, sexp, t) -> List (Atom tag :: sexp :: to_sexps_hum t []) :: ac + | With_backtrace (t, backtrace) -> + Sexp.List [ to_sexp_hum t; Sexp.Atom backtrace ] :: ac + | Of_list (_, ts) -> + List.fold (List.rev ts) ~init:ac ~f:(fun ac t -> to_sexps_hum t ac) + and to_sexp_hum t = + match to_sexps_hum t [] with + | [sexp] -> sexp + | sexps -> Sexp.List sexps + ;; + + (* We use [protect] to guard against exceptions raised by user-supplied functions, so + that failure to produce one part of an info doesn't interfere with other parts. *) + let protect f = + try f () with exn -> Could_not_construct (Exn.sexp_of_t exn) + ;; + + let of_info info = protect (fun () -> Lazy.force info) + let to_info t = lazy t +end + +open Message + +type t = Message.t Lazy.t + +let invariant _ = () + +let to_message = Message.of_info +let of_message = Message.to_info + +(* It is OK to use [Message.to_sexp_hum], which is not stable, because [t_of_sexp] below + can handle any sexp. *) +let sexp_of_t t = Message.to_sexp_hum (to_message t) + +let t_of_sexp sexp = lazy (Message.Sexp sexp) + +let compare t1 t2 = + Sexp.compare (sexp_of_t t1) (sexp_of_t t2) +;; + +let hash_fold_t state t = Sexp.hash_fold_t state (sexp_of_t t) +let hash t = Hash.run hash_fold_t t + +let to_string_hum t = + match to_message t with + | String s -> s + | message -> Sexp.to_string_hum (Message.to_sexp_hum message) +;; + +let to_string_hum_deprecated t = Message.to_string_hum_deprecated (to_message t) + +let to_string_mach t = Sexp.to_string_mach (sexp_of_t t) + +let of_lazy l = lazy (protect (fun () -> String (Lazy.force l))) + +let of_lazy_t lazy_t = Lazy.join lazy_t + +let of_string message = Lazy.from_val (String message) + +let createf format = Printf.ksprintf of_string format + +let of_thunk f = lazy (protect (fun () -> String (f ()))) + +let create ?here ?strict tag x sexp_of_x = + match strict with + | None -> lazy (protect (fun () -> Tag_sexp (tag, sexp_of_x x, here))) + | Some () -> of_message ( Tag_sexp (tag, sexp_of_x x, here)) +;; + +let create_s sexp = Lazy.from_val (Sexp sexp) + +let tag t ~tag = lazy (Tag_t (tag, to_message t)) + +let tag_arg t tag x sexp_of_x = + lazy (protect (fun () -> Tag_arg (tag, sexp_of_x x, to_message t))) +;; + +let of_list ?trunc_after ts = + lazy (Of_list (trunc_after, List.map ts ~f:to_message)) +;; + +exception Exn of t + +let () = + (* We install a custom exn-converter rather than use + [exception Exn of t [@@deriving_inline sexp][@@@end]] to eliminate the extra wrapping of + "(Exn ...)". *) + Sexplib.Conv.Exn_converter.add [%extension_constructor Exn] + (function + | Exn t -> sexp_of_t t + | _ -> + (* Reaching this branch indicates a bug in sexplib. *) + assert false) +;; + +let to_exn t = + if not (Lazy.is_val t) + then Exn t + else + match Lazy.force t with + | Message.Exn exn -> exn + | _ -> Exn t +;; + +let of_exn ?backtrace exn = + let backtrace = + match backtrace with + | None -> None + | Some `Get -> Some (Caml.Printexc.get_backtrace ()) + | Some (`This s) -> Some s + in + match exn, backtrace with + | Exn t, None -> t + | Exn t, Some backtrace -> lazy (With_backtrace (to_message t, backtrace)) + | _ , None -> Lazy.from_val (Message.Exn exn) + | _ , Some backtrace -> lazy (With_backtrace (Sexp (Exn.sexp_of_t exn), backtrace)) +;; + +include Pretty_printer.Register_pp(struct + type nonrec t = t + let module_name = "Base.Info" + let pp ppf t = Caml.Format.pp_print_string ppf (to_string_hum t) + end) + +module Internal_repr = Message + diff --git a/src/info.mli b/src/info.mli new file mode 100644 index 0000000..fc27b4a --- /dev/null +++ b/src/info.mli @@ -0,0 +1 @@ +include Info_intf.Info (** @inline *) diff --git a/src/info_intf.ml b/src/info_intf.ml new file mode 100644 index 0000000..6ed86ef --- /dev/null +++ b/src/info_intf.ml @@ -0,0 +1,148 @@ +(** [Info] is a library for lazily constructing human-readable information as a string + or sexp, with a primary use being error messages. + + Using [Info] is often preferable to [sprintf] or manually constructing strings + because you don't have to eagerly construct the string -- you only need to pay when + you actually want to display the info, which for many applications is rare. Using + [Info] is also better than creating custom exceptions because you have more control + over the format. + + Info is intended to be constructed in the following style; for simple info, you + write: + + {[Info.of_string "Unable to find file"]} + + Or for a more descriptive [Info] without attaching any content (but evaluating the + result eagerly): + + {[Info.createf "Process %s exited with code %d" process exit_code]} + + For info where you want to attach some content, you would write: + + {[Info.create "Unable to find file" filename [%sexp_of: string]]} + + Or even, + + {[ + Info.create "price too big" (price, [`Max max_price]) + [%sexp_of: float * [`Max of float]] + ]} + + Note that an [Info.t] can be created from any arbitrary sexp with [Info.t_of_sexp]. +*) + +open! Import + +module type S = sig + + (** Serialization and comparison force the lazy message. *) + type t [@@deriving_inline compare, hash, sexp] + include + sig + [@@@ocaml.warning "-32"] + val compare : t -> t -> int + val hash_fold_t : + Ppx_hash_lib.Std.Hash.state -> t -> Ppx_hash_lib.Std.Hash.state + val hash : t -> Ppx_hash_lib.Std.Hash.hash_value + include Ppx_sexp_conv_lib.Sexpable.S with type t := t + end[@@ocaml.doc "@inline"] + [@@@end] + + include Invariant_intf.S with type t := t + + (** [to_string_hum] forces the lazy message, which might be an expensive operation. + + [to_string_hum] usually produces a sexp; however, it is guaranteed that + [to_string_hum (of_string s) = s]. + + If this string is going to go into a log file, you may find it useful to ensure that + the string is only one line long. To do this, use [to_string_mach t]. *) + val to_string_hum : t -> string + + (** [to_string_mach t] outputs [t] as a sexp on a single line. *) + val to_string_mach : t -> string + + (** Old version (pre 109.61) of [to_string_hum] that some applications rely on. + + Calls should be replaced with [to_string_mach t], which outputs more parentheses and + backslashes. *) + val to_string_hum_deprecated : t -> string + + val of_string : string -> t + + (** Be careful that the body of the lazy or thunk does not access mutable data, since it + will only be called at an undetermined later point. *) + + val of_lazy : string Lazy.t -> t + val of_thunk : (unit -> string) -> t + val of_lazy_t : t Lazy.t -> t + + (** For [create message a sexp_of_a], [sexp_of_a a] is lazily computed, when the info is + converted to a sexp. So if [a] is mutated in the time between the call to [create] + and the sexp conversion, those mutations will be reflected in the sexp. Use + [~strict:()] to force [sexp_of_a a] to be computed immediately. *) + val create + : ?here : Source_code_position0.t + -> ?strict : unit + -> string + -> 'a + -> ('a -> Sexp.t) + -> t + + val create_s : Sexp.t -> t + + (** Constructs a [t] containing only a string from a format. This eagerly constructs + the string. *) + val createf : ('a, unit, string, t) format4 -> 'a + + (** Adds a string to the front. *) + val tag : t -> tag:string -> t + + (** Adds a string and some other data in the form of an s-expression at the front. *) + val tag_arg : t -> string -> 'a -> ('a -> Sexp.t) -> t + + (** Combines multiple infos into one. *) + val of_list : ?trunc_after:int -> t list -> t + + (** [of_exn] and [to_exn] are primarily used with [Error], but their definitions have to + be here because they refer to the underlying representation. + + [~backtrace:`Get] attaches the backtrace for the most recent exception. The same + caveats as for [Printexc.print_backtrace] apply. [~backtrace:(`This s)] attaches + the backtrace [s]. The default is no backtrace. *) + val of_exn : ?backtrace:[ `Get | `This of string ] -> exn -> t + val to_exn : t -> exn + + val pp : Formatter.t -> t -> unit + + module Internal_repr : sig + type info = t + + (** The internal representation. It is exposed so that we can write efficient + serializers outside of this module. *) + type t = + | Could_not_construct of Sexp.t + | String of string + | Exn of exn + | Sexp of Sexp.t + | Tag_sexp of string * Sexp.t * Source_code_position0.t option + | Tag_t of string * t + | Tag_arg of string * Sexp.t * t + | Of_list of int option * t list + | With_backtrace of t * string (** The second argument is the backtrace *) + [@@deriving_inline sexp_of] + include + sig [@@@ocaml.warning "-32"] val sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t + end[@@ocaml.doc "@inline"] + [@@@end] + + val of_info : info -> t + val to_info : t -> info + end with type info := t +end + +module type Info = sig + module type S = S + + include S +end diff --git a/src/int.ml b/src/int.ml new file mode 100644 index 0000000..eb97185 --- /dev/null +++ b/src/int.ml @@ -0,0 +1,310 @@ +open! Import + +include Int_intf +include Int0 + +module T = struct + type t = int [@@deriving_inline hash, sexp] + let (hash_fold_t : + Ppx_hash_lib.Std.Hash.state -> t -> Ppx_hash_lib.Std.Hash.state) = + hash_fold_int + and (hash : t -> Ppx_hash_lib.Std.Hash.hash_value) = + let func = hash_int in fun x -> func x + let t_of_sexp : Ppx_sexp_conv_lib.Sexp.t -> t = int_of_sexp + let sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t = sexp_of_int + [@@@end] + + let compare (x : t) y = Bool.to_int (x > y) - Bool.to_int (x < y) + + let of_string s = + try of_string s with + | _ -> Printf.failwithf "Int.of_string: %S" s () + + let to_string = to_string +end + +let num_bits = Int_conversions.num_bits_int + +let float_lower_bound = Float0.lower_bound_for_int num_bits +let float_upper_bound = Float0.upper_bound_for_int num_bits + +let to_float = Caml.float_of_int +let of_float_unchecked = Caml.int_of_float +let of_float f = + if Float_replace_polymorphic_compare.(>=) f float_lower_bound + && Float_replace_polymorphic_compare.(<=) f float_upper_bound + then + Caml.int_of_float f + else + Printf.invalid_argf "Int.of_float: argument (%f) is out of range or NaN" + (Float0.box f) + () + +let zero = 0 +let one = 1 +let minus_one = -1 + +include T +include Comparator.Make(T) +include Comparable.Validate_with_zero (struct + include T + let zero = zero + end) + +module Conv = Int_conversions +include Conv.Make (T) +include Conv.Make_hex(struct + open Int_replace_polymorphic_compare + type t = int [@@deriving_inline compare, hash] + let compare : t -> t -> int = compare_int + let (hash_fold_t : + Ppx_hash_lib.Std.Hash.state -> t -> Ppx_hash_lib.Std.Hash.state) = + hash_fold_int + and (hash : t -> Ppx_hash_lib.Std.Hash.hash_value) = + let func = hash_int in fun x -> func x + [@@@end] + + let zero = zero + let neg = (~-) + let (<) = (<) + let to_string i = Printf.sprintf "%x" i + let of_string s = Caml.Scanf.sscanf s "%x" Fn.id + + let module_name = "Base.Int.Hex" + end) + +include Pretty_printer.Register (struct + type nonrec t = t + let to_string = to_string + let module_name = "Base.Int" + end) + +(* Open replace_polymorphic_compare after including functor instantiations so + they do not shadow its definitions. This is here so that efficient versions + of the comparison functions are available within this module. *) +open! Int_replace_polymorphic_compare + +let between t ~low ~high = low <= t && t <= high +let clamp_unchecked t ~min ~max = + if t < min then min else if t <= max then t else max + +let clamp_exn t ~min ~max = + assert (min <= max); + clamp_unchecked t ~min ~max + +let clamp t ~min ~max = + if min > max then + Or_error.error_s + (Sexp.message "clamp requires [min <= max]" + [ "min", T.sexp_of_t min + ; "max", T.sexp_of_t max + ]) + else + Ok (clamp_unchecked t ~min ~max) + +let pred i = i - 1 +let succ i = i + 1 + +let to_int i = i +let to_int_exn = to_int +let of_int i = i +let of_int_exn = of_int + +let max_value = Caml.max_int +let min_value = Caml.min_int + +let max_value_30_bits = 0x3FFF_FFFF + +let of_int32 = Conv.int32_to_int +let of_int32_exn = Conv.int32_to_int_exn +let of_int32_trunc = Conv.int32_to_int_trunc +let to_int32 = Conv.int_to_int32 +let to_int32_exn = Conv.int_to_int32_exn +let to_int32_trunc = Conv.int_to_int32_trunc +let of_int64 = Conv.int64_to_int +let of_int64_exn = Conv.int64_to_int_exn +let of_int64_trunc = Conv.int64_to_int_trunc +let to_int64 = Conv.int_to_int64 +let of_nativeint = Conv.nativeint_to_int +let of_nativeint_exn = Conv.nativeint_to_int_exn +let of_nativeint_trunc = Conv.nativeint_to_int_trunc +let to_nativeint = Conv.int_to_nativeint +let to_nativeint_exn = to_nativeint + +let abs x = abs x + +let ( + ) x y = ( + ) x y +let ( - ) x y = ( - ) x y +let ( * ) x y = ( * ) x y +let ( / ) x y = ( / ) x y + +let neg x = -x +let ( ~- ) = neg + +(* note that rem is not same as % *) +let rem a b = a mod b + +let incr = Caml.incr +let decr = Caml.decr + +let shift_right a b = a asr b +let shift_right_logical a b = a lsr b +let shift_left a b = a lsl b +let bit_not a = lnot a +let bit_or a b = a lor b +let bit_and a b = a land b +let bit_xor a b = a lxor b + +let pow = Int_math.int_pow +let ( ** ) b e = pow b e + +module Pow2 = struct + open! Import + + module Sys = Sys0 + + let raise_s = Error.raise_s + + let non_positive_argument () = + Printf.invalid_argf "argument must be strictly positive" () + + (** "ceiling power of 2" - Least power of 2 greater than or equal to x. *) + let ceil_pow2 x = + if x <= 0 then non_positive_argument (); + let x = x - 1 in + let x = x lor (x lsr 1) in + let x = x lor (x lsr 2) in + let x = x lor (x lsr 4) in + let x = x lor (x lsr 8) in + let x = x lor (x lsr 16) in + (* The next line is superfluous on 32-bit architectures, but it's faster to do it + anyway than to branch *) + let x = x lor (x lsr 32) in + x + 1 + + (** "floor power of 2" - Largest power of 2 less than or equal to x. *) + let floor_pow2 x = + if x <= 0 then non_positive_argument (); + let x = x lor (x lsr 1) in + let x = x lor (x lsr 2) in + let x = x lor (x lsr 4) in + let x = x lor (x lsr 8) in + let x = x lor (x lsr 16) in + (* The next line is superfluous on 32-bit architectures, but it's faster to do it + anyway than to branch *) + let x = x lor (x lsr 32) in + x - (x lsr 1) + + let is_pow2 x = + if x <= 0 then non_positive_argument (); + (x land (x-1)) = 0 + ;; + + (* C stub for int clz to use the CLZ/BSR instruction where possible *) + external int_clz : int -> int = "Base_int_math_int_clz" [@@noalloc] + + (** Hacker's Delight Second Edition p106 *) + let floor_log2 i = + if i <= 0 then + raise_s (Sexp.message "[Int.floor_log2] got invalid input" + ["", sexp_of_int i]); + Sys.word_size_in_bits - 1 - int_clz i + ;; + + let ceil_log2 i = + if i <= 0 then + raise_s (Sexp.message "[Int.ceil_log2] got invalid input" + ["", sexp_of_int i]); + if i = 1 + then 0 + else Sys.word_size_in_bits - int_clz (i - 1) + ;; +end +include Pow2 + +(* This is already defined by Comparable.Validate_with_zero, but Sign.of_int is + more direct. *) +let sign = Sign.of_int + +let popcount = Popcount.int_popcount + +module Pre_O = struct + let ( + ) = ( + ) + let ( - ) = ( - ) + let ( * ) = ( * ) + let ( / ) = ( / ) + let ( ~- ) = ( ~- ) + let ( ** ) = ( ** ) + include (Int_replace_polymorphic_compare : Comparisons.Infix with type t := t) + let abs = abs + let neg = neg + let zero = zero + let of_int_exn = of_int_exn +end + +module O = struct + include Pre_O + module F = Int_math.Make (struct + type nonrec t = t + include Pre_O + let rem = rem + let to_float = to_float + let of_float = of_float + let of_string = T.of_string + let to_string = T.to_string + end) + include F + + (* These inlined versions of (%), (/%), and (//) perform better than their functorized + counterparts in [F] (see benchmarks below). + + The reason these functions are inlined in [Int] but not in any of the other integer + modules is that they existed in [Int] and [Int] alone prior to the introduction of + the [Int_math.Make] functor, and we didn't want to degrade their performance. + + We won't pre-emptively do the same for new functions, unless someone cares, on a case + by case fashion. *) + + let ( % ) x y = + if y <= zero then + Printf.invalid_argf + "%s %% %s in core_int.ml: modulus should be positive" + (to_string x) (to_string y) (); + let rval = rem x y in + if rval < zero + then rval + y + else rval + ;; + + let ( /% ) x y = + if y <= zero then + Printf.invalid_argf + "%s /%% %s in core_int.ml: divisor should be positive" + (to_string x) (to_string y) (); + if x < zero + then (x + one) / y - one + else x / y + ;; + + let (//) x y = to_float x /. to_float y + ;; + + let ( land ) = ( land ) + let ( lor ) = ( lor ) + let ( lxor ) = ( lxor ) + let ( lnot ) = ( lnot ) + let ( lsl ) = ( lsl ) + let ( asr ) = ( asr ) + let ( lsr ) = ( lsr ) +end + +include O (* [Int] and [Int.O] agree value-wise *) + +module Private = struct + module O_F = O.F +end + +(* Include type-specific [Replace_polymorphic_compare] at the end, after including functor + application that could shadow its definitions. This is here so that efficient versions + of the comparison functions are exported by this module. *) +include Int_replace_polymorphic_compare diff --git a/src/int.mli b/src/int.mli new file mode 100644 index 0000000..6643709 --- /dev/null +++ b/src/int.mli @@ -0,0 +1 @@ +include Int_intf.Int (** @inline *) diff --git a/src/int0.ml b/src/int0.ml new file mode 100644 index 0000000..2306cc4 --- /dev/null +++ b/src/int0.ml @@ -0,0 +1,23 @@ +(* [Int0] defines integer functions that are primitives or can be simply + defined in terms of [Caml]. [Int0] is intended to completely express the + part of [Caml] that [Base] uses for integers -- no other file in Base other + than int0.ml should use these functions directly through [Caml]. [Int0] has + few dependencies, and so is available early in Base's build order. + + All Base files that need to use ints and come before [Base.Int] in build + order should do: + + {[ + module Int = Int0 + ]} + + Defining [module Int = Int0] is also necessary because it prevents ocamldep + from mistakenly causing a file to depend on [Base.Int]. *) + +let to_string = Caml.string_of_int +let of_string = Caml.int_of_string +let to_float = Caml.float_of_int +let of_float = Caml.int_of_float +let max_value = Caml.max_int +let min_value = Caml.min_int +let succ = Caml.succ diff --git a/src/int32.ml b/src/int32.ml new file mode 100644 index 0000000..8ecf0d6 --- /dev/null +++ b/src/int32.ml @@ -0,0 +1,274 @@ +open! Import +open! Caml.Int32 + +module T = struct + type t = int32 [@@deriving_inline hash, sexp] + let (hash_fold_t : + Ppx_hash_lib.Std.Hash.state -> t -> Ppx_hash_lib.Std.Hash.state) = + hash_fold_int32 + and (hash : t -> Ppx_hash_lib.Std.Hash.hash_value) = + let func = hash_int32 in fun x -> func x + let t_of_sexp : Ppx_sexp_conv_lib.Sexp.t -> t = int32_of_sexp + let sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t = sexp_of_int32 + [@@@end] + let compare (x : t) y = compare x y + + let to_string = to_string + let of_string = of_string +end + +include T +include Comparator.Make(T) + +let num_bits = 32 + +let float_lower_bound = Float0.lower_bound_for_int num_bits +let float_upper_bound = Float0.upper_bound_for_int num_bits + +let float_of_bits = float_of_bits +let bits_of_float = bits_of_float +let shift_right_logical = shift_right_logical +let shift_right = shift_right +let shift_left = shift_left +let bit_not = lognot +let bit_xor = logxor +let bit_or = logor +let bit_and = logand +let min_value = min_int +let max_value = max_int +let abs = abs +let pred = pred +let succ = succ +let rem = rem +let neg = neg +let minus_one = minus_one +let one = one +let zero = zero +let compare = compare +let to_float = to_float +let of_float_unchecked = of_float +let of_float f = + if Float_replace_polymorphic_compare.(>=) f float_lower_bound + && Float_replace_polymorphic_compare.(<=) f float_upper_bound + then + of_float f + else + Printf.invalid_argf "Int32.of_float: argument (%f) is out of range or NaN" + (Float0.box f) + () +;; + +include Comparable.Validate_with_zero (struct + include T + let zero = zero + end) + +module Infix_compare = struct + open Poly + + let ( >= ) (x : t) y = x >= y + let ( <= ) (x : t) y = x <= y + let ( = ) (x : t) y = x = y + let ( > ) (x : t) y = x > y + let ( < ) (x : t) y = x < y + let ( <> ) (x : t) y = x <> y +end + +module Compare = struct + include Infix_compare + + let compare = compare + let ascending = compare + let descending x y = compare y x + let min (x : t) y = if x < y then x else y + let max (x : t) y = if x > y then x else y + let equal (x : t) y = x = y + let between t ~low ~high = low <= t && t <= high + let clamp_unchecked t ~min ~max = + if t < min then min else if t <= max then t else max + + let clamp_exn t ~min ~max = + assert (min <= max); + clamp_unchecked t ~min ~max + + let clamp t ~min ~max = + if min > max then + Or_error.error_s + (Sexp.message "clamp requires [min <= max]" + [ "min", T.sexp_of_t min + ; "max", T.sexp_of_t max + ]) + else + Ok (clamp_unchecked t ~min ~max) +end + +include Compare + +let ( / ) = div +let ( * ) = mul +let ( - ) = sub +let ( + ) = add +let ( ~- ) = neg + +let incr r = r := !r + one +let decr r = r := !r - one + +let of_int32 t = t +let of_int32_exn = of_int32 +let to_int32 t = t +let to_int32_exn = to_int32 + +let popcount = Popcount.int32_popcount + +module Conv = Int_conversions +let of_int = Conv.int_to_int32 +let of_int_exn = Conv.int_to_int32_exn +let of_int_trunc = Conv.int_to_int32_trunc +let to_int = Conv.int32_to_int +let to_int_exn = Conv.int32_to_int_exn +let to_int_trunc = Conv.int32_to_int_trunc +let of_int64 = Conv.int64_to_int32 +let of_int64_exn = Conv.int64_to_int32_exn +let of_int64_trunc = Conv.int64_to_int32_trunc +let to_int64 = Conv.int32_to_int64 +let of_nativeint = Conv.nativeint_to_int32 +let of_nativeint_exn = Conv.nativeint_to_int32_exn +let of_nativeint_trunc = Conv.nativeint_to_int32_trunc +let to_nativeint = Conv.int32_to_nativeint +let to_nativeint_exn = to_nativeint + +let pow b e = of_int_exn (Int_math.int_pow (to_int_exn b) (to_int_exn e)) +let ( ** ) b e = pow b e + +module Pow2 = struct + open! Import + open Int32_replace_polymorphic_compare + + module Sys = Sys0 + + let raise_s = Error.raise_s + + let non_positive_argument () = + Printf.invalid_argf "argument must be strictly positive" () + + let ( lor ) = Caml.Int32.logor;; + let ( lsr ) = Caml.Int32.shift_right_logical;; + let ( land ) = Caml.Int32.logand;; + + (** "ceiling power of 2" - Least power of 2 greater than or equal to x. *) + let ceil_pow2 x = + if x <= Caml.Int32.zero then non_positive_argument (); + let x = Caml.Int32.pred x in + let x = x lor (x lsr 1) in + let x = x lor (x lsr 2) in + let x = x lor (x lsr 4) in + let x = x lor (x lsr 8) in + let x = x lor (x lsr 16) in + Caml.Int32.succ x + ;; + + (** "floor power of 2" - Largest power of 2 less than or equal to x. *) + let floor_pow2 x = + if x <= Caml.Int32.zero then non_positive_argument (); + let x = x lor (x lsr 1) in + let x = x lor (x lsr 2) in + let x = x lor (x lsr 4) in + let x = x lor (x lsr 8) in + let x = x lor (x lsr 16) in + Caml.Int32.sub x (x lsr 1) + ;; + + let is_pow2 x = + if x <= Caml.Int32.zero then non_positive_argument (); + (x land (Caml.Int32.pred x)) = Caml.Int32.zero + ;; + + (* C stub for int clz to use the CLZ/BSR instruction where possible. *) + external int32_clz : int32 -> int = "Base_int_math_int32_clz" [@@noalloc] + + (** Hacker's Delight Second Edition p106 *) + let floor_log2 i = + if i <= Caml.Int32.zero then + raise_s (Sexp.message "[Int32.floor_log2] got invalid input" + ["", sexp_of_int32 i]); + num_bits - 1 - int32_clz i + ;; + + (** Hacker's Delight Second Edition p106 *) + let ceil_log2 i = + if i <= Caml.Int32.zero then + raise_s (Sexp.message "[Int32.ceil_log2] got invalid input" + ["", sexp_of_int32 i]); + (* The [i = 1] check is needed because clz(0) is undefined *) + if Caml.Int32.equal i Caml.Int32.one + then 0 + else num_bits - int32_clz (Caml.Int32.pred i) + ;; +end +include Pow2 + +include Conv.Make (T) + +include Conv.Make_hex(struct + + type t = int32 [@@deriving_inline compare, hash] + let compare : t -> t -> int = compare_int32 + let (hash_fold_t : + Ppx_hash_lib.Std.Hash.state -> t -> Ppx_hash_lib.Std.Hash.state) = + hash_fold_int32 + and (hash : t -> Ppx_hash_lib.Std.Hash.hash_value) = + let func = hash_int32 in fun x -> func x + [@@@end] + + let zero = zero + let neg = (~-) + let (<) = (<) + let to_string i = Printf.sprintf "%lx" i + let of_string s = Caml.Scanf.sscanf s "%lx" Fn.id + + let module_name = "Base.Int32.Hex" + + end) + +include Pretty_printer.Register (struct + type nonrec t = t + let to_string = to_string + let module_name = "Base.Int32" + end) + +module Pre_O = struct + let ( + ) = ( + ) + let ( - ) = ( - ) + let ( * ) = ( * ) + let ( / ) = ( / ) + let ( ~- ) = ( ~- ) + let ( ** ) = ( ** ) + include (Compare : Comparisons.Infix with type t := t) + let abs = abs + let neg = neg + let zero = zero + let of_int_exn = of_int_exn +end + +module O = struct + include Pre_O + include Int_math.Make (struct + type nonrec t = t + include Pre_O + let rem = rem + let to_float = to_float + let of_float = of_float + let of_string = T.of_string + let to_string = T.to_string + end) + + let ( land ) = bit_and + let ( lor ) = bit_or + let ( lxor ) = bit_xor + let ( lnot ) = bit_not + let ( lsl ) = shift_left + let ( asr ) = shift_right + let ( lsr ) = shift_right_logical +end + +include O (* [Int32] and [Int32.O] agree value-wise *) diff --git a/src/int32.mli b/src/int32.mli new file mode 100644 index 0000000..bbefbd1 --- /dev/null +++ b/src/int32.mli @@ -0,0 +1,48 @@ +(** An int of exactly 32 bits, regardless of the machine. + + Side note: There's not much reason to want an int of at least 32 bits (i.e., 32 on + 32-bit machines and 63 on 64-bit machines) because [Int63] is basically just as + efficient. + + Overflow issues are {i not} generally considered and explicitly handled. This may be + more of an issue for 32-bit ints than 64-bit ints. + + [Int32.t] is boxed on both 32-bit and 64-bit machines. *) + +open! Import + +include Int_intf.S with type t = int32 + +(** {2 Conversion functions} *) + +val of_int : int -> t option +val to_int : t -> int option + +val of_int32 : int32 -> t +val to_int32 : t -> int32 + +val of_nativeint : nativeint -> t option +val to_nativeint : t -> nativeint + +val of_int64 : int64 -> t option + +(** {3 Truncating conversions} + + These functions return the least-significant bits of the input. In cases where + optional conversions return [Some x], truncating conversions return [x]. *) + +val of_int_trunc : int -> t +val to_int_trunc : t -> int +val of_nativeint_trunc : nativeint -> t +val of_int64_trunc : int64 -> t + +(** {3 Low-level float conversions} *) + +(** Rounds a regular 64-bit OCaml float to a 32-bit IEEE-754 "single" float, and returns + its bit representation. We make no promises about the exact rounding behavior, or + what happens in case of over- or underflow. *) +val bits_of_float : float -> t + +(** Creates a 32-bit IEEE-754 "single" float from the given bits, and converts it to a + regular 64-bit OCaml float. *) +val float_of_bits : t -> float diff --git a/src/int63.ml b/src/int63.ml new file mode 100644 index 0000000..ac40b85 --- /dev/null +++ b/src/int63.ml @@ -0,0 +1,77 @@ +open! Import + +let raise_s = Error.raise_s +module Repr = Int63_emul.Repr + +include Int63_backend + +module Overflow_exn = struct + let ( + ) t u = + let sum = t + u in + if bit_or (bit_xor t u) (bit_xor t (bit_not sum)) < zero + then sum + else raise_s (Sexp.message "( + ) overflow" + [ "t" , sexp_of_t t + ; "u" , sexp_of_t u + ; "sum", sexp_of_t sum + ]) + ;; + + let ( - ) t u = + let diff = t - u in + let pos_diff = t > u in + if t <> u && Bool.(<>) pos_diff (is_positive diff) then + raise_s (Sexp.message "( - ) overflow" + [ "t" , sexp_of_t t + ; "u" , sexp_of_t u + ; "diff", sexp_of_t diff + ]) + else diff + ;; + + let abs t = if t = min_value then failwith "abs overflow" else abs t + let neg t = if t = min_value then failwith "neg overflow" else neg t +end + +let () = assert (Int.(=) num_bits 63) + +let random_of_int ?(state = Random.State.default) bound = + of_int (Random.State.int state (to_int_exn bound)) + +let random_of_int64 ?(state = Random.State.default) bound = + of_int64_exn (Random.State.int64 state (to_int64 bound)) + +let random = + match Word_size.word_size with + | W64 -> random_of_int + | W32 -> random_of_int64 + +let random_incl_of_int ?(state = Random.State.default) lo hi = + of_int (Random.State.int_incl state (to_int_exn lo) (to_int_exn hi)) + +let random_incl_of_int64 ?(state = Random.State.default) lo hi = + of_int64_exn (Random.State.int64_incl state (to_int64 lo) (to_int64 hi)) + +let random_incl = + match Word_size.word_size with + | W64 -> random_incl_of_int + | W32 -> random_incl_of_int64 + +let floor_log2 t = + match Word_size.word_size with + | W64 -> t |> to_int_exn |> Int.floor_log2 + | W32 -> + if t <= zero + then raise_s (Sexp.message "[Int.floor_log2] got invalid input" + ["", sexp_of_t t]); + let floor_log2 = ref (Int.( - ) num_bits 2) in + while equal zero (bit_and t (shift_left one !floor_log2)) do + floor_log2 := Int.( - ) !floor_log2 1 + done; + !floor_log2 +;; + +module Private = struct + module Repr = Repr + let repr = repr +end diff --git a/src/int63.mli b/src/int63.mli new file mode 100644 index 0000000..6a5c3d8 --- /dev/null +++ b/src/int63.mli @@ -0,0 +1,86 @@ +(** 63-bit integers. + + The size of Int63 is always 63 bits. On a 64-bit platform it is just an int + (63-bits), and on a 32-bit platform it is an int64 wrapped to respect the + semantics of 63-bit integers. + + Because [Int63] has different representations on 32-bit and 64-bit platforms, + marshalling [Int63] will not work between 32-bit and 64-bit platforms -- [unmarshal] + will segfault. *) + +open! Import + +(** In 64-bit architectures, we expose [type t = private int] so that the compiler can + omit [caml_modify] when dealing with record fields holding [Int63.t]. + + Code should not explicitly make use of the [private], e.g., via [(i :> int)], since + such code will not compile on 32-bit platforms. *) +include Int_intf.S with type t = Int63_backend.t + +(** {2 Arithmetic with overflow} + + Unlike the usual operations, these never overflow, preferring instead to raise. *) + +module Overflow_exn : sig + val ( + ) : t -> t -> t + val ( - ) : t -> t -> t + val abs : t -> t + val neg : t -> t +end + +(** {2 Conversion functions} *) + +val of_int : int -> t +val to_int : t -> int option + +val of_int32 : Int32.t -> t +val to_int32 : t -> Int32.t option + +val of_int64 : Int64.t -> t option + +val of_nativeint : nativeint -> t option +val to_nativeint : t -> nativeint option + +(** {3 Truncating conversions} + + These functions return the least-significant bits of the input. In cases where + optional conversions return [Some x], truncating conversions return [x]. *) + +val to_int_trunc : t -> int +val to_int32_trunc : t -> Int32.t +val of_int64_trunc : Int64.t -> t +val of_nativeint_trunc : nativeint -> t +val to_nativeint_trunc : t -> nativeint + +(** {2 Random generation} *) + +(** [random ~state bound] returns a random integer between 0 (inclusive) and [bound] + (exclusive). [bound] must be greater than 0. + + The default [~state] is [Random.State.default]. *) +val random : ?state:Random.State.t -> t -> t + +(** [random_incl ~state lo hi] returns a random integer between [lo] (inclusive) and [hi] + (inclusive). Raises if [lo > hi]. + + The default [~state] is [Random.State.default]. *) +val random_incl : ?state:Random.State.t -> t -> t -> t + +(** [floor_log2 x] returns the floor of log-base-2 of [x], and raises if [x <= 0]. *) +val floor_log2 : t -> int + +(**/**) +(*_ See the Jane Street Style Guide for an explanation of [Private] submodules: + + https://opensource.janestreet.com/standards/#private-submodules *) +module Private : sig + (** [val repr] states how [Int63.t] is represented, i.e., as an [int] or an [int64], and + can be used for building [Int63] operations that behave differently depending on the + representation (e.g., see core_int63.ml). *) + module Repr : sig + type ('underlying_type, 'intermediate_type) t = + | Int : (int , int ) t + | Int64 : (int64 , Int63_emul.t) t + end + val repr : (t, t) Repr.t +end diff --git a/src/int63_backends.ml b/src/int63_backends.ml new file mode 100644 index 0000000..5efde7d --- /dev/null +++ b/src/int63_backends.ml @@ -0,0 +1,50 @@ +open! Import + +let raise_s = Error.raise_s + +module type Int_or_more = sig + type t [@@deriving_inline hash] + include + sig + [@@@ocaml.warning "-32"] + val hash_fold_t : + Ppx_hash_lib.Std.Hash.state -> t -> Ppx_hash_lib.Std.Hash.state + val hash : t -> Ppx_hash_lib.Std.Hash.hash_value + end[@@ocaml.doc "@inline"] + [@@@end] + include Int_intf.S with type t := t + val of_int : int -> t + val to_int : t -> int option + val to_int_trunc : t -> int + val of_int32 : int32 -> t + val to_int32 : t -> Int32.t option + val to_int32_trunc : t -> Int32.t + val of_int64 : Int64.t -> t option + val of_int64_trunc : Int64.t -> t + val of_nativeint : nativeint -> t option + val to_nativeint : t -> nativeint option + val of_nativeint_trunc : nativeint -> t + val to_nativeint_trunc : t -> nativeint + val of_float_unchecked : float -> t + val repr : (t, t) Int63_emul.Repr.t +end + +module Native : Int_or_more with type t = private int = struct + include Int + let to_int x = Some x + let to_int_trunc x = x + (* [of_int32_exn] is a safe operation on platforms with 64-bit word sizes. *) + let of_int32 = of_int32_exn + let to_nativeint_trunc x = to_nativeint x + let to_nativeint x = Some (to_nativeint x) + let repr = Int63_emul.Repr.Int +end + +module Emulated : Int_or_more with type t = Int63_emul.t = Int63_emul + +let dynamic : (module Int_or_more) = + match Word_size.word_size with + | W64 -> (module Native : Int_or_more) + | W32 -> (module Emulated : Int_or_more) + +module Dynamic = (val dynamic) diff --git a/src/int63_emul.ml b/src/int63_emul.ml new file mode 100644 index 0000000..d91e3c9 --- /dev/null +++ b/src/int63_emul.ml @@ -0,0 +1,399 @@ +(* A 63bit integer is a 64bit integer with its bits shifted to the left + and its lowest bit set to 0. + This is the same kind of encoding as OCaml int on 64bit architecture. + The only difference being the lowest bit (immediate bit) set to 1. *) + +open! Import +include Int64_replace_polymorphic_compare + +module T0 = struct + module T = struct + type t = int64 [@@deriving_inline compare, hash, sexp] + let compare : t -> t -> int = compare_int64 + let (hash_fold_t : + Ppx_hash_lib.Std.Hash.state -> t -> Ppx_hash_lib.Std.Hash.state) = + hash_fold_int64 + and (hash : t -> Ppx_hash_lib.Std.Hash.hash_value) = + let func = hash_int64 in fun x -> func x + let t_of_sexp : Ppx_sexp_conv_lib.Sexp.t -> t = int64_of_sexp + let sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t = sexp_of_int64 + [@@@end] + end + include T + include Comparator.Make(T) +end + +module Conv = Int_conversions + +module W : sig + type t = int64 + include (module type of struct include T0 end with type t := t) + val wrap_exn : Caml.Int64.t -> t + val wrap_modulo : Caml.Int64.t -> t + val unwrap : t -> Caml.Int64.t + + (** Returns a non-negative int64 that is equal to the input int63 modulo 2^63. *) + val unwrap_unsigned : t -> Caml.Int64.t + val add : t -> t -> t + val sub : t -> t -> t + val neg : t -> t + val abs : t -> t + val succ : t -> t + val pred : t -> t + val mul : t -> t -> t + val pow : t -> t -> t + val div : t -> t -> t + val rem : t -> t -> t + val popcount : t -> int + val bit_not : t -> t + val bit_xor : t -> t -> t + val bit_or : t -> t -> t + val bit_and : t -> t -> t + val shift_left : t -> int -> t + val shift_right : t -> int -> t + val shift_right_logical : t -> int -> t + val min_value : t + val max_value : t + + val to_int64 : t -> Caml.Int64.t + val of_int64 : Caml.Int64.t -> t option + val of_int64_exn : Caml.Int64.t -> t + val of_int64_trunc : Caml.Int64.t -> t + + val compare : t -> t -> int + + val ceil_pow2 : t -> t + val floor_pow2 : t -> t + val ceil_log2 : t -> int + val floor_log2 : t -> int + val is_pow2 : t -> bool +end = struct + type t = int64 + include (T0 : module type of struct include T0 end with type t := t) + + let wrap_exn x = + (* Raises if the int64 value does not fit on int63. *) + Conv.int64_fit_on_int63_exn x; + Caml.Int64.mul x 2L + let wrap x = + if Conv.int64_is_representable_as_int63 x + then Some (Caml.Int64.mul x 2L) + else None + let wrap_modulo x = + Caml.Int64.mul x 2L + let unwrap x = + Caml.Int64.shift_right x 1 + let unwrap_unsigned x = + Caml.Int64.shift_right_logical x 1 + + (* This does not use wrap or unwrap to avoid generating exceptions in the case of + overflows. This is to preserve the semantics of int type on 64 bit architecture. *) + let f2 f a b = Caml.Int64.mul (f (Caml.Int64.shift_right a 1) (Caml.Int64.shift_right b 1)) 2L + + let mask = 0xffff_ffff_ffff_fffeL + + let m x = Caml.Int64.logand x mask + + let add x y = Caml.Int64.add x y + let sub x y = Caml.Int64.sub x y + let neg x = Caml.Int64.neg x + let abs x = Caml.Int64.abs x + let one = wrap_exn 1L + let succ a = add a one + let pred a = sub a one + let min_value = m Caml.Int64.min_int + let max_value = m Caml.Int64.max_int + let bit_not x = m (Caml.Int64.lognot x) + let bit_and = Caml.Int64.logand + let bit_xor = Caml.Int64.logxor + let bit_or = Caml.Int64.logor + let shift_left x i = Caml.Int64.shift_left x i + let shift_right x i = m (Caml.Int64.shift_right x i) + let shift_right_logical x i = m (Caml.Int64.shift_right_logical x i) + let pow = f2 Int_math.int63_pow_on_int64 + let mul a b = Caml.Int64.mul a (Caml.Int64.shift_right b 1) + let div a b = wrap_modulo (Caml.Int64.div a b) + let rem a b = Caml.Int64.rem a b + let popcount x = Popcount.int64_popcount x + + let to_int64 t = unwrap t + let of_int64 t = wrap t + let of_int64_exn t = wrap_exn t + let of_int64_trunc t = wrap_modulo t + + let t_of_sexp x = wrap_exn (int64_of_sexp x) + let sexp_of_t x = sexp_of_int64 (unwrap x) + + let compare (x : t) y = compare x y + + let is_pow2 x = Int64.is_pow2 (unwrap x) + let floor_pow2 x = Int64.floor_pow2 (unwrap x) |> wrap_exn + let ceil_pow2 x = Int64.floor_pow2 (unwrap x) |> wrap_exn + let floor_log2 x = Int64.floor_log2 (unwrap x) + let ceil_log2 x = Int64.ceil_log2 (unwrap x) +end + +open W + +module T = struct + type t = W.t [@@deriving_inline hash, sexp] + let (hash_fold_t : + Ppx_hash_lib.Std.Hash.state -> t -> Ppx_hash_lib.Std.Hash.state) = + W.hash_fold_t + and (hash : t -> Ppx_hash_lib.Std.Hash.hash_value) = + let func = W.hash in fun x -> func x + let t_of_sexp : Ppx_sexp_conv_lib.Sexp.t -> t = W.t_of_sexp + let sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t = W.sexp_of_t + [@@@end] + type comparator_witness = W.comparator_witness + let comparator = W.comparator + + let compare = W.compare + + (* We don't expect [hash] to follow the behavior of int in 64bit architecture *) + let _ = hash + let hash (x : t) = Caml.Hashtbl.hash x + + let invalid_str x = failwith (Printf.sprintf "Int63.of_string: invalid input %S" x) + + (* + "sign" refers to whether the number starts with a '-' + "signedness = false" means the rest of the number is parsed as unsigned and then cast + to signed with wrap-around modulo 2^i + "signedness = true" means no such craziness happens + + The terminology and the logic is due to the code in byterun/ints.c in ocaml 4.03 + ([parse_sign_and_base] function). + + Signedness equals true for plain decimal number (e.g. 1235, -6789) + + Signedness equals false in the following cases: + - [0xffff], [-0xffff] (hexadecimal representation) + - [0b0101], [-0b0101] (binary representation) + - [0o1237], [-0o1237] (octal representation) + - [0u9812], [-0u9812] (unsigned decimal representation - available from OCaml 4.03) *) + let sign_and_signedness x = + let len = String.length x in + let open Int_replace_polymorphic_compare in + let pos,sign = + if 0 < len + then match x.[0] with + | '-' -> 1,`Neg + | '+' -> 1, `Pos + | _ -> 0, `Pos + else + 0, `Pos + in + if pos + 2 < len then + let c1 = x.[pos] in + let c2 = x.[pos + 1] in + match c1, c2 with + | '0', ('0' .. '9') -> sign,true + | '0', _ -> sign,false + | _ -> sign,true + else sign, true + + let to_string x = Caml.Int64.to_string (unwrap x) + + let of_string str = + try + let sign,signedness = sign_and_signedness str in + if signedness + then of_int64_exn (Caml.Int64.of_string str) + else + let pos_str = + match sign with + | `Neg -> String.sub str ~pos:1 ~len:(String.length str - 1) + | `Pos -> str + in + let int64 = Caml.Int64.of_string pos_str in + (* unsigned 63-bit int must parse as a positive signed 64-bit int *) + if Int64_replace_polymorphic_compare.(<) int64 0L then invalid_str str; + let int63 = wrap_modulo int64 in + match sign with + | `Neg -> neg int63 + | `Pos -> int63 + with _ -> invalid_str str +end + +include T + +let num_bits = 63 + +let float_lower_bound = Float0.lower_bound_for_int num_bits +let float_upper_bound = Float0.upper_bound_for_int num_bits + +let shift_right_logical = shift_right_logical +let shift_right = shift_right +let shift_left = shift_left +let bit_not = bit_not +let bit_xor = bit_xor +let bit_or = bit_or +let bit_and = bit_and +let popcount = popcount +let abs = abs +let pred = pred +let succ = succ +let pow = pow +let rem = rem +let neg = neg +let max_value = max_value +let min_value = min_value +let minus_one = wrap_exn Caml.Int64.minus_one +let one = wrap_exn Caml.Int64.one +let zero = wrap_exn Caml.Int64.zero +let is_pow2 = is_pow2 +let floor_pow2 = floor_pow2 +let ceil_pow2 = ceil_pow2 +let floor_log2 = floor_log2 +let ceil_log2 = ceil_log2 +let to_float x = Caml.Int64.to_float (unwrap x) +let of_float_unchecked x = wrap_modulo (Caml.Int64.of_float x) +let of_float t = + let open Float_replace_polymorphic_compare in + if t >= float_lower_bound && t <= float_upper_bound then + wrap_modulo (Caml.Int64.of_float t) + else + Printf.invalid_argf "Int63.of_float: argument (%f) is out of range or NaN" + (Float0.box t) + () +let of_int64 = of_int64 +let of_int64_exn = of_int64_exn +let of_int64_trunc = of_int64_trunc +let to_int64 = to_int64 + +include Comparable.Validate_with_zero (struct + include T + let zero = zero + end) + +let between t ~low ~high = low <= t && t <= high +let clamp_unchecked t ~min ~max = + if t < min then min else if t <= max then t else max + +let clamp_exn t ~min ~max = + assert (min <= max); + clamp_unchecked t ~min ~max + +let clamp t ~min ~max = + if min > max then + Or_error.error_s + (Sexp.message "clamp requires [min <= max]" + [ "min", T.sexp_of_t min + ; "max", T.sexp_of_t max + ]) + else + Ok (clamp_unchecked t ~min ~max) + +let ( / ) = div +let ( * ) = mul +let ( - ) = sub +let ( + ) = add +let ( ~- ) = neg + +let ( ** ) b e = pow b e + +let incr r = r := !r + one +let decr r = r := !r - one + +(* We can reuse conversion function from/to int64 here. *) +let of_int x = wrap_exn (Conv.int_to_int64 x) +let of_int_exn x = of_int x +let to_int x = Conv.int64_to_int (unwrap x) +let to_int_exn x = Conv.int64_to_int_exn (unwrap x) +let to_int_trunc x = Conv.int64_to_int_trunc (unwrap x) + +let of_int32 x = wrap_exn (Conv.int32_to_int64 x) +let of_int32_exn x = of_int32 x +let to_int32 x = Conv.int64_to_int32 (unwrap x) +let to_int32_exn x = Conv.int64_to_int32_exn (unwrap x) +let to_int32_trunc x = Conv.int64_to_int32_trunc (unwrap x) + +let of_nativeint x = of_int64 (Conv.nativeint_to_int64 x) +let of_nativeint_exn x = wrap_exn (Conv.nativeint_to_int64 x) +let of_nativeint_trunc x = of_int64_trunc (Conv.nativeint_to_int64 x) +let to_nativeint x = Conv.int64_to_nativeint (unwrap x) +let to_nativeint_exn x = Conv.int64_to_nativeint_exn (unwrap x) +let to_nativeint_trunc x = Conv.int64_to_nativeint_trunc (unwrap x) + +include Conv.Make (T) + +include Conv.Make_hex(struct + + type t = T.t [@@deriving_inline compare, hash] + let compare : t -> t -> int = T.compare + let (hash_fold_t : + Ppx_hash_lib.Std.Hash.state -> t -> Ppx_hash_lib.Std.Hash.state) = + T.hash_fold_t + and (hash : t -> Ppx_hash_lib.Std.Hash.hash_value) = + let func = T.hash in fun x -> func x + [@@@end] + + let zero = zero + let neg = (~-) + let (<) = (<) + let to_string i = + (* the use of [unwrap_unsigned] here is important for the case of [min_value] *) + Printf.sprintf "%Lx" (unwrap_unsigned i) + let of_string s = of_string ("0x"^s) + let module_name = "Base.Int63.Hex" + + end) + +include Pretty_printer.Register (struct + type nonrec t = t + let to_string x = to_string x + let module_name = "Base.Int63" + end) + +module Pre_O = struct + let ( + ) = ( + ) + let ( - ) = ( - ) + let ( * ) = ( * ) + let ( / ) = ( / ) + let ( ~- ) = ( ~- ) + let ( ** ) = ( ** ) + include (Int64_replace_polymorphic_compare : Comparisons.Infix with type t := t) + let abs = abs + let neg = neg + let zero = zero + let of_int_exn = of_int_exn +end + +module O = struct + include Pre_O + include Int_math.Make (struct + type nonrec t = t + include Pre_O + let rem = rem + let to_float = to_float + let of_float = of_float + let of_string = T.of_string + let to_string = T.to_string + end) + + let ( land ) = bit_and + let ( lor ) = bit_or + let ( lxor ) = bit_xor + let ( lnot ) = bit_not + let ( lsl ) = shift_left + let ( asr ) = shift_right + let ( lsr ) = shift_right_logical +end + +include O (* [Int63] and [Int63.O] agree value-wise *) + +module Repr = struct + type emulated = t + type ('underlying_type, 'intermediate_type) t = + | Int : (int , int ) t + | Int64 : (int64 , emulated) t +end + +let repr = Repr.Int64 + +(* Include type-specific [Replace_polymorphic_compare] at the end, after + including functor application that could shadow its definitions. This is + here so that efficient versions of the comparison functions are exported by + this module. *) +include Int64_replace_polymorphic_compare diff --git a/src/int63_emul.mli b/src/int63_emul.mli new file mode 100644 index 0000000..2c159fb --- /dev/null +++ b/src/int63_emul.mli @@ -0,0 +1,34 @@ +open! Import + +include Int_intf.S + +val of_int : int -> t +val to_int : t -> int option +val to_int_trunc : t -> int + +val of_int32 : int32 -> t +val to_int32 : t -> Int32.t option +val to_int32_trunc : t -> Int32.t + +val of_int64 : Int64.t -> t option +val of_int64_trunc : Int64.t -> t + +val of_nativeint : nativeint -> t option +val to_nativeint : t -> nativeint option +val of_nativeint_trunc : nativeint -> t +val to_nativeint_trunc : t -> nativeint + +(*_ exported for Core_kernel *) +module W : sig + val wrap_exn : int64 -> t + val unwrap : t -> int64 +end + +module Repr : sig + type emulated = t + type ('underlying_type, 'intermediate_type) t = + | Int : (int , int ) t + | Int64 : (int64 , emulated) t +end with type emulated := t + +val repr : (t, t) Repr.t diff --git a/src/int64.ml b/src/int64.ml new file mode 100644 index 0000000..3023a4d --- /dev/null +++ b/src/int64.ml @@ -0,0 +1,260 @@ +open! Import +open! Caml.Int64 + +module T = struct + type t = int64 [@@deriving_inline hash, sexp] + let (hash_fold_t : + Ppx_hash_lib.Std.Hash.state -> t -> Ppx_hash_lib.Std.Hash.state) = + hash_fold_int64 + and (hash : t -> Ppx_hash_lib.Std.Hash.hash_value) = + let func = hash_int64 in fun x -> func x + let t_of_sexp : Ppx_sexp_conv_lib.Sexp.t -> t = int64_of_sexp + let sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t = sexp_of_int64 + [@@@end] + + let compare = Int64_replace_polymorphic_compare.compare + + let to_string = to_string + let of_string = of_string +end + +include T +include Comparator.Make(T) + +let num_bits = 64 +let float_lower_bound = Float0.lower_bound_for_int num_bits +let float_upper_bound = Float0.upper_bound_for_int num_bits + +let float_of_bits = float_of_bits +let bits_of_float = bits_of_float +let shift_right_logical = shift_right_logical +let shift_right = shift_right +let shift_left = shift_left +let bit_not = lognot +let bit_xor = logxor +let bit_or = logor +let bit_and = logand +let min_value = min_int +let max_value = max_int +let abs = abs +let pred = pred +let succ = succ +let pow = Int_math.int64_pow +let rem = rem +let neg = neg +let minus_one = minus_one +let one = one +let zero = zero +let to_float = to_float +let of_float_unchecked = Caml.Int64.of_float +let of_float f = + if Float_replace_polymorphic_compare.(>=) f float_lower_bound + && Float_replace_polymorphic_compare.(<=) f float_upper_bound + then + Caml.Int64.of_float f + else + Printf.invalid_argf "Int64.of_float: argument (%f) is out of range or NaN" + (Float0.box f) + () + +let ( ** ) b e = pow b e + +include Comparable.Validate_with_zero (struct + include T + let zero = zero + end) + +(* Open replace_polymorphic_compare after including functor instantiations so they do not + shadow its definitions. This is here so that efficient versions of the comparison + functions are available within this module. *) +open Int64_replace_polymorphic_compare + +let between t ~low ~high = low <= t && t <= high +let clamp_unchecked t ~min ~max = + if t < min then min else if t <= max then t else max + +let clamp_exn t ~min ~max = + assert (min <= max); + clamp_unchecked t ~min ~max + +let clamp t ~min ~max = + if min > max then + Or_error.error_s + (Sexp.message "clamp requires [min <= max]" + [ "min", T.sexp_of_t min + ; "max", T.sexp_of_t max + ]) + else + Ok (clamp_unchecked t ~min ~max) + +let ( / ) = div +let ( * ) = mul +let ( - ) = sub +let ( + ) = add +let ( ~- ) = neg + +let incr r = r := !r + one +let decr r = r := !r - one + +let of_int64 t = t +let of_int64_exn = of_int64 +let to_int64 t = t + +let popcount = Popcount.int64_popcount + +module Conv = Int_conversions +let of_int = Conv.int_to_int64 +let of_int_exn = of_int +let to_int = Conv.int64_to_int +let to_int_exn = Conv.int64_to_int_exn +let to_int_trunc = Conv.int64_to_int_trunc +let of_int32 = Conv.int32_to_int64 +let of_int32_exn = of_int32 +let to_int32 = Conv.int64_to_int32 +let to_int32_exn = Conv.int64_to_int32_exn +let to_int32_trunc = Conv.int64_to_int32_trunc +let of_nativeint = Conv.nativeint_to_int64 +let of_nativeint_exn = of_nativeint +let to_nativeint = Conv.int64_to_nativeint +let to_nativeint_exn = Conv.int64_to_nativeint_exn +let to_nativeint_trunc = Conv.int64_to_nativeint_trunc + +module Pow2 = struct + open! Import + open Int64_replace_polymorphic_compare + + module Sys = Sys0 + + let raise_s = Error.raise_s + + let non_positive_argument () = + Printf.invalid_argf "argument must be strictly positive" () + + let ( lor ) = Caml.Int64.logor;; + let ( lsr ) = Caml.Int64.shift_right_logical;; + let ( land ) = Caml.Int64.logand;; + + (** "ceiling power of 2" - Least power of 2 greater than or equal to x. *) + let ceil_pow2 x = + if x <= Caml.Int64.zero then non_positive_argument (); + let x = Caml.Int64.pred x in + let x = x lor (x lsr 1) in + let x = x lor (x lsr 2) in + let x = x lor (x lsr 4) in + let x = x lor (x lsr 8) in + let x = x lor (x lsr 16) in + let x = x lor (x lsr 32) in + Caml.Int64.succ x + ;; + + (** "floor power of 2" - Largest power of 2 less than or equal to x. *) + let floor_pow2 x = + if x <= Caml.Int64.zero then non_positive_argument (); + let x = x lor (x lsr 1) in + let x = x lor (x lsr 2) in + let x = x lor (x lsr 4) in + let x = x lor (x lsr 8) in + let x = x lor (x lsr 16) in + let x = x lor (x lsr 32) in + Caml.Int64.sub x (x lsr 1) + ;; + + let is_pow2 x = + if x <= Caml.Int64.zero then non_positive_argument (); + (x land (Caml.Int64.pred x)) = Caml.Int64.zero + ;; + + (* C stub for int clz to use the CLZ/BSR instruction where possible *) + external int64_clz : int64 -> int = "Base_int_math_int64_clz" [@@noalloc] + + (** Hacker's Delight Second Edition p106 *) + let floor_log2 i = + if i <= Caml.Int64.zero then + raise_s (Sexp.message "[Int64.floor_log2] got invalid input" + ["", sexp_of_int64 i]); + num_bits - 1 - int64_clz i + ;; + + (** Hacker's Delight Second Edition p106 *) + let ceil_log2 i = + if Poly.( <= ) i Caml.Int64.zero then + raise_s (Sexp.message "[Int64.ceil_log2] got invalid input" + ["", sexp_of_int64 i]); + if Caml.Int64.equal i Caml.Int64.one + then 0 + else num_bits - int64_clz (Caml.Int64.pred i) + ;; +end +include Pow2 + +include Conv.Make (T) + +include Conv.Make_hex(struct + + type t = int64 [@@deriving_inline compare, hash] + let compare : t -> t -> int = compare_int64 + let (hash_fold_t : + Ppx_hash_lib.Std.Hash.state -> t -> Ppx_hash_lib.Std.Hash.state) = + hash_fold_int64 + and (hash : t -> Ppx_hash_lib.Std.Hash.hash_value) = + let func = hash_int64 in fun x -> func x + [@@@end] + + let zero = zero + let neg = (~-) + let (<) = (<) + let to_string i = Printf.sprintf "%Lx" i + let of_string s = Caml.Scanf.sscanf s "%Lx" Fn.id + + let module_name = "Base.Int64.Hex" + + end) + +include Pretty_printer.Register (struct + type nonrec t = t + let to_string = to_string + let module_name = "Base.Int64" + end) + +module Pre_O = struct + let ( + ) = ( + ) + let ( - ) = ( - ) + let ( * ) = ( * ) + let ( / ) = ( / ) + let ( ~- ) = ( ~- ) + let ( ** ) = ( ** ) + include (Int64_replace_polymorphic_compare : Comparisons.Infix with type t := t) + let abs = abs + let neg = neg + let zero = zero + let of_int_exn = of_int_exn +end + +module O = struct + include Pre_O + include Int_math.Make (struct + type nonrec t = t + include Pre_O + let rem = rem + let to_float = to_float + let of_float = of_float + let of_string = T.of_string + let to_string = T.to_string + end) + + let ( land ) = bit_and + let ( lor ) = bit_or + let ( lxor ) = bit_xor + let ( lnot ) = bit_not + let ( lsl ) = shift_left + let ( asr ) = shift_right + let ( lsr ) = shift_right_logical +end + +include O (* [Int64] and [Int64.O] agree value-wise *) + +(* Include type-specific [Replace_polymorphic_compare] at the end, after + including functor application that could shadow its definitions. This is + here so that efficient versions of the comparison functions are exported by + this module. *) +include Int64_replace_polymorphic_compare diff --git a/src/int64.mli b/src/int64.mli new file mode 100644 index 0000000..ffbf26c --- /dev/null +++ b/src/int64.mli @@ -0,0 +1,32 @@ +(** 64-bit integers. *) + +open! Import + +include Int_intf.S with type t = int64 + +(** {2 Conversion functions} *) + +val of_int : int -> t +val to_int : t -> int option + +val of_int32 : int32 -> t +val to_int32 : t -> int32 option + +val of_nativeint : nativeint -> t +val to_nativeint : t -> nativeint option + +val of_int64 : t -> t + +(** {3 Truncating conversions} + + These functions return the least-significant bits of the input. In cases where + optional conversions return [Some x], truncating conversions return [x]. *) + +val to_int_trunc : t -> int +val to_int32_trunc : t -> int32 +val to_nativeint_trunc : t -> nativeint + +(** {3 Low-level float conversions} *) + +val bits_of_float : float -> t +val float_of_bits : t -> float diff --git a/src/int_conversions.ml b/src/int_conversions.ml new file mode 100644 index 0000000..5a34845 --- /dev/null +++ b/src/int_conversions.ml @@ -0,0 +1,359 @@ +open! Import + +module Int = Int0 +module Sys = Sys0 + +let [@inline never] convert_failure x a b to_string = + Printf.failwithf + "conversion from %s to %s failed: %s is out of range" + a + b + (to_string x) + () + +let num_bits_int = Sys.int_size_in_bits +let num_bits_int32 = 32 +let num_bits_int64 = 64 +let num_bits_nativeint = Word_size.num_bits Word_size.word_size + +let () = + assert (num_bits_int = 63 + || num_bits_int = 31 + || num_bits_int = 32) + +let min_int32 = Caml.Int32.min_int +let max_int32 = Caml.Int32.max_int +let min_int64 = Caml.Int64.min_int +let max_int64 = Caml.Int64.max_int +let min_nativeint = Caml.Nativeint.min_int +let max_nativeint = Caml.Nativeint.max_int + +let int_to_string = Caml.string_of_int +let int32_to_string = Caml.Int32.to_string +let int64_to_string = Caml.Int64.to_string +let nativeint_to_string = Caml.Nativeint.to_string + +(* int <-> int32 *) + +let int_to_int32_failure x = convert_failure x "int" "int32" int_to_string +let int32_to_int_failure x = convert_failure x "int32" "int" int32_to_string + +let int32_to_int_trunc = Caml.Int32.to_int +let int_to_int32_trunc = Caml.Int32.of_int + +let int_is_representable_as_int32 = + if num_bits_int <= num_bits_int32 + then (fun _ -> true) + else + let min = int32_to_int_trunc min_int32 in + let max = int32_to_int_trunc max_int32 in + (fun x -> compare_int min x <= 0 && compare_int x max <= 0) + +let int32_is_representable_as_int = + if num_bits_int32 <= num_bits_int + then (fun _ -> true) + else + let min = int_to_int32_trunc Int.min_value in + let max = int_to_int32_trunc Int.max_value in + (fun x -> compare_int32 min x <= 0 && compare_int32 x max <= 0) + +let int_to_int32 x = + if int_is_representable_as_int32 x + then Some (int_to_int32_trunc x) + else None + +let int32_to_int x = + if int32_is_representable_as_int x + then Some (int32_to_int_trunc x) + else None + +let int_to_int32_exn x = + if int_is_representable_as_int32 x + then int_to_int32_trunc x + else int_to_int32_failure x + +let int32_to_int_exn x = + if int32_is_representable_as_int x + then int32_to_int_trunc x + else int32_to_int_failure x + +(* int <-> int64 *) + +let int64_to_int_failure x = convert_failure x "int64" "int" int64_to_string + +let () = assert (num_bits_int < num_bits_int64) + +let int_to_int64 = Caml.Int64.of_int +let int64_to_int_trunc = Caml.Int64.to_int + +let int64_is_representable_as_int = + let min = int_to_int64 Int.min_value in + let max = int_to_int64 Int.max_value in + (fun x -> compare_int64 min x <= 0 && compare_int64 x max <= 0) + +let int64_to_int x = + if int64_is_representable_as_int x + then Some (int64_to_int_trunc x) + else None + +let int64_to_int_exn x = + if int64_is_representable_as_int x + then int64_to_int_trunc x + else int64_to_int_failure x + +(* int <-> nativeint *) + +let nativeint_to_int_failure x = convert_failure x "nativeint" "int" nativeint_to_string + +let () = assert (num_bits_int <= num_bits_nativeint) + +let int_to_nativeint = Caml.Nativeint.of_int +let nativeint_to_int_trunc = Caml.Nativeint.to_int + +let nativeint_is_representable_as_int = + if num_bits_nativeint <= num_bits_int + then (fun _ -> true) + else + let min = int_to_nativeint Int.min_value in + let max = int_to_nativeint Int.max_value in + (fun x -> compare_nativeint min x <= 0 && compare_nativeint x max <= 0) + +let nativeint_to_int x = + if nativeint_is_representable_as_int x + then Some (nativeint_to_int_trunc x) + else None + +let nativeint_to_int_exn x = + if nativeint_is_representable_as_int x + then nativeint_to_int_trunc x + else nativeint_to_int_failure x + +(* int32 <-> int64 *) + +let int64_to_int32_failure x = convert_failure x "int64" "int32" int64_to_string + +let () = assert (num_bits_int32 < num_bits_int64) + +let int32_to_int64 = Caml.Int64.of_int32 +let int64_to_int32_trunc = Caml.Int64.to_int32 + +let int64_is_representable_as_int32 = + let min = int32_to_int64 min_int32 in + let max = int32_to_int64 max_int32 in + (fun x -> compare_int64 min x <= 0 && compare_int64 x max <= 0) + +let int64_to_int32 x = + if int64_is_representable_as_int32 x + then Some (int64_to_int32_trunc x) + else None + +let int64_to_int32_exn x = + if int64_is_representable_as_int32 x + then int64_to_int32_trunc x + else int64_to_int32_failure x + +(* int32 <-> nativeint *) + +let nativeint_to_int32_failure x = + convert_failure x "nativeint" "int32" nativeint_to_string + +let () = assert (num_bits_int32 <= num_bits_nativeint) + +let int32_to_nativeint = Caml.Nativeint.of_int32 +let nativeint_to_int32_trunc = Caml.Nativeint.to_int32 + +let nativeint_is_representable_as_int32 = + if num_bits_nativeint <= num_bits_int32 + then (fun _ -> true) + else + let min = int32_to_nativeint min_int32 in + let max = int32_to_nativeint max_int32 in + (fun x -> compare_nativeint min x <= 0 && compare_nativeint x max <= 0) + +let nativeint_to_int32 x = + if nativeint_is_representable_as_int32 x + then Some (nativeint_to_int32_trunc x) + else None + +let nativeint_to_int32_exn x = + if nativeint_is_representable_as_int32 x + then nativeint_to_int32_trunc x + else nativeint_to_int32_failure x + + +(* int64 <-> nativeint *) + +let int64_to_nativeint_failure x = convert_failure x "int64" "nativeint" int64_to_string + +let () = assert (num_bits_int64 >= num_bits_nativeint) + +let int64_to_nativeint_trunc = Caml.Int64.to_nativeint +let nativeint_to_int64 = Caml.Int64.of_nativeint + +let int64_is_representable_as_nativeint = + if num_bits_int64 <= num_bits_nativeint + then (fun _ -> true) + else + let min = nativeint_to_int64 min_nativeint in + let max = nativeint_to_int64 max_nativeint in + (fun x -> compare_int64 min x <= 0 && compare_int64 x max <= 0) + +let int64_to_nativeint x = + if int64_is_representable_as_nativeint x + then Some (int64_to_nativeint_trunc x) + else None + +let int64_to_nativeint_exn x = + if int64_is_representable_as_nativeint x + then int64_to_nativeint_trunc x + else int64_to_nativeint_failure x + +(* int64 <-> int63 *) + +let int64_to_int63_failure x = convert_failure x "int64" "int63" int64_to_string + +let int64_is_representable_as_int63 = + let min = Caml.Int64.shift_right min_int64 1 in + let max = Caml.Int64.shift_right max_int64 1 in + (fun x -> compare_int64 min x <= 0 && compare_int64 x max <= 0) + +let int64_fit_on_int63_exn x = + if int64_is_representable_as_int63 x + then () + else int64_to_int63_failure x + +(* string conversions *) + +let insert_delimiter_every input ~delimiter ~chars_per_delimiter = + let input_length = String.length input in + if input_length <= chars_per_delimiter then + input + else begin + let has_sign = match input.[0] with '+' | '-' -> true | _ -> false in + let num_digits = if has_sign then input_length - 1 else input_length in + let num_delimiters = (num_digits - 1) / chars_per_delimiter in + let output_length = input_length + num_delimiters in + let output = Bytes.create output_length in + let input_pos = ref (input_length - 1) in + let output_pos = ref (output_length - 1) in + let num_chars_until_delimiter = ref chars_per_delimiter in + let first_digit_pos = if has_sign then 1 else 0 in + while !input_pos >= first_digit_pos do + if !num_chars_until_delimiter = 0 then begin + Bytes.set output !output_pos delimiter; + decr output_pos; + num_chars_until_delimiter := chars_per_delimiter; + end; + Bytes.set output !output_pos input.[!input_pos]; + decr input_pos; + decr output_pos; + decr num_chars_until_delimiter; + done; + if has_sign then Bytes.set output 0 input.[0]; + Bytes.unsafe_to_string ~no_mutation_while_string_reachable:output; + end +;; + +let insert_delimiter input ~delimiter = + insert_delimiter_every input ~delimiter ~chars_per_delimiter:3 + +let insert_underscores input = + insert_delimiter input ~delimiter:'_' + +let sexp_of_int_style = Sexp.of_int_style + +module Make (I : sig + type t + val to_string : t -> string + end) = struct + + open I + + let chars_per_delimiter = 3 + + let to_string_hum ?(delimiter='_') t = + insert_delimiter_every (to_string t) ~delimiter ~chars_per_delimiter + + let sexp_of_t t = + let s = to_string t in + Sexp.Atom + (match !sexp_of_int_style with + | `Underscores -> insert_delimiter_every s ~chars_per_delimiter ~delimiter:'_' + | `No_underscores -> s) + ;; +end + +module Make_hex (I : sig + type t [@@deriving_inline compare, hash] + include + sig + [@@@ocaml.warning "-32"] + val compare : t -> t -> int + val hash_fold_t : + Ppx_hash_lib.Std.Hash.state -> t -> Ppx_hash_lib.Std.Hash.state + val hash : t -> Ppx_hash_lib.Std.Hash.hash_value + end[@@ocaml.doc "@inline"] + [@@@end] + val to_string : t -> string + val of_string : string -> t + val zero : t + val (<) : t -> t -> bool + val neg : t -> t + val module_name : string + end) = +struct + + module T_hex = struct + + type t = I.t [@@deriving_inline compare, hash] + let compare : t -> t -> int = I.compare + let (hash_fold_t : + Ppx_hash_lib.Std.Hash.state -> t -> Ppx_hash_lib.Std.Hash.state) = + I.hash_fold_t + and (hash : t -> Ppx_hash_lib.Std.Hash.hash_value) = + let func = I.hash in fun x -> func x + [@@@end] + + let chars_per_delimiter = 4 + + let to_string' ?delimiter t = + let make_suffix = + match delimiter with + | None -> I.to_string + | Some delimiter -> + (fun t -> + insert_delimiter_every (I.to_string t) ~delimiter ~chars_per_delimiter) + in + if I.(<) t I.zero + then "-0x" ^ make_suffix (I.neg t) + else "0x" ^ make_suffix t + + let to_string t = to_string' t ?delimiter:None + + let to_string_hum ?(delimiter='_') t = to_string' t ~delimiter + + let invalid str = + failwith (Printf.sprintf "%s.of_string: invalid input %S" I.module_name str) + + let of_string_with_delimiter str = + I.of_string (String.filter str ~f:(fun c -> Char.( <> ) c '_')) + + let of_string str = + let module L = Hex_lexer in + let lex = Caml.Lexing.from_string str in + let result = Option.try_with (fun () -> L.parse_hex lex) in + if lex.lex_curr_pos = lex.lex_buffer_len then ( + match result with + | None -> invalid str + | Some (Neg body) -> I.neg (of_string_with_delimiter body) + | Some (Pos body) -> of_string_with_delimiter body + ) else + invalid str + end + + module Hex = struct + include T_hex + include Sexpable.Of_stringable(T_hex) + end + +end diff --git a/src/int_conversions.mli b/src/int_conversions.mli new file mode 100644 index 0000000..23255ae --- /dev/null +++ b/src/int_conversions.mli @@ -0,0 +1,134 @@ +(** Conversions between various integer types *) + +open! Import + +(** Ocaml has the following integer types, with the following bit widths + on 32-bit and 64-bit architectures. + + {v + arch arch + type 32b 64b + ---------------------- + int 31 63 (32 when compiled to JavaScript) + nativeint 32 64 + int32 32 32 + int64 64 64 + v} + + In both cases, the following inequalities hold: + + {[ + width(int) < width(nativeint) + && width(int32) <= width(nativeint) <= width(int64) + ]} + + The conversion functions come in one of two flavors. + + If width(foo) <= width(bar) on both 32-bit and 64-bit architectures, then we have + + {[ val foo_to_bar : foo -> bar ]} + + otherwise we have + + {[ + val foo_to_bar : foo -> bar option + val foo_to_bar_exn : foo -> bar + ]} *) +val int_to_int32 : int -> int32 option +val int_to_int32_exn : int -> int32 +val int_to_int32_trunc : int -> int32 +val int_to_int64 : int -> int64 +val int_to_nativeint : int -> nativeint + +val int32_to_int : int32 -> int option +val int32_to_int_exn : int32 -> int +val int32_to_int_trunc : int32 -> int +val int32_to_int64 : int32 -> int64 +val int32_to_nativeint : int32 -> nativeint + +val int64_to_int : int64 -> int option +val int64_to_int_exn : int64 -> int +val int64_to_int_trunc : int64 -> int +val int64_to_int32 : int64 -> int32 option +val int64_to_int32_exn : int64 -> int32 +val int64_to_int32_trunc : int64 -> int32 +val int64_to_nativeint : int64 -> nativeint option +val int64_to_nativeint_exn : int64 -> nativeint +val int64_to_nativeint_trunc : int64 -> nativeint + +val int64_fit_on_int63_exn : int64 -> unit +val int64_is_representable_as_int63 : int64 -> bool + +val nativeint_to_int : nativeint -> int option +val nativeint_to_int_exn : nativeint -> int +val nativeint_to_int_trunc : nativeint -> int +val nativeint_to_int32 : nativeint -> int32 option +val nativeint_to_int32_exn : nativeint -> int32 +val nativeint_to_int32_trunc : nativeint -> int32 +val nativeint_to_int64 : nativeint -> int64 + +val num_bits_int : int +val num_bits_int32 : int +val num_bits_int64 : int +val num_bits_nativeint : int + +(** human-friendly string (and possibly sexp) conversions *) +module Make (I : sig + + type t + + val to_string : t -> string + + end) : sig + + val to_string_hum + : ?delimiter:char (** defaults to ['_'] *) + -> I.t + -> string + + val sexp_of_t : I.t -> Sexp.t + +end + +module Make_hex (I : sig + type t [@@deriving_inline compare, hash] + include + sig + [@@@ocaml.warning "-32"] + val compare : t -> t -> int + val hash_fold_t : + Ppx_hash_lib.Std.Hash.state -> t -> Ppx_hash_lib.Std.Hash.state + val hash : t -> Ppx_hash_lib.Std.Hash.hash_value + end[@@ocaml.doc "@inline"] + [@@@end] + + (** [to_string] and [of_string] convert between [t] and unsigned, + unprefixed hexadecimal. + They must be able to handle all non-negative values and also + [min_value]. [to_string min_value] must write a positive hex + representation. *) + val to_string : t -> string + val of_string : string -> t + val zero : t + val (<) : t -> t -> bool + val neg : t -> t + val module_name : string + end) + : Int_intf.Hexable with type t := I.t +(** in the output, [to_string], [of_string], [sexp_of_t], and [t_of_sexp] convert + between [t] and signed hexadecimal with an optional "0x" or "0X" prefix. *) + +(** global ref affecting whether the [sexp_of_t] returned by [Make] + is consistent with the [to_string] input or the [to_string_hum] output *) +val sexp_of_int_style : [ `No_underscores | `Underscores ] ref + +(** utility for defining to_string_hum on numeric types -- takes a string matching + (-|+)?[0-9a-fA-F]+ and puts [delimiter] every [chars_per_delimiter] characters + starting from the right. *) +val insert_delimiter_every : string -> delimiter:char -> chars_per_delimiter:int -> string + +(** [insert_delimiter_every ~chars_per_delimiter:3] *) +val insert_delimiter : string -> delimiter:char -> string + +(** [insert_delimiter ~delimiter:'_'] *) +val insert_underscores : string -> string diff --git a/src/int_intf.ml b/src/int_intf.ml new file mode 100644 index 0000000..2c75eaa --- /dev/null +++ b/src/int_intf.ml @@ -0,0 +1,363 @@ +(** An interface to use for int-like types, e.g., {{!Base.Int}[Int]} and + {{!Base.Int64}[Int64]}. *) + +open! Import + +module type Round = sig + type t + + (** [round] rounds an int to a multiple of a given [to_multiple_of] argument, according + to a direction [dir], with default [dir] being [`Nearest]. [round] will raise if + [to_multiple_of <= 0]. + + {v + | `Down | rounds toward Int.neg_infinity | + | `Up | rounds toward Int.infinity | + | `Nearest | rounds to the nearest multiple, or `Up in case of a tie | + | `Zero | rounds toward zero | + v} + + Here are some examples for [round ~to_multiple_of:10] for each direction: + + {v + | `Down | {10 .. 19} --> 10 | { 0 ... 9} --> 0 | {-10 ... -1} --> -10 | + | `Up | { 1 .. 10} --> 10 | {-9 ... 0} --> 0 | {-19 .. -10} --> -10 | + | `Zero | {10 .. 19} --> 10 | {-9 ... 9} --> 0 | {-19 .. -10} --> -10 | + | `Nearest | { 5 .. 14} --> 10 | {-5 ... 4} --> 0 | {-15 ... -6} --> -10 | + v} + + For convenience and performance, there are variants of [round] with [dir] + hard-coded. If you are writing performance-critical code you should use these. *) + + val round : ?dir:[ `Zero | `Nearest | `Up | `Down ] -> t -> to_multiple_of:t -> t + + val round_towards_zero : t -> to_multiple_of:t -> t + val round_down : t -> to_multiple_of:t -> t + val round_up : t -> to_multiple_of:t -> t + val round_nearest : t -> to_multiple_of:t -> t +end + +module type Hexable = sig + type t + module Hex : sig + type nonrec t = t [@@deriving_inline sexp, compare, hash] + include + sig + [@@@ocaml.warning "-32"] + include Ppx_sexp_conv_lib.Sexpable.S with type t := t + val compare : t -> t -> int + val hash_fold_t : + Ppx_hash_lib.Std.Hash.state -> t -> Ppx_hash_lib.Std.Hash.state + val hash : t -> Ppx_hash_lib.Std.Hash.hash_value + end[@@ocaml.doc "@inline"] + [@@@end] + + include Stringable.S with type t := t + + val to_string_hum : ?delimiter:char -> t -> string + end +end + +module type S_common = sig + type t [@@deriving_inline hash, sexp] + include + sig + [@@@ocaml.warning "-32"] + val hash_fold_t : + Ppx_hash_lib.Std.Hash.state -> t -> Ppx_hash_lib.Std.Hash.state + val hash : t -> Ppx_hash_lib.Std.Hash.hash_value + include Ppx_sexp_conv_lib.Sexpable.S with type t := t + end[@@ocaml.doc "@inline"] + [@@@end] + + include Floatable.S with type t := t + include Intable.S with type t := t + include Identifiable.S with type t := t + include Comparable.With_zero with type t := t + include Hexable with type t := t + + (** [delimiter] is an underscore by default. *) + val to_string_hum : ?delimiter:char -> t -> string + + (** {2 Infix operators and constants} *) + + val zero : t + val one : t + val minus_one : t + + val ( + ) : t -> t -> t + val ( - ) : t -> t -> t + val ( * ) : t -> t -> t + + (** Integer exponentiation *) + val ( ** ) : t -> t -> t + + (** Negation *) + + val neg : t -> t + val ( ~- ) : t -> t + + (** There are two pairs of integer division and remainder functions, [/%] and [%], and + [/] and [rem]. They both satisfy the same equation relating the quotient and the + remainder: + + {[ + x = (x /% y) * y + (x % y); + x = (x / y) * y + (rem x y); + ]} + + The functions return the same values if [x] and [y] are positive. They all raise + if [y = 0]. + + The functions differ if [x < 0] or [y < 0]. + + If [y < 0], then [%] and [/%] raise, whereas [/] and [rem] do not. + + [x % y] always returns a value between 0 and [y - 1], even when [x < 0]. On the + other hand, [rem x y] returns a negative value if and only if [x < 0]; that value + satisfies [abs (rem x y) <= abs y - 1]. *) + + val ( /% ) : t -> t -> t + val ( % ) : t -> t -> t + val ( / ) : t -> t -> t + val rem : t -> t -> t + + (** Float division of integers. *) + val ( // ) : t -> t -> float + + (** Same as [bit_and]. *) + val ( land ) : t -> t -> t + + (** Same as [bit_or]. *) + val ( lor ) : t -> t -> t + + (** Same as [bit_xor]. *) + val ( lxor ) : t -> t -> t + + (** Same as [bit_not]. *) + val lnot : t -> t + + (** Same as [shift_left]. *) + val ( lsl ) : t -> int -> t + + (** Same as [shift_right]. *) + val ( asr ) : t -> int -> t + + (** {2 Other common functions} *) + + include Round with type t := t + + (** Returns the absolute value of the argument. May be negative if the input is + [min_value]. *) + val abs : t -> t + + (** {2 Successor and predecessor functions} *) + + val succ : t -> t + val pred : t -> t + + (** {2 Exponentiation} *) + + (** [pow base exponent] returns [base] raised to the power of [exponent]. It is OK if + [base <= 0]. [pow] raises if [exponent < 0], or an integer overflow would occur. *) + val pow : t -> t -> t + + (** {2 Bit-wise logical operations } *) + + (** These are identical to [land], [lor], etc. except they're not infix and have + different names. *) + val bit_and : t -> t -> t + val bit_or : t -> t -> t + val bit_xor : t -> t -> t + val bit_not : t -> t + + (** Returns the number of 1 bits in the binary representation of the input. *) + val popcount : t -> int + + (** {2 Bit-shifting operations } + + The results are unspecified for negative shifts and shifts [>= num_bits]. *) + + (** Shifts left, filling in with zeroes. *) + val shift_left : t -> int -> t + + (** Shifts right, preserving the sign of the input. *) + val shift_right : t -> int -> t + + (** {2 Increment and decrement functions for integer references } *) + + val decr : t ref -> unit + val incr : t ref -> unit + + (** {2 Conversion functions to related integer types} *) + + val of_int32_exn : int32 -> t + val to_int32_exn : t -> int32 + val of_int64_exn : int64 -> t + val to_int64 : t -> int64 + val of_nativeint_exn : nativeint -> t + val to_nativeint_exn : t -> nativeint + + (** [of_float_unchecked] truncates the given floating point number to an integer, + rounding towards zero. + The result is unspecified if the argument is nan or falls outside the range + of representable integers. *) + val of_float_unchecked : float -> t +end + +module type Operators_unbounded = sig + type t + + val ( + ) : t -> t -> t + val ( - ) : t -> t -> t + val ( * ) : t -> t -> t + val ( / ) : t -> t -> t + val ( ~- ) : t -> t + val ( ** ) : t -> t -> t + include Comparisons.Infix with type t := t + + val abs : t -> t + val neg : t -> t + val zero : t + + val ( % ) : t -> t -> t + val ( /% ) : t -> t -> t + val ( // ) : t -> t -> float + + val ( land ) : t -> t -> t + val ( lor ) : t -> t -> t + val ( lxor ) : t -> t -> t + val lnot : t -> t + + val ( lsl ) : t -> int -> t + val ( asr ) : t -> int -> t +end + +module type Operators = sig + include Operators_unbounded + val ( lsr ) : t -> int -> t +end + +(** [S_unbounded] is a generic interface for unbounded integers, e.g. [Bignum.Bigint]. + [S_unbounded] is a restriction of [S] (below) that omits values that depend on + fixed-size integers. *) +module type S_unbounded = sig + include S_common (** @inline *) + + (** A sub-module designed to be opened to make working with ints more convenient. *) + module O : Operators_unbounded with type t := t +end + +(** [S] is a generic interface for fixed-size integers. *) +module type S = sig + include S_common (** @inline *) + + (** The number of bits available in this integer type. Note that the integer + representations are signed. *) + val num_bits : int + + (** The largest representable integer. *) + val max_value : t + + (** The smallest representable integer. *) + val min_value : t + + (** Same as [shift_right_logical]. *) + val ( lsr ) : t -> int -> t + + (** Shifts right, filling in with zeroes, which will not preserve the sign of the + input. *) + val shift_right_logical : t -> int -> t + + (** [ceil_pow2 x] returns the smallest power of 2 that is greater than or equal to [x]. + The implementation may only be called for [x > 0]. Example: [ceil_pow2 17 = 32] *) + val ceil_pow2 : t -> t + + (** [floor_pow2 x] returns the largest power of 2 that is less than or equal to [x]. The + implementation may only be called for [x > 0]. Example: [floor_pow2 17 = 16] *) + val floor_pow2 : t -> t + + (** [ceil_log2 x] returns the ceiling of log-base-2 of [x], and raises if [x <= 0]. *) + val ceil_log2 : t -> int + + (** [floor_log2 x] returns the floor of log-base-2 of [x], and raises if [x <= 0]. *) + val floor_log2 : t -> int + + (** [is_pow2 x] returns true iff [x] is a power of 2. [is_pow2] raises if [x <= 0]. *) + val is_pow2 : t -> bool + + (** A sub-module designed to be opened to make working with ints more convenient. *) + module O : Operators with type t := t +end + +include + (struct + (** Various functors whose type-correctness ensures desired relationships between + interfaces. *) + + module Check_O_contained_in_S (M : S) = (M : module type of M.O) + module Check_O_contained_in_S_unbounded (M : S_unbounded) = (M : module type of M.O) + module Check_S_unbounded_in_S (M : S) = (M : S_unbounded) + end : sig end) + +module type Int_without_module_types = sig + include S with type t = int + + (** [max_value_30_bits = 2^30 - 1]. It is useful for writing tests that work on both + 64-bit and 32-bit platforms. *) + val max_value_30_bits : t + + (** {2 Conversion functions} *) + + val of_int : int -> t + val to_int : t -> int + val of_int32 : int32 -> t option + val to_int32 : t -> int32 option + val of_int64 : int64 -> t option + val of_nativeint : nativeint -> t option + val to_nativeint : t -> nativeint + + (** {3 Truncating conversions} + + These functions return the least-significant bits of the input. In cases + where optional conversions return [Some x], truncating conversions return [x]. *) + + val of_int32_trunc : int32 -> t + val to_int32_trunc : t -> int32 + val of_int64_trunc : int64 -> t + val of_nativeint_trunc : nativeint -> t + + (**/**) + (*_ See the Jane Street Style Guide for an explanation of [Private] submodules: + + https://opensource.janestreet.com/standards/#private-submodules *) + module Private : sig + (*_ For ../bench/bench_int.ml *) + module O_F : sig + val ( % ) : int -> int -> int + val ( /% ) : int -> int -> int + val ( // ) : int -> int -> float + end + end +end + +(** OCaml's native integer type. + + The number of bits in an integer is platform dependent, being 31-bits on a 32-bit + platform, and 63-bits on a 64-bit platform. [int] is a signed integer type. [int]s + are also subject to overflow, meaning that [Int.max_value + 1 = Int.min_value]. + + [int]s always fit in a machine word. *) +module type Int = sig + include Int_without_module_types + + (** {2 Module types specifying integer operations.} *) + module type Hexable = Hexable + module type Int_without_module_types = Int_without_module_types + module type Operators = Operators + module type Operators_unbounded = Operators_unbounded + module type Round = Round + module type S = S + module type S_common = S_common + module type S_unbounded = S_unbounded +end diff --git a/src/int_math.ml b/src/int_math.ml new file mode 100644 index 0000000..63db7b8 --- /dev/null +++ b/src/int_math.ml @@ -0,0 +1,144 @@ +open! Import + +let invalid_argf = Printf.invalid_argf + +let negative_exponent () = + Printf.invalid_argf "exponent can not be negative" () + +let overflow () = + Printf.invalid_argf "integer overflow in pow" () + +(* To implement [int64_pow], we use C code rather than OCaml to eliminate allocation. *) +external int_math_int_pow : int -> int -> int = "Base_int_math_int_pow_stub" [@@noalloc] +external int_math_int64_pow : int64 -> int64 -> int64 = "Base_int_math_int64_pow_stub" + +let int_pow base exponent = + if exponent < 0 then negative_exponent (); + + if abs(base) > 1 && + (exponent > 63 || + abs(base) > Pow_overflow_bounds.int_positive_overflow_bounds.(exponent)) + then overflow (); + + int_math_int_pow base exponent +;; + +module Int64_with_comparisons = struct + include Caml.Int64 + external ( < ) : int64 -> int64 -> bool = "%lessthan" + external ( > ) : int64 -> int64 -> bool = "%greaterthan" + external ( >= ) : int64 -> int64 -> bool = "%greaterequal" +end + +(* we don't do [abs] in int64 case to avoid allocation *) +let int64_pow base exponent = + let open Int64_with_comparisons in + if exponent < 0L then negative_exponent (); + + if (base > 1L || base < (-1L)) && + (exponent > 63L || + (base >= 0L && + base > Pow_overflow_bounds.int64_positive_overflow_bounds.(to_int exponent)) + || + (base < 0L && + base < Pow_overflow_bounds.int64_negative_overflow_bounds.(to_int exponent))) + then overflow (); + + int_math_int64_pow base exponent +;; + + +let int63_pow_on_int64 base exponent = + let open Int64_with_comparisons in + if exponent < 0L then negative_exponent (); + + if abs(base) > 1L && + (exponent > 63L || + abs(base) > Pow_overflow_bounds.int63_on_int64_positive_overflow_bounds.(to_int exponent)) + then overflow (); + + int_math_int64_pow base exponent +;; + +module type T = sig + type t + include Floatable.S with type t := t + include Stringable.S with type t := t + + val ( + ) : t -> t -> t + val ( - ) : t -> t -> t + val ( * ) : t -> t -> t + val ( / ) : t -> t -> t + val ( ~- ) : t -> t + include Comparisons.Infix with type t := t + + val abs : t -> t + val neg : t -> t + val zero : t + val of_int_exn : int -> t + val rem : t -> t -> t +end + +module Make (X : T) = struct + open X + + let ( % ) x y = + if y <= zero then + invalid_argf + "%s %% %s in core_int.ml: modulus should be positive" + (to_string x) (to_string y) (); + let rval = X.rem x y in + if rval < zero + then rval + y + else rval + ;; + + let one = of_int_exn 1 + ;; + + let ( /% ) x y = + if y <= zero then + invalid_argf + "%s /%% %s in core_int.ml: divisor should be positive" + (to_string x) (to_string y) (); + if x < zero + then (x + one) / y - one + else x / y + ;; + + (** float division of integers *) + let (//) x y = to_float x /. to_float y + ;; + + let round_down i ~to_multiple_of:modulus = i - (i % modulus) + ;; + + let round_up i ~to_multiple_of:modulus = + let remainder = i % modulus in + if remainder = zero + then i + else i + modulus - remainder + ;; + + let round_towards_zero i ~to_multiple_of = + if i = zero then zero else + if i > zero + then round_down i ~to_multiple_of + else round_up i ~to_multiple_of + ;; + + let round_nearest i ~to_multiple_of:modulus = + let remainder = i % modulus in + if remainder * of_int_exn 2 < modulus + then i - remainder + else i - remainder + modulus + ;; + + let round ?(dir=`Nearest) i ~to_multiple_of = + match dir with + | `Nearest -> round_nearest i ~to_multiple_of + | `Down -> round_down i ~to_multiple_of + | `Up -> round_up i ~to_multiple_of + | `Zero -> round_towards_zero i ~to_multiple_of + ;; +end diff --git a/src/int_math.mli b/src/int_math.mli new file mode 100644 index 0000000..bfd0f82 --- /dev/null +++ b/src/int_math.mli @@ -0,0 +1,39 @@ +(** This module is not exposed in Core. Instead, these functions are accessed and + commented in the various Core modules implementing [Int_intf.S]. *) + +open! Import + +(*_ This interface is not defined in int_intf.ml because we don't want users of Core to + think about it. *) +module type T = sig + type t + include Floatable.S with type t := t + include Stringable.S with type t := t + + val ( + ) : t -> t -> t + val ( - ) : t -> t -> t + val ( * ) : t -> t -> t + val ( / ) : t -> t -> t + val ( ~- ) : t -> t + include Comparisons.Infix with type t := t + + val abs : t -> t + val neg : t -> t + val zero : t + val of_int_exn : int -> t + + val rem : t -> t -> t +end + +(** derived operations common to various integer modules *) +module Make (X : T) : sig + val ( % ) : X.t -> X.t -> X.t + val ( /% ) : X.t -> X.t -> X.t + val ( // ) : X.t -> X.t -> float + include Int_intf.Round with type t := X.t +end + +val int_pow : int -> int -> int +val int64_pow : int64 -> int64 -> int64 + +val int63_pow_on_int64 : int64 -> int64 -> int64 diff --git a/src/int_math_stubs.c b/src/int_math_stubs.c new file mode 100644 index 0000000..733a327 --- /dev/null +++ b/src/int_math_stubs.c @@ -0,0 +1,92 @@ +#include <stdlib.h> +#include <stdbool.h> +#include <stdint.h> +#include <caml/alloc.h> +#include <caml/mlvalues.h> +#include <caml/memory.h> + +#ifdef _MSC_VER + +#include <intrin.h> + +#define __builtin_popcountll __popcnt64 +#define __builtin_popcount __popcnt + +static uint32_t __inline __builtin_clz(uint32_t x) +{ + int r = 0; + _BitScanForward(&r, x); + return r; +} + +static uint64_t __inline __builtin_clzll(uint64_t x) +{ + int r = 0; + _BitScanForward64(&r, x); + return r; +} + +#endif + +static int64_t int_pow(int64_t base, int64_t exponent) { + int64_t ret = 1; + int64_t mul[4]; + mul[0] = 1; + mul[1] = base; + mul[3] = 1; + + while(exponent != 0) { + mul[1] *= mul[3]; + mul[2] = mul[1] * mul[1]; + mul[3] = mul[2] * mul[1]; + ret *= mul[exponent & 3]; + exponent >>= 2; + } + + return ret; +} + +CAMLprim value Base_int_math_int_pow_stub(value base, value exponent) { + return (Val_long(int_pow(Long_val(base), Long_val(exponent)))); +} + +CAMLprim value Base_int_math_int64_pow_stub(value base, value exponent) { + CAMLparam2(base, exponent); + CAMLreturn(caml_copy_int64(int_pow(Int64_val(base), Int64_val(exponent)))); +} + +/* This implementation is faster than [__builtin_popcount(v) - 1], even though + * it seems more complicated. The [&] clears the shifted sign bit after + * [Long_val] or [Int_val]. */ +CAMLprim value Base_int_math_int_popcount(value v) { +#ifdef ARCH_SIXTYFOUR + return Val_int (__builtin_popcountll (Long_val (v) & ~((uint64_t)1 << 63))); +#else + return Val_int (__builtin_popcount (Int_val (v) & ~((uint32_t)1 << 31))); +#endif +} + +/* The specification of all below [clz] functions is undefined for [v = 0]. */ +CAMLprim value Base_int_math_int_clz(value v) { +#ifdef ARCH_SIXTYFOUR + return Val_int (__builtin_clzll (Long_val(v))); +#else + return Val_int (__builtin_clz (Int_val (v))); +#endif +} + +CAMLprim value Base_int_math_int32_clz(value v) { + return Val_int (__builtin_clz (Int32_val(v))); +} + +CAMLprim value Base_int_math_int64_clz(value v) { + return Val_int (__builtin_clzll (Int64_val(v))); +} + +CAMLprim value Base_int_math_nativeint_clz(value v) { +#ifdef ARCH_SIXTYFOUR + return Val_int (__builtin_clzll (Nativeint_val(v))); +#else + return Val_int (__builtin_clz (Nativeint_val(v))); +#endif +} diff --git a/src/intable.ml b/src/intable.ml new file mode 100644 index 0000000..2300503 --- /dev/null +++ b/src/intable.ml @@ -0,0 +1,11 @@ +(** Functor that adds integer conversion functions to a module. *) + +open! Import + +module type S = sig + type t + + val of_int_exn : int -> t + val to_int_exn : t -> int +end + diff --git a/src/internalhash.h b/src/internalhash.h new file mode 100644 index 0000000..b752b91 --- /dev/null +++ b/src/internalhash.h @@ -0,0 +1,3 @@ +#include <stdint.h> +#include <caml/mlvalues.h> +CAMLexport uint32_t Base_internalhash_fold_blob(uint32_t h, mlsize_t len, uint8_t *s); diff --git a/src/internalhash_stubs.c b/src/internalhash_stubs.c new file mode 100644 index 0000000..f3d0b39 --- /dev/null +++ b/src/internalhash_stubs.c @@ -0,0 +1,101 @@ +#include <stdint.h> +#include <caml/mlvalues.h> +#include <caml/hash.h> +#include "internalhash.h" + +/* This pretends that the state of the OCaml internal hash function, which is an + int32, is actually stored in an OCaml int. */ + +CAMLprim value Base_internalhash_fold_int32(value st, value i) +{ + return Val_long(caml_hash_mix_uint32(Long_val(st), Int32_val(i))); +} + +CAMLprim value Base_internalhash_fold_nativeint(value st, value i) +{ + return Val_long(caml_hash_mix_intnat(Long_val(st), Nativeint_val(i))); +} + +CAMLprim value Base_internalhash_fold_int64(value st, value i) +{ + return Val_long(caml_hash_mix_int64(Long_val(st), Int64_val(i))); +} + +CAMLprim value Base_internalhash_fold_int(value st, value i) +{ + return Val_long(caml_hash_mix_intnat(Long_val(st), Long_val(i))); +} + +CAMLprim value Base_internalhash_fold_float(value st, value i) +{ + return Val_long(caml_hash_mix_double(Long_val(st), Double_val(i))); +} + +/* This code mimics what hashtbl.hash does in OCaml's hash.c */ +#define FINAL_MIX(h) \ + h ^= h >> 16; \ + h *= 0x85ebca6b; \ + h ^= h >> 13; \ + h *= 0xc2b2ae35; \ + h ^= h >> 16; + +CAMLprim value Base_internalhash_get_hash_value(value st) +{ + uint32_t h = Int_val(st); + FINAL_MIX(h); + return Val_int(h & 0x3FFFFFFFU); /*30 bits*/ +} + +/* Macros copied from hash.c in ocaml distribution */ +#define ROTL32(x,n) ((x) << n | (x) >> (32-n)) + +#define MIX(h,d) \ + d *= 0xcc9e2d51; \ + d = ROTL32(d, 15); \ + d *= 0x1b873593; \ + h ^= d; \ + h = ROTL32(h, 13); \ + h = h * 5 + 0xe6546b64; + +/* Version of [caml_hash_mix_string] from hash.c - adapted for arbitrary char arrays */ +CAMLexport uint32_t Base_internalhash_fold_blob(uint32_t h, mlsize_t len, uint8_t *s) +{ + mlsize_t i; + uint32_t w; + + /* Mix by 32-bit blocks (little-endian) */ + for (i = 0; i + 4 <= len; i += 4) { +#ifdef ARCH_BIG_ENDIAN + w = s[i] + | (s[i+1] << 8) + | (s[i+2] << 16) + | (s[i+3] << 24); +#else + w = *((uint32_t *) &(s[i])); +#endif + MIX(h, w); + } + /* Finish with up to 3 bytes */ + w = 0; + switch (len & 3) { + case 3: w = s[i+2] << 16; /* fallthrough */ + case 2: w |= s[i+1] << 8; /* fallthrough */ + case 1: w |= s[i]; + MIX(h, w); + default: /*skip*/; /* len & 3 == 0, no extra bytes, do nothing */ + } + /* Finally, mix in the length. Ignore the upper 32 bits, generally 0. */ + h ^= (uint32_t) len; + return h; +} + +CAMLprim value Base_internalhash_fold_string(value st, value v_str) +{ + uint32_t h = Long_val(st); + mlsize_t len = caml_string_length(v_str); + uint8_t *s = (uint8_t *) String_val(v_str); + + h = Base_internalhash_fold_blob(h, len, s); + + return Val_long(h); +} diff --git a/src/invariant.ml b/src/invariant.ml new file mode 100644 index 0000000..026a9e7 --- /dev/null +++ b/src/invariant.ml @@ -0,0 +1,26 @@ +open! Import + +include Invariant_intf + +let raise_s = Error.raise_s + +let invariant here t sexp_of_t f : unit = + try + f () + with exn -> + raise_s + (Sexp.message "invariant failed" + [ "" , Source_code_position0.sexp_of_t here + ; "exn", sexp_of_exn exn + ; "" , sexp_of_t t ]) +;; + +let check_field t f field = + try + f (Field.get field t) + with exn -> + raise_s (Sexp.message "problem with field" + [ "field", sexp_of_string (Field.name field) + ; "exn" , sexp_of_exn exn + ]) +;; diff --git a/src/invariant.mli b/src/invariant.mli new file mode 100644 index 0000000..23ca91b --- /dev/null +++ b/src/invariant.mli @@ -0,0 +1 @@ +include Invariant_intf.Invariant (** @inline *) diff --git a/src/invariant_intf.ml b/src/invariant_intf.ml new file mode 100644 index 0000000..4e657a4 --- /dev/null +++ b/src/invariant_intf.ml @@ -0,0 +1,98 @@ +open! Import + +type 'a t = 'a -> unit + +type 'a inv = 'a t + +module type S = sig + type t + val invariant : t inv +end + +module type S1 = sig + type 'a t + val invariant : 'a inv -> 'a t inv +end + +module type S2 = sig + type ('a, 'b) t + val invariant : 'a inv -> 'b inv -> ('a, 'b) t inv +end + +module type S3 = sig + type ('a, 'b, 'c) t + val invariant : 'a inv -> 'b inv -> 'c inv -> ('a, 'b, 'c) t inv +end + +module type Invariant = sig + + (** This module defines signatures that are to be included in other signatures to ensure + a consistent interface to invariant-style functions. There is a signature ([S], + [S1], [S2], [S3]) for each arity of type. Usage looks like: + + {[ + type t + include Invariant.S with type t := t + ]} + + or + + {[ + type 'a t + include Invariant.S1 with type 'a t := 'a t + ]} + *) + + type nonrec 'a t = 'a t + + module type S = S + module type S1 = S1 + module type S2 = S2 + module type S3 = S3 + + (** [invariant here t sexp_of_t f] runs [f ()], and if [f] raises, wraps the exception + in an [Error.t] that states "invariant failed" and includes both the exception + raised by [f], as well as [sexp_of_t t]. Idiomatic usage looks like: + + {[ + invariant [%here] t [%sexp_of: t] (fun () -> + ... check t's invariants ... ) + ]} + + For polymorphic types: + + {[ + let invariant check_a t = + Invariant.invariant [%here] t [%sexp_of: _ t] (fun () -> ... ) + ]} + + It's okay to use [ [%sexp_of: _ t] ] because the exceptions raised by [check_a] will + show the parts that are sexp_opaque at top-level. *) + val invariant + : Source_code_position0.t + -> 'a + -> ('a -> Sexp.t) + -> (unit -> unit) + -> unit + + (** [check_field] is used when checking invariants using [Fields.iter]. It wraps an + exception raised when checking a field with the field's name. Idiomatic usage looks + like: + + {[ + type t = + { foo : Foo.t; + bar : Bar.t; + } + [@@deriving_inline fields][@@@end] + + let invariant t : unit = + Invariant.invariant [%here] t [%sexp_of: t] (fun () -> + let check f = Invariant.check_field t f in + Fields.iter + ~foo:(check Foo.invariant) + ~bar:(check Bar.invariant)) + ;; + ]} *) + val check_field : 'a -> 'b t -> ('a, 'b) Field.t -> unit +end diff --git a/src/lazy.ml b/src/lazy.ml new file mode 100644 index 0000000..6a32e4a --- /dev/null +++ b/src/lazy.ml @@ -0,0 +1,44 @@ +open! Import + +type 'a t = 'a lazy_t [@@deriving_inline sexp] +let t_of_sexp : + 'a . (Ppx_sexp_conv_lib.Sexp.t -> 'a) -> Ppx_sexp_conv_lib.Sexp.t -> 'a t = + lazy_t_of_sexp +let sexp_of_t : + 'a . ('a -> Ppx_sexp_conv_lib.Sexp.t) -> 'a t -> Ppx_sexp_conv_lib.Sexp.t = + sexp_of_lazy_t +[@@@end] + +include (Caml.Lazy : module type of Caml.Lazy with type 'a t := 'a t) + +let map t ~f = lazy (f (force t)) + +let compare compare_a t1 t2 = + if phys_equal t1 t2 + then 0 + else compare_a (force t1) (force t2) +;; + +let hash_fold_t = Hash.Builtin.hash_fold_lazy_t + +include Monad.Make (struct + type nonrec 'a t = 'a t + + let return x = from_val x + + let bind t ~f = lazy (force (f (force t))) + + let map = map + + let map = `Custom map + end) + +module T_unforcing = struct + type nonrec 'a t = 'a t + + let sexp_of_t sexp_of_a t = + if is_val t + then sexp_of_a (force t) + else sexp_of_string "<unforced lazy>" + ;; +end diff --git a/src/lazy.mli b/src/lazy.mli new file mode 100644 index 0000000..1a0c599 --- /dev/null +++ b/src/lazy.mli @@ -0,0 +1,81 @@ +(*_ JS-only: This file is a modified version of lazy.mli from the OCaml distribution. *) + +(** A value of type ['a Lazy.t] is a deferred computation, called a suspension, that has a + result of type ['a]. + + The special expression syntax [lazy (expr)] makes a suspension of the computation of + [expr], without computing [expr] itself yet. "Forcing" the suspension will then + compute [expr] and return its result. + + Note: [lazy_t] is the built-in type constructor used by the compiler for the [lazy] + keyword. You should not use it directly. Always use [Lazy.t] instead. + + Note: [Lazy.force] is not thread-safe. If you use this module in a multi-threaded + program, you will need to add some locks. + + Note: if the program is compiled with the [-rectypes] option, ill-founded recursive + definitions of the form [let rec x = lazy x] or [let rec x = lazy(lazy(...(lazy x)))] + are accepted by the type-checker and lead, when forced, to ill-formed values that + trigger infinite loops in the garbage collector and other parts of the run-time + system. Without the [-rectypes] option, such ill-founded recursive definitions are + rejected by the type-checker. *) + +open! Import + +type 'a t = 'a lazy_t [@@deriving_inline compare, hash, sexp] +include +sig + [@@@ocaml.warning "-32"] + val compare : ('a -> 'a -> int) -> 'a t -> 'a t -> int + val hash_fold_t : + (Ppx_hash_lib.Std.Hash.state -> 'a -> Ppx_hash_lib.Std.Hash.state) -> + Ppx_hash_lib.Std.Hash.state -> 'a t -> Ppx_hash_lib.Std.Hash.state + include Ppx_sexp_conv_lib.Sexpable.S1 with type 'a t := 'a t +end[@@ocaml.doc "@inline"] +[@@@end] + +include Monad.S with type 'a t := 'a t + +exception Undefined + +(** [force x] forces the suspension [x] and returns its result. If [x] has already been + forced, [Lazy.force x] returns the same value again without recomputing it. If it + raised an exception, the same exception is raised again. Raise [Undefined] if the + forcing of [x] tries to force [x] itself recursively. *) +external force : 'a t -> 'a = "%lazy_force" + +(** Like [force] except that [force_val x] does not use an exception handler, so it may be + more efficient. However, if the computation of [x] raises an exception, it is + unspecified whether [force_val x] raises the same exception or [Undefined]. *) +val force_val : 'a t -> 'a + +(** [from_fun f] is the same as [lazy (f ())] but slightly more efficient if [f] is a + variable. [from_fun] should only be used if the function [f] is already defined. In + particular it is always less efficient to write [from_fun (fun () -> expr)] than [lazy + expr]. *) +val from_fun : (unit -> 'a) -> 'a t + +(** [from_val v] returns an already-forced suspension of [v] (where [v] can be any + expression). Essentially, [from_val expr] is the same as [let var = expr in lazy + var]. *) +val from_val : 'a -> 'a t + +(** [is_val x] returns [true] if [x] has already been forced and did not raise an + exception. *) +val is_val : 'a t -> bool + +(** This type offers a serialization function [sexp_of_t] that won't force its argument. + Instead, it will serialize the ['a] if it is available, or just use a custom string + indicating it is not forced. Note that this is not a round-trippable type, thus the + type does not expose [of_sexp]. To be used in debug code, while tracking a Heisenbug, + etc. *) +module T_unforcing : sig + type nonrec 'a t = 'a t [@@deriving_inline sexp_of] + include + sig + [@@@ocaml.warning "-32"] + val sexp_of_t : + ('a -> Ppx_sexp_conv_lib.Sexp.t) -> 'a t -> Ppx_sexp_conv_lib.Sexp.t + end[@@ocaml.doc "@inline"] + [@@@end] +end diff --git a/src/linked_queue.ml b/src/linked_queue.ml new file mode 100644 index 0000000..bc5d7af --- /dev/null +++ b/src/linked_queue.ml @@ -0,0 +1,159 @@ +open! Import + +include Linked_queue0 + + +let enqueue t x = Linked_queue0.push x t + +let dequeue t = if is_empty t then None else Some (Linked_queue0.pop t) +let dequeue_exn = Linked_queue0.pop + +let peek t = if is_empty t then None else Some (Linked_queue0.peek t) +let peek_exn = Linked_queue0.peek + + +module C = + Indexed_container.Make (struct + type nonrec 'a t = 'a t + let fold = fold + let iter = `Custom iter + let length = `Custom length + let foldi = `Define_using_fold + let iteri = `Define_using_fold + end) + +let count = C.count +let exists = C.exists +let find = C.find +let find_map = C.find_map +let fold_result = C.fold_result +let fold_until = C.fold_until +let for_all = C.for_all +let max_elt = C.max_elt +let mem = C.mem +let min_elt = C.min_elt +let sum = C.sum +let to_list = C.to_list + +let counti = C.counti +let existsi = C.existsi +let find_mapi = C.find_mapi +let findi = C.findi +let foldi = C.foldi +let for_alli = C.for_alli +let iteri = C.iteri + +let transfer ~src ~dst = Linked_queue0.transfer src dst + +let concat_map t ~f = + let res = create () in + iter t ~f:(fun a -> + List.iter (f a) ~f:(fun b -> enqueue res b)); + res +;; + +let concat_mapi t ~f = + let res = create () in + iteri t ~f:(fun i a -> + List.iter (f i a) ~f:(fun b -> enqueue res b)); + res +;; + +let filter_map t ~f = + let res = create () in + iter t ~f:(fun a -> + match f a with + | None -> () + | Some b -> enqueue res b); + res; +;; + +let filter_mapi t ~f = + let res = create () in + iteri t ~f:(fun i a -> + match f i a with + | None -> () + | Some b -> enqueue res b); + res; +;; + +let filter t ~f = + let res = create () in + iter t ~f:(fun a -> if f a then enqueue res a); + res; +;; + +let filteri t ~f = + let res = create () in + iteri t ~f:(fun i a -> if f i a then enqueue res a); + res; +;; + +let map t ~f = + let res = create () in + iter t ~f:(fun a -> enqueue res (f a)); + res; +;; + +let mapi t ~f = + let res = create () in + iteri t ~f:(fun i a -> enqueue res (f i a)); + res; +;; + +let filter_inplace q ~f = + let q' = filter q ~f in + clear q; + transfer ~src:q' ~dst:q; +;; + +let filteri_inplace q ~f = + let q' = filteri q ~f in + clear q; + transfer ~src:q' ~dst:q; +;; + +let enqueue_all t list = + List.iter list ~f:(fun x -> enqueue t x) +;; + +let of_list list = + let t = create () in + List.iter list ~f:(fun x -> enqueue t x); + t +;; + +let of_array array = + let t = create () in + Array.iter array ~f:(fun x -> enqueue t x); + t +;; + +let init len ~f = + let t = create () in + for i = 0 to len - 1 do + enqueue t (f i) + done; + t +;; + +let to_array t = + match length t with + | 0 -> [||] + | len -> + let arr = Array.create ~len (peek_exn t) in + let i = ref 0 in + iter t ~f:(fun v -> + arr.(!i) <- v; + incr i); + arr +;; + +let t_of_sexp a_of_sexp sexp = of_list (list_of_sexp a_of_sexp sexp) +let sexp_of_t sexp_of_a t = sexp_of_list sexp_of_a (to_list t) + +let singleton a = + let t = create () in + enqueue t a; + t +;; diff --git a/src/linked_queue.mli b/src/linked_queue.mli new file mode 100644 index 0000000..158e8f8 --- /dev/null +++ b/src/linked_queue.mli @@ -0,0 +1,19 @@ +(** This module is a Base-style wrapper around OCaml's standard [Queue] module. *) + +open! Import + +include Queue_intf.S with type 'a t = 'a Caml.Queue.t (** @inline *) + +(** [create ()] returns an empty queue. *) +val create : unit -> _ t + +(** [transfer ~src ~dst] adds all of the elements of [src] to the end of [dst], then + clears [src]. It is equivalent to the sequence: + + {[ + iter ~src ~f:(enqueue dst); + clear src + ]} + + but runs in constant time. *) +val transfer : src:'a t -> dst:'a t -> unit diff --git a/src/linked_queue0.ml b/src/linked_queue0.ml new file mode 100644 index 0000000..fcd67f4 --- /dev/null +++ b/src/linked_queue0.ml @@ -0,0 +1,16 @@ +open! Import0 + +type 'a t = 'a Caml.Queue.t + +let create = Caml.Queue.create +let clear = Caml.Queue.clear +let copy = Caml.Queue.copy +let is_empty = Caml.Queue.is_empty +let length = Caml.Queue.length +let peek = Caml.Queue.peek +let pop = Caml.Queue.pop +let push = Caml.Queue.push +let transfer = Caml.Queue.transfer + +let iter t ~f = Caml.Queue.iter f t +let fold t ~init ~f = Caml.Queue.fold f init t diff --git a/src/list.ml b/src/list.ml new file mode 100644 index 0000000..fc5f13c --- /dev/null +++ b/src/list.ml @@ -0,0 +1,1065 @@ +open! Import + +module Array = Array0 + + +include List1 (* This itself includes [List0]. *) + +let invalid_argf = Printf.invalid_argf + +module T = struct + type 'a t = 'a list [@@deriving_inline sexp] + let t_of_sexp : + 'a . (Ppx_sexp_conv_lib.Sexp.t -> 'a) -> Ppx_sexp_conv_lib.Sexp.t -> 'a t = + list_of_sexp + let sexp_of_t : + 'a . ('a -> Ppx_sexp_conv_lib.Sexp.t) -> 'a t -> Ppx_sexp_conv_lib.Sexp.t = + sexp_of_list + [@@@end] +end + +module Or_unequal_lengths = struct + type 'a t = + | Ok of 'a + | Unequal_lengths + [@@deriving_inline compare, sexp_of] + let compare : 'a . ('a -> 'a -> int) -> 'a t -> 'a t -> int = + fun _cmp__a -> + fun a__001_ -> + fun b__002_ -> + if Ppx_compare_lib.phys_equal a__001_ b__002_ + then 0 + else + (match (a__001_, b__002_) with + | (Ok _a__003_, Ok _b__004_) -> _cmp__a _a__003_ _b__004_ + | (Ok _, _) -> (-1) + | (_, Ok _) -> 1 + | (Unequal_lengths, Unequal_lengths) -> 0) + let sexp_of_t : type a. + (a -> Ppx_sexp_conv_lib.Sexp.t) -> a t -> Ppx_sexp_conv_lib.Sexp.t = + fun _of_a -> + function + | Ok v0 -> + let v0 = _of_a v0 in + Ppx_sexp_conv_lib.Sexp.List [Ppx_sexp_conv_lib.Sexp.Atom "Ok"; v0] + | Unequal_lengths -> Ppx_sexp_conv_lib.Sexp.Atom "Unequal_lengths" + [@@@end] +end + +include T + +let of_list t = t + +let range' ~compare ~stride ?(start=`inclusive) ?(stop=`exclusive) start_i stop_i = + let next_i = stride start_i in + let order x y = Ordering.of_int (compare x y) in + let raise_stride_cannot_return_same_value () = + invalid_arg "List.range': stride function cannot return the same value" + in + let initial_stride_order = + match order start_i next_i with + | Equal -> raise_stride_cannot_return_same_value () + | Less -> `Less + | Greater -> `Greater + in + let rec loop i accum = + let i_to_stop_order = order i stop_i in + match i_to_stop_order, initial_stride_order with + | (Less, `Less) + | (Greater, `Greater) -> begin + (* haven't yet reached [stop_i]. Continue. *) + let next_i = stride i in + match order i next_i, initial_stride_order with + | Equal, _ -> raise_stride_cannot_return_same_value () + | Less, `Greater + | Greater, `Less -> + invalid_arg "List.range': stride function cannot change direction" + | Less, `Less + | Greater, `Greater -> loop next_i (i :: accum) + end + | (Less, `Greater) + | (Greater, `Less) -> + (* stepped past [stop_i]. Finished. *) + accum + | (Equal, _) -> + (* reached [stop_i]. Finished. *) + match stop with + | `inclusive -> i :: accum + | `exclusive -> accum + in + let start_i = + match start with + | `inclusive -> start_i + | `exclusive -> next_i + in + rev (loop start_i []) +;; + +let range ?(stride=1) ?(start=`inclusive) ?(stop=`exclusive) start_i stop_i = + if stride = 0 then invalid_arg "List.range: stride must be non-zero"; + range' ~compare ~stride:(fun x -> x + stride) ~start ~stop start_i stop_i +;; + +let hd t = + match t with + | [] -> None + | x :: _ -> Some x +;; + +let tl t = + match t with + | [] -> None + | _ :: t' -> Some t' +;; + +let nth t n = + if n < 0 then None else + let rec nth_aux t n = + match t with + | [] -> None + | a :: t -> if n = 0 then Some a else nth_aux t (n-1) + in nth_aux t n +;; + +let nth_exn t n = + match nth t n with + | None -> + invalid_argf "List.nth_exn %d called on list of length %d" + n (length t) () + | Some a -> a +;; + +let unordered_append l1 l2 = + match l1, l2 with + | [], l | l, [] -> l + | _ -> rev_append l1 l2 + +let check_length2_exn name l1 l2 = + let n1 = length l1 in + let n2 = length l2 in + if n1 <> n2 + then raise (invalid_argf "length mismatch in %s: %d <> %d " name n1 n2 ()) +;; + +let check_length3_exn name l1 l2 l3 = + let n1 = length l1 in + let n2 = length l2 in + let n3 = length l3 in + if n1 <> n2 || n2 <> n3 + then raise (invalid_argf "length mismatch in %s: %d <> %d || %d <> %d" + name n1 n2 n2 n3 ()) +;; + +let check_length2 l1 l2 ~f = + if length l1 <> length l2 + then Or_unequal_lengths.Unequal_lengths + else Ok (f l1 l2) +;; + +let check_length3 l1 l2 l3 ~f = + let n1 = length l1 in + let n2 = length l2 in + let n3 = length l3 in + if n1 <> n2 || n2 <> n3 + then Or_unequal_lengths.Unequal_lengths + else Ok (f l1 l2 l3) +;; + +let iter2 l1 l2 ~f = check_length2 l1 l2 ~f:(iter2_ok ~f) + +let iter2_exn l1 l2 ~f = + check_length2_exn "iter2_exn" l1 l2; + iter2_ok l1 l2 ~f; +;; + +let rev_map2 l1 l2 ~f = check_length2 l1 l2 ~f:(rev_map2_ok ~f) + +let rev_map2_exn l1 l2 ~f = + check_length2_exn "rev_map2_exn" l1 l2; + rev_map2_ok l1 l2 ~f; +;; + +let fold2 l1 l2 ~init ~f = check_length2 l1 l2 ~f:(fold2_ok ~init ~f) + +let fold2_exn l1 l2 ~init ~f = + check_length2_exn "fold2_exn" l1 l2; + fold2_ok l1 l2 ~init ~f; +;; + +let for_all2 l1 l2 ~f = check_length2 l1 l2 ~f:(for_all2_ok ~f) + +let for_all2_exn l1 l2 ~f = + check_length2_exn "for_all2_exn" l1 l2; + for_all2_ok l1 l2 ~f; +;; + +let exists2 l1 l2 ~f = check_length2 l1 l2 ~f:(exists2_ok ~f) + +let exists2_exn l1 l2 ~f = + check_length2_exn "exists2_exn" l1 l2; + exists2_ok l1 l2 ~f; +;; + +let mem t a ~equal = + let rec loop equal a = function + | [] -> false + | b :: bs -> equal a b || loop equal a bs + in + loop equal a t + +(* This is a copy of the code from the standard library, with an extra eta-expansion to + avoid creating partial closures (showed up for [filter]) in profiling). *) +let rev_filter t ~f = + let rec find ~f accu = function + | [] -> accu + | x :: l -> if f x then find ~f (x :: accu) l else find ~f accu l + in + find ~f [] t +;; + +let filter t ~f = rev (rev_filter t ~f) + +let find_map t ~f = + let rec loop = function + | [] -> None + | x :: l -> + match f x with + | None -> loop l + | Some _ as r -> r + in + loop t +;; + +let find_map_exn t ~f = + match find_map t ~f with + | None -> raise Caml.Not_found + | Some x -> x + +let find t ~f = + let rec loop = function + | [] -> None + | x :: l -> if f x then Some x else loop l + in + loop t +;; + +let findi t ~f = + let rec loop i t = + match t with + | [] -> None + | x :: l -> if f i x then Some (i, x) else loop (i + 1) l + in + loop 0 t +;; + +let find_mapi t ~f = + let rec loop i t = + match t with + | [] -> None + | x :: l -> + match f i x with + | Some _ as result -> result + | None -> loop (i + 1) l + in + loop 0 t +;; + +let find_mapi_exn t ~f = + match find_mapi t ~f with + | None -> raise Caml.Not_found + | Some x -> x + +let for_alli t ~f = + let rec loop i t = + match t with + | [] -> true + | hd :: tl -> f i hd && loop (i+1) tl + in + loop 0 t + +let existsi t ~f = + let rec loop i t = + match t with + | [] -> false + | hd :: tl -> f i hd || loop (i+1) tl + in + loop 0 t + +(** For the container interface. *) +let fold_left = fold +let to_array = Array.of_list +let to_list t = t + +(** Tail recursive versions of standard [List] module *) + +let slow_append l1 l2 = rev_append (rev l1) l2 + +(* There are a few optimized list operations here, including append and map. There are + basically two optimizations in play: loop unrolling, and dynamic switching between + stack and heap allocation. + + The loop-unrolling is straightforward, we just unroll 5 levels of the loop. This makes + each iteration faster, and also reduces the number of stack frames consumed per list + element. + + The dynamic switching is done by counting the number of stack frames, and then + switching to the "slow" implementation when we exceed a given limit. This means that + short lists use the fast stack-allocation method, and long lists use a slower one that + doesn't require stack space. *) +let rec count_append l1 l2 count = + match l2 with + | [] -> l1 + | _ -> + match l1 with + | [] -> l2 + | [x1] -> x1 :: l2 + | [x1; x2] -> x1 :: x2 :: l2 + | [x1; x2; x3] -> x1 :: x2 :: x3 :: l2 + | [x1; x2; x3; x4] -> x1 :: x2 :: x3 :: x4 :: l2 + | x1 :: x2 :: x3 :: x4 :: x5 :: tl -> + x1 :: x2 :: x3 :: x4 :: x5 :: + (if count > 1000 + then slow_append tl l2 + else count_append tl l2 (count + 1)) + +let append l1 l2 = count_append l1 l2 0 + +let slow_map l ~f = rev (rev_map l ~f) + +let rec count_map ~f l ctr = + match l with + | [] -> [] + | [x1] -> + let f1 = f x1 in + [f1] + | [x1; x2] -> + let f1 = f x1 in + let f2 = f x2 in + [f1; f2] + | [x1; x2; x3] -> + let f1 = f x1 in + let f2 = f x2 in + let f3 = f x3 in + [f1; f2; f3] + | [x1; x2; x3; x4] -> + let f1 = f x1 in + let f2 = f x2 in + let f3 = f x3 in + let f4 = f x4 in + [f1; f2; f3; f4] + | x1 :: x2 :: x3 :: x4 :: x5 :: tl -> + let f1 = f x1 in + let f2 = f x2 in + let f3 = f x3 in + let f4 = f x4 in + let f5 = f x5 in + f1 :: f2 :: f3 :: f4 :: f5 :: + (if ctr > 1000 + then slow_map ~f tl + else count_map ~f tl (ctr + 1)) + +let map l ~f = count_map ~f l 0 + +let folding_map t ~init ~f = + let acc = ref init in + map t ~f:(fun x -> + let new_acc, y = f !acc x in + acc := new_acc; + y) +;; + +let fold_map t ~init ~f = + let acc = ref init in + let result = + map t ~f:(fun x -> + let new_acc, y = f !acc x in + acc := new_acc; + y) + in + !acc, result +;; + +let (>>|) l f = map l ~f + +let map2_ok l1 l2 ~f = rev (rev_map2_ok l1 l2 ~f) + +let map2 l1 l2 ~f = check_length2 l1 l2 ~f:(map2_ok ~f) + +let map2_exn l1 l2 ~f = + check_length2_exn "map2_exn" l1 l2; + map2_ok l1 l2 ~f +;; + +let rev_map3_ok l1 l2 l3 ~f = + let rec loop l1 l2 l3 ac = + match (l1, l2, l3) with + | ([], [], []) -> ac + | (x1 :: l1, x2 :: l2, x3 :: l3) -> loop l1 l2 l3 (f x1 x2 x3 :: ac) + | _ -> assert false + in + loop l1 l2 l3 []; +;; + +let rev_map3 l1 l2 l3 ~f = + check_length3 l1 l2 l3 ~f:(rev_map3_ok ~f) +;; + +let rev_map3_exn l1 l2 l3 ~f = + check_length3_exn "rev_map3_exn" l1 l2 l3; + rev_map3_ok l1 l2 l3 ~f +;; + +let map3_ok l1 l2 l3 ~f = rev (rev_map3_ok l1 l2 l3 ~f) + +let map3 l1 l2 l3 ~f = check_length3 l1 l2 l3 ~f:(map3_ok ~f) + +let map3_exn l1 l2 l3 ~f = + check_length3_exn "map3_exn" l1 l2 l3; + map3_ok l1 l2 l3 ~f; +;; + +let rec rev_map_append l1 l2 ~f = + match l1 with + | [] -> l2 + | h :: t -> rev_map_append ~f t (f h :: l2) + +let fold_right l ~f ~init = + match l with + | [] -> init (* avoid the allocation of [~f] below *) + | _ -> fold ~f:(fun a b -> f b a) ~init (rev l) + +let unzip list = + let rec loop list l1 l2 = + match list with + | [] -> (rev l1, rev l2) + | (x, y) :: tl -> loop tl (x :: l1) (y :: l2) + in + loop list [] [] + +let unzip3 list = + let rec loop list l1 l2 l3 = + match list with + | [] -> (rev l1, rev l2, rev l3) + | (x, y, z) :: tl -> loop tl (x :: l1) (y :: l2) (z :: l3) + in + loop list [] [] [] + +let zip_exn l1 l2 = + check_length2_exn "zip_exn" l1 l2; + map2_ok ~f:(fun a b -> (a, b)) l1 l2 + +let zip l1 l2 = map2 ~f:(fun a b -> (a, b)) l1 l2 + +(** Additional list operations *) + +let rev_mapi l ~f = + let rec loop i acc = function + | [] -> acc + | h :: t -> loop (i + 1) (f i h :: acc) t + in + loop 0 [] l + +let mapi l ~f = rev (rev_mapi l ~f) + +let folding_mapi t ~init ~f = + let acc = ref init in + mapi t ~f:(fun i x -> + let new_acc, y = f i !acc x in + acc := new_acc; + y) +;; + +let fold_mapi t ~init ~f = + let acc = ref init in + let result = + mapi t ~f:(fun i x -> + let new_acc, y = f i !acc x in + acc := new_acc; + y) + in + !acc, result +;; + +let iteri l ~f = + ignore (fold l ~init:0 ~f:(fun i x -> f i x; i + 1)); +;; + +let foldi t ~init ~f = + snd (fold t ~init:(0, init) ~f:(fun (i, acc) v -> (i + 1, f i acc v))) +;; + +let filteri l ~f = + rev (foldi l + ~f:(fun pos acc x -> + if f pos x then x :: acc else acc) + ~init:[]) + +let reduce l ~f = match l with + | [] -> None + | hd :: tl -> Some (fold ~init:hd ~f tl) + +let reduce_exn l ~f = + match reduce l ~f with + | None -> raise (Invalid_argument "List.reduce_exn") + | Some v -> v + +let reduce_balanced l ~f = + (* Call the "size" of a value the number of list elements that have been combined into + it via calls to [f]. We proceed by using [f] to combine elements in the accumulator + of the same size until we can't combine any more, then getting a new element from the + input list and repeating. + + With this strategy, in the accumulator: + - we only ever have elements of sizes a power of two + - we never have more than one element of each size + - the sum of all the element sizes is equal to the number of elements consumed + + These conditions enforce that list of elements of each size is precisely the binary + expansion of the number of elements consumed: if you've consumed 13 = 0b1101 + elements, you have one element of size 8, one of size 4, and one of size 1. Hence + when a new element comes along, the number of combinings you need to do is the number + of trailing 1s in the binary expansion of [num], the number of elements that have + already gone into the accumulator. The accumulator is in ascending order of size, so + the next element to combine with is always the head of the list. *) + let rec step_accum num acc x = + if num land 1 = 0 + then x :: acc + else + match acc with + | [] -> assert false + (* New elements from later in the input list go on the front of the accumulator, so + the accumulator is in reverse order wrt the original list order, hence [f y x] + instead of [f x y]. *) + | y :: ys -> step_accum (num asr 1) ys (f y x) + in + (* Experimentally, inlining [foldi] and unrolling this loop a few times can reduce + runtime down to a third and allocation to 1/16th or so in the microbenchmarks below. + However, in most use cases [f] is likely to be expensive (otherwise why do you care + about the order of reduction?) so the overhead of this function itself doesn't really + matter. If you come up with a use-case where it does, then that's something you might + want to try: see hg log -pr 49ef065f429d. *) + match foldi l ~init:[] ~f:step_accum with + | [] -> None + | x :: xs -> Some (fold xs ~init:x ~f:(fun x y -> f y x)) + +let reduce_balanced_exn l ~f = + match reduce_balanced l ~f with + | None -> raise (Invalid_argument "List.reduce_balanced_exn") + | Some v -> v + +let groupi l ~break = + let groups = + foldi l ~init:[] ~f:(fun i acc x -> + match acc with + | [] -> [[x]] + | current_group :: tl -> + if break i (hd_exn current_group) x then + [x] :: current_group :: tl (* start new group *) + else + (x :: current_group) :: tl) (* extend current group *) + in + match groups with + | [] -> [] + | l -> rev_map l ~f:rev + +let group l ~break = groupi l ~break:(fun _ x y -> break x y) + +let concat_map l ~f = + let rec aux acc = function + | [] -> rev acc + | hd :: tl -> aux (rev_append (f hd) acc) tl + in + aux [] l + +let concat_mapi l ~f = + let rec aux cont acc = function + | [] -> rev acc + | hd :: tl -> aux (cont + 1) (rev_append (f cont hd) acc) tl + in + aux 0 [] l + +let merge l1 l2 ~compare = + let rec loop acc l1 l2 = + match l1,l2 with + | [], l2 -> rev_append acc l2 + | l1, [] -> rev_append acc l1 + | h1 :: t1, h2 :: t2 -> + if compare h1 h2 <= 0 + then loop (h1 :: acc) t1 l2 + else loop (h2 :: acc) l1 t2 + in + loop [] l1 l2 +;; + + +include struct + (* We are explicit about what we import from the general Monad functor so that we don't + accidentally rebind more efficient list-specific functions. *) + module Monad = Monad.Make (struct + type 'a t = 'a list + let bind x ~f = concat_map x ~f + let map = `Custom map + let return x = [x] + end) + open Monad + module Monad_infix = Monad_infix + module Let_syntax = Let_syntax + let ignore_m = ignore_m + let join = join + let bind = bind + let (>>=) t f = bind t ~f + let return = return + let all = all + let all_unit = all_unit + let all_ignore = all_unit +end + +(** returns final element of list *) +let rec last_exn list = match list with + | [x] -> x + | _ :: tl -> last_exn tl + | [] -> raise (Invalid_argument "List.last") + +(** optionally returns final element of list *) +let rec last list = match list with + | [x] -> Some x + | _ :: tl -> last tl + | [] -> None +;; + +let rec is_prefix list ~prefix ~equal = + match prefix with + | [] -> true + | hd::tl -> + match list with + | [] -> false + | hd'::tl' -> equal hd hd' && is_prefix tl' ~prefix:tl ~equal +;; + +let find_consecutive_duplicate t ~equal = + match t with + | [] -> None + | a1 :: t -> + let rec loop a1 t = + match t with + | [] -> None + | a2 :: t -> if equal a1 a2 then Some (a1, a2) else loop a2 t + in + loop a1 t +;; + +(* returns list without adjacent duplicates *) +let remove_consecutive_duplicates ?(which_to_keep=`Last) list ~equal = + let rec loop to_keep accum = function + | [] -> to_keep :: accum + | hd :: tl -> + if equal hd to_keep then ( + let to_keep = + match which_to_keep with + | `First -> to_keep + | `Last -> hd + in + loop to_keep accum tl + ) else ( + loop hd (to_keep :: accum) tl + ) + in + match list with + | [] -> [] + | hd :: tl -> rev (loop hd [] tl) +;; + +(** returns sorted version of list with duplicates removed *) +let dedup_and_sort ~compare list = + match list with + | [] -> [] (* performance hack *) + | _ -> + let equal x x' = compare x x' = 0 in + let sorted = sort ~compare list in + remove_consecutive_duplicates ~equal sorted + +let dedup = dedup_and_sort + +let find_a_dup ~compare l = + let sorted = sort ~compare l in + let rec loop l = match l with + | [] | [_] -> None + | hd1 :: (hd2 :: _ as tl) -> + if compare hd1 hd2 = 0 then Some hd1 else loop tl + in + loop sorted +;; + +let contains_dup ~compare lst = + match find_a_dup ~compare lst with + | Some _ -> true + | None -> false +;; + +let find_all_dups ~compare l = + (* We add this reversal, so we can skip a [rev] at the end. We could skip + [rev] anyway since we don not give any ordering guarantees, but it is + nice to get results in natural order. *) + let compare a b = (-1) * compare a b in + let sorted = sort ~compare l in + (* Walk the list and record the first of each consecutive run of identical elements *) + let rec loop sorted prev ~already_recorded acc = + match sorted with + | [] -> acc + | hd :: tl -> + if compare prev hd <> 0 + then loop tl hd ~already_recorded:false acc + else if already_recorded + then loop tl hd ~already_recorded:true acc + else loop tl hd ~already_recorded:true (hd :: acc) + in + match sorted with + | [] -> [] + | hd :: tl -> loop tl hd ~already_recorded:false [] +;; + +let count t ~f = Container.count ~fold t ~f +let sum m t ~f = Container.sum ~fold m t ~f +let min_elt t ~compare = Container.min_elt ~fold t ~compare +let max_elt t ~compare = Container.max_elt ~fold t ~compare + +let counti t ~f = foldi t ~init:0 ~f:(fun idx count a -> if f idx a then count + 1 else count) + +let init n ~f = + if n < 0 then invalid_argf "List.init %d" n (); + let rec loop i accum = + assert (i >= 0); + if i = 0 then accum + else loop (i-1) (f (i-1) :: accum) + in + loop n [] +;; + +let rev_filter_map l ~f = + let rec loop l accum = + match l with + | [] -> accum + | hd :: tl -> + match f hd with + | Some x -> loop tl (x :: accum) + | None -> loop tl accum + in + loop l [] +;; + +let filter_map l ~f = rev (rev_filter_map l ~f) + +let rev_filter_mapi l ~f = + let rec loop i l accum = + match l with + | [] -> accum + | hd :: tl -> + match f i hd with + | Some x -> loop (i + 1) tl (x :: accum) + | None -> loop (i + 1) tl accum + in + loop 0 l [] +;; + +let filter_mapi l ~f = rev (rev_filter_mapi l ~f) + +let filter_opt l = filter_map l ~f:Fn.id + +let partition3_map t ~f = + let rec loop t fst snd trd = + match t with + | [] -> (rev fst, rev snd, rev trd) + | x :: t -> + match f x with + | `Fst y -> loop t (y :: fst) snd trd + | `Snd y -> loop t fst (y :: snd) trd + | `Trd y -> loop t fst snd (y :: trd) + in + loop t [] [] [] +;; + +let partition_tf t ~f = + let f x = if f x then `Fst x else `Snd x in + partition_map t ~f +;; + +let partition_result t = + let f x = + match x with + | Ok v -> `Fst v + | Error e -> `Snd e + in + partition_map t ~f + +module Assoc = struct + + type ('a, 'b) t = ('a * 'b) list [@@deriving_inline sexp] + let t_of_sexp : + 'a 'b . + (Ppx_sexp_conv_lib.Sexp.t -> 'a) -> + (Ppx_sexp_conv_lib.Sexp.t -> 'b) -> + Ppx_sexp_conv_lib.Sexp.t -> ('a, 'b) t + = + let _tp_loc = "src/list.ml.Assoc.t" in + fun _of_a -> + fun _of_b -> + fun t -> + list_of_sexp + (function + | Ppx_sexp_conv_lib.Sexp.List (v0::v1::[]) -> + let v0 = _of_a v0 + and v1 = _of_b v1 in (v0, v1) + | sexp -> + Ppx_sexp_conv_lib.Conv_error.tuple_of_size_n_expected _tp_loc + 2 sexp) t + let sexp_of_t : + 'a 'b . + ('a -> Ppx_sexp_conv_lib.Sexp.t) -> + ('b -> Ppx_sexp_conv_lib.Sexp.t) -> + ('a, 'b) t -> Ppx_sexp_conv_lib.Sexp.t + = + fun _of_a -> + fun _of_b -> + fun v -> + sexp_of_list + (function + | (v0, v1) -> + let v0 = _of_a v0 + and v1 = _of_b v1 in Ppx_sexp_conv_lib.Sexp.List [v0; v1]) v + [@@@end] + + let find t ~equal key = + match find t ~f:(fun (key', _) -> equal key key') with + | None -> None + | Some x -> Some (snd x) + + let find_exn t ~equal key = + match find t key ~equal with + | None -> raise Caml.Not_found + | Some value -> value + + let mem t ~equal key = + match find t ~equal key with + | None -> false + | Some _ -> true + ;; + + let remove t ~equal key = + filter t ~f:(fun (key', _) -> not (equal key key')) + + let add t ~equal key value = + (* the remove doesn't change the map semantics, but keeps the list small *) + (key, value) :: remove t ~equal key + + let inverse t = map t ~f:(fun (x, y) -> (y, x)) + + let map t ~f = map t ~f:(fun (key, value) -> (key, f value)) + +end + +let sub l ~pos ~len = + (* We use [pos > length l - len] rather than [pos + len > length l] to avoid the + possibility of overflow. *) + if pos < 0 || len < 0 || pos > length l - len then invalid_arg "List.sub"; + rev + (foldi l ~init:[] + ~f:(fun i acc el -> + if i >= pos && i < (pos + len) + then el :: acc + else acc + ) + ) +;; + +let split_n t_orig n = + if n <= 0 then + ([], t_orig) + else + let rec loop n t accum = + if n = 0 then + (rev accum, t) + else + match t with + | [] -> (t_orig, []) (* in this case, t_orig = rev accum *) + | hd :: tl -> loop (n - 1) tl (hd :: accum) + in + loop n t_orig [] + +(* copied from [split_n] to avoid allocating a tuple *) +let take t_orig n = + if n <= 0 then + [] + else + let rec loop n t accum = + if n = 0 then + rev accum + else + match t with + | [] -> t_orig + | hd :: tl -> loop (n - 1) tl (hd :: accum) + in + loop n t_orig [] + +let rec drop t n = + match t with + | _ :: tl when n > 0 -> drop tl (n - 1) + | t -> t + +let chunks_of l ~length = + if length <= 0 + then invalid_argf "List.chunks_of: Expected length > 0, got %d" length (); + let rec aux of_length acc l = + match l with + | [] -> rev acc + | _ :: _ -> + let sublist, l = split_n l length in + aux of_length (sublist :: acc) l + in + aux length [] l +;; + +let split_while xs ~f = + let rec loop acc = function + | hd :: tl when f hd -> loop (hd :: acc) tl + | t -> (rev acc, t) + in + loop [] xs +;; + +(* copied from [split_while] to avoid allocating a tuple *) +let take_while xs ~f = + let rec loop acc = function + | hd :: tl when f hd -> loop (hd :: acc) tl + | _ -> rev acc + in + loop [] xs +;; + +let rec drop_while t ~f = + match t with + | hd :: tl when f hd -> drop_while tl ~f + | t -> t + +let cartesian_product list1 list2 = + if is_empty list2 then [] else + let rec loop l1 l2 accum = match l1 with + | [] -> accum + | (hd :: tl) -> + loop tl l2 + (rev_append + (map ~f:(fun x -> (hd,x)) l2) + accum) + in + rev (loop list1 list2 []) + +let concat l = fold_right l ~init:[] ~f:append + +let concat_no_order l = fold l ~init:[] ~f:(fun acc l -> rev_append l acc) + +let cons x l = x :: l + +let is_sorted l ~compare = + let rec loop l = + match l with + | [] | [_] -> true + | x1 :: ((x2 :: _) as rest) -> + compare x1 x2 <= 0 && loop rest + in loop l + +let is_sorted_strictly l ~compare = + let rec loop l = + match l with + | [] | [_] -> true + | x1 :: ((x2 :: _) as rest) -> + compare x1 x2 < 0 && loop rest + in loop l +;; + +module Infix = struct + let ( @ ) = append +end + +let permute ?(random_state = Random.State.default) list = + match list with + (* special cases to speed things up in trivial cases *) + | [] | [_] -> list + | [ x; y ] -> if Random.State.bool random_state then [ y; x ] else list + | _ -> + let arr = Array.of_list list in + Array_permute.permute arr ~random_state; + Array.to_list arr; +;; + +let random_element_exn ?(random_state = Random.State.default) list = + if is_empty list + then failwith "List.random_element_exn: empty list" + else nth_exn list (Random.State.int random_state (length list)) +;; + +let random_element ?(random_state = Random.State.default) list = + try Some (random_element_exn ~random_state list) + with _ -> None +;; + +let rec compare cmp a b = + match a, b with + | [], [] -> 0 + | [], _ -> -1 + | _ , [] -> 1 + | x :: xs, y :: ys -> + let n = cmp x y in + if n = 0 then compare cmp xs ys + else n +;; + +let hash_fold_t = hash_fold_list + +let equal equal t1 t2 = + let rec loop ~equal t1 t2 = + match t1, t2 with + | [], [] -> true + | x1 :: t1, x2 :: t2 -> equal x1 x2 && loop ~equal t1 t2 + | _ -> false + in + loop ~equal t1 t2 +;; + +let transpose = + let rec transpose_aux t rev_columns = + match partition_map t ~f:(function [] -> `Snd () | x :: xs -> `Fst (x, xs)) with + | (_ :: _, _ :: _) -> None + | ([], _) -> Some (rev_append rev_columns []) + | (heads_and_tails, []) -> + let (column, trimmed_rows) = unzip heads_and_tails in + transpose_aux trimmed_rows (column :: rev_columns) + in + fun t -> + transpose_aux t [] + +exception Transpose_got_lists_of_different_lengths of int list [@@deriving_inline sexp] +let () = + Ppx_sexp_conv_lib.Conv.Exn_converter.add + ([%extension_constructor Transpose_got_lists_of_different_lengths]) + (function + | Transpose_got_lists_of_different_lengths v0 -> + let v0 = sexp_of_list sexp_of_int v0 in + Ppx_sexp_conv_lib.Sexp.List + [Ppx_sexp_conv_lib.Sexp.Atom + "src/list.ml.Transpose_got_lists_of_different_lengths"; + v0] + | _ -> assert false) +[@@@end] + +let transpose_exn l = + match transpose l with + | Some l -> l + | None -> + raise (Transpose_got_lists_of_different_lengths (map l ~f:length)) + +let intersperse t ~sep = + match t with + | [] -> [] + | x :: xs -> x :: fold_right xs ~init:[] ~f:(fun y acc -> sep :: y :: acc) + +let fold_result t ~init ~f = Container.fold_result ~fold ~init ~f t +let fold_until t ~init ~f = Container.fold_until ~fold ~init ~f t diff --git a/src/list.mli b/src/list.mli new file mode 100644 index 0000000..0a1729a --- /dev/null +++ b/src/list.mli @@ -0,0 +1,514 @@ +(** Immutable, singly-linked lists, giving fast access to the front of the list, and slow + (i.e., O(n)) access to the back of the list. The comparison functions on lists are + lexicographic. *) + +open! Import + +type 'a t = 'a list [@@deriving_inline compare, hash, sexp] +include +sig + [@@@ocaml.warning "-32"] + val compare : ('a -> 'a -> int) -> 'a t -> 'a t -> int + val hash_fold_t : + (Ppx_hash_lib.Std.Hash.state -> 'a -> Ppx_hash_lib.Std.Hash.state) -> + Ppx_hash_lib.Std.Hash.state -> 'a t -> Ppx_hash_lib.Std.Hash.state + include Ppx_sexp_conv_lib.Sexpable.S1 with type 'a t := 'a t +end[@@ocaml.doc "@inline"] +[@@@end] + +include Container.S1 with type 'a t := 'a t +include Monad.S with type 'a t := 'a t + +(** [Or_unequal_lengths] is used for functions that take multiple lists and that only make + sense if all the lists have the same length, e.g., [iter2], [map3]. Such functions + check the list lengths prior to doing anything else, and return [Unequal_lengths] if + not all the lists have the same length. *) +module Or_unequal_lengths : sig + type 'a t = + | Ok of 'a + | Unequal_lengths + [@@deriving_inline compare, sexp_of] + include + sig + [@@@ocaml.warning "-32"] + val compare : ('a -> 'a -> int) -> 'a t -> 'a t -> int + val sexp_of_t : + ('a -> Ppx_sexp_conv_lib.Sexp.t) -> 'a t -> Ppx_sexp_conv_lib.Sexp.t + end[@@ocaml.doc "@inline"] + [@@@end] +end + +(** [of_list] is the identity function. It is useful so that the [List] module matches + the same signature that other container modules do, namely: + + {[ + val of_list : 'a List.t -> 'a t + ]} *) +val of_list : 'a t -> 'a t + +val nth : 'a t -> int -> 'a option + +(** Return the [n]-th element of the given list. The first element (head of the list) is + at position 0. Raise if the list is too short or [n] is negative. *) +val nth_exn : 'a t -> int -> 'a + +(** List reversal. *) +val rev : 'a t -> 'a t + +(** [rev_append l1 l2] reverses [l1] and concatenates it to [l2]. This is equivalent + to [(]{!List.rev}[ l1) @ l2], but [rev_append] is more efficient. *) +val rev_append : 'a t -> 'a t -> 'a t + +(** [unordered_append l1 l2] has the same elements as [l1 @ l2], but in some + unspecified order. Generally takes time proportional to length of first list, but is + O(1) if either list is empty. *) +val unordered_append : 'a t -> 'a t -> 'a t + +(** [rev_map l ~f] gives the same result as {!List.rev}[ (]{!ListLabels.map}[ f l)], + but is more efficient. *) +val rev_map : 'a t -> f:('a -> 'b) -> 'b t + +(** [iter2 [a1; ...; an] [b1; ...; bn] ~f] calls in turn [f a1 b1; ...; f an bn]. + The exn version will raise if the two lists have different lengths. *) +val iter2_exn : 'a t -> 'b t -> f:('a -> 'b -> unit) -> unit +val iter2 : 'a t -> 'b t -> f:('a -> 'b -> unit) -> unit Or_unequal_lengths.t + +(** [rev_map2_exn l1 l2 ~f] gives the same result as [List.rev (List.map2_exn l1 l2 + ~f)], but is more efficient. *) +val rev_map2_exn : 'a t -> 'b t -> f:('a -> 'b -> 'c) -> 'c t +val rev_map2 : 'a t -> 'b t -> f:('a -> 'b -> 'c) -> 'c t Or_unequal_lengths.t + +(** [fold2 ~f ~init:a [b1; ...; bn] [c1; ...; cn]] is [f (... (f (f a b1 c1) b2 c2) + ...) bn cn]. The exn version will raise if the two lists have different lengths. *) +val fold2_exn : 'a t -> 'b t -> init:'c -> f:('c -> 'a -> 'b -> 'c) -> 'c +val fold2 : 'a t -> 'b t -> init:'c -> f:('c -> 'a -> 'b -> 'c) -> 'c Or_unequal_lengths.t + +(** Like {!List.for_all}, but passes the index as an argument. *) +val for_alli : 'a t -> f:(int -> 'a -> bool) -> bool + +(** Like {!List.for_all}, but for a two-argument predicate. The exn version will raise if + the two lists have different lengths. *) +val for_all2_exn : 'a t -> 'b t -> f:('a -> 'b -> bool) -> bool +val for_all2 : 'a t -> 'b t -> f:('a -> 'b -> bool) -> bool Or_unequal_lengths.t + +(** Like {!List.exists}, but passes the index as an argument. *) +val existsi : 'a t -> f:(int -> 'a -> bool) -> bool + +(** Like {!List.exists}, but for a two-argument predicate. The exn version will raise if + the two lists have different lengths. *) +val exists2_exn : 'a t -> 'b t -> f:('a -> 'b -> bool) -> bool +val exists2 : 'a t -> 'b t -> f:('a -> 'b -> bool) -> bool Or_unequal_lengths.t + +(** [filter l ~f] returns all the elements of the list [l] that satisfy the predicate [p]. + The order of the elements in the input list is preserved. *) +val filter : 'a t -> f:('a -> bool) -> 'a t + +(** Like [filter], but reverses the order of the input list. *) +val rev_filter : 'a t -> f:('a -> bool) -> 'a t + +val filteri : 'a t -> f: (int -> 'a -> bool) -> 'a t + +(** [partition_map t ~f] partitions [t] according to [f]. *) +val partition_map : 'a t -> f:('a -> [ `Fst of 'b | `Snd of 'c ]) -> 'b t * 'c t + +val partition3_map + : 'a t + -> f:('a -> [ `Fst of 'b | `Snd of 'c | `Trd of 'd]) + -> 'b t * 'c t * 'd t + +(** [partition_tf l ~f] returns a pair of lists [(l1, l2)], where [l1] is the list of all + the elements of [l] that satisfy the predicate [p], and [l2] is the list of all the + elements of [l] that do not satisfy [p]. The order of the elements in the input list + is preserved. The "tf" suffix is mnemonic to remind readers at a call that the result + is (trues, falses). *) +val partition_tf : 'a t -> f:('a -> bool) -> 'a t * 'a t + +(** [partition_result l] returns a pair of lists [(l1, l2)], where [l1] is the + list of all [Ok] elements in [l] and [l2] is the list of all [Error] + elements. + The order of elements in the input list is preserved. *) +val partition_result : ('ok, 'error) Result.t t -> 'ok t * 'error t + +(** [split_n \[e1; ...; em\] n] is [(\[e1; ...; en\], \[en+1; ...; em\])]. + + - If [n > m], [(\[e1; ...; em\], \[\])] is returned. + - If [n < 0], [(\[\], \[e1; ...; em\])] is returned. *) +val split_n : 'a t -> int -> 'a t * 'a t + +(** Sort a list in increasing order according to a comparison function. The comparison + function must return 0 if its arguments compare as equal, a positive integer if the + first is greater, and a negative integer if the first is smaller (see [Array.sort] for + a complete specification). For example, {!Poly.compare} is a suitable + comparison function. + + The current implementation uses Merge Sort. It runs in linear heap space and + logarithmic stack space. + + Presently, the sort is stable, meaning that two equal elements in the input will be in + the same order in the output. *) +val sort : 'a t -> compare:('a -> 'a -> int) -> 'a t + +(** Like [sort], but guaranteed to be stable. *) +val stable_sort : 'a t -> compare:('a -> 'a -> int) -> 'a t + +(** Merges two lists: assuming that [l1] and [l2] are sorted according to the comparison + function [compare], [merge compare l1 l2] will return a sorted list containing all the + elements of [l1] and [l2]. If several elements compare equal, the elements of [l1] + will be before the elements of [l2]. *) +val merge : 'a t -> 'a t -> compare:('a -> 'a -> int) -> 'a t + +val hd : 'a t -> 'a option + +val tl : 'a t -> 'a t option + +(** Returns the first element of the given list. Raises if the list is empty. *) +val hd_exn : 'a t -> 'a + +(** Returns the given list without its first element. Raises if the list is empty. *) +val tl_exn : 'a t -> 'a t + +val findi : 'a t -> f:(int -> 'a -> bool) -> (int * 'a) option + +(** [find_exn t ~f] returns the first element of [t] that satisfies [f]. It raises + [Caml.Not_found] or [Not_found_s] if there is no such element. *) +val find_exn : 'a t -> f:('a -> bool) -> 'a + +(** Returns the first evaluation of [f] that returns [Some]. Raises [Caml.Not_found] or + [Not_found_s] if [f] always returns [None]. *) +val find_map_exn : 'a t -> f:('a -> 'b option) -> 'b + +(** Like [find_map] and [find_map_exn], but passes the index as an argument. *) + +val find_mapi : 'a t -> f:(int -> 'a -> 'b option) -> 'b option +val find_mapi_exn : 'a t -> f:(int -> 'a -> 'b option) -> 'b + +(** E.g., [append [1; 2] [3; 4; 5]] is [[1; 2; 3; 4; 5]] *) +val append : 'a t -> 'a t -> 'a t + +(** [map f [a1; ...; an]] applies function [f] to [a1], [a2], ..., [an], in order, + and builds the list [[f a1; ...; f an]] with the results returned by [f]. *) +val map : 'a t -> f:('a -> 'b) -> 'b t + +(** [folding_map] is a version of [map] that threads an accumulator through calls to + [f]. *) + +val folding_map : 'a t -> init:'b -> f:( 'b -> 'a -> 'b * 'c) -> 'c t +val folding_mapi : 'a t -> init:'b -> f:(int -> 'b -> 'a -> 'b * 'c) -> 'c t + +(** [fold_map] is a combination of [fold] and [map] that threads an accumulator through + calls to [f]. *) + +val fold_map : 'a t -> init:'b -> f:( 'b -> 'a -> 'b * 'c) -> 'b * 'c t +val fold_mapi : 'a t -> init:'b -> f:(int -> 'b -> 'a -> 'b * 'c) -> 'b * 'c t + +(** [concat_map t ~f] is [concat (map t ~f)], except that there is no guarantee about the + order in which [f] is applied to the elements of [t]. *) +val concat_map : 'a t -> f:('a -> 'b t) -> 'b t + +(** [concat_mapi t ~f] is like concat_map, but passes the index as an argument *) +val concat_mapi : 'a t -> f:(int -> 'a -> 'b t) -> 'b t + +(** [map2 [a1; ...; an] [b1; ...; bn] ~f] is [[f a1 b1; ...; f an bn]]. The exn + version will raise if the two lists have different lengths. *) +val map2_exn :'a t -> 'b t -> f:('a -> 'b -> 'c) -> 'c t +val map2 :'a t -> 'b t -> f:('a -> 'b -> 'c) -> 'c t Or_unequal_lengths.t + +(** Analogous to [rev_map2]. *) + +val rev_map3_exn : 'a t -> 'b t -> 'c t -> f:('a -> 'b -> 'c -> 'd) -> 'd t +val rev_map3 : 'a t -> 'b t -> 'c t -> f:('a -> 'b -> 'c -> 'd) -> 'd t Or_unequal_lengths.t + +(** Analogous to [map2]. *) + +val map3_exn : 'a t -> 'b t -> 'c t -> f:('a -> 'b -> 'c -> 'd) -> 'd t +val map3 : 'a t -> 'b t -> 'c t -> f:('a -> 'b -> 'c -> 'd) -> 'd t Or_unequal_lengths.t + +(** [rev_map_append l1 l2 ~f] reverses [l1] mapping [f] over each element, and appends the + result to the front of [l2]. *) +val rev_map_append : 'a t -> 'b t -> f:('a -> 'b) -> 'b t + +(** [fold_right [a1; ...; an] ~f ~init:b] is [f a1 (f a2 (... (f an b) ...))]. *) +val fold_right : 'a t -> f:('a -> 'b -> 'b) -> init:'b -> 'b + +(** [fold_left] is the same as {!Container.S1.fold}, and one should always use + [fold] rather than [fold_left], except in functors that are parameterized + over a more general signature where this equivalence does not hold. *) +val fold_left : 'a t -> init:'b -> f:('b -> 'a -> 'b) -> 'b + +(** Transform a list of pairs into a pair of lists: [unzip [(a1,b1); ...; (an,bn)]] is + [([a1; ...; an], [b1; ...; bn])]. *) + +val unzip : ('a * 'b ) t -> 'a t * 'b t +val unzip3 : ('a * 'b * 'c) t -> 'a t * 'b t * 'c t + +(** Transform a pair of lists into an (optional) list of pairs: [zip [a1; ...; an] [b1; + ...; bn]] is [[(a1,b1); ...; (an,bn)]]. Returns [Unequal_lengths] if the two lists + have different lengths. *) + +val zip : 'a t -> 'b t -> ('a * 'b) t Or_unequal_lengths.t +val zip_exn : 'a t -> 'b t -> ('a * 'b) t + +(** [mapi] is just like map, but it also passes in the index of each element as the first + argument to the mapped function. Tail-recursive. *) +val mapi : 'a t -> f:(int -> 'a -> 'b) -> 'b t + +val rev_mapi : 'a t -> f:(int -> 'a -> 'b) -> 'b t + +(** [iteri] is just like [iter], but it also passes in the index of each element as the + first argument to the iter'd function. Tail-recursive. *) +val iteri : 'a t -> f:(int -> 'a -> unit) -> unit + +(** [foldi] is just like [fold], but it also passes in the index of each element as the + first argument to the folded function. Tail-recursive. *) +val foldi : 'a t -> init:'b -> f:(int -> 'b -> 'a -> 'b) -> 'b + +(** [reduce_exn [a1; ...; an] ~f] is [f (... (f (f a1 a2) a3) ...) an]. It fails on the + empty list. Tail recursive. *) +val reduce_exn : 'a t -> f:('a -> 'a -> 'a) -> 'a +val reduce : 'a t -> f:('a -> 'a -> 'a) -> 'a option + +(** [reduce_balanced] returns the same value as [reduce] when [f] is associative, but + differs in that the tree of nested applications of [f] has logarithmic depth. + + This is useful when your ['a] grows in size as you reduce it and [f] becomes more + expensive with bigger inputs. For example, [reduce_balanced ~f:(^)] takes [n*log(n)] + time, while [reduce ~f:(^)] takes quadratic time. *) +val reduce_balanced : 'a t -> f:('a -> 'a -> 'a) -> 'a option +val reduce_balanced_exn : 'a t -> f:('a -> 'a -> 'a) -> 'a + +(** [group l ~break] returns a list of lists (i.e., groups) whose concatenation is equal + to the original list. Each group is broken where [break] returns true on a pair of + successive elements. + + Example: + + {[ + group ~break:(<>) ['M';'i';'s';'s';'i';'s';'s';'i';'p';'p';'i'] -> + + [['M'];['i'];['s';'s'];['i'];['s';'s'];['i'];['p';'p'];['i']] ]} *) +val group : 'a t -> break:('a -> 'a -> bool) -> 'a t t + +(** This is just like [group], except that you get the index in the original list of the + current element along with the two elements. + + Example, group the chars of ["Mississippi"] into triples: + + {[ + groupi ~break:(fun i _ _ -> i mod 3 = 0) + ['M';'i';'s';'s';'i';'s';'s';'i';'p';'p';'i'] -> + + [['M'; 'i'; 's']; ['s'; 'i'; 's']; ['s'; 'i'; 'p']; ['p'; 'i']] ]} +*) +val groupi : 'a t -> break:(int -> 'a -> 'a -> bool) -> 'a t t + +(** [chunks_of l ~length] returns a list of lists whose concatenation is equal to the + original list. Every list has [length] elements, except for possibly the last list, + which may have fewer. [chunks_of] raises if [length <= 0]. *) +val chunks_of : 'a t -> length : int -> 'a t t + +(** The final element of a list. The [_exn] version raises on the empty list. *) +val last : 'a t -> 'a option +val last_exn : 'a t -> 'a + +(** [is_prefix xs ~prefix] returns [true] if [xs] starts with [prefix]. *) +val is_prefix : 'a t -> prefix:'a t -> equal:('a -> 'a -> bool) -> bool + + +(** [find_consecutive_duplicate t ~equal] returns the first pair of consecutive elements + [(a1, a2)] in [t] such that [equal a1 a2]. They are returned in the same order as + they appear in [t]. [equal] need not be an equivalence relation; it is simply used as + a predicate on consecutive elements. *) +val find_consecutive_duplicate : 'a t -> equal:('a -> 'a -> bool) -> ('a * 'a) option + +(** Returns the given list with consecutive duplicates removed. The relative order of the + other elements is unaffected. The element kept from a run of duplicates is determined + by [which_to_keep]. *) +val remove_consecutive_duplicates + : ?which_to_keep:[ `First | `Last ] (** default = `Last *) + -> 'a t + -> equal:('a -> 'a -> bool) + -> 'a t + +(** Returns the given list with duplicates removed and in sorted order. *) +val dedup_and_sort : compare:('a -> 'a -> int) -> 'a t -> 'a t + +val dedup : compare:('a -> 'a -> int) -> 'a t -> 'a t +[@@deprecated "[since 2017-04] Use [dedup_and_sort] instead"] + +(** [find_a_dup] returns a duplicate from the list (with no guarantees about which + duplicate you get), or [None] if there are no dups. *) +val find_a_dup : compare:('a -> 'a -> int) -> 'a t -> 'a option + +(** Returns true if there are any two elements in the list which are the same. O(n log n) + time complexity. *) +val contains_dup : compare:('a -> 'a -> int) -> 'a t -> bool + +(** [find_all_dups] returns a list of all elements that occur more than once, with + no guarantees about order. O(n log n) time complexity. *) +val find_all_dups : compare:('a -> 'a -> int) -> 'a t -> 'a list + +(** [count l ~f] is the number of elements in [l] that satisfy the predicate [f]. *) +val count : 'a t -> f:( 'a -> bool) -> int +val counti : 'a t -> f:(int -> 'a -> bool) -> int + +(** [range ?stride ?start ?stop start_i stop_i] is the list of integers from [start_i] to + [stop_i], stepping by [stride]. If [stride] < 0 then we need [start_i] > [stop_i] for + the result to be nonempty (or [start_i] = [stop_i] in the case where both bounds are + inclusive). *) +val range + : ?stride:int (** default = 1 *) + -> ?start:[`inclusive|`exclusive] (** default = `inclusive *) + -> ?stop:[`inclusive|`exclusive] (** default = `exclusive *) + -> int + -> int + -> int t + +(** [range'] is analogous to [range] for general start/stop/stride types. [range'] raises + if [stride x] returns [x] or if the direction that [stride x] moves [x] changes from + one call to the next. *) +val range' + : compare:('a -> 'a -> int) + -> stride:('a -> 'a) + -> ?start:[`inclusive|`exclusive] (** default = `inclusive *) + -> ?stop:[`inclusive|`exclusive] (** default = `exclusive *) + -> 'a + -> 'a + -> 'a t + +(** [init n ~f] is [[(f 0); (f 1); ...; (f (n-1))]]. It is an error if [n < 0]. + [init] applies [f] to values in decreasing order; starting with [n - 1], and + ending with [0]. This is the opposite order to [Array.init]. *) +val init : int -> f:(int -> 'a) -> 'a t + +(** [rev_filter_map l ~f] is the reversed sublist of [l] containing only elements for + which [f] returns [Some e]. *) +val rev_filter_map : 'a t -> f:('a -> 'b option) -> 'b t + +(** rev_filter_mapi is just like [rev_filter_map], but it also passes in the index of each + element as the first argument to the mapped function. Tail-recursive. *) +val rev_filter_mapi : 'a t -> f:(int -> 'a -> 'b option) -> 'b t + +(** [filter_map l ~f] is the sublist of [l] containing only elements for which [f] returns + [Some e]. *) +val filter_map : 'a t -> f:('a -> 'b option) -> 'b t + +(** filter_mapi is just like [filter_map], but it also passes in the index of each element + as the first argument to the mapped function. Tail-recursive. *) +val filter_mapi : 'a t -> f:(int -> 'a -> 'b option) -> 'b t + +(** [filter_opt l] is the sublist of [l] containing only elements which are [Some e]. In + other words, [filter_opt l] = [filter_map ~f:ident l]. *) +val filter_opt : 'a option t -> 'a t + +(** Interpret a list of (key, value) pairs as a map in which only the first occurrence of + a key affects the semantics, i.e.: + + {[List.Assoc.xxx alist ...args... ]} + + is always the same as (or at least sort of isomorphic to): + + {[ Map.xxx (alist |> Map.of_alist_multi |> Map.map ~f:List.hd) ...args... ]} *) +module Assoc : sig + + type ('a, 'b) t = ('a * 'b) list [@@deriving_inline sexp] + include + sig + [@@@ocaml.warning "-32"] + include Ppx_sexp_conv_lib.Sexpable.S2 with type ('a,'b) t := ('a, 'b) t + end[@@ocaml.doc "@inline"] + [@@@end] + + val add : ('a, 'b) t -> equal:('a -> 'a -> bool) -> 'a -> 'b -> ('a, 'b) t + val find : ('a, 'b) t -> equal:('a -> 'a -> bool) -> 'a -> 'b option + val find_exn : ('a, 'b) t -> equal:('a -> 'a -> bool) -> 'a -> 'b + val mem : ('a, 'b) t -> equal:('a -> 'a -> bool) -> 'a -> bool + val remove : ('a, 'b) t -> equal:('a -> 'a -> bool) -> 'a -> ('a, 'b) t + val map : ('a, 'b) t -> f:('b -> 'c) -> ('a, 'c) t + + (** Bijectivity is not guaranteed because we allow a key to appear more than once. *) + val inverse : ('a, 'b) t -> ('b, 'a) t +end + +(** Note that [sub], unlike [slice], doesn't use Python-style indices! *) + +(** [sub pos len l] is the [len]-element sublist of [l], starting at [pos]. *) +val sub : 'a t -> pos:int -> len:int -> 'a t + +(** [take l n] returns the first [n] elements of [l], or all of [l] if [n > length l]. + [take l n = fst (split_n l n)]. *) +val take : 'a t -> int -> 'a t + +(** [drop l n] returns [l] without the first [n] elements, or the empty list if [n > + length l]. [drop l n] is equivalent to [snd (split_n l n)]. *) +val drop : 'a t -> int -> 'a t + +(** [take_while l ~f] returns the longest prefix of [l] for which [f] is [true]. *) +val take_while : 'a t -> f : ('a -> bool) -> 'a t + +(** [drop_while l ~f] drops the longest prefix of [l] for which [f] is [true]. *) +val drop_while : 'a t -> f : ('a -> bool) -> 'a t + +(** [split_while xs ~f = (take_while xs ~f, drop_while xs ~f)]. *) +val split_while : 'a t -> f : ('a -> bool) -> 'a t * 'a t + +(** Concatenates a list of lists. The elements of the argument are all concatenated + together (in the same order) to give the result. Tail recursive over outer and inner + lists. *) +val concat : 'a t t -> 'a t + +(** Like [concat], but faster and without preserving any ordering (i.e., for lists that + are essentially viewed as multi-sets). *) +val concat_no_order : 'a t t -> 'a t + +val cons : 'a -> 'a t -> 'a t + +(** Returns a list with all possible pairs -- if the input lists have length [len1] and + [len2], the resulting list will have length [len1 * len2]. *) +val cartesian_product : 'a t -> 'b t -> ('a * 'b) t + +(** [permute ?random_state t] returns a permutation of [t]. + + [permute] side-effects [random_state] by repeated calls to [Random.State.int]. If + [random_state] is not supplied, [permute] uses [Random.State.default]. *) +val permute : ?random_state:Random.State.t -> 'a t -> 'a t + +(** [random_element ?random_state t] is [None] if [t] is empty, else it is [Some x] for + some [x] chosen uniformly at random from [t]. + + [random_element] side-effects [random_state] by calling [Random.State.int]. If + [random_state] is not supplied, [random_element] uses [Random.State.default]. *) +val random_element : ?random_state:Random.State.t -> 'a t -> 'a option +val random_element_exn : ?random_state:Random.State.t -> 'a t -> 'a + +(** [is_sorted t ~compare] returns [true] iff for all adjacent [a1; a2] in [t], [compare + a1 a2 <= 0]. + + [is_sorted_strictly] is similar, except it uses [<] instead of [<=]. *) +val is_sorted : 'a t -> compare:('a -> 'a -> int) -> bool +val is_sorted_strictly : 'a t -> compare:('a -> 'a -> int) -> bool + +val equal : ('a -> 'a -> bool) -> 'a t -> 'a t -> bool + +module Infix : sig + val ( @ ) : 'a t -> 'a t -> 'a t +end + +(** [transpose m] transposes the rows and columns of the matrix [m], + considered as either a row of column lists or (dually) a column of row lists. + + Example: + + {[transpose [[1;2;3];[4;5;6]] = [[1;4];[2;5];[3;6]]]} + + On non-empty rectangular matrices, [transpose] is an involution (i.e., [transpose + (transpose m) = m]). Transpose returns [None] when called on lists of lists with + non-uniform lengths. *) +val transpose : 'a t t -> 'a t t option + +(** [transpose_exn] transposes the rows and columns of its argument, throwing an exception + if the list is not rectangular. *) +val transpose_exn : 'a t t -> 'a t t + +(** [intersperse xs ~sep] places [sep] between adjacent elements of [xs]. For example, + [intersperse [1;2;3] ~sep:0 = [1;0;2;0;3]]. *) +val intersperse : 'a t -> sep:'a -> 'a t diff --git a/src/list0.ml b/src/list0.ml new file mode 100644 index 0000000..44e423a --- /dev/null +++ b/src/list0.ml @@ -0,0 +1,42 @@ +(* [List0] defines list functions that are primitives or can be simply defined in terms of + [Caml.List]. [List0] is intended to completely express the part of [Caml.List] that + [Base] uses -- no other file in Base other than list0.ml should use [Caml.List]. + [List0] has few dependencies, and so is available early in Base's build order. All + Base files that need to use lists and come before [Base.List] in build order should do + [module List = List0]. Defining [module List = List0] is also necessary because it + prevents ocamldep from mistakenly causing a file to depend on [Base.List]. *) + +open! Import0 + +let hd_exn = Caml.List.hd +let length = Caml.List.length +let rev_append = Caml.List.rev_append +let tl_exn = Caml.List.tl +let unzip = Caml.List.split + +(* These are eta expanded in order to permute parameter order to follow Base + conventions. *) +let exists t ~f = Caml.List.exists t ~f +let exists2_ok l1 l2 ~f = Caml.List.exists2 l1 l2 ~f +let find_exn t ~f = Caml.List.find t ~f +let fold t ~init ~f = Caml.List.fold_left t ~f ~init +let fold2_ok l1 l2 ~init ~f = Caml.List.fold_left2 l1 l2 ~init ~f +let for_all t ~f = Caml.List.for_all t ~f +let for_all2_ok l1 l2 ~f = Caml.List.for_all2 l1 l2 ~f +let iter t ~f = Caml.List.iter t ~f +let iter2_ok l1 l2 ~f = Caml.List.iter2 l1 l2 ~f +let nontail_map t ~f = Caml.List.map t ~f +let nontail_mapi t ~f = Caml.List.mapi t ~f +let partition t ~f = Caml.List.partition t ~f +let rev_map t ~f = Caml.List.rev_map t ~f +let rev_map2_ok l1 l2 ~f = Caml.List.rev_map2 l1 l2 ~f + +let sort l ~compare = Caml.List.sort l ~cmp:compare +let stable_sort l ~compare = Caml.List.stable_sort l ~cmp:compare + +let rev = function + | [] | [_] as res -> res + | x :: y :: rest -> rev_append rest [y; x] +;; + +let is_empty = function [] -> true | _ -> false diff --git a/src/list1.ml b/src/list1.ml new file mode 100644 index 0000000..d3856eb --- /dev/null +++ b/src/list1.ml @@ -0,0 +1,15 @@ +open! Import + +include List0 + +let partition_map t ~f = + let rec loop t fst snd = + match t with + | [] -> (rev fst, rev snd) + | x :: t -> + match f x with + | `Fst y -> loop t (y :: fst) snd + | `Snd y -> loop t fst (y :: snd) + in + loop t [] [] +;; diff --git a/src/map.ml b/src/map.ml new file mode 100644 index 0000000..18a2ee8 --- /dev/null +++ b/src/map.ml @@ -0,0 +1,1770 @@ +(***********************************************************************) +(* *) +(* Objective Caml *) +(* *) +(* Xavier Leroy, projet Cristal, INRIA Rocquencourt *) +(* *) +(* Copyright 1996 Institut National de Recherche en Informatique et *) +(* en Automatique. All rights reserved. This file is distributed *) +(* under the terms of the Apache 2.0 license. See ../THIRD-PARTY.txt *) +(* for details. *) +(* *) +(***********************************************************************) + +open! Import + +module List = List0 + +include Map_intf + +let with_return = With_return.with_return + +exception Duplicate [@@deriving_inline sexp] +let () = + Ppx_sexp_conv_lib.Conv.Exn_converter.add + ([%extension_constructor Duplicate]) + (function + | Duplicate -> Ppx_sexp_conv_lib.Sexp.Atom "src/map.ml.Duplicate" + | _ -> assert false) +[@@@end] + +module Tree0 = struct + + type ('k, 'v) t = + | Empty + | Leaf of 'k * 'v + | Node of ('k, 'v) t * 'k * 'v * ('k, 'v) t * int + + type ('k, 'v) tree = ('k, 'v) t + + let height = function + | Empty -> 0 + | Leaf _ -> 1 + | Node(_,_,_,_,h) -> h + ;; + + let invariants = + let in_range lower upper compare_key k = + (match lower with + | None -> true + | Some lower -> compare_key lower k < 0 + ) + && (match upper with + | None -> true + | Some upper -> compare_key k upper < 0 + ) + in + let rec loop lower upper compare_key t = + match t with + | Empty -> true + | Leaf (k, _) -> in_range lower upper compare_key k + | Node (l, k, _, r, h) -> + let hl = height l and hr = height r in + abs (hl - hr) <= 2 + && h = (max hl hr) + 1 + && in_range lower upper compare_key k + && loop lower (Some k) compare_key l + && loop (Some k) upper compare_key r + in + fun t ~compare_key -> + loop None None compare_key t + ;; + + (* precondition: |height(l) - height(r)| <= 2 *) + let create l x d r = + let hl = height l and hr = height r in + if hl = 0 && hr = 0 then + Leaf (x, d) + else + Node(l, x, d, r, (if hl >= hr then hl + 1 else hr + 1)) + ;; + + let singleton key data = Leaf (key, data) + + (* We must call [f] with increasing indexes, because the bin_prot reader in + Core_kernel.Map needs it. *) + let of_increasing_iterator_unchecked ~len ~f = + let rec loop n ~f i : (_, _) t = + match n with + | 0 -> Empty + | 1 -> + let k,v = f i in + Leaf (k, v) + | 2 -> + let kl,vl = f i in + let k ,v = f (i + 1) in + Node (Leaf (kl, vl), k, v, Empty, 2) + | 3 -> + let kl,vl = f i in + let k,v = f (i + 1) in + let kr,vr = f (i + 2) in + Node (Leaf (kl, vl), k, v, Leaf (kr, vr), 2) + | n -> + let left_length = n lsr 1 in + let right_length = n - left_length - 1 in + let left = loop left_length ~f i in + let k, v = f (i + left_length) in + let right = loop right_length ~f (i + left_length + 1) in + create left k v right + in + loop len ~f 0 + + let of_sorted_array_unchecked array ~compare_key = + let array_length = Array.length array in + let next = + if array_length < 2 + || let k0, _ = array.(0) in + let k1, _ = array.(1) in + compare_key k0 k1 < 0 + then + (fun i -> array.(i)) + else + (fun i -> array.(array_length - 1 - i)) + in + (of_increasing_iterator_unchecked ~len:array_length ~f:next, array_length) + + let of_sorted_array array ~compare_key = + match array with + | [||] | [|_|] -> Result.Ok (of_sorted_array_unchecked array ~compare_key) + | _ -> + with_return (fun r -> + let increasing = + match compare_key (fst array.(0)) (fst array.(1)) with + | 0 -> r.return (Or_error.error_string "of_sorted_array: duplicated elements") + | i -> i < 0 + in + for i = 1 to Array.length array - 2 do + match compare_key (fst array.(i)) (fst array.(i+1)) with + | 0 -> r.return (Or_error.error_string "of_sorted_array: duplicated elements") + | i -> + if Poly.(<>) (i < 0) increasing then + r.return (Or_error.error_string "of_sorted_array: elements are not ordered") + done; + Result.Ok (of_sorted_array_unchecked array ~compare_key) + ) + + (* precondition: |height(l) - height(r)| <= 3 *) + let bal l x d r = + let hl = height l in + let hr = height r in + if hl > hr + 2 then begin + match l with + | Empty -> invalid_arg "Map.bal" + | Leaf _ -> assert false (* height(Leaf) = 1 && 1 is not larger than hr + 2 *) + | Node(ll, lv, ld, lr, _) -> + if height ll >= height lr then + create ll lv ld (create lr x d r) + else begin + match lr with + Empty -> invalid_arg "Map.bal" + | Leaf (lrv, lrd) -> + create (create ll lv ld Empty) lrv lrd (create Empty x d r) + | Node(lrl, lrv, lrd, lrr, _)-> + create (create ll lv ld lrl) lrv lrd (create lrr x d r) + end + end else if hr > hl + 2 then begin + match r with + | Empty -> invalid_arg "Map.bal" + | Leaf _ -> assert false (* height(Leaf) = 1 && 1 is not larger than hl + 2 *) + | Node(rl, rv, rd, rr, _) -> + if height rr >= height rl then + create (create l x d rl) rv rd rr + else begin + match rl with + Empty -> invalid_arg "Map.bal" + | Leaf (rlv, rld) -> + create (create l x d Empty) rlv rld (create Empty rv rd rr) + | Node(rll, rlv, rld, rlr, _) -> + create (create l x d rll) rlv rld (create rlr rv rd rr) + end + end + else create l x d r + ;; + + let empty = Empty + + let is_empty = function Empty -> true | _ -> false + + let raise_key_already_present ~key ~sexp_of_key = + Error.raise_s ( + Sexp.message + "[Map.add_exn] got key already present" + [ "key", key |> sexp_of_key ]) + ;; + + module Add_or_set = struct + type t = + | Add_exn_internal + | Add_exn + | Set + end + + let rec find_and_add_or_set t ~length ~key:x ~data ~compare_key ~sexp_of_key + ~(add_or_set : Add_or_set.t) = + match t with + | Empty -> (Leaf (x, data), length + 1) + | Leaf(v, d) -> + let c = compare_key x v in + if c = 0 then + (match add_or_set with + | Add_exn_internal -> Exn.raise_without_backtrace Duplicate + | Add_exn -> raise_key_already_present ~key:x ~sexp_of_key + | Set -> (Leaf(x, data), length)) + else if c < 0 then + (Node(Leaf(x, data), v, d, Empty, 2), length + 1) + else + (Node(Empty, v, d, Leaf(x, data), 2), length + 1) + | Node(l, v, d, r, h) -> + let c = compare_key x v in + if c = 0 then + (match add_or_set with + | Add_exn_internal -> Exn.raise_without_backtrace Duplicate + | Add_exn -> raise_key_already_present ~key:x ~sexp_of_key + | Set -> (Node(l, x, data, r, h), length)) + else if c < 0 then + let l, length = + find_and_add_or_set ~length ~key:x ~data l ~compare_key ~sexp_of_key ~add_or_set + in + (bal l v d r, length) + else + let r, length = + find_and_add_or_set ~length ~key:x ~data r ~compare_key ~sexp_of_key ~add_or_set + in + (bal l v d r, length) + ;; + + let add_exn t ~length ~key ~data ~compare_key ~sexp_of_key = + find_and_add_or_set t ~length ~key ~data ~compare_key ~sexp_of_key ~add_or_set:Add_exn + ;; + + let add_exn_internal t ~length ~key ~data ~compare_key ~sexp_of_key = + find_and_add_or_set + t ~length ~key ~data ~compare_key ~sexp_of_key ~add_or_set:Add_exn_internal + ;; + + let set t ~length ~key ~data ~compare_key = + find_and_add_or_set t ~length ~key ~data + ~compare_key + ~sexp_of_key:(fun _ -> List []) + ~add_or_set:Set + ;; + + let set' t key data ~compare_key = fst (set t ~length:0 ~key ~data ~compare_key) + + module Build_increasing = struct + module Fragment = struct + type nonrec ('k, 'v) t = { + left_subtree : ('k, 'v) t; + key : 'k; + data: 'v; + } + + let singleton_to_tree_exn = function + | { left_subtree = Empty; key; data; } -> singleton key data + | _ -> failwith "Map.singleton_to_tree_exn: not a singleton" + ;; + + let singleton ~key ~data = { left_subtree = Empty; key; data; } + + (* precondition: |height(l.left_subtree) - height(r)| <= 2, + max_key(l) < min_key(r) + *) + let collapse l r = create l.left_subtree l.key l.data r + + (* precondition: |height(l.left_subtree) - height(r.left_subtree)| <= 2, + max_key(l) < min_key(r) + *) + let join l r = { r with left_subtree = collapse l r.left_subtree; } + + let max_key t = t.key + end + + (** Build trees from singletons in a balanced way by using skew binary encoding. + Each level contains trees of the same height, consecutive levels have consecutive + heights. There are no gaps. The first level are single keys. + *) + type ('k, 'v) t = + | Zero of unit (* [unit] to make pattern matching faster *) + | One of ('k, 'v) t * ('k, 'v) Fragment.t + | Two of ('k, 'v) t * ('k, 'v) Fragment.t * ('k, 'v) Fragment.t + + let empty = Zero () + + let add_unchecked = + let rec go t x = match t with + | Zero () -> One (t, x) + | One (t, y) -> Two (t, y, x) + | Two (t, z, y) -> One (go t (Fragment.join z y), x) + in + fun t ~key ~data -> go t (Fragment.singleton ~key ~data) + ;; + + let to_tree = + let rec go t r = match t with + | Zero () -> r + | One (t, l) -> go t (Fragment.collapse l r) + | Two (t, ll, l) -> go t (Fragment.collapse (Fragment.join ll l) r) + in + function + | Zero () -> Empty + | One (t, r) -> go t (Fragment.singleton_to_tree_exn r) + | Two (t, l, r) -> go (One (t, l)) (Fragment.singleton_to_tree_exn r) + ;; + + let max_key = function + | Zero () -> None + | One (_, r) | Two (_, _, r) -> Some (Fragment.max_key r) + ;; + end + + let of_increasing_sequence seq ~compare_key = with_return (fun { return; } -> + let builder, length = + Sequence.fold seq ~init:(Build_increasing.empty, 0) + ~f:(fun (builder, length) (key, data) -> + match Build_increasing.max_key builder with + | Some prev_key when compare_key prev_key key >= 0 -> + return (Or_error.error_string "of_increasing_sequence: non-increasing key") + | _ -> Build_increasing.add_unchecked builder ~key ~data, length + 1) + in + Ok (Build_increasing.to_tree builder, length)) + ;; + + (* Like [bal] but allows any difference in height between [l] and [r]. + + O(|height l - height r|) *) + let rec join l k d r ~compare_key = + match l, r with + | Empty, _ -> set' r k d ~compare_key + | _, Empty -> set' l k d ~compare_key + | Leaf(lk, ld), _ -> set' (set' r k d ~compare_key) lk ld ~compare_key + | _, Leaf(rk, rd) -> set' (set' l k d ~compare_key) rk rd ~compare_key + | Node(ll, lk, ld, lr, lh), Node(rl, rk, rd, rr, rh) -> + (* [bal] requires height difference <= 3. *) + if lh > rh + 3 + (* [height lr >= height r], + therefore [height (join lr k d r ...)] is [height rl + 1] or [height rl] + therefore the height difference with [ll] will be <= 3 *) + then bal ll lk ld (join lr k d r ~compare_key) + else if rh > lh + 3 + then bal (join l k d rl ~compare_key) rk rd rr + else bal l k d r + ;; + + let rec split t x ~compare_key = + match t with + | Empty -> (Empty, None, Empty) + | Leaf(k, d) -> + let cmp = compare_key x k in + if cmp = 0 then (Empty, Some (k, d), Empty) + else if cmp < 0 then (Empty, None, t) + else (t, None, Empty) + | Node(l, k, d, r, _) -> + let cmp = compare_key x k in + if cmp = 0 then (l, Some (k, d), r) + else if cmp < 0 then + let ll, maybe, lr = split l x ~compare_key in + (ll, maybe, join lr k d r ~compare_key) + else + let rl, maybe, rr = split r x ~compare_key in + (join l k d rl ~compare_key, maybe, rr) + ;; + + let split_and_reinsert_boundary t ~into x ~compare_key = + let left, boundary_opt, right = split t x ~compare_key in + match boundary_opt with + | None -> left, right + | Some (key, data) -> + let insert_into tree = fst (set tree ~key ~data ~length:0 ~compare_key) in + match into with + | `Left -> insert_into left, right + | `Right -> left, insert_into right + ;; + + let split_range t ~(lower_bound : 'a Maybe_bound.t) ~(upper_bound : 'a Maybe_bound.t) ~compare_key = + if Maybe_bound.bounds_crossed ~compare:compare_key ~lower:lower_bound ~upper:upper_bound + then empty, empty, empty + else + let left, mid_and_right = + match lower_bound with + | Unbounded -> empty, t + | Incl lb -> split_and_reinsert_boundary ~into:`Right t lb ~compare_key + | Excl lb -> split_and_reinsert_boundary ~into:`Left t lb ~compare_key + in + let mid, right = + match upper_bound with + | Unbounded -> mid_and_right, empty + | Incl lb -> split_and_reinsert_boundary ~into:`Left mid_and_right lb ~compare_key + | Excl lb -> split_and_reinsert_boundary ~into:`Right mid_and_right lb ~compare_key + in + left, mid, right + ;; + + let rec find t x ~compare_key = + match t with + | Empty -> None + | Leaf (v, d) -> if compare_key x v = 0 then Some d else None + | Node(l, v, d, r, _) -> + let c = compare_key x v in + if c = 0 then Some d + else find (if c < 0 then l else r) x ~compare_key + ;; + + let add_multi t ~length ~key ~data ~compare_key = + let data = data :: Option.value (find t key ~compare_key) ~default:[] in + set ~length ~key ~data t ~compare_key + ;; + + let find_multi t x ~compare_key = + match find t x ~compare_key with + | None -> [] + | Some l -> l + ;; + + let find_exn t x ~compare_key = + match find t x ~compare_key with + | Some data -> data + | None -> + raise Caml.Not_found + ;; + + let mem t x ~compare_key = Option.is_some (find t x ~compare_key) + + let rec min_elt = function + | Empty -> None + | Leaf (k, d) -> Some (k, d) + | Node (Empty, k, d, _, _) -> Some (k, d) + | Node (l, _, _, _, _) -> min_elt l + ;; + + exception Map_min_elt_exn_of_empty_map [@@deriving_inline sexp] + let () = + Ppx_sexp_conv_lib.Conv.Exn_converter.add + ([%extension_constructor Map_min_elt_exn_of_empty_map]) + (function + | Map_min_elt_exn_of_empty_map -> + Ppx_sexp_conv_lib.Sexp.Atom + "src/map.ml.Tree0.Map_min_elt_exn_of_empty_map" + | _ -> assert false) + [@@@end] + exception Map_max_elt_exn_of_empty_map [@@deriving_inline sexp] + let () = + Ppx_sexp_conv_lib.Conv.Exn_converter.add + ([%extension_constructor Map_max_elt_exn_of_empty_map]) + (function + | Map_max_elt_exn_of_empty_map -> + Ppx_sexp_conv_lib.Sexp.Atom + "src/map.ml.Tree0.Map_max_elt_exn_of_empty_map" + | _ -> assert false) + [@@@end] + + let min_elt_exn t = + match min_elt t with + | None -> raise Map_min_elt_exn_of_empty_map + | Some v -> v + ;; + + let rec max_elt = function + | Empty -> None + | Leaf (k, d) -> Some (k, d) + | Node (_, k, d, Empty, _) -> Some (k, d) + | Node (_, _, _, r, _) -> max_elt r + ;; + let max_elt_exn t = + match max_elt t with + | None -> raise Map_max_elt_exn_of_empty_map + | Some v -> v + ;; + + let rec remove_min_elt t = + match t with + Empty -> invalid_arg "Map.remove_min_elt" + | Leaf _ -> Empty + | Node(Empty, _, _, r, _) -> r + | Node(l, x, d, r, _) -> bal (remove_min_elt l) x d r + + let append ~lower_part ~upper_part ~compare_key = + match max_elt lower_part, min_elt upper_part with + | None, _ -> `Ok upper_part + | _, None -> `Ok lower_part + | Some (max_lower, _), Some (min_upper, v) + when compare_key max_lower min_upper < 0 -> + let upper_part_without_min = remove_min_elt upper_part in + `Ok (join ~compare_key:compare_key lower_part min_upper v upper_part_without_min) + | _ -> `Overlapping_key_ranges + ;; + + + let fold_range_inclusive = + (* This assumes that min <= max, which is checked by the outer function. *) + let rec go t ~min ~max ~init ~f ~compare_key = + match t with + | Empty -> init + | Leaf (k, d) -> + if compare_key k min < 0 || compare_key k max > 0 then + (* k < min || k > max *) + init + else + f ~key:k ~data:d init + | Node (l, k, d, r, _) -> + let c_min = compare_key k min in + if c_min < 0 then + (* if k < min, then this node and its left branch are outside our range *) + go r ~min ~max ~init ~f ~compare_key + else if c_min = 0 then + (* if k = min, then this node's left branch is outside our range *) + go r ~min ~max ~init:(f ~key:k ~data:d init) ~f ~compare_key + else (* k > min *) + begin + let z = go l ~min ~max ~init ~f ~compare_key in + let c_max = compare_key k max in + (* if k > max, we're done *) + if c_max > 0 then z + else + let z = f ~key:k ~data:d z in + (* if k = max, then we fold in this one last value and we're done *) + if c_max = 0 then z + else go r ~min ~max ~init:z ~f ~compare_key + end + in fun t ~min ~max ~init ~f ~compare_key -> + if compare_key min max <= 0 then + go t ~min ~max ~init ~f ~compare_key + else + init + ;; + + let range_to_alist t ~min ~max ~compare_key = + List.rev + (fold_range_inclusive t ~min ~max ~init:[] ~f:(fun ~key ~data l -> (key,data)::l) + ~compare_key) + ;; + + let concat_unchecked t1 t2 = + match (t1, t2) with + | (Empty, t) -> t + | (t, Empty) -> t + | (_, _) -> + let (x, d) = min_elt_exn t2 in + bal t1 x d (remove_min_elt t2) + ;; + + let rec remove t x ~length ~compare_key = + match t with + | Empty -> (Empty, length) + | Leaf (v, _) -> + if compare_key x v = 0 then + (Empty, length - 1) + else (t, length) + | Node(l, v, d, r, _) -> + let c = compare_key x v in + if c = 0 then + (concat_unchecked l r, length - 1) + else if c < 0 then + let l, length = remove l x ~length ~compare_key in + (bal l v d r, length) + else + let r, length = remove r x ~length ~compare_key in + (bal l v d r, length) + ;; + + (* Use exception to avoid tree-rebuild in no-op case *) + exception Change_no_op + + let change t key ~f ~length ~compare_key = + let rec change_core t key f = + match t with + | Empty -> + begin match (f None) with + | None -> raise Change_no_op (* equivalent to returning: Empty *) + | Some data -> (Leaf(key, data), length + 1) + end + | Leaf(v, d) -> + let c = compare_key key v in + if c = 0 then + match f (Some d) with + | None -> (Empty, length - 1) + | Some d' -> (Leaf(v, d'), length) + else if c < 0 then + let l, length = change_core Empty key f in + (bal l v d Empty, length) + else + let r, length = change_core Empty key f in + (bal Empty v d r, length) + | Node(l, v, d, r, h) -> + let c = compare_key key v in + if c = 0 then + begin match (f (Some d)) with + | None -> (concat_unchecked l r, length - 1) + | Some data -> (Node(l, key, data, r, h), length) + end + else + if c < 0 then + let l, length = change_core l key f in + (bal l v d r, length) + else + let r, length = change_core r key f in + (bal l v d r, length) + in + try change_core t key f with Change_no_op -> (t, length) + ;; + + let remove_multi t key ~length ~compare_key = + change t key ~length ~compare_key ~f:(function + | None | Some ([] | [_]) -> None + | Some (_ :: ((_ :: _) as non_empty_tail)) -> Some non_empty_tail) + ;; + + let rec iter_keys t ~f = + match t with + | Empty -> () + | Leaf(v, _) -> f v + | Node(l, v, _, r, _) -> iter_keys ~f l; f v; iter_keys ~f r + ;; + + let rec iter t ~f = + match t with + | Empty -> () + | Leaf(_, d) -> f d + | Node(l, _, d, r, _) -> iter ~f l; f d; iter ~f r + ;; + + let rec iteri t ~f = + match t with + | Empty -> () + | Leaf(v, d) -> f ~key:v ~data:d + | Node(l, v, d, r, _) -> iteri ~f l; f ~key:v ~data:d; iteri ~f r + ;; + + let rec map t ~f = + match t with + | Empty -> Empty + | Leaf(v, d) -> Leaf(v, f d) + | Node(l, v, d, r, h) -> + let l' = map ~f l in + let d' = f d in + let r' = map ~f r in + Node(l', v, d', r', h) + ;; + + let rec mapi t ~f = + match t with + | Empty -> Empty + | Leaf(v, d) -> Leaf(v, f ~key:v ~data:d) + | Node(l, v, d, r, h) -> + let l' = mapi ~f l in + let d' = f ~key:v ~data:d in + let r' = mapi ~f r in + Node(l', v, d', r', h) + ;; + + let rec fold t ~init:accu ~f = + match t with + | Empty -> accu + | Leaf(v, d) -> f ~key:v ~data:d accu + | Node(l, v, d, r, _) -> fold ~f r ~init:(f ~key:v ~data:d (fold ~f l ~init:accu)) + ;; + + let rec fold_right t ~init:accu ~f = + match t with + | Empty -> accu + | Leaf(v, d) -> f ~key:v ~data:d accu + | Node(l, v, d, r, _) -> + fold_right ~f l ~init:(f ~key:v ~data:d (fold_right ~f r ~init:accu)) + ;; + + let filter_keys t ~f ~compare_key = + fold ~init:(Empty, 0) t ~f:(fun ~key ~data (accu, length) -> + if f key + then set ~length ~key ~data accu ~compare_key + else (accu, length)) + ;; + + let filter t ~f ~compare_key = + fold ~init:(Empty, 0) t ~f:(fun ~key ~data (accu, length) -> + if f data + then set ~length ~key ~data accu ~compare_key + else (accu, length)) + ;; + + let filteri t ~f ~compare_key = + fold ~init:(Empty, 0) t ~f:(fun ~key ~data (accu, length) -> + if f ~key ~data + then set ~length ~key ~data accu ~compare_key + else (accu, length)) + ;; + + let filter_map t ~f ~compare_key = + fold ~init:(Empty, 0) t ~f:(fun ~key ~data (accu, length) -> + match f data with + | None -> (accu, length) + | Some b -> set ~length ~key ~data:b accu ~compare_key) + ;; + + let filter_mapi t ~f ~compare_key = + fold ~init:(Empty, 0) t ~f:(fun ~key ~data (accu, length) -> + match f ~key ~data with + | None -> (accu, length) + | Some b -> set ~length ~key ~data:b accu ~compare_key) + ;; + + let partition_mapi t ~f ~compare_key = + fold t ~init:((Empty, 0), (Empty, 0)) ~f:(fun ~key ~data (pair1, pair2) -> + match f ~key ~data with + | `Fst x -> + let t, length = pair1 in + (set t ~key ~data:x ~compare_key ~length, pair2) + | `Snd y -> + let t, length = pair2 in + (pair1, set t ~key ~data:y ~compare_key ~length)) + ;; + + let partition_map t ~f ~compare_key = + partition_mapi t ~compare_key ~f:(fun ~key:_ ~data -> f data) + ;; + + let partitioni_tf t ~f ~compare_key = + partition_mapi t ~compare_key ~f:(fun ~key ~data -> + if f ~key ~data + then `Fst data + else `Snd data) + ;; + + let partition_tf t ~f ~compare_key = + partition_mapi t ~compare_key ~f:(fun ~key:_ ~data -> + if f data + then `Fst data + else `Snd data) + ;; + + module Enum = struct + type increasing + type decreasing + + type ('k, 'v, 'direction) t = + | End + | More of 'k * 'v * ('k, 'v) tree * ('k, 'v, 'direction) t + + let rec cons t (e : (_, _, increasing) t) : (_, _, increasing) t = + match t with + | Empty -> e + | Leaf(v, d) -> More(v, d, Empty, e) + | Node(l, v, d, r, _) -> cons l (More(v, d, r, e)) + ;; + + let rec cons_right t (e : (_, _, decreasing) t) : (_, _, decreasing) t = + match t with + | Empty -> e + | Leaf(v, d) -> More(v, d, Empty, e) + | Node(l, v, d, r, _) -> cons_right r (More(v, d, l, e)) + ;; + + let of_tree tree : (_, _, increasing) t = cons tree End + ;; + + let of_tree_right tree : (_, _, decreasing) t = cons_right tree End + ;; + + let starting_at_increasing t key compare : (_, _, increasing) t = + let rec loop t e = + match t with + | Empty -> e + | Leaf(v, d) -> loop (Node(Empty, v, d, Empty, 1)) e + | Node(_, v, _, r, _) when compare v key < 0 -> loop r e + | Node(l, v, d, r, _) -> loop l (More(v, d, r, e)) + in + loop t End + ;; + + let starting_at_decreasing t key compare : (_, _, decreasing) t = + let rec loop t e = + match t with + | Empty -> e + | Leaf(v, d) -> loop (Node(Empty, v, d, Empty, 1)) e + | Node(l, v, _, _, _) when compare v key > 0 -> loop l e + | Node(l, v, d, r, _) -> loop r (More(v, d, l, e)) + in + loop t End + ;; + + let compare compare_key compare_data t1 t2 = + let rec loop t1 t2 = + match t1, t2 with + | (End, End) -> 0 + | (End, _) -> -1 + | (_, End) -> 1 + | (More (v1, d1, r1, e1), More (v2, d2, r2, e2)) -> + let c = compare_key v1 v2 in + if c <> 0 then c else + let c = compare_data d1 d2 in + if c <> 0 then c else + if phys_equal r1 r2 + then loop e1 e2 + else loop (cons r1 e1) (cons r2 e2) + in + loop t1 t2 + ;; + + let equal compare_key data_equal t1 t2 = + let rec loop t1 t2 = + match t1, t2 with + | (End, End) -> true + | (End, _) | (_, End) -> false + | (More (v1, d1, r1, e1), More (v2, d2, r2, e2)) -> + compare_key v1 v2 = 0 + && data_equal d1 d2 + && (if phys_equal r1 r2 + then loop e1 e2 + else loop (cons r1 e1) (cons r2 e2)) + in + loop t1 t2 + ;; + + let rec fold ~init ~f = function + | End -> init + | More (key, data, tree, enum) -> + let next = f ~key ~data init in + fold (cons tree enum) ~init:next ~f + ;; + + let fold2 compare_key t1 t2 ~init ~f = + let rec loop t1 t2 curr = + match t1, t2 with + | End, End -> curr + | End, _ -> + fold t2 ~init:curr ~f:(fun ~key ~data acc -> f ~key ~data:(`Right data) acc) + | _ , End -> + fold t1 ~init:curr ~f:(fun ~key ~data acc -> f ~key ~data:(`Left data) acc) + | More (k1, v1, tree1, enum1), More (k2, v2, tree2, enum2) -> + let compare_result = compare_key k1 k2 in + if compare_result = 0 then begin + let next = f ~key:k1 ~data:(`Both (v1, v2)) curr in + loop (cons tree1 enum1) (cons tree2 enum2) next + end else if compare_result < 0 then begin + let next = f ~key:k1 ~data:(`Left v1) curr in + loop (cons tree1 enum1) t2 next + end else begin + let next = f ~key:k2 ~data:(`Right v2) curr in + loop t1 (cons tree2 enum2) next + end + in + loop t1 t2 init + ;; + + let symmetric_diff t1 t2 ~compare_key ~data_equal = + let step state = + match state with + | End, End -> + Sequence.Step.Done + | End, More (key, data, tree, enum) -> + Sequence.Step.Yield ((key, `Right data), (End, cons tree enum)) + | More (key, data, tree, enum), End -> + Sequence.Step.Yield ((key, `Left data), (cons tree enum, End)) + | (More (k1, v1, tree1, enum1) as left), (More (k2, v2, tree2, enum2) as right) -> + let compare_result = compare_key k1 k2 in + if compare_result = 0 then begin + let next_state = + if phys_equal tree1 tree2 + then (enum1, enum2) + else (cons tree1 enum1, cons tree2 enum2) + in + if data_equal v1 v2 + then Sequence.Step.Skip next_state + else Sequence.Step.Yield ((k1, `Unequal (v1, v2)), next_state) + end else if compare_result < 0 then begin + Sequence.Step.Yield ((k1, `Left v1), (cons tree1 enum1, right)) + end else begin + Sequence.Step.Yield ((k2, `Right v2), (left, (cons tree2 enum2))) + end + in + Sequence.unfold_step ~init:(of_tree t1, of_tree t2) ~f:step + ;; + + + end + + let to_sequence_increasing comparator ~from_key t = + let next enum = + match enum with + | Enum.End -> Sequence.Step.Done + | Enum.More(k,v,t,e) -> Sequence.Step.Yield((k,v), Enum.cons t e) + in + let init = + match from_key with + | None -> Enum.of_tree t + | Some key -> Enum.starting_at_increasing t key comparator.Comparator.compare + in + Sequence.unfold_step ~init ~f:next + ;; + + let to_sequence_decreasing comparator ~from_key t = + let next enum = + match enum with + | Enum.End -> Sequence.Step.Done + | Enum.More(k,v,t,e) -> Sequence.Step.Yield((k,v), Enum.cons_right t e) + in + let init = + match from_key with + | None -> Enum.of_tree_right t + | Some key -> Enum.starting_at_decreasing t key comparator.Comparator.compare + in + Sequence.unfold_step ~init ~f:next + ;; + + let to_sequence comparator ?(order=`Increasing_key) ?keys_greater_or_equal_to + ?keys_less_or_equal_to t = + let inclusive_bound side t bound = + let compare_key = comparator.Comparator.compare in + let l, maybe, r = split t bound ~compare_key in + let t = side (l, r) in + match maybe with + | None -> t + | Some (key, data) -> set' t key data ~compare_key + in + match order with + | `Increasing_key -> + let t = Option.fold keys_less_or_equal_to ~init:t ~f:(inclusive_bound fst) in + to_sequence_increasing comparator ~from_key:keys_greater_or_equal_to t + | `Decreasing_key -> + let t = Option.fold keys_greater_or_equal_to ~init:t ~f:(inclusive_bound snd) in + to_sequence_decreasing comparator ~from_key:keys_less_or_equal_to t + ;; + + let compare compare_key compare_data t1 t2 = + Enum.compare compare_key compare_data (Enum.of_tree t1) (Enum.of_tree t2) + ;; + + let equal compare_key compare_data t1 t2 = + Enum.equal compare_key compare_data (Enum.of_tree t1) (Enum.of_tree t2) + ;; + + let iter2 t1 t2 ~f ~compare_key = + Enum.fold2 compare_key (Enum.of_tree t1) (Enum.of_tree t2) + ~init:() + ~f:(fun ~key ~data () -> f ~key ~data) + ;; + + let fold2 t1 t2 ~init ~f ~compare_key = + Enum.fold2 compare_key (Enum.of_tree t1) (Enum.of_tree t2) ~f ~init + ;; + + let symmetric_diff = Enum.symmetric_diff + + let rec length = function + | Empty -> 0 + | Leaf _ -> 1 + | Node (l, _, _, r, _) -> length l + length r + 1 + ;; + + let hash_fold_t_ignoring_structure hash_fold_key hash_fold_data state t = + fold t ~init:(hash_fold_int state (length t)) + ~f:(fun ~key ~data state -> hash_fold_data (hash_fold_key state key) data) + ;; + + let of_alist_fold alist ~init ~f ~compare_key = + List.fold alist ~init:(empty, 0) + ~f:(fun (accum, length) (key, data) -> + let prev_data = + match find accum key ~compare_key with + | None -> init + | Some prev -> prev + in + let data = f prev_data data in + set accum ~length ~key ~data ~compare_key) + ;; + + let of_alist_reduce alist ~f ~compare_key = + List.fold alist ~init:(empty, 0) + ~f:(fun (accum, length) (key, data) -> + let new_data = + match find accum key ~compare_key with + | None -> data + | Some prev -> f prev data + in + set accum ~length ~key ~data:new_data ~compare_key) + ;; + + let keys t = fold_right ~f:(fun ~key ~data:_ list -> key::list) t ~init:[] + let data t = fold_right ~f:(fun ~key:_ ~data list -> data::list) t ~init:[] + + let of_alist alist ~compare_key = + with_return (fun r -> + let map = + List.fold alist ~init:(empty, 0) ~f:(fun (t, length) (key,data) -> + let ((_, length') as acc) = set ~length ~key ~data t ~compare_key in + if length = length' then r.return (`Duplicate_key key) + else acc) + in + `Ok map) + ;; + + let for_all t ~f = + with_return (fun r -> + iter t ~f:(fun data -> if not (f data) then r.return false); + true) + + let for_alli t ~f = + with_return (fun r -> + iteri t ~f:(fun ~key ~data -> if not (f ~key ~data) then r.return false); + true) + + let exists t ~f = + with_return (fun r -> + iter t ~f:(fun data -> if f data then r.return true); + false) + + let existsi t ~f = + with_return (fun r -> + iteri t ~f:(fun ~key ~data -> if f ~key ~data then r.return true); + false) + + let count t ~f = + fold t ~init:0 ~f:(fun ~key:_ ~data acc -> if f data then acc + 1 else acc) + + let counti t ~f = + fold t ~init:0 ~f:(fun ~key ~data acc -> if f ~key ~data then acc + 1 else acc) + + let of_alist_or_error alist ~comparator = + match of_alist alist ~compare_key:comparator.Comparator.compare with + | `Ok x -> Result.Ok x + | `Duplicate_key key -> + Or_error.error "Map.of_alist_or_error: duplicate key" key comparator.sexp_of_t + ;; + + let of_alist_exn alist ~comparator = + match of_alist alist ~compare_key:comparator.Comparator.compare with + | `Ok x -> x + | `Duplicate_key key -> + Error.create "Map.of_alist_exn: duplicate key" key comparator.sexp_of_t + |> Error.raise + ;; + + let of_alist_multi alist ~compare_key = + let alist = List.rev alist in + of_alist_fold alist ~init:[] ~f:(fun l x -> x :: l) ~compare_key + ;; + + let to_alist ?(key_order = `Increasing) t = + match key_order with + | `Increasing -> fold_right t ~init:[] ~f:(fun ~key ~data x -> (key, data) :: x) + | `Decreasing -> fold t ~init:[] ~f:(fun ~key ~data x -> (key, data) :: x) + ;; + + let merge t1 t2 ~f ~compare_key = + let elts = Uniform_array.unsafe_create_uninitialized ~len:(length t1 + length t2) in + let i = ref 0 in + iter2 t1 t2 ~compare_key ~f:(fun ~key ~data:values -> + match f ~key values with + | Some value -> Uniform_array.set elts !i (key, value); incr i + | None -> ()); + let len = !i in + let get i = Uniform_array.get elts i in + let tree = of_increasing_iterator_unchecked ~len ~f:get in + tree, len + ;; + + module Closest_key_impl = struct + (* [marker] and [repackage] allow us to create "logical" options without actually + allocating any options. Passing [Found key value] to a function is equivalent to + passing [Some (key, value)]; passing [Missing () ()] is equivalent to passing + [None]. *) + type ('k, 'v, 'k_opt, 'v_opt) marker = + | Missing : ('k, 'v, unit, unit) marker + | Found : ('k, 'v, 'k, 'v) marker + + let repackage (type k) (type v) (type k_opt) (type v_opt) + (marker : (k, v, k_opt, v_opt) marker) (k : k_opt) (v : v_opt) + : (k * v) option = + match marker with + | Missing -> None + | Found -> Some (k, v) + ;; + + (* The type signature is explicit here to allow polymorphic recursion. *) + let rec loop + : 'k 'v 'k_opt 'v_opt. + ('k, 'v) tree + -> [ `Greater_or_equal_to + | `Greater_than + | `Less_or_equal_to + | `Less_than + ] + -> 'k -> compare_key:('k -> 'k -> int) + -> ('k, 'v, 'k_opt, 'v_opt) marker -> 'k_opt -> 'v_opt + -> ('k * 'v) option + = fun t dir k ~compare_key found_marker found_key found_value -> + match t with + | Empty -> + repackage found_marker found_key found_value + | Leaf (k', v') -> + let c = compare_key k' k in + if + match dir with + | `Greater_or_equal_to -> c >= 0 + | `Greater_than -> c > 0 + | `Less_or_equal_to -> c <= 0 + | `Less_than -> c < 0 + then + Some (k', v') + else + repackage found_marker found_key found_value + | Node (l, k', v', r, _) -> + let c = compare_key k' k in + if c = 0 then begin + (* This is a base case (no recursive call). *) + match dir with + | `Greater_or_equal_to | `Less_or_equal_to -> + Some (k', v') + | `Greater_than -> + if is_empty r then + repackage found_marker found_key found_value + else + min_elt r + | `Less_than -> + if is_empty l then + repackage found_marker found_key found_value + else + max_elt l + end else begin + (* We are guaranteed here that k' <> k. *) + (* This is the only recursive case. *) + match dir with + | `Greater_or_equal_to | `Greater_than -> + if c > 0 + then loop l dir k ~compare_key Found k' v' + else loop r dir k ~compare_key found_marker found_key found_value + | `Less_or_equal_to | `Less_than -> + if c < 0 + then loop r dir k ~compare_key Found k' v' + else loop l dir k ~compare_key found_marker found_key found_value + end + ;; + + let closest_key t dir k ~compare_key = + loop t dir k ~compare_key Missing () () + ;; + end + let closest_key = Closest_key_impl.closest_key + + let rec rank t k ~compare_key = + match t with + | Empty -> None + | Leaf (k', _) -> if compare_key k' k = 0 then Some 0 else None + | Node (l, k', _, r, _) -> + let c = compare_key k' k in + if c = 0 + then Some (length l) + else if c > 0 + then rank l k ~compare_key + else Option.map (rank r k ~compare_key) ~f:(fun rank -> rank + 1 + (length l)) + ;; + + (* this could be implemented using [Sequence] interface but the following implementation + allocates only 2 words and doesn't require write-barrier *) + let rec nth' num_to_search = function + | Empty -> None + | Leaf (k, v) -> + if !num_to_search = 0 + then Some (k, v) + else begin + decr num_to_search; + None + end + | Node (l, k, v, r, _) -> + match nth' num_to_search l with + | (Some _) as some -> some + | None -> + if !num_to_search = 0 + then Some (k, v) + else begin + decr num_to_search; + nth' num_to_search r + end + + let nth t n = nth' (ref n) t + ;; + + type ('k, 'v) acc = + { mutable bad_key : 'k option + ; mutable map_length : (('k, 'v) t * int) + } + + let of_iteri ~iteri ~compare_key = + let acc = { bad_key = None; map_length = (empty, 0) } in + iteri ~f:(fun ~key ~data -> + let map, length = acc.map_length in + let ((_, length') as pair) = set ~length ~key ~data map ~compare_key in + if length = length' && Option.is_none acc.bad_key + then acc.bad_key <- Some key + else acc.map_length <- pair); + match acc.bad_key with + | None -> `Ok acc.map_length + | Some key -> `Duplicate_key key + ;; + + let t_of_sexp_direct key_of_sexp value_of_sexp sexp ~comparator = + let alist = list_of_sexp (pair_of_sexp key_of_sexp value_of_sexp) sexp in + of_alist_exn alist ~comparator + ;; + + let sexp_of_t sexp_of_key sexp_of_value t = + let f ~key ~data acc = Sexp.List [sexp_of_key key; sexp_of_value data] :: acc in + Sexp.List (fold_right ~f t ~init:[]) + ;; +end + +type ('k, 'v, 'comparator) t = + { (* [comparator] is the first field so that polymorphic equality fails on a map due + to the functional value in the comparator. + Note that this does not affect polymorphic [compare]: that still produces + nonsense. *) + comparator : ('k, 'comparator) Comparator.t; + tree : ('k, 'v) Tree0.t; + length : int; + } + +type ('k, 'v, 'comparator) tree = ('k, 'v) Tree0.t + +let compare_key t = t.comparator.Comparator.compare + +let like {tree = _; length = _; comparator} (tree, length) = {tree; length; comparator} +let like2 x (y,z) = like x y, like x z +let with_same_length { tree = _; comparator; length } tree = + { tree; comparator; length } + +let of_tree ~comparator tree = { tree; comparator; length = Tree0.length tree} + +(* Exposing this function would make it very easy for the invariants + of this module to be broken. *) +let of_tree_unsafe ~comparator ~length tree = {tree; comparator; length} + +module Accessors = struct + let comparator t = t.comparator + let to_tree t = t.tree + let invariants t = Tree0.invariants t.tree ~compare_key:(compare_key t) + let is_empty t = Tree0.is_empty t.tree + let length t = t.length + let set t ~key ~data = + like t (Tree0.set t.tree ~length:t.length ~key ~data ~compare_key:(compare_key t)) + ;; + let add_exn t ~key ~data = + like t (Tree0.add_exn t.tree ~length:t.length ~key ~data ~compare_key:(compare_key t) + ~sexp_of_key:t.comparator.sexp_of_t) + ;; + let add_exn_internal t ~key ~data = + like t (Tree0.add_exn_internal + t.tree ~length:t.length ~key ~data ~compare_key:(compare_key t) + ~sexp_of_key:t.comparator.sexp_of_t) + ;; + let add t ~key ~data = + match add_exn_internal t ~key ~data with + | result -> `Ok result + | exception Duplicate -> `Duplicate + ;; + let add_multi t ~key ~data = + like t + (Tree0.add_multi t.tree ~length:t.length ~key ~data ~compare_key:(compare_key t)) + ;; + let remove_multi t key = + like t (Tree0.remove_multi t.tree ~length:t.length key ~compare_key:(compare_key t)) + ;; + let find_multi t key = Tree0.find_multi t.tree key ~compare_key:(compare_key t) + let change t key ~f = + like t (Tree0.change t.tree key ~f ~length:t.length ~compare_key:(compare_key t)) + ;; + let update t key ~f = change t key ~f:(fun data -> Some (f data)) + let find_exn t key = Tree0.find_exn t.tree key ~compare_key:(compare_key t) + let find t key = Tree0.find t.tree key ~compare_key:(compare_key t) + let remove t key = + like t (Tree0.remove t.tree key ~length:t.length ~compare_key:(compare_key t)) + ;; + let mem t key = Tree0.mem t.tree key ~compare_key:(compare_key t) + ;; + let iter_keys t ~f = Tree0.iter_keys t.tree ~f + let iter t ~f = Tree0.iter t.tree ~f + let iteri t ~f = Tree0.iteri t.tree ~f + let iter2 t1 t2 ~f = Tree0.iter2 t1.tree t2.tree ~f ~compare_key:(compare_key t1) + ;; + let map t ~f = with_same_length t (Tree0.map t.tree ~f) + let mapi t ~f = with_same_length t (Tree0.mapi t.tree ~f) + let fold t ~init ~f = Tree0.fold t.tree ~f ~init + let fold_right t ~init ~f = Tree0.fold_right t.tree ~f ~init + let fold2 t1 t2 ~init ~f = + Tree0.fold2 t1.tree t2.tree ~init ~f ~compare_key:(compare_key t1) + ;; + let filter_keys t ~f = like t (Tree0.filter_keys t.tree ~f ~compare_key:(compare_key t)) + let filter t ~f = like t (Tree0.filter t.tree ~f ~compare_key:(compare_key t)) + let filteri t ~f = like t (Tree0.filteri t.tree ~f ~compare_key:(compare_key t)) + let filter_map t ~f = like t (Tree0.filter_map t.tree ~f ~compare_key:(compare_key t)) + let filter_mapi t ~f = like t (Tree0.filter_mapi t.tree ~f ~compare_key:(compare_key t)) + ;; + let partition_mapi t ~f = + like2 t (Tree0.partition_mapi t.tree ~f ~compare_key:(compare_key t)) + ;; + let partition_map t ~f = + like2 t (Tree0.partition_map t.tree ~f ~compare_key:(compare_key t)) + ;; + let partitioni_tf t ~f = + like2 t (Tree0.partitioni_tf t.tree ~f ~compare_key:(compare_key t)) + ;; + let partition_tf t ~f = + like2 t (Tree0.partition_tf t.tree ~f ~compare_key:(compare_key t)) + ;; + + let compare_direct compare_data t1 t2 = + Tree0.compare (compare_key t1) compare_data t1.tree t2.tree + ;; + let equal compare_data t1 t2 = + Tree0.equal (compare_key t1) compare_data t1.tree t2.tree + ;; + let keys t = Tree0.keys t.tree + let data t = Tree0.data t.tree + let to_alist ?key_order t = Tree0.to_alist ?key_order t.tree + let validate ~name f t = Validate.alist ~name f (to_alist t) + let symmetric_diff t1 t2 ~data_equal = + Tree0.symmetric_diff t1.tree t2.tree ~compare_key:(compare_key t1) ~data_equal + ;; + let merge t1 t2 ~f = + like t1 (Tree0.merge t1.tree t2.tree ~f ~compare_key:(compare_key t1)) + ;; + let min_elt t = Tree0.min_elt t.tree + let min_elt_exn t = Tree0.min_elt_exn t.tree + let max_elt t = Tree0.max_elt t.tree + let max_elt_exn t = Tree0.max_elt_exn t.tree + let for_all t ~f = Tree0.for_all t.tree ~f + let for_alli t ~f = Tree0.for_alli t.tree ~f + let exists t ~f = Tree0.exists t.tree ~f + let existsi t ~f = Tree0.existsi t.tree ~f + let count t ~f = Tree0.count t.tree ~f + let counti t ~f = Tree0.counti t.tree ~f + let split t k = + let l, maybe, r = Tree0.split t.tree k ~compare_key:(compare_key t) in + let comparator = comparator t in + (* Try to traverse the least amount possible to calculate the length, + using height as a heuristic. *) + let both_len = if Option.is_some maybe then t.length - 1 else t.length in + if Tree0.height l < Tree0.height r then + let l = of_tree l ~comparator in + l, maybe, of_tree_unsafe r ~comparator ~length:(both_len - length l) + else + let r = of_tree r ~comparator in + of_tree_unsafe l ~comparator ~length:(both_len - length r), maybe, r + ;; + let subrange t ~lower_bound ~upper_bound = + let left, mid, right = + Tree0.split_range t.tree ~lower_bound ~upper_bound ~compare_key:(compare_key t) + in + (* Try to traverse the least amount possible to calculate the length, + using height as a heuristic. *) + let outer_joined_height = + let h_l = Tree0.height left + and h_r = Tree0.height right in + if h_l = h_r then + h_l + 1 + else + max h_l h_r + in + if outer_joined_height < Tree0.height mid then + let mid_length = t.length - (Tree0.length left + Tree0.length right) in + of_tree_unsafe mid ~comparator:(comparator t) ~length:mid_length + else + of_tree mid ~comparator:(comparator t) + ;; + let append ~lower_part ~upper_part = + match Tree0.append ~compare_key:(compare_key lower_part) ~lower_part:lower_part.tree ~upper_part:upper_part.tree with + | `Ok tree -> `Ok (of_tree_unsafe tree ~comparator:(comparator lower_part) ~length:(lower_part.length + upper_part.length)) + | `Overlapping_key_ranges -> `Overlapping_key_ranges + ;; + let fold_range_inclusive t ~min ~max ~init ~f = + Tree0.fold_range_inclusive t.tree ~min ~max ~init ~f ~compare_key:(compare_key t) + ;; + let range_to_alist t ~min ~max = + Tree0.range_to_alist t.tree ~min ~max ~compare_key:(compare_key t) + ;; + let closest_key t dir key = Tree0.closest_key t.tree dir key ~compare_key:(compare_key t) + let nth t n = Tree0.nth t.tree n + let nth_exn t n = Option.value_exn (nth t n) + let rank t key = Tree0.rank t.tree key ~compare_key:(compare_key t) + let sexp_of_t sexp_of_k sexp_of_v _ t = Tree0.sexp_of_t sexp_of_k sexp_of_v t.tree + let to_sequence ?order ?keys_greater_or_equal_to ?keys_less_or_equal_to t = + Tree0.to_sequence t.comparator ?order ?keys_greater_or_equal_to + ?keys_less_or_equal_to t.tree + + let hash_fold_direct hash_fold_key hash_fold_data state t = + Tree0.hash_fold_t_ignoring_structure hash_fold_key hash_fold_data state t.tree +end + +(* [0] is used as the [length] argument everywhere in this module, since trees do not + have their lengths stored at the root, unlike maps. The values are discarded always. *) +module Tree = struct + type ('k, 'v, 'comparator) t = ('k, 'v, 'comparator) tree + + let empty_without_value_restriction = Tree0.empty + let empty ~comparator:_ = empty_without_value_restriction + let of_tree ~comparator:_ tree = tree + let singleton ~comparator:_ k v = Tree0.singleton k v + let of_sorted_array_unchecked ~comparator array = + fst (Tree0.of_sorted_array_unchecked array ~compare_key:comparator.Comparator.compare) + ;; + let of_sorted_array ~comparator array = + Tree0.of_sorted_array array ~compare_key:comparator.Comparator.compare + |> Or_error.map ~f:fst + ;; + let of_alist ~comparator alist = + match Tree0.of_alist alist ~compare_key:comparator.Comparator.compare with + | `Duplicate_key _ as d -> d + | `Ok (tree, _size) -> `Ok tree + ;; + let of_alist_or_error ~comparator alist = + Tree0.of_alist_or_error alist ~comparator + |> Or_error.map ~f:fst + let of_alist_exn ~comparator alist = fst (Tree0.of_alist_exn alist ~comparator) + let of_alist_multi ~comparator alist = + fst (Tree0.of_alist_multi alist ~compare_key:comparator.Comparator.compare) + ;; + let of_alist_fold ~comparator alist ~init ~f = + fst (Tree0.of_alist_fold alist ~init ~f ~compare_key:comparator.Comparator.compare) + ;; + let of_alist_reduce ~comparator alist ~f = + fst (Tree0.of_alist_reduce alist ~f ~compare_key:comparator.Comparator.compare) + ;; + let of_iteri ~comparator ~iteri = + match Tree0.of_iteri ~iteri ~compare_key:comparator.Comparator.compare with + | `Ok (tree, _size) -> `Ok tree + | `Duplicate_key _ as d -> d + ;; + let of_increasing_iterator_unchecked ~comparator:_required_by_intf ~len ~f = + Tree0.of_increasing_iterator_unchecked ~len ~f + ;; + let of_increasing_sequence ~comparator seq = + Or_error.map ~f:fst (Tree0.of_increasing_sequence seq ~compare_key:comparator.Comparator.compare) + ;; + + let to_tree t = t + let invariants ~comparator t = + Tree0.invariants t ~compare_key:comparator.Comparator.compare + ;; + let is_empty t = Tree0.is_empty t + let length t = Tree0.length t + let set ~comparator t ~key ~data = + fst (Tree0.set t ~key ~data ~length:0 ~compare_key:comparator.Comparator.compare) + ;; + let add_exn ~comparator t ~key ~data = + fst (Tree0.add_exn t ~key ~data ~length:0 ~compare_key:comparator.Comparator.compare + ~sexp_of_key:comparator.sexp_of_t) + ;; + let add ~comparator t ~key ~data = + try + `Ok (add_exn t ~comparator ~key ~data) + with _ -> + `Duplicate + ;; + let add_multi ~comparator t ~key ~data = + Tree0.add_multi t ~key ~data ~length:0 ~compare_key:comparator.Comparator.compare + |> fst + ;; + let remove_multi ~comparator t key = + Tree0.remove_multi t key ~length:0 ~compare_key:comparator.Comparator.compare + |> fst + ;; + let find_multi ~comparator t key = + Tree0.find_multi t key ~compare_key:comparator.Comparator.compare + ;; + let change ~comparator t key ~f = + fst (Tree0.change t key ~f ~length:0 ~compare_key:comparator.Comparator.compare) + ;; + let update ~comparator t key ~f = change ~comparator t key ~f:(fun data -> Some (f data)) + let find_exn ~comparator t key = + Tree0.find_exn t key ~compare_key:comparator.Comparator.compare + ;; + let find ~comparator t key = + Tree0.find t key ~compare_key:comparator.Comparator.compare + ;; + let remove ~comparator t key = + fst (Tree0.remove t key ~length:0 ~compare_key:comparator.Comparator.compare) + ;; + let mem ~comparator t key = Tree0.mem t key ~compare_key:comparator.Comparator.compare + ;; + let iter_keys t ~f = Tree0.iter_keys t ~f + let iter t ~f = Tree0.iter t ~f + let iteri t ~f = Tree0.iteri t ~f + let iter2 ~comparator t1 t2 ~f = + Tree0.iter2 t1 t2 ~f ~compare_key:comparator.Comparator.compare + ;; + let map t ~f = Tree0.map t ~f + let mapi t ~f = Tree0.mapi t ~f + let fold t ~init ~f = Tree0.fold t ~f ~init + let fold_right t ~init ~f = Tree0.fold_right t ~f ~init + let fold2 ~comparator t1 t2 ~init ~f = + Tree0.fold2 t1 t2 ~init ~f ~compare_key:comparator.Comparator.compare + ;; + let filter_keys ~comparator t ~f = + fst (Tree0.filter_keys t ~f ~compare_key:comparator.Comparator.compare) + let filter ~comparator t ~f = + fst (Tree0.filter t ~f ~compare_key:comparator.Comparator.compare) + let filteri ~comparator t ~f = + fst (Tree0.filteri t ~f ~compare_key:comparator.Comparator.compare) + let filter_map ~comparator t ~f = + fst (Tree0.filter_map t ~f ~compare_key:comparator.Comparator.compare) + let filter_mapi ~comparator t ~f = + fst (Tree0.filter_mapi t ~f ~compare_key:comparator.Comparator.compare) + ;; + let partition_mapi ~comparator t ~f = + let (a, _), (b, _) = + Tree0.partition_mapi t ~f ~compare_key:comparator.Comparator.compare + in + (a, b) + ;; + let partition_map ~comparator t ~f = + let (a, _), (b, _) = + Tree0.partition_map t ~f ~compare_key:comparator.Comparator.compare + in + (a, b) + ;; + let partitioni_tf ~comparator t ~f = + let (a, _), (b, _) = + Tree0.partitioni_tf t ~f ~compare_key:comparator.Comparator.compare + in + (a, b) + ;; + let partition_tf ~comparator t ~f = + let (a, _), (b, _) = + Tree0.partition_tf t ~f ~compare_key:comparator.Comparator.compare + in + (a, b) + ;; + let compare_direct ~comparator compare_data t1 t2 = + Tree0.compare comparator.Comparator.compare compare_data t1 t2 + ;; + let equal ~comparator compare_data t1 t2 = + Tree0.equal comparator.Comparator.compare compare_data t1 t2 + ;; + let keys t = Tree0.keys t + let data t = Tree0.data t + let to_alist ?key_order t = Tree0.to_alist ?key_order t + let validate ~name f t = Validate.alist ~name f (to_alist t) + let symmetric_diff ~comparator t1 t2 ~data_equal = + Tree0.symmetric_diff t1 t2 ~compare_key:comparator.Comparator.compare ~data_equal + ;; + let merge ~comparator t1 t2 ~f = + fst (Tree0.merge t1 t2 ~f ~compare_key:comparator.Comparator.compare) + ;; + let min_elt t = Tree0.min_elt t + let min_elt_exn t = Tree0.min_elt_exn t + let max_elt t = Tree0.max_elt t + let max_elt_exn t = Tree0.max_elt_exn t + let for_all t ~f = Tree0.for_all t ~f + let for_alli t ~f = Tree0.for_alli t ~f + let exists t ~f = Tree0.exists t ~f + let existsi t ~f = Tree0.existsi t ~f + let count t ~f = Tree0.count t ~f + let counti t ~f = Tree0.counti t ~f + let split ~comparator t k = Tree0.split t k ~compare_key:comparator.Comparator.compare + let append ~comparator ~lower_part ~upper_part = + Tree0.append ~lower_part ~upper_part ~compare_key:comparator.Comparator.compare + let subrange ~comparator t ~lower_bound ~upper_bound = + let (_, ret, _) = + Tree0.split_range t ~lower_bound ~upper_bound + ~compare_key:comparator.Comparator.compare in + ret + ;; + let fold_range_inclusive ~comparator t ~min ~max ~init ~f = + Tree0.fold_range_inclusive t ~min ~max ~init ~f + ~compare_key:comparator.Comparator.compare + ;; + let range_to_alist ~comparator t ~min ~max = + Tree0.range_to_alist t ~min ~max ~compare_key:comparator.Comparator.compare + ;; + let closest_key ~comparator t dir key = + Tree0.closest_key t dir key ~compare_key:comparator.Comparator.compare + ;; + let nth ~comparator:_ t n = Tree0.nth t n + let nth_exn ~comparator t n = Option.value_exn (nth ~comparator t n) + let rank ~comparator t key = Tree0.rank t key ~compare_key:comparator.Comparator.compare + let sexp_of_t sexp_of_k sexp_of_v _ t = Tree0.sexp_of_t sexp_of_k sexp_of_v t + let t_of_sexp_direct ~comparator k_of_sexp v_of_sexp sexp = + fst (Tree0.t_of_sexp_direct k_of_sexp v_of_sexp sexp ~comparator) + ;; + + let to_sequence ~comparator ?order ?keys_greater_or_equal_to ?keys_less_or_equal_to t + = + Tree0.to_sequence comparator ?order ?keys_greater_or_equal_to ?keys_less_or_equal_to + t +end + +module Using_comparator = struct + type nonrec ('k, 'v, 'cmp) t = ('k, 'v, 'cmp) t + + include Accessors + + let empty ~comparator = { tree = Tree0.empty; comparator; length = 0 } + + let singleton ~comparator k v = { comparator; tree = Tree0.singleton k v; length = 1 } + + let of_tree0 ~comparator (tree, length) = + { comparator; tree; length } + + let of_tree ~comparator tree = of_tree0 ~comparator (tree, Tree0.length tree) + let to_tree = to_tree + + let of_sorted_array_unchecked ~comparator array = + of_tree0 ~comparator + (Tree0.of_sorted_array_unchecked array ~compare_key:comparator.Comparator.compare) + ;; + + let of_sorted_array ~comparator array = + Or_error.map (Tree0.of_sorted_array array ~compare_key:comparator.Comparator.compare) + ~f:(fun tree -> of_tree0 ~comparator tree) + ;; + + let of_alist ~comparator alist = + match Tree0.of_alist alist ~compare_key:comparator.Comparator.compare with + | `Ok (tree, length) -> `Ok { comparator; tree; length } + | `Duplicate_key _ as z -> z + ;; + + let of_alist_or_error ~comparator alist = + Result.map (Tree0.of_alist_or_error alist ~comparator) + ~f:(fun tree -> of_tree0 ~comparator tree) + ;; + + let of_alist_exn ~comparator alist = + of_tree0 ~comparator (Tree0.of_alist_exn alist ~comparator) + ;; + + let of_alist_multi ~comparator alist = + of_tree0 ~comparator + (Tree0.of_alist_multi alist ~compare_key:comparator.Comparator.compare) + ;; + + let of_alist_fold ~comparator alist ~init ~f = + of_tree0 ~comparator + (Tree0.of_alist_fold alist ~init ~f ~compare_key:comparator.Comparator.compare) + ;; + + let of_alist_reduce ~comparator alist ~f = + of_tree0 ~comparator + (Tree0.of_alist_reduce alist ~f ~compare_key:comparator.Comparator.compare) + ;; + + let of_iteri ~comparator ~iteri = + match Tree0.of_iteri ~compare_key:comparator.Comparator.compare ~iteri with + | `Ok tree_length -> `Ok (of_tree0 ~comparator tree_length) + | `Duplicate_key _ as z -> z + ;; + + let of_increasing_iterator_unchecked ~comparator ~len ~f = + of_tree0 ~comparator (Tree0.of_increasing_iterator_unchecked ~len ~f, len) + + let of_increasing_sequence ~comparator seq = + Or_error.map ~f:(of_tree0 ~comparator) + (Tree0.of_increasing_sequence seq ~compare_key:comparator.Comparator.compare) + ;; + + let t_of_sexp_direct ~comparator k_of_sexp v_of_sexp sexp = + of_tree0 ~comparator (Tree0.t_of_sexp_direct k_of_sexp v_of_sexp sexp ~comparator) + ;; + + module Empty_without_value_restriction(K : Comparator.S1) = struct + let empty = { tree = Tree0.empty; comparator = K.comparator; length = 0 } + end + + module Tree = Tree +end + +include Accessors + +type ('k, 'cmp) comparator = + (module Comparator.S with type t = 'k and type comparator_witness = 'cmp) + +let comparator_s (type k cmp) t : (k, cmp) comparator = + (module struct + type t = k + type comparator_witness = cmp + let comparator = t.comparator + end) + +let to_comparator (type k cmp) ((module M) : (k, cmp) comparator) = M.comparator + +let empty m = Using_comparator.empty ~comparator:(to_comparator m) +let singleton m a = Using_comparator.singleton ~comparator:(to_comparator m) a +let of_alist m a = Using_comparator.of_alist ~comparator:(to_comparator m) a +let of_alist_or_error m a = Using_comparator.of_alist_or_error ~comparator:(to_comparator m) a +let of_alist_exn m a = Using_comparator.of_alist_exn ~comparator:(to_comparator m) a +let of_alist_multi m a = Using_comparator.of_alist_multi ~comparator:(to_comparator m) a +let of_alist_fold m a ~init ~f = Using_comparator.of_alist_fold ~comparator:(to_comparator m) a ~init ~f +let of_alist_reduce m a ~f = Using_comparator.of_alist_reduce ~comparator:(to_comparator m) a ~f +let of_sorted_array_unchecked m a = Using_comparator.of_sorted_array_unchecked ~comparator:(to_comparator m) a +let of_sorted_array m a = Using_comparator.of_sorted_array ~comparator:(to_comparator m) a +let of_iteri m ~iteri = Using_comparator.of_iteri ~iteri ~comparator:(to_comparator m) +let of_increasing_iterator_unchecked m ~len ~f = + Using_comparator.of_increasing_iterator_unchecked ~len ~f ~comparator:(to_comparator m) +let of_increasing_sequence m seq = Using_comparator.of_increasing_sequence ~comparator:(to_comparator m) seq + +module M(K : sig type t type comparator_witness end) = struct + type nonrec 'v t = (K.t, 'v, K.comparator_witness) t +end +module type Sexp_of_m = sig type t [@@deriving_inline sexp_of] + include + sig [@@@ocaml.warning "-32"] val sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t + end[@@ocaml.doc "@inline"] + [@@@end] end +module type M_of_sexp = sig + type t [@@deriving_inline of_sexp] + include + sig [@@@ocaml.warning "-32"] val t_of_sexp : Ppx_sexp_conv_lib.Sexp.t -> t + end[@@ocaml.doc "@inline"] + [@@@end] include Comparator.S with type t := t +end +module type Compare_m = sig end +module type Hash_fold_m = Hasher.S + +let sexp_of_m__t (type k) (module K : Sexp_of_m with type t = k) sexp_of_v t = + sexp_of_t K.sexp_of_t sexp_of_v (fun _ -> Sexp.Atom "_") t + +let m__t_of_sexp (type k cmp) + (module K : M_of_sexp with type t = k and type comparator_witness = cmp) + v_of_sexp sexp = + Using_comparator.t_of_sexp_direct ~comparator:K.comparator K.t_of_sexp v_of_sexp sexp + +let compare_m__t (module K : Compare_m) compare_v t1 t2 = + compare_direct compare_v t1 t2 + +let hash_fold_m__t (type k) (module K : Hash_fold_m with type t = k) hash_fold_v state = + hash_fold_direct K.hash_fold_t hash_fold_v state + +let merge_skewed t1 t2 ~combine = + let t1, t2, combine = + if length t2 <= length t1 + then t1, t2, combine + else t2, t1, (fun ~key v1 v2 -> combine ~key v2 v1) + in + fold t2 ~init:t1 ~f:(fun ~key ~data:v2 t1 -> + change t1 key ~f:(function + | None -> Some v2 + | Some v1 -> Some (combine ~key v1 v2))) + +module Poly = struct + type nonrec ('k, 'v) t = ('k, 'v, Comparator.Poly.comparator_witness) t + type nonrec ('k, 'v) tree = ('k, 'v) Tree0.t + + include Accessors + + let comparator = Comparator.Poly.comparator + + let of_tree tree = { tree; comparator; length = Tree0.length tree} + + include Using_comparator.Empty_without_value_restriction(Comparator.Poly) + + let singleton a = Using_comparator.singleton ~comparator a + let of_alist a = Using_comparator.of_alist ~comparator a + let of_alist_or_error a = Using_comparator.of_alist_or_error ~comparator a + let of_alist_exn a = Using_comparator.of_alist_exn ~comparator a + let of_alist_multi a = Using_comparator.of_alist_multi ~comparator a + let of_alist_fold a ~init ~f = Using_comparator.of_alist_fold ~comparator a ~init ~f + let of_alist_reduce a ~f = Using_comparator.of_alist_reduce ~comparator a ~f + let of_sorted_array_unchecked a = Using_comparator.of_sorted_array_unchecked ~comparator a + let of_sorted_array a = Using_comparator.of_sorted_array ~comparator a + let of_iteri ~iteri = Using_comparator.of_iteri ~iteri ~comparator + let of_increasing_iterator_unchecked ~len ~f = + Using_comparator.of_increasing_iterator_unchecked ~len ~f ~comparator + let of_increasing_sequence seq = Using_comparator.of_increasing_sequence ~comparator seq +end diff --git a/src/map.mli b/src/map.mli new file mode 100644 index 0000000..2bee42d --- /dev/null +++ b/src/map.mli @@ -0,0 +1 @@ +include Map_intf.Map (** @inline *) diff --git a/src/map_intf.ml b/src/map_intf.ml new file mode 100644 index 0000000..58ebd99 --- /dev/null +++ b/src/map_intf.ml @@ -0,0 +1,1936 @@ +open! Import +open! T + +module Or_duplicate = struct + type 'a t = [ `Ok of 'a | `Duplicate ] + [@@deriving_inline sexp_of] + let sexp_of_t : + 'a . ('a -> Ppx_sexp_conv_lib.Sexp.t) -> 'a t -> Ppx_sexp_conv_lib.Sexp.t = + fun _of_a -> + function + | `Ok v0 -> + Ppx_sexp_conv_lib.Sexp.List + [Ppx_sexp_conv_lib.Sexp.Atom "Ok"; _of_a v0] + | `Duplicate -> Ppx_sexp_conv_lib.Sexp.Atom "Duplicate" + [@@@end] +end + +module Without_comparator = struct + type ('key, 'cmp, 'z) t = 'z +end + +module With_comparator = struct + type ('key, 'cmp, 'z) t = comparator:('key, 'cmp) Comparator.t -> 'z +end + +module With_first_class_module = struct + type ('key, 'cmp, 'z) t = + (module Comparator.S with type t = 'key and type comparator_witness = 'cmp) -> 'z +end + +module Symmetric_diff_element = struct + type ('k, 'v) t = 'k * [ `Left of 'v | `Right of 'v | `Unequal of 'v * 'v ] + [@@deriving_inline compare, sexp] + let compare : + 'k 'v . + ('k -> 'k -> int) -> ('v -> 'v -> int) -> ('k, 'v) t -> ('k, 'v) t -> int + = + fun _cmp__k -> + fun _cmp__v -> + fun a__001_ -> + fun b__002_ -> + let (t__003_, t__004_) = a__001_ in + let (t__005_, t__006_) = b__002_ in + match _cmp__k t__003_ t__005_ with + | 0 -> + if Ppx_compare_lib.phys_equal t__004_ t__006_ + then 0 + else + (match (t__004_, t__006_) with + | (`Left _left__007_, `Left _right__008_) -> + _cmp__v _left__007_ _right__008_ + | (`Right _left__009_, `Right _right__010_) -> + _cmp__v _left__009_ _right__010_ + | (`Unequal _left__011_, `Unequal _right__012_) -> + let (t__013_, t__014_) = _left__011_ in + let (t__015_, t__016_) = _right__012_ in + (match _cmp__v t__013_ t__015_ with + | 0 -> _cmp__v t__014_ t__016_ + | n -> n) + | (x, y) -> Ppx_compare_lib.polymorphic_compare x y) + | n -> n + let t_of_sexp : + 'k 'v . + (Ppx_sexp_conv_lib.Sexp.t -> 'k) -> + (Ppx_sexp_conv_lib.Sexp.t -> 'v) -> + Ppx_sexp_conv_lib.Sexp.t -> ('k, 'v) t + = + let _tp_loc = "src/map_intf.ml.Symmetric_diff_element.t" in + fun _of_k -> + fun _of_v -> + function + | Ppx_sexp_conv_lib.Sexp.List (v0::v1::[]) -> + let v0 = _of_k v0 + and v1 = + (fun sexp -> + try + match sexp with + | Ppx_sexp_conv_lib.Sexp.Atom atom as _sexp -> + (match atom with + | "Left" -> + Ppx_sexp_conv_lib.Conv_error.ptag_takes_args + _tp_loc _sexp + | "Right" -> + Ppx_sexp_conv_lib.Conv_error.ptag_takes_args + _tp_loc _sexp + | "Unequal" -> + Ppx_sexp_conv_lib.Conv_error.ptag_takes_args + _tp_loc _sexp + | _ -> Ppx_sexp_conv_lib.Conv_error.no_variant_match ()) + | Ppx_sexp_conv_lib.Sexp.List ((Ppx_sexp_conv_lib.Sexp.Atom + atom)::sexp_args) as _sexp -> + (match atom with + | "Left" as _tag -> + (match sexp_args with + | v0::[] -> let v0 = _of_v v0 in `Left v0 + | _ -> + Ppx_sexp_conv_lib.Conv_error.ptag_incorrect_n_args + _tp_loc _tag _sexp) + | "Right" as _tag -> + (match sexp_args with + | v0::[] -> let v0 = _of_v v0 in `Right v0 + | _ -> + Ppx_sexp_conv_lib.Conv_error.ptag_incorrect_n_args + _tp_loc _tag _sexp) + | "Unequal" as _tag -> + (match sexp_args with + | v0::[] -> + let v0 = + match v0 with + | Ppx_sexp_conv_lib.Sexp.List (v0::v1::[]) + -> + let v0 = _of_v v0 + and v1 = _of_v v1 in (v0, v1) + | sexp -> + Ppx_sexp_conv_lib.Conv_error.tuple_of_size_n_expected + _tp_loc 2 sexp in + `Unequal v0 + | _ -> + Ppx_sexp_conv_lib.Conv_error.ptag_incorrect_n_args + _tp_loc _tag _sexp) + | _ -> Ppx_sexp_conv_lib.Conv_error.no_variant_match ()) + | Ppx_sexp_conv_lib.Sexp.List ((Ppx_sexp_conv_lib.Sexp.List + _)::_) as sexp -> + Ppx_sexp_conv_lib.Conv_error.nested_list_invalid_poly_var + _tp_loc sexp + | Ppx_sexp_conv_lib.Sexp.List [] as sexp -> + Ppx_sexp_conv_lib.Conv_error.empty_list_invalid_poly_var + _tp_loc sexp + with + | Ppx_sexp_conv_lib.Conv_error.No_variant_match -> + Ppx_sexp_conv_lib.Conv_error.no_matching_variant_found + _tp_loc sexp) v1 in + (v0, v1) + | sexp -> + Ppx_sexp_conv_lib.Conv_error.tuple_of_size_n_expected _tp_loc 2 + sexp + let sexp_of_t : + 'k 'v . + ('k -> Ppx_sexp_conv_lib.Sexp.t) -> + ('v -> Ppx_sexp_conv_lib.Sexp.t) -> + ('k, 'v) t -> Ppx_sexp_conv_lib.Sexp.t + = + fun _of_k -> + fun _of_v -> + function + | (v0, v1) -> + let v0 = _of_k v0 + and v1 = + match v1 with + | `Left v0 -> + Ppx_sexp_conv_lib.Sexp.List + [Ppx_sexp_conv_lib.Sexp.Atom "Left"; _of_v v0] + | `Right v0 -> + Ppx_sexp_conv_lib.Sexp.List + [Ppx_sexp_conv_lib.Sexp.Atom "Right"; _of_v v0] + | `Unequal v0 -> + Ppx_sexp_conv_lib.Sexp.List + [Ppx_sexp_conv_lib.Sexp.Atom "Unequal"; + (let (v0, v1) = v0 in + let v0 = _of_v v0 + and v1 = _of_v v1 in Ppx_sexp_conv_lib.Sexp.List [v0; v1])] in + Ppx_sexp_conv_lib.Sexp.List [v0; v1] + [@@@end] +end + +module type Accessors_generic = sig + type ('a, 'b, 'cmp) t + type ('a, 'b, 'cmp) tree + type 'a key + type ('a, 'cmp, 'z) options + + val invariants + : ('k, 'cmp, + ('k, 'v, 'cmp) t -> bool + ) options + + val is_empty : (_, _, _) t -> bool + + val length : (_, _, _) t -> int + + val add + : ('k, 'cmp, + ('k, 'v, 'cmp) t -> key:'k key -> data:'v -> ('k, 'v, 'cmp) t Or_duplicate.t + ) options + + val add_exn + : ('k, 'cmp, + ('k, 'v, 'cmp) t -> key:'k key -> data:'v -> ('k, 'v, 'cmp) t + ) options + + val set + : ('k, 'cmp, + ('k, 'v, 'cmp) t -> key:'k key -> data:'v -> ('k, 'v, 'cmp) t + ) options + + val add_multi + : ('k, 'cmp, + ('k, 'v list, 'cmp) t + -> key:'k key + -> data:'v + -> ('k, 'v list, 'cmp) t + ) options + + val remove_multi + : ('k, 'cmp, + ('k, 'v list, 'cmp) t + -> 'k key + -> ('k, 'v list, 'cmp) t + ) options + + val find_multi + : ('k, 'cmp, + ('k, 'v list, 'cmp) t + -> 'k key -> 'v list + ) options + + val change + : ('k, 'cmp, + ('k, 'v, 'cmp) t + -> 'k key + -> f:('v option -> 'v option) + -> ('k, 'v, 'cmp) t + ) options + + val update + : ('k, 'cmp, + ('k, 'v, 'cmp) t + -> 'k key + -> f:('v option -> 'v) + -> ('k, 'v, 'cmp) t + ) options + + val find : ('k, 'cmp, ('k, 'v, 'cmp) t -> 'k key -> 'v option) options + val find_exn : ('k, 'cmp, ('k, 'v, 'cmp) t -> 'k key -> 'v ) options + + val remove + : ('k, 'cmp, ('k, 'v, 'cmp) t -> 'k key -> ('k, 'v, 'cmp) t + ) options + + val mem : ('k, 'cmp, ('k, _, 'cmp) t -> 'k key -> bool) options + + val iter_keys : ('k, _, _) t -> f:('k key -> unit) -> unit + val iter : ( _, 'v, _) t -> f:('v -> unit) -> unit + val iteri : ('k, 'v, _) t -> f:(key:'k key -> data:'v -> unit) -> unit + + val iter2 + : ('k, 'cmp, + ('k, 'v1, 'cmp) t + -> ('k, 'v2, 'cmp) t + -> f:(key:'k key + -> data:[ `Left of 'v1 | `Right of 'v2 | `Both of 'v1 * 'v2 ] + -> unit) + -> unit + ) options + + val map : ('k, 'v1, 'cmp) t -> f:('v1 -> 'v2) -> ('k, 'v2, 'cmp) t + + val mapi + : ('k, 'v1, 'cmp) t + -> f:(key:'k key -> data:'v1 -> 'v2) + -> ('k, 'v2, 'cmp) t + + val fold : ('k, 'v, _) t -> init:'a -> f:(key:'k key -> data:'v -> 'a -> 'a) -> 'a + val fold_right : ('k, 'v, _) t -> init:'a -> f:(key:'k key -> data:'v -> 'a -> 'a) -> 'a + + val fold2 + : ('k, 'cmp, + ('k, 'v1, 'cmp) t + -> ('k, 'v2, 'cmp) t + -> init:'a + -> f:(key:'k key + -> data:[ `Left of 'v1 | `Right of 'v2 | `Both of 'v1 * 'v2 ] + -> 'a + -> 'a) + -> 'a + ) options + + val filter_keys + : ('k, 'cmp, + ('k, 'v, 'cmp) t + -> f:('k key -> bool) + -> ('k, 'v, 'cmp) t + ) options + + val filter + : ('k, 'cmp, + ('k, 'v, 'cmp) t + -> f:('v -> bool) + -> ('k, 'v, 'cmp) t + ) options + + val filteri + : ('k, 'cmp, + ('k, 'v, 'cmp) t + -> f:(key:'k key -> data:'v -> bool) + -> ('k, 'v, 'cmp) t + ) options + + val filter_map + : ('k, 'cmp, + ('k, 'v1, 'cmp) t + -> f:('v1 -> 'v2 option) + -> ('k, 'v2, 'cmp) t + ) options + + val filter_mapi + : ('k, 'cmp, + ('k, 'v1, 'cmp) t + -> f:(key:'k key -> data:'v1 -> 'v2 option) + -> ('k, 'v2, 'cmp) t + ) options + + val partition_mapi + : ('k, 'cmp, + ('k, 'v1, 'cmp) t + -> f:(key:'k key -> data:'v1 -> [`Fst of 'v2 | `Snd of 'v3]) + -> ('k, 'v2, 'cmp) t * ('k, 'v3, 'cmp) t + ) options + + val partition_map + : ('k, 'cmp, + ('k, 'v1, 'cmp) t + -> f:('v1 -> [`Fst of 'v2 | `Snd of 'v3]) + -> ('k, 'v2, 'cmp) t * ('k, 'v3, 'cmp) t + ) options + + val partitioni_tf + : ('k, 'cmp, + ('k, 'v, 'cmp) t + -> f:(key:'k key -> data:'v -> bool) + -> ('k, 'v, 'cmp) t * ('k, 'v, 'cmp) t + ) options + + val partition_tf + : ('k, 'cmp, + ('k, 'v, 'cmp) t + -> f:('v -> bool) + -> ('k, 'v, 'cmp) t * ('k, 'v, 'cmp) t + ) options + + val compare_direct + : ('k, 'cmp, + ('v -> 'v -> int) + -> ('k, 'v, 'cmp) t + -> ('k, 'v, 'cmp) t + -> int + ) options + + val equal + : ('k, 'cmp, + ('v -> 'v -> bool) + -> ('k, 'v, 'cmp) t + -> ('k, 'v, 'cmp) t + -> bool + ) options + + val keys : ('k, _, _) t -> 'k key list + + val data : (_, 'v, _) t -> 'v list + + val to_alist + : ?key_order:[`Increasing|`Decreasing] + -> ('k, 'v, _) t + -> ('k key * 'v) list + + val validate + : name:('k key -> string) + -> 'v Validate.check + -> ('k, 'v, _) t Validate.check + + val merge + : ('k, 'cmp, + ('k, 'v1, 'cmp) t + -> ('k, 'v2, 'cmp) t + -> f:(key:'k key + -> [ `Left of 'v1 | `Right of 'v2 | `Both of 'v1 * 'v2 ] + -> 'v3 option) + -> ('k, 'v3, 'cmp) t + ) options + + val symmetric_diff + : ('k, 'cmp, + ('k, 'v, 'cmp) t + -> ('k, 'v, 'cmp) t + -> data_equal:('v -> 'v -> bool) + -> ('k key, 'v) Symmetric_diff_element.t Sequence.t + ) options + + val min_elt : ('k, 'v, _) t -> ('k key * 'v) option + val min_elt_exn : ('k, 'v, _) t -> 'k key * 'v + + val max_elt : ('k, 'v, _) t -> ('k key * 'v) option + val max_elt_exn : ('k, 'v, _) t -> 'k key * 'v + + val for_all : ('k, 'v, _) t -> f:( 'v -> bool) -> bool + val for_alli : ('k, 'v, _) t -> f:(key:'k key -> data:'v -> bool) -> bool + val exists : ('k, 'v, _) t -> f:( 'v -> bool) -> bool + val existsi : ('k, 'v, _) t -> f:(key:'k key -> data:'v -> bool) -> bool + val count : ('k, 'v, _) t -> f:( 'v -> bool) -> int + val counti : ('k, 'v, _) t -> f:(key:'k key -> data:'v -> bool) -> int + + val split + : ('k, 'cmp, + ('k, 'v, 'cmp) t + -> 'k key + -> ('k, 'v, 'cmp) t * ('k key * 'v) option * ('k, 'v, 'cmp) t + ) options + + val append + : ('k, 'cmp, + lower_part:('k, 'v, 'cmp) t + -> upper_part:('k, 'v, 'cmp) t + -> [ `Ok of ('k, 'v, 'cmp) t | `Overlapping_key_ranges ] + ) options + + val subrange + : ('k, 'cmp, + ('k, 'v, 'cmp) t + -> lower_bound:'k key Maybe_bound.t + -> upper_bound:'k key Maybe_bound.t + -> ('k, 'v, 'cmp) t + ) options + + val fold_range_inclusive + : ('k, 'cmp, + ('k, 'v, 'cmp) t + -> min:'k key + -> max:'k key + -> init:'a + -> f:(key:'k key -> data:'v -> 'a -> 'a) + -> 'a + ) options + + val range_to_alist + : ('k, 'cmp, + ('k, 'v, 'cmp) t -> min:'k key -> max:'k key -> ('k key * 'v) list + ) options + + val closest_key + : ('k, 'cmp, + ('k, 'v, 'cmp) t + -> [ `Greater_or_equal_to + | `Greater_than + | `Less_or_equal_to + | `Less_than + ] + -> 'k key -> ('k key * 'v) option + ) options + + val nth + : ('k, 'cmp, + ('k, 'v, 'cmp) t -> int -> ('k key * 'v) option + ) options + + val nth_exn + : ('k, 'cmp, + ('k, 'v, 'cmp) t -> int -> ('k key * 'v) + ) options + + val rank + : ('k, 'cmp, + ('k, _, 'cmp) t -> 'k key -> int option + ) options + + val to_tree : ('k, 'v, 'cmp) t -> ('k key, 'v, 'cmp) tree + + val to_sequence + : ('k, 'cmp, + ?order:[ `Increasing_key | `Decreasing_key ] + -> ?keys_greater_or_equal_to:'k key + -> ?keys_less_or_equal_to:'k key + -> ('k, 'v, 'cmp) t + -> ('k key * 'v) Sequence.t + ) options + +end + +module type Accessors1 = sig + type 'a t + type 'a tree + type key + val invariants : _ t -> bool + val is_empty : _ t -> bool + val length : _ t -> int + val add : 'a t -> key:key -> data:'a -> 'a t Or_duplicate.t + val add_exn : 'a t -> key:key -> data:'a -> 'a t + val set : 'a t -> key:key -> data:'a -> 'a t + val add_multi : 'a list t -> key:key -> data:'a -> 'a list t + val remove_multi : 'a list t -> key -> 'a list t + val find_multi : 'a list t -> key -> 'a list + val change : 'a t -> key -> f:('a option -> 'a option) -> 'a t + val update : 'a t -> key -> f:('a option -> 'a) -> 'a t + val find : 'a t -> key -> 'a option + val find_exn : 'a t -> key -> 'a + val remove : 'a t -> key -> 'a t + val mem : _ t -> key -> bool + + val iter_keys : _ t -> f:(key -> unit) -> unit + val iter : 'a t -> f:('a -> unit) -> unit + val iteri : 'a t -> f:(key:key -> data:'a -> unit) -> unit + val iter2 + : 'a t + -> 'b t + -> f:(key:key -> data:[ `Left of 'a | `Right of 'b | `Both of 'a * 'b ] -> unit) + -> unit + val map : 'a t -> f:('a -> 'b) -> 'b t + val mapi : 'a t -> f:(key:key -> data:'a -> 'b) -> 'b t + val fold : 'a t -> init:'b -> f:(key:key -> data:'a -> 'b -> 'b) -> 'b + val fold_right : 'a t -> init:'b -> f:(key:key -> data:'a -> 'b -> 'b) -> 'b + val fold2 + : 'a t + -> 'b t + -> init:'c + -> f:(key:key -> data:[ `Left of 'a | `Right of 'b | `Both of 'a * 'b ] -> 'c -> 'c) + -> 'c + + val filter_keys : 'a t -> f:(key -> bool) -> 'a t + val filter : 'a t -> f:('a -> bool) -> 'a t + val filteri : 'a t -> f:(key:key -> data:'a -> bool) -> 'a t + val filter_map : 'a t -> f:('a -> 'b option) -> 'b t + val filter_mapi : 'a t -> f:(key:key -> data:'a -> 'b option) -> 'b t + val partition_mapi + : 'a t + -> f:(key:key -> data:'a -> [`Fst of 'b | `Snd of 'c]) + -> 'b t * 'c t + val partition_map + : 'a t + -> f:('a -> [`Fst of 'b | `Snd of 'c]) + -> 'b t * 'c t + val partitioni_tf + : 'a t + -> f:(key:key -> data:'a -> bool) + -> 'a t * 'a t + val partition_tf + : 'a t + -> f:('a -> bool) + -> 'a t * 'a t + val compare_direct : ('a -> 'a -> int) -> 'a t -> 'a t -> int + val equal : ('a -> 'a -> bool)-> 'a t -> 'a t -> bool + val keys : _ t -> key list + val data : 'a t -> 'a list + val to_alist : ?key_order:[`Increasing|`Decreasing] -> 'a t -> (key * 'a) list + val validate : name:(key -> string) -> 'a Validate.check -> 'a t Validate.check + val merge + : 'a t + -> 'b t + -> f:(key:key -> [ `Left of 'a | `Right of 'b | `Both of 'a * 'b ] -> 'c option) + -> 'c t + val symmetric_diff + : 'a t + -> 'a t + -> data_equal:('a -> 'a -> bool) + -> (key, 'a) Symmetric_diff_element.t Sequence.t + val min_elt : 'a t -> (key * 'a) option + val min_elt_exn : 'a t -> key * 'a + val max_elt : 'a t -> (key * 'a) option + val max_elt_exn : 'a t -> key * 'a + val for_all : 'a t -> f:( 'a -> bool) -> bool + val for_alli : 'a t -> f:(key:key -> data:'a -> bool) -> bool + val exists : 'a t -> f:( 'a -> bool) -> bool + val existsi : 'a t -> f:(key:key -> data:'a -> bool) -> bool + val count : 'a t -> f:( 'a -> bool) -> int + val counti : 'a t -> f:(key:key -> data:'a -> bool) -> int + val split : 'a t -> key -> 'a t * (key * 'a) option * 'a t + val append + : lower_part:'a t + -> upper_part:'a t + -> [ `Ok of 'a t | `Overlapping_key_ranges ] + val subrange + : 'a t + -> lower_bound:key Maybe_bound.t + -> upper_bound:key Maybe_bound.t + -> 'a t + val fold_range_inclusive + : 'a t + -> min:key + -> max:key + -> init:'b + -> f:(key:key -> data:'a -> 'b -> 'b) + -> 'b + val range_to_alist : 'a t -> min:key -> max:key -> (key * 'a) list + val closest_key : 'a t + -> [ `Greater_or_equal_to + | `Greater_than + | `Less_or_equal_to + | `Less_than + ] + -> key -> (key * 'a) option + val nth : 'a t -> int -> (key * 'a) option + val nth_exn : 'a t -> int -> (key * 'a) + val rank : _ t -> key -> int option + val to_tree : 'a t -> 'a tree + val to_sequence + : ?order:[ `Increasing_key | `Decreasing_key ] + -> ?keys_greater_or_equal_to:key + -> ?keys_less_or_equal_to:key + -> 'a t + -> (key * 'a) Sequence.t +end + +module type Accessors2 = sig + type ('a, 'b) t + type ('a, 'b) tree + val invariants : (_, _) t -> bool + val is_empty : (_, _) t -> bool + val length : (_, _) t -> int + val add : ('a, 'b) t -> key:'a -> data:'b -> ('a, 'b) t Or_duplicate.t + val add_exn : ('a, 'b) t -> key:'a -> data:'b -> ('a, 'b) t + val set : ('a, 'b) t -> key:'a -> data:'b -> ('a, 'b) t + val add_multi : ('a, 'b list) t -> key:'a -> data:'b -> ('a, 'b list) t + val remove_multi : ('a, 'b list) t -> 'a -> ('a, 'b list) t + val find_multi : ('a, 'b list) t -> 'a -> 'b list + val change : ('a, 'b) t -> 'a -> f:('b option -> 'b option) -> ('a, 'b) t + val update : ('a, 'b) t -> 'a -> f:('b option -> 'b) -> ('a, 'b) t + val find : ('a, 'b) t -> 'a -> 'b option + val find_exn : ('a, 'b) t -> 'a -> 'b + val remove : ('a, 'b) t -> 'a -> ('a, 'b) t + val mem : ('a, 'b) t -> 'a -> bool + + val iter_keys : ('a, _) t -> f:('a -> unit) -> unit + val iter : ( _, 'b) t -> f:('b -> unit) -> unit + val iteri : ('a, 'b) t -> f:(key:'a -> data:'b -> unit) -> unit + val iter2 + : ('a, 'b) t + -> ('a, 'c) t + -> f:(key:'a -> data:[ `Left of 'b | `Right of 'c | `Both of 'b * 'c ] -> unit) + -> unit + val map : ('a, 'b) t -> f:('b -> 'c) -> ('a, 'c) t + val mapi : ('a, 'b) t -> f:(key:'a -> data:'b -> 'c) -> ('a, 'c) t + val fold : ('a, 'b) t -> init:'c -> f:(key:'a -> data:'b -> 'c -> 'c) -> 'c + val fold_right : ('a, 'b) t -> init:'c -> f:(key:'a -> data:'b -> 'c -> 'c) -> 'c + val fold2 + : ('a, 'b) t + -> ('a, 'c) t + -> init:'d + -> f:(key:'a -> data:[ `Left of 'b | `Right of 'c | `Both of 'b * 'c ] -> 'd -> 'd) + -> 'd + + val filter_keys : ('a, 'b) t -> f:('a -> bool) -> ('a, 'b) t + val filter : ('a, 'b) t -> f:('b -> bool) -> ('a, 'b) t + val filteri : ('a, 'b) t -> f:(key:'a -> data:'b -> bool) -> ('a, 'b) t + val filter_map : ('a, 'b) t -> f:('b -> 'c option) -> ('a, 'c) t + val filter_mapi : ('a, 'b) t -> f:(key:'a -> data:'b -> 'c option) -> ('a, 'c) t + val partition_mapi + : ('a, 'b) t + -> f:(key:'a -> data:'b -> [`Fst of 'c | `Snd of 'd]) + -> ('a, 'c) t * ('a, 'd) t + val partition_map + : ('a, 'b) t + -> f:('b -> [`Fst of 'c | `Snd of 'd]) + -> ('a, 'c) t * ('a, 'd) t + val partitioni_tf + : ('a, 'b) t + -> f:(key:'a -> data:'b -> bool) + -> ('a, 'b) t * ('a, 'b) t + val partition_tf + : ('a, 'b) t + -> f:('b -> bool) + -> ('a, 'b) t * ('a, 'b) t + val compare_direct : ('b -> 'b -> int) -> ('a, 'b) t -> ('a, 'b) t -> int + val equal : ('b -> 'b -> bool)-> ('a, 'b) t -> ('a, 'b) t -> bool + val keys : ('a, _) t -> 'a list + val data : (_, 'b) t -> 'b list + val to_alist : ?key_order:[`Increasing|`Decreasing] -> ('a, 'b) t -> ('a * 'b) list + val validate + : name:('a -> string) -> 'b Validate.check -> ('a, 'b) t Validate.check + val merge + : ('a, 'b) t + -> ('a, 'c) t + -> f:(key:'a -> [ `Left of 'b | `Right of 'c | `Both of 'b * 'c ] -> 'd option) + -> ('a, 'd) t + val symmetric_diff + : ('a, 'b) t + -> ('a, 'b) t + -> data_equal:('b -> 'b -> bool) + -> ('a, 'b) Symmetric_diff_element.t Sequence.t + val min_elt : ('a, 'b) t -> ('a * 'b) option + val min_elt_exn : ('a, 'b) t -> 'a * 'b + val max_elt : ('a, 'b) t -> ('a * 'b) option + val max_elt_exn : ('a, 'b) t -> 'a * 'b + val for_all : ( _, 'b) t -> f:( 'b -> bool) -> bool + val for_alli : ('a, 'b) t -> f:(key:'a -> data:'b -> bool) -> bool + val exists : ( _, 'b) t -> f:( 'b -> bool) -> bool + val existsi : ('a, 'b) t -> f:(key:'a -> data:'b -> bool) -> bool + val count : ( _, 'b) t -> f:( 'b -> bool) -> int + val counti : ('a, 'b) t -> f:(key:'a -> data:'b -> bool) -> int + val split : ('a, 'b) t -> 'a -> ('a, 'b) t * ('a * 'b) option * ('a, 'b) t + val append + : lower_part:('a, 'b) t + -> upper_part:('a, 'b) t + -> [ `Ok of ('a, 'b) t | `Overlapping_key_ranges ] + val subrange + : ('a, 'b) t + -> lower_bound:'a Maybe_bound.t + -> upper_bound:'a Maybe_bound.t + -> ('a, 'b) t + val fold_range_inclusive + : ('a, 'b) t -> min:'a -> max:'a -> init:'c -> f:(key:'a -> data:'b -> 'c -> 'c) -> 'c + val range_to_alist : ('a, 'b) t -> min:'a -> max:'a -> ('a * 'b) list + val closest_key : ('a, 'b) t + -> [ `Greater_or_equal_to + | `Greater_than + | `Less_or_equal_to + | `Less_than + ] + -> 'a -> ('a * 'b) option + val nth : ('a, 'b) t -> int -> ('a * 'b) option + val nth_exn : ('a, 'b) t -> int -> ('a * 'b) + val rank : ('a, _) t -> 'a -> int option + val to_tree : ('a, 'b) t -> ('a, 'b) tree + val to_sequence + : ?order:[ `Increasing_key | `Decreasing_key ] + -> ?keys_greater_or_equal_to:'a + -> ?keys_less_or_equal_to:'a + -> ('a, 'b) t + -> ('a * 'b) Sequence.t +end + +module type Accessors3 = sig + type ('a, 'b, 'cmp) t + type ('a, 'b, 'cmp) tree + val invariants : (_, _, _) t -> bool + val is_empty : (_, _, _) t -> bool + val length : (_, _, _) t -> int + val add : ('a, 'b, 'cmp) t -> key:'a -> data:'b -> ('a, 'b , 'cmp) t Or_duplicate.t + val add_exn : ('a, 'b, 'cmp) t -> key:'a -> data:'b -> ('a, 'b , 'cmp) t + val set : ('a, 'b, 'cmp) t -> key:'a -> data:'b -> ('a, 'b , 'cmp) t + val add_multi : ('a, 'b list, 'cmp) t -> key:'a -> data:'b -> ('a, 'b list, 'cmp) t + val remove_multi : ('a, 'b list, 'cmp) t -> 'a -> ('a, 'b list, 'cmp) t + val find_multi : ('a, 'b list, 'cmp) t -> 'a -> 'b list + val change : ('a, 'b, 'cmp) t -> 'a -> f:('b option -> 'b option) -> ('a, 'b, 'cmp) t + val update : ('a, 'b, 'cmp) t -> 'a -> f:('b option -> 'b) -> ('a, 'b, 'cmp) t + val find : ('a, 'b, 'cmp) t -> 'a -> 'b option + val find_exn : ('a, 'b, 'cmp) t -> 'a -> 'b + val remove : ('a, 'b, 'cmp) t -> 'a -> ('a, 'b, 'cmp) t + val mem : ('a, 'b, 'cmp) t -> 'a -> bool + + val iter_keys : ('a, _, 'cmp) t -> f:('a -> unit) -> unit + val iter : ( _, 'b, 'cmp) t -> f:('b -> unit) -> unit + val iteri : ('a, 'b, 'cmp) t -> f:(key:'a -> data:'b -> unit) -> unit + val iter2 + : ('a, 'b, 'cmp) t + -> ('a, 'c, 'cmp) t + -> f:(key:'a -> data:[ `Left of 'b | `Right of 'c | `Both of 'b * 'c ] -> unit) + -> unit + val map : ('a, 'b, 'cmp) t -> f:('b -> 'c) -> ('a, 'c, 'cmp) t + val mapi : ('a, 'b, 'cmp) t -> f:(key:'a -> data:'b -> 'c) -> ('a, 'c, 'cmp) t + val fold : ('a, 'b, _) t -> init:'c -> f:(key:'a -> data:'b -> 'c -> 'c) -> 'c + val fold_right : ('a, 'b, _) t -> init:'c -> f:(key:'a -> data:'b -> 'c -> 'c) -> 'c + val fold2 + : ('a, 'b, 'cmp) t + -> ('a, 'c, 'cmp) t + -> init:'d + -> f:(key:'a -> data:[ `Left of 'b | `Right of 'c | `Both of 'b * 'c ] -> 'd -> 'd) + -> 'd + + val filter_keys : ('a, 'b, 'cmp) t -> f:('a -> bool) -> ('a, 'b, 'cmp) t + val filter : ('a, 'b, 'cmp) t -> f:('b -> bool) -> ('a, 'b, 'cmp) t + val filteri : ('a, 'b, 'cmp) t -> f:(key:'a -> data:'b -> bool) -> ('a, 'b, 'cmp) t + val filter_map : ('a, 'b, 'cmp) t -> f:('b -> 'c option) -> ('a, 'c, 'cmp) t + val filter_mapi + : ('a, 'b, 'cmp) t -> f:(key:'a -> data:'b -> 'c option) -> ('a, 'c, 'cmp) t + val partition_mapi + : ('a, 'b, 'cmp) t + -> f:(key:'a -> data:'b -> [`Fst of 'c | `Snd of 'd]) + -> ('a, 'c, 'cmp) t * ('a, 'd, 'cmp) t + val partition_map + : ('a, 'b, 'cmp) t + -> f:('b -> [`Fst of 'c | `Snd of 'd]) + -> ('a, 'c, 'cmp) t * ('a, 'd, 'cmp) t + val partitioni_tf + : ('a, 'b, 'cmp) t + -> f:(key:'a -> data:'b -> bool) + -> ('a, 'b, 'cmp) t * ('a, 'b, 'cmp) t + val partition_tf + : ('a, 'b, 'cmp) t + -> f:('b -> bool) + -> ('a, 'b, 'cmp) t * ('a, 'b, 'cmp) t + val compare_direct : ('b -> 'b -> int) -> ('a, 'b, 'cmp) t -> ('a, 'b, 'cmp) t -> int + val equal : ('b -> 'b -> bool)-> ('a, 'b, 'cmp) t -> ('a, 'b, 'cmp) t -> bool + val keys : ('a, _, _) t -> 'a list + val data : (_, 'b, _) t -> 'b list + val to_alist : ?key_order:[`Increasing|`Decreasing] -> ('a, 'b, _) t -> ('a * 'b) list + val validate + : name:('a -> string) -> 'b Validate.check -> ('a, 'b, _) t Validate.check + val merge + : ('a, 'b, 'cmp) t -> ('a, 'c, 'cmp) t + -> f:(key:'a -> [ `Left of 'b | `Right of 'c | `Both of 'b * 'c ] -> 'd option) + -> ('a, 'd, 'cmp) t + val symmetric_diff + : ('a, 'b, 'cmp) t + -> ('a, 'b, 'cmp) t + -> data_equal:('b -> 'b -> bool) + -> ('a, 'b) Symmetric_diff_element.t Sequence.t + val min_elt : ('a, 'b, 'cmp) t -> ('a * 'b) option + val min_elt_exn : ('a, 'b, 'cmp) t -> 'a * 'b + val max_elt : ('a, 'b, 'cmp) t -> ('a * 'b) option + val max_elt_exn : ('a, 'b, 'cmp) t -> 'a * 'b + val for_all : ( _, 'b, _) t -> f:( 'b -> bool) -> bool + val for_alli : ('a, 'b, _) t -> f:(key:'a -> data:'b -> bool) -> bool + val exists : ( _, 'b, _) t -> f:( 'b -> bool) -> bool + val existsi : ('a, 'b, _) t -> f:(key:'a -> data:'b -> bool) -> bool + val count : ( _, 'b, _) t -> f:( 'b -> bool) -> int + val counti : ('a, 'b, _) t -> f:(key:'a -> data:'b -> bool) -> int + val split + : ('k, 'v, 'cmp) t + -> 'k + -> ('k, 'v, 'cmp) t * ('k * 'v) option * ('k, 'v, 'cmp) t + val append + : lower_part:('k, 'v, 'cmp) t + -> upper_part:('k, 'v, 'cmp) t + -> [ `Ok of ('k, 'v, 'cmp) t | `Overlapping_key_ranges ] + val subrange + : ('k, 'v, 'cmp) t + -> lower_bound:'k Maybe_bound.t + -> upper_bound:'k Maybe_bound.t + -> ('k, 'v, 'cmp) t + val fold_range_inclusive + : ('a, 'b, _) t + -> min:'a + -> max:'a + -> init:'c + -> f:(key:'a -> data:'b -> 'c -> 'c) + -> 'c + val range_to_alist : ('a, 'b, _) t -> min:'a -> max:'a -> ('a * 'b) list + val closest_key : ('a, 'b, _) t + -> [ `Greater_or_equal_to + | `Greater_than + | `Less_or_equal_to + | `Less_than + ] + -> 'a -> ('a * 'b) option + val nth : ('a, 'b, _) t -> int -> ('a * 'b) option + val nth_exn : ('a, 'b, _) t -> int -> ('a * 'b) + val rank : ('a, _, _) t -> 'a -> int option + val to_tree : ('a, 'b, 'cmp) t -> ('a, 'b, 'cmp) tree + val to_sequence + : ?order:[ `Increasing_key | `Decreasing_key ] + -> ?keys_greater_or_equal_to:'a + -> ?keys_less_or_equal_to:'a + -> ('a, 'b, _) t + -> ('a * 'b) Sequence.t +end + +module type Accessors3_with_comparator = sig + type ('a, 'b, 'cmp) t + type ('a, 'b, 'cmp) tree + val invariants : comparator:('a, 'cmp) Comparator.t -> ('a, 'b, 'cmp) t -> bool + val is_empty : ('a, 'b, 'cmp) t -> bool + val length : ('a, 'b, 'cmp) t -> int + val add + : comparator:('a, 'cmp) Comparator.t + -> ('a, 'b, 'cmp) t -> key:'a -> data:'b -> ('a, 'b, 'cmp) t Or_duplicate.t + val add_exn + : comparator:('a, 'cmp) Comparator.t + -> ('a, 'b, 'cmp) t -> key:'a -> data:'b -> ('a, 'b, 'cmp) t + val set + : comparator:('a, 'cmp) Comparator.t + -> ('a, 'b, 'cmp) t -> key:'a -> data:'b -> ('a, 'b, 'cmp) t + val add_multi + : comparator:('a, 'cmp) Comparator.t + -> ('a, 'b list, 'cmp) t -> key:'a -> data:'b -> ('a, 'b list, 'cmp) t + val remove_multi + : comparator:('a, 'cmp) Comparator.t + -> ('a, 'b list, 'cmp) t -> 'a -> ('a, 'b list, 'cmp) t + val find_multi + : comparator:('a, 'cmp) Comparator.t + -> ('a, 'b list, 'cmp) t -> 'a -> 'b list + val change + : comparator:('a, 'cmp) Comparator.t + -> ('a, 'b, 'cmp) t -> 'a -> f:('b option -> 'b option) -> ('a, 'b, 'cmp) t + val update + : comparator:('a, 'cmp) Comparator.t + -> ('a, 'b, 'cmp) t -> 'a -> f:('b option -> 'b) -> ('a, 'b, 'cmp) t + val find + : comparator:('a, 'cmp) Comparator.t + -> ('a, 'b, 'cmp) t -> 'a -> 'b option + val find_exn + : comparator:('a, 'cmp) Comparator.t + -> ('a, 'b, 'cmp) t -> 'a -> 'b + val remove + : comparator:('a, 'cmp) Comparator.t + -> ('a, 'b, 'cmp) t -> 'a -> ('a, 'b, 'cmp) t + val mem + : comparator:('a, 'cmp) Comparator.t + -> ('a, 'b, 'cmp) t -> 'a -> bool + + val iter_keys : ('a, _, 'cmp) t -> f:('a -> unit) -> unit + val iter : ( _, 'b, 'cmp) t -> f:('b -> unit) -> unit + val iteri : ('a, 'b, 'cmp) t -> f:(key:'a -> data:'b -> unit) -> unit + val iter2 + : comparator:('a, 'cmp) Comparator.t + -> ('a, 'b, 'cmp) t -> ('a, 'c, 'cmp) t + -> f:(key:'a -> data:[ `Left of 'b | `Right of 'c | `Both of 'b * 'c ]-> unit) + -> unit + val map : ('a, 'b, 'cmp) t -> f:('b -> 'c) -> ('a, 'c, 'cmp) t + val mapi : ('a, 'b, 'cmp) t -> f:(key:'a -> data:'b -> 'c) -> ('a, 'c, 'cmp) t + val fold : ('a, 'b, _) t -> init:'c -> f:(key:'a -> data:'b -> 'c -> 'c) -> 'c + val fold_right : ('a, 'b, _) t -> init:'c -> f:(key:'a -> data:'b -> 'c -> 'c) -> 'c + val fold2 + : comparator:('a, 'cmp) Comparator.t + -> ('a, 'b, 'cmp) t -> ('a, 'c, 'cmp) t + -> init:'d + -> f:(key:'a -> data:[ `Left of 'b | `Right of 'c | `Both of 'b * 'c] -> 'd -> 'd) + -> 'd + + val filter_keys + : comparator:('a, 'cmp) Comparator.t + -> ('a, 'b, 'cmp) t -> f:('a -> bool) -> ('a, 'b, 'cmp) t + val filter + : comparator:('a, 'cmp) Comparator.t + -> ('a, 'b, 'cmp) t -> f:('b -> bool) -> ('a, 'b, 'cmp) t + val filteri + : comparator:('a, 'cmp) Comparator.t + -> ('a, 'b, 'cmp) t -> f:(key:'a -> data:'b -> bool) -> ('a, 'b, 'cmp) t + val filter_map + : comparator:('a, 'cmp) Comparator.t + -> ('a, 'b, 'cmp) t -> f:('b -> 'c option) -> ('a, 'c, 'cmp) t + val filter_mapi + : comparator:('a, 'cmp) Comparator.t + -> ('a, 'b, 'cmp) t -> f:(key:'a -> data:'b -> 'c option) -> ('a, 'c, 'cmp) t + val partition_mapi + : comparator:('a, 'cmp) Comparator.t + -> ('a, 'b, 'cmp) t + -> f:(key:'a -> data:'b -> [`Fst of 'c | `Snd of 'd]) + -> ('a, 'c, 'cmp) t * ('a, 'd, 'cmp) t + val partition_map + : comparator:('a, 'cmp) Comparator.t + -> ('a, 'b, 'cmp) t + -> f:('b -> [`Fst of 'c | `Snd of 'd]) + -> ('a, 'c, 'cmp) t * ('a, 'd, 'cmp) t + val partitioni_tf + : comparator:('a, 'cmp) Comparator.t + -> ('a, 'b, 'cmp) t + -> f:(key:'a -> data:'b -> bool) + -> ('a, 'b, 'cmp) t * ('a, 'b, 'cmp) t + val partition_tf + : comparator:('a, 'cmp) Comparator.t + -> ('a, 'b, 'cmp) t + -> f:('b -> bool) + -> ('a, 'b, 'cmp) t * ('a, 'b, 'cmp) t + val compare_direct + : comparator:('a, 'cmp) Comparator.t + -> ('b -> 'b -> int) + -> ('a, 'b, 'cmp) t + -> ('a, 'b, 'cmp) t + -> int + val equal + : comparator:('a, 'cmp) Comparator.t + -> ('b -> 'b -> bool) -> ('a, 'b, 'cmp) t -> ('a, 'b, 'cmp) t -> bool + val keys : ('a, _, _) t -> 'a list + val data : (_ , 'b, _) t -> 'b list + val to_alist + : ?key_order:[`Increasing|`Decreasing] -> ('a, 'b, _) t -> ('a * 'b) list + val validate + : name:('a -> string) -> 'b Validate.check -> ('a, 'b, _) t Validate.check + val merge + : comparator:('a, 'cmp) Comparator.t + -> ('a, 'b, 'cmp) t + -> ('a, 'c, 'cmp) t + -> f:(key:'a -> [ `Left of 'b | `Right of 'c | `Both of 'b * 'c ] -> 'd option) + -> ('a, 'd, 'cmp) t + val symmetric_diff + : comparator:('a, 'cmp) Comparator.t + -> ('a, 'b, 'cmp) t + -> ('a, 'b, 'cmp) t + -> data_equal:('b -> 'b -> bool) + -> ('a, 'b) Symmetric_diff_element.t Sequence.t + val min_elt : ('a, 'b, 'cmp) t -> ('a * 'b) option + val min_elt_exn : ('a, 'b, 'cmp) t -> 'a * 'b + val max_elt : ('a, 'b, 'cmp) t -> ('a * 'b) option + val max_elt_exn : ('a, 'b, 'cmp) t -> 'a * 'b + val for_all : ('a, 'b, 'cmp) t -> f:( 'b -> bool) -> bool + val for_alli : ('a, 'b, 'cmp) t -> f:(key:'a -> data:'b -> bool) -> bool + val exists : ('a, 'b, 'cmp) t -> f:( 'b -> bool) -> bool + val existsi : ('a, 'b, 'cmp) t -> f:(key:'a -> data:'b -> bool) -> bool + val count : ('a, 'b, 'cmp) t -> f:( 'b -> bool) -> int + val counti : ('a, 'b, 'cmp) t -> f:(key:'a -> data:'b -> bool) -> int + val split + : comparator:('a, 'cmp) Comparator.t + -> ('a, 'b, 'cmp) t + -> 'a + -> ('a, 'b, 'cmp) t * ('a * 'b) option * ('a, 'b, 'cmp) t + val append + : comparator:('a, 'cmp) Comparator.t + -> lower_part:('a, 'b, 'cmp) t + -> upper_part:('a, 'b, 'cmp) t + -> [ `Ok of ('a, 'b, 'cmp) t | `Overlapping_key_ranges ] + val subrange + : comparator:('a, 'cmp) Comparator.t + -> ('a, 'b, 'cmp) t + -> lower_bound:'a Maybe_bound.t + -> upper_bound:'a Maybe_bound.t + -> ('a, 'b, 'cmp) t + val fold_range_inclusive + : comparator:('a, 'cmp) Comparator.t + -> ('a, 'b, 'cmp) t + -> min:'a -> max:'a -> init:'c -> f:(key:'a -> data:'b -> 'c -> 'c) -> 'c + val range_to_alist + : comparator:('a, 'cmp) Comparator.t + -> ('a, 'b, 'cmp) t -> min:'a -> max:'a -> ('a * 'b) list + val closest_key + : comparator:('a, 'cmp) Comparator.t + -> ('a, 'b, 'cmp) t + -> [ `Greater_or_equal_to + | `Greater_than + | `Less_or_equal_to + | `Less_than + ] + -> 'a -> ('a * 'b) option + val nth + : comparator:('a, 'cmp) Comparator.t + -> ('a, 'b, 'cmp) t -> int -> ('a * 'b) option + val nth_exn + : comparator:('a, 'cmp) Comparator.t + -> ('a, 'b, 'cmp) t -> int -> ('a * 'b) + val rank + : comparator:('a, 'cmp) Comparator.t + -> ('a, 'b, 'cmp) t -> 'a -> int option + val to_tree : ('a, 'b, 'cmp) t -> ('a, 'b, 'cmp) tree + val to_sequence + : comparator:('a, 'cmp) Comparator.t + -> ?order:[ `Increasing_key | `Decreasing_key ] + -> ?keys_greater_or_equal_to:'a + -> ?keys_less_or_equal_to:'a + -> ('a, 'b, 'cmp) t + -> ('a * 'b) Sequence.t +end + +(** Consistency checks (same as in [Container]). *) +module Check_accessors (T : T3) (Tree : T3) (Key : T1) (Options : T3) + (M : Accessors_generic + with type ('a, 'b, 'c) options := ('a, 'b, 'c) Options.t + with type ('a, 'b, 'c) t := ('a, 'b, 'c) T.t + with type ('a, 'b, 'c) tree := ('a, 'b, 'c) Tree.t + with type 'a key := 'a Key.t) += struct end + +module Check_accessors1 (M : Accessors1) = + Check_accessors + (struct type ('a, 'b, 'c) t = 'b M.t end) + (struct type ('a, 'b, 'c) t = 'b M.tree end) + (struct type 'a t = M.key end) + (Without_comparator) + (M) + +module Check_accessors2 (M : Accessors2) = + Check_accessors + (struct type ('a, 'b, 'c) t = ('a, 'b) M.t end) + (struct type ('a, 'b, 'c) t = ('a, 'b) M.tree end) + (struct type 'a t = 'a end) + (Without_comparator) + (M) + +module Check_accessors3 (M : Accessors3) = + Check_accessors + (struct type ('a, 'b, 'c) t = ('a, 'b, 'c) M.t end) + (struct type ('a, 'b, 'c) t = ('a, 'b, 'c) M.tree end) + (struct type 'a t = 'a end) + (Without_comparator) + (M) + +module Check_accessors3_with_comparator (M : Accessors3_with_comparator) = + Check_accessors + (struct type ('a, 'b, 'c) t = ('a, 'b, 'c) M.t end) + (struct type ('a, 'b, 'c) t = ('a, 'b, 'c) M.tree end) + (struct type 'a t = 'a end) + (With_comparator) + (M) + +module type Creators_generic = sig + type ('k, 'v, 'cmp) t + type ('k, 'v, 'cmp) tree + type 'k key + type ('a, 'cmp, 'z) options + + val empty : ('k, 'cmp, ('k, _, 'cmp) t) options + + val singleton : ('k, 'cmp, 'k key -> 'v -> ('k, 'v, 'cmp) t) options + + val of_sorted_array + : ('k, 'cmp, ('k key * 'v) array -> ('k, 'v, 'cmp) t Or_error.t) options + + val of_sorted_array_unchecked + : ('k, 'cmp, ('k key * 'v) array -> ('k, 'v, 'cmp) t) options + + val of_increasing_iterator_unchecked + : ('k, 'cmp, len:int -> f:(int -> 'k key * 'v) -> ('k, 'v, 'cmp) t) options + + val of_increasing_sequence + : ('k, 'cmp, ('k key * 'v) Sequence.t -> ('k, 'v, 'cmp) t Or_error.t) options + + val of_alist + : ('k, + 'cmp, + ('k key * 'v) list -> [ `Ok of ('k, 'v, 'cmp) t | `Duplicate_key of 'k key ] + ) options + + val of_alist_or_error + : ('k, 'cmp, ('k key * 'v) list -> ('k, 'v, 'cmp) t Or_error.t) options + + val of_alist_exn : ('k, 'cmp, ('k key * 'v) list -> ('k, 'v, 'cmp) t) options + + val of_alist_multi : ('k, 'cmp, ('k key * 'v) list -> ('k, 'v list, 'cmp) t) options + + val of_alist_fold + : ('k, 'cmp, + ('k key * 'v1) list + -> init:'v2 + -> f:('v2 -> 'v1 -> 'v2) + -> ('k, 'v2, 'cmp) t + ) options + + val of_alist_reduce + : ('k, 'cmp, + ('k key * 'v) list + -> f:('v -> 'v -> 'v) + -> ('k, 'v, 'cmp) t + ) options + + val of_iteri + : ('k, 'cmp, + iteri:(f:(key:'k key + -> data:'v + -> unit) + -> unit) + -> [ `Ok of ('k, 'v, 'cmp) t + | `Duplicate_key of 'k key ] + ) options + + val of_tree + : ('k, 'cmp, + ('k key, 'v, 'cmp) tree -> ('k, 'v, 'cmp) t + ) options +end + +module type Creators1 = sig + type 'a t + type 'a tree + type key + val empty : _ t + val singleton : key -> 'a -> 'a t + val of_alist : (key * 'a) list -> [ `Ok of 'a t | `Duplicate_key of key ] + val of_alist_or_error : (key * 'a) list -> 'a t Or_error.t + val of_alist_exn : (key * 'a) list -> 'a t + val of_alist_multi : (key * 'a) list -> 'a list t + val of_alist_fold : (key * 'a) list -> init:'b -> f:('b -> 'a -> 'b) -> 'b t + val of_alist_reduce : (key * 'a) list -> f:('a -> 'a -> 'a) -> 'a t + val of_sorted_array : (key * 'a) array -> 'a t Or_error.t + val of_sorted_array_unchecked : (key * 'a) array -> 'a t + val of_increasing_iterator_unchecked : len:int -> f:(int -> key * 'a) -> 'a t + val of_increasing_sequence : (key * 'a) Sequence.t -> 'a t Or_error.t + val of_iteri : iteri:(f:(key:key -> data:'v -> unit) -> unit) + -> [ `Ok of 'v t | `Duplicate_key of key ] + val of_tree : 'a tree -> 'a t +end + +module type Creators2 = sig + type ('a, 'b) t + type ('a, 'b) tree + val empty : (_, _) t + val singleton : 'a -> 'b -> ('a, 'b) t + val of_alist : ('a * 'b) list -> [ `Ok of ('a, 'b) t | `Duplicate_key of 'a ] + val of_alist_or_error : ('a * 'b) list -> ('a, 'b) t Or_error.t + val of_alist_exn : ('a * 'b) list -> ('a, 'b) t + val of_alist_multi : ('a * 'b) list -> ('a, 'b list) t + val of_alist_fold : ('a * 'b) list -> init:'c -> f:('c -> 'b -> 'c) -> ('a, 'c) t + val of_alist_reduce : ('a * 'b) list -> f:('b -> 'b -> 'b) -> ('a, 'b) t + val of_sorted_array : ('a * 'b) array -> ('a, 'b) t Or_error.t + val of_sorted_array_unchecked : ('a * 'b) array -> ('a, 'b) t + val of_increasing_iterator_unchecked : len:int -> f:(int -> 'a * 'b) -> ('a, 'b) t + val of_increasing_sequence : ('a * 'b) Sequence.t -> ('a, 'b) t Or_error.t + val of_iteri : iteri:(f:(key:'a -> data:'b -> unit) -> unit) + -> [ `Ok of ('a, 'b) t | `Duplicate_key of 'a ] + val of_tree : ('a, 'b) tree -> ('a, 'b) t +end + +module type Creators3_with_comparator = sig + type ('a, 'b, 'cmp) t + type ('a, 'b, 'cmp) tree + val empty : comparator:('a, 'cmp) Comparator.t -> ('a, _, 'cmp) t + val singleton : comparator:('a, 'cmp) Comparator.t -> 'a -> 'b -> ('a, 'b, 'cmp) t + val of_alist + : comparator:('a, 'cmp) Comparator.t + -> ('a * 'b) list -> [ `Ok of ('a, 'b, 'cmp) t | `Duplicate_key of 'a ] + val of_alist_or_error + : comparator:('a, 'cmp) Comparator.t + -> ('a * 'b) list -> ('a, 'b, 'cmp) t Or_error.t + val of_alist_exn + : comparator:('a, 'cmp) Comparator.t + -> ('a * 'b) list -> ('a, 'b, 'cmp) t + val of_alist_multi + : comparator:('a, 'cmp) Comparator.t + -> ('a * 'b) list -> ('a, 'b list, 'cmp) t + val of_alist_fold + : comparator:('a, 'cmp) Comparator.t + -> ('a * 'b) list -> init:'c -> f:('c -> 'b -> 'c) -> ('a, 'c, 'cmp) t + val of_alist_reduce + : comparator:('a, 'cmp) Comparator.t + -> ('a * 'b) list -> f:('b -> 'b -> 'b) -> ('a, 'b, 'cmp) t + val of_sorted_array + : comparator:('a, 'cmp) Comparator.t + -> ('a * 'b) array -> ('a, 'b, 'cmp) t Or_error.t + val of_sorted_array_unchecked + : comparator:('a, 'cmp) Comparator.t + -> ('a * 'b) array -> ('a, 'b, 'cmp) t + val of_increasing_iterator_unchecked + : comparator:('a, 'cmp) Comparator.t + -> len:int -> f:(int -> 'a * 'b) -> ('a, 'b, 'cmp) t + val of_increasing_sequence + : comparator:('a, 'cmp) Comparator.t + -> ('a * 'b) Sequence.t -> ('a, 'b, 'cmp) t Or_error.t + val of_iteri + : comparator:('a, 'cmp) Comparator.t + -> iteri:(f:(key:'a -> data:'b -> unit) -> unit) + -> [ `Ok of ('a, 'b, 'cmp) t | `Duplicate_key of 'a ] + val of_tree + : comparator:('a, 'cmp) Comparator.t + -> ('a, 'b, 'cmp) tree -> ('a, 'b, 'cmp) t +end + +module Check_creators (T : T3) (Tree : T3) (Key : T1) (Options : T3) + (M : Creators_generic + with type ('a, 'b, 'c) options := ('a, 'b, 'c) Options.t + with type ('a, 'b, 'c) t := ('a, 'b, 'c) T.t + with type ('a, 'b, 'c) tree := ('a, 'b, 'c) Tree.t + with type 'a key := 'a Key.t) += struct end + +module Check_creators1 (M : Creators1) = + Check_creators + (struct type ('a, 'b, 'c) t = 'b M.t end) + (struct type ('a, 'b, 'c) t = 'b M.tree end) + (struct type 'a t = M.key end) + (Without_comparator) + (M) + +module Check_creators2 (M : Creators2) = + Check_creators + (struct type ('a, 'b, 'c) t = ('a, 'b) M.t end) + (struct type ('a, 'b, 'c) t = ('a, 'b) M.tree end) + (struct type 'a t = 'a end) + (Without_comparator) + (M) + +module Check_creators3_with_comparator (M : Creators3_with_comparator) = + Check_creators + (struct type ('a, 'b, 'c) t = ('a, 'b, 'c) M.t end) + (struct type ('a, 'b, 'c) t = ('a, 'b, 'c) M.tree end) + (struct type 'a t = 'a end) + (With_comparator) + (M) + +module type Creators_and_accessors_generic = sig + include Creators_generic + include Accessors_generic + with type ('a, 'b, 'c) t := ('a, 'b, 'c) t + with type ('a, 'b, 'c) tree := ('a, 'b, 'c) tree + with type 'a key := 'a key + with type ('a, 'b, 'c) options := ('a, 'b, 'c) options +end + +module type Creators_and_accessors1 = sig + include Creators1 + include Accessors1 + with type 'a t := 'a t + with type 'a tree := 'a tree + with type key := key +end + +module type Creators_and_accessors2 = sig + include Creators2 + include Accessors2 + with type ('a, 'b) t := ('a, 'b) t + with type ('a, 'b) tree := ('a, 'b) tree +end + +module type Creators_and_accessors3_with_comparator = sig + include Creators3_with_comparator + include Accessors3_with_comparator + with type ('a, 'b, 'c) t := ('a, 'b, 'c) t + with type ('a, 'b, 'c) tree := ('a, 'b, 'c) tree +end + +module type S_poly = Creators_and_accessors2 + +module type For_deriving = sig + type ('a, 'b, 'c) t + + module type Sexp_of_m = sig type t [@@deriving_inline sexp_of] + include + sig [@@@ocaml.warning "-32"] val sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t + end[@@ocaml.doc "@inline"] + [@@@end] end + module type M_of_sexp = sig + type t [@@deriving_inline of_sexp] + include + sig [@@@ocaml.warning "-32"] val t_of_sexp : Ppx_sexp_conv_lib.Sexp.t -> t + end[@@ocaml.doc "@inline"] + [@@@end] include Comparator.S with type t := t + end + module type Compare_m = sig end + module type Hash_fold_m = Hasher.S + + val sexp_of_m__t + : (module Sexp_of_m with type t = 'k) + -> ('v -> Sexp.t) + -> ('k, 'v, 'cmp) t + -> Sexp.t + + val m__t_of_sexp + : (module M_of_sexp with type t = 'k and type comparator_witness = 'cmp) + -> (Sexp.t -> 'v) + -> Sexp.t + -> ('k, 'v, 'cmp) t + + val compare_m__t + : (module Compare_m) + -> ('v -> 'v -> int) + -> ('k, 'v, 'cmp) t + -> ('k, 'v, 'cmp) t + -> int + + val hash_fold_m__t + : (module Hash_fold_m with type t = 'k) + -> (Hash.state -> 'v -> Hash.state) + -> (Hash.state -> ('k, 'v, _) t -> Hash.state) +end + +module type Map = sig + (** [Map] is a functional data structure (balanced binary tree) implementing finite maps + over a totally-ordered domain, called a "key". *) + + type ('key, +'value, 'cmp) t + + module Or_duplicate = Or_duplicate + + type ('k, 'cmp) comparator = + (module Comparator.S with type t = 'k and type comparator_witness = 'cmp) + + (** Test if the invariants of the internal AVL search tree hold. *) + val invariants : (_, _, _) t -> bool + + (** Returns a first-class module that can be used to build other map/set/etc. + with the same notion of comparison. *) + val comparator_s : ('a, _, 'cmp) t -> ('a, 'cmp) comparator + + val comparator : ('a, _, 'cmp) t -> ('a, 'cmp) Comparator.t + + (** The empty map. *) + val empty : ('a, 'cmp) comparator -> ('a, 'b, 'cmp) t + + (** A map with one (key, data) pair. *) + val singleton : ('a, 'cmp) comparator -> 'a -> 'b -> ('a, 'b, 'cmp) t + + (** Creates a map from an association list with unique keys. *) + val of_alist + : ('a, 'cmp) comparator + -> ('a * 'b) list + -> [ `Ok of ('a, 'b, 'cmp) t | `Duplicate_key of 'a ] + + (** Creates a map from an association list with unique keys, returning an error if + duplicate ['a] keys are found. *) + val of_alist_or_error + : ('a, 'cmp) comparator + -> ('a * 'b) list -> ('a, 'b, 'cmp) t Or_error.t + + (** Creates a map from an association list with unique keys, raising an exception if + duplicate ['a] keys are found. *) + val of_alist_exn + : ('a, 'cmp) comparator + -> ('a * 'b) list -> ('a, 'b, 'cmp) t + + (** Creates a map from an association list with possibly repeated keys. The values in + the map for a given key appear in the same order as they did in the association + list. *) + val of_alist_multi + : ('a, 'cmp) comparator + -> ('a * 'b) list -> ('a, 'b list, 'cmp) t + + (** Combines an association list into a map, folding together bound values with common + keys. *) + val of_alist_fold + : ('a, 'cmp) comparator + -> ('a * 'b) list -> init:'c -> f:('c -> 'b -> 'c) -> ('a, 'c, 'cmp) t + + (** Combines an association list into a map, reducing together bound values with common + keys. *) + val of_alist_reduce + : ('a, 'cmp) comparator + -> ('a * 'b) list -> f:('b -> 'b -> 'b) -> ('a, 'b, 'cmp) t + + (** [of_iteri ~iteri] behaves like [of_alist], except that instead of taking a concrete + data structure, it takes an iteration function. For instance, to convert a string table + into a map: [of_iteri (module String) ~f:(Hashtbl.iteri table)]. It is faster than + adding the elements one by one. *) + val of_iteri + : ('a, 'cmp) comparator + -> iteri:(f:(key:'a -> data:'b -> unit) -> unit) + -> [ `Ok of ('a, 'b, 'cmp) t + | `Duplicate_key of 'a ] + + (** Creates a map from a sorted array of key-data pairs. The input array must be sorted + (either in ascending or descending order), as given by the relevant comparator, and + must not contain duplicate keys. If either of these conditions does not hold, + an error is returned. *) + val of_sorted_array + : ('a, 'cmp) comparator + -> ('a * 'b) array -> ('a, 'b, 'cmp) t Or_error.t + + (** Like [of_sorted_array] except that it returns a map with broken invariants when an + [Error] would have been returned. *) + val of_sorted_array_unchecked + : ('a, 'cmp) comparator + -> ('a * 'b) array -> ('a, 'b, 'cmp) t + + (** [of_increasing_iterator_unchecked c ~len ~f] behaves like [of_sorted_array_unchecked c + (Array.init len ~f)], with the additional restriction that a decreasing order is not + supported. The advantage is not requiring you to allocate an intermediate array. [f] + will be called with 0, 1, ... [len - 1], in order. *) + val of_increasing_iterator_unchecked + : ('a, 'cmp) comparator + -> len:int + -> f:(int -> 'a * 'b) + -> ('a, 'b, 'cmp) t + + (** [of_increasing_sequence c seq] behaves like [of_sorted_array c (Sequence.to_array + seq)], but does not allocate the intermediate array. + + The sequence will be folded over once, and the additional time complexity is {e O(n)}. + *) + val of_increasing_sequence + : ('k, 'cmp) comparator + -> ('k * 'v) Sequence.t + -> ('k, 'v, 'cmp) t Or_error.t + + (** Tests whether a map is empty. *) + val is_empty : (_, _, _) t -> bool + + (** [length map] returns the number of elements in [map]. O(1), but [Tree.length] is + O(n). *) + val length : (_, _, _) t -> int + + (** Returns a new map with the specified new binding; if the key was already bound, its + previous binding disappears. *) + val set : ('k, 'v, 'cmp) t -> key:'k -> data:'v -> ('k, 'v, 'cmp) t + + (** [add t ~key ~data] adds a new entry to [t] mapping [key] to [data] and returns [`Ok] + with the new map, or if [key] is already present in [t], returns [`Duplicate]. *) + val add : ('k, 'v, 'cmp) t -> key:'k -> data:'v -> ('k, 'v, 'cmp) t Or_duplicate.t + val add_exn : ('k, 'v, 'cmp) t -> key:'k -> data:'v -> ('k, 'v, 'cmp) t + + (** If [key] is not present then add a singleton list, otherwise, cons data onto the + head of the existing list. *) + val add_multi + : ('k, 'v list, 'cmp) t + -> key:'k + -> data:'v + -> ('k, 'v list, 'cmp) t + + (** If the key is present, then remove its head element; if the result is empty, remove + the key. *) + val remove_multi + : ('k, 'v list, 'cmp) t + -> 'k + -> ('k, 'v list, 'cmp) t + + (** Returns the value bound to the given key, or the empty list if there is none. *) + val find_multi + : ('k, 'v list, 'cmp) t + -> 'k + -> 'v list + + (** [change t key ~f] returns a new map [m] that is the same as [t] on all keys except + for [key], and whose value for [key] is defined by [f], i.e., [find m key = f (find + t key)]. *) + val change + : ('k, 'v, 'cmp) t + -> 'k + -> f:('v option -> 'v option) + -> ('k, 'v, 'cmp) t + + (** [update t key ~f] is [change t key ~f:(fun o -> Some (f o))]. *) + val update + : ('k, 'v, 'cmp) t + -> 'k + -> f:('v option -> 'v) + -> ('k, 'v, 'cmp) t + + (** Returns [Some value] bound to the given key, or [None] if none exists. *) + val find : ('k, 'v, 'cmp) t -> 'k -> 'v option + + (** Returns the value bound to the given key, raising [Caml.Not_found] of [Not_found_s] + if none exists. *) + val find_exn : ('k, 'v, 'cmp) t -> 'k -> 'v + + (** Returns a new map with any binding for the key in question removed. *) + val remove : ('k, 'v, 'cmp) t -> 'k -> ('k, 'v, 'cmp) t + + (** [mem map key] tests whether [map] contains a binding for [key]. *) + val mem : ('k, _, 'cmp) t -> 'k -> bool + + val iter_keys : ('k, _, _) t -> f:('k -> unit) -> unit + val iter : (_, 'v, _) t -> f:('v -> unit) -> unit + val iteri : ('k, 'v, _) t -> f:(key:'k -> data:'v -> unit) -> unit + + (** Iterates two maps side by side. The complexity of this function is O(M + N). If two + inputs are [[(0, a); (1, a)]] and [[(1, b); (2, b)]], [f] will be called with [[(0, + `Left a); (1, `Both (a, b)); (2, `Right b)]]. *) + val iter2 + : ('k, 'v1, 'cmp) t + -> ('k, 'v2, 'cmp) t + -> f:(key:'k -> data:[ `Left of 'v1 | `Right of 'v2 | `Both of 'v1 * 'v2 ] -> unit) + -> unit + + (** Returns a new map with bound values replaced by [f] applied to the bound values.*) + val map : ('k, 'v1, 'cmp) t -> f:('v1 -> 'v2) -> ('k, 'v2, 'cmp) t + + (** Like [map], but the passed function takes both [key] and [data] as arguments. *) + val mapi + : ('k, 'v1, 'cmp) t + -> f:(key:'k -> data:'v1 -> 'v2) + -> ('k, 'v2, 'cmp) t + + (** Folds over keys and data in the map in increasing order of [key]. *) + val fold : ('k, 'v, _) t -> init:'a -> f:(key:'k -> data:'v -> 'a -> 'a) -> 'a + + (** Folds over keys and data in the map in decreasing order of [key]. *) + val fold_right : ('k, 'v, _) t -> init:'a -> f:(key:'k -> data:'v -> 'a -> 'a) -> 'a + + (** Folds over two maps side by side, like [iter2]. *) + val fold2 + : ('k, 'v1, 'cmp) t + -> ('k, 'v2, 'cmp) t + -> init:'a + -> f:(key:'k -> data:[ `Left of 'v1 | `Right of 'v2 | `Both of 'v1 * 'v2 ] -> 'a -> 'a) + -> 'a + + (** [filter], [filteri], [filter_keys], [filter_map], and [filter_mapi] run in O(n * lg + n) time; they simply accumulate each key & data pair retained by [f] into a new map + using [add]. *) + val filter_keys : ('k, 'v, 'cmp) t -> f:('k -> bool) -> ('k, 'v, 'cmp) t + val filter : ('k, 'v, 'cmp) t -> f:('v -> bool) -> ('k, 'v, 'cmp) t + val filteri : ('k, 'v, 'cmp) t -> f:(key:'k -> data:'v -> bool) -> ('k, 'v, 'cmp) t + + (** Returns a new map with bound values filtered by [f] applied to the bound values. *) + val filter_map + : ('k, 'v1, 'cmp) t + -> f:('v1 -> 'v2 option) + -> ('k, 'v2, 'cmp) t + + (** Like [filter_map], but the passed function takes both [key] and [data] as + arguments. *) + val filter_mapi + : ('k, 'v1, 'cmp) t + -> f:(key:'k -> data:'v1 -> 'v2 option) + -> ('k, 'v2, 'cmp) t + + (** [partition_mapi t ~f] returns two new [t]s, with each key in [t] appearing in + exactly one of the resulting maps depending on its mapping in [f]. *) + val partition_mapi + : ('k, 'v1, 'cmp) t + -> f:(key:'k -> data:'v1 -> [`Fst of 'v2 | `Snd of 'v3]) + -> ('k, 'v2, 'cmp) t * ('k, 'v3, 'cmp) t + + (** [partition_map t ~f = partition_mapi t ~f:(fun ~key:_ ~data -> f data)] *) + val partition_map + : ('k, 'v1, 'cmp) t + -> f:('v1 -> [`Fst of 'v2 | `Snd of 'v3]) + -> ('k, 'v2, 'cmp) t * ('k, 'v3, 'cmp) t + + (** + {[ + partitioni_tf t ~f + = + partition_mapi t ~f:(fun ~key ~data -> + if f ~key ~data + then `Fst data + else `Snd data) + ]} *) + val partitioni_tf + : ('k, 'v, 'cmp) t + -> f:(key:'k -> data:'v -> bool) + -> ('k, 'v, 'cmp) t * ('k, 'v, 'cmp) t + + (** [partition_tf t ~f = partitioni_tf t ~f:(fun ~key:_ ~data -> f data)] *) + val partition_tf + : ('k, 'v, 'cmp) t + -> f:('v -> bool) + -> ('k, 'v, 'cmp) t * ('k, 'v, 'cmp) t + + (** Returns a total ordering between maps. The first argument is a total ordering used + to compare data associated with equal keys in the two maps. *) + val compare_direct + : ('v -> 'v -> int) + -> ('k, 'v, 'cmp) t + -> ('k, 'v, 'cmp) t + -> int + + (** Hash function: a building block to use when hashing data structures containing maps in + them. [hash_fold_direct hash_fold_key] is compatible with [compare_direct] iff + [hash_fold_key] is compatible with [(comparator m).compare] of the map [m] being + hashed. *) + val hash_fold_direct + : 'k Hash.folder + -> 'v Hash.folder + -> ('k, 'v, 'cmp) t Hash.folder + + (** [equal cmp m1 m2] tests whether the maps [m1] and [m2] are equal, that is, contain + the same keys and associate each key with the same value. [cmp] is the equality + predicate used to compare the values associated with the keys. *) + val equal + : ('v -> 'v -> bool) + -> ('k, 'v, 'cmp) t + -> ('k, 'v, 'cmp) t + -> bool + + (** Returns a list of the keys in the given map. *) + val keys : ('k, _, _) t -> 'k list + + (** Returns a list of the data in the given map. *) + val data : (_, 'v, _) t -> 'v list + + (** Creates an association list from the given map. *) + val to_alist + : ?key_order : [ `Increasing | `Decreasing ] (** default is [`Increasing] *) + -> ('k, 'v, _) t + -> ('k * 'v) list + + val validate : name:('k -> string) -> 'v Validate.check -> ('k, 'v, _) t Validate.check + + (** {2 Additional operations on maps} *) + + (** Merges two maps. The runtime is O(length(t1) + length(t2)). You shouldn't use this + function to merge a list of maps; consider using [merge_skewed] instead. *) + val merge + : ('k, 'v1, 'cmp) t + -> ('k, 'v2, 'cmp) t + -> f:(key:'k + -> [ `Left of 'v1 | `Right of 'v2 | `Both of 'v1 * 'v2 ] + -> 'v3 option) + -> ('k, 'v3, 'cmp) t + + (** A special case of [merge], [merge_skewed t1 t2] is a map containing all the + bindings of [t1] and [t2]. Bindings that appear in both [t1] and [t2] are + combined into a single value using the [combine] function. In a call + [combine ~key v1 v2], the value [v1] comes from [t1] and [v2] from [t2]. + + The runtime of [merge_skewed] is [O(l1 * log(l2))], where [l1] is the length + of the smaller map and [l2] the length of the larger map. This is likely to + be faster than [merge] when one of the maps is a lot smaller, or when you + merge a list of maps. *) + val merge_skewed + : ('k, 'v, 'cmp) t + -> ('k, 'v, 'cmp) t + -> combine:(key:'k -> 'v -> 'v -> 'v) + -> ('k, 'v, 'cmp) t + + module Symmetric_diff_element : sig + type ('k, 'v) t = 'k * [ `Left of 'v | `Right of 'v | `Unequal of 'v * 'v ] + [@@deriving_inline compare, sexp] + include + sig + [@@@ocaml.warning "-32"] + val compare : + ('k -> 'k -> int) -> + ('v -> 'v -> int) -> ('k, 'v) t -> ('k, 'v) t -> int + include Ppx_sexp_conv_lib.Sexpable.S2 with type ('k,'v) t := ('k, 'v) t + end[@@ocaml.doc "@inline"] + [@@@end] + end + + (** [symmetric_diff t1 t2 ~data_equal] returns a list of changes between [t1] and [t2]. + It is intended to be efficient in the case where [t1] and [t2] share a large amount of + structure. The keys in the output sequence will be in sorted order. *) + val symmetric_diff + : ('k, 'v, 'cmp) t + -> ('k, 'v, 'cmp) t + -> data_equal:('v -> 'v -> bool) + -> ('k, 'v) Symmetric_diff_element.t Sequence.t + + (** [min_elt map] returns [Some (key, data)] pair corresponding to the minimum key in + [map], or [None] if empty. *) + val min_elt : ('k, 'v, _) t -> ('k * 'v) option + val min_elt_exn : ('k, 'v, _) t -> 'k * 'v + + (** [max_elt map] returns [Some (key, data)] pair corresponding to the maximum key in + [map], or [None] if [map] is empty. *) + val max_elt : ('k, 'v, _) t -> ('k * 'v) option + val max_elt_exn : ('k, 'v, _) t -> 'k * 'v + + (** These functions have the same semantics as similar functions in [List]. *) + + val for_all : ('k, 'v, _) t -> f:( 'v -> bool) -> bool + val for_alli : ('k, 'v, _) t -> f:(key:'k -> data:'v -> bool) -> bool + val exists : ('k, 'v, _) t -> f:( 'v -> bool) -> bool + val existsi : ('k, 'v, _) t -> f:(key:'k -> data:'v -> bool) -> bool + val count : ('k, 'v, _) t -> f:( 'v -> bool) -> int + val counti : ('k, 'v, _) t -> f:(key:'k -> data:'v -> bool) -> int + + (** [split t key] returns a map of keys strictly less than [key], the mapping of [key] if + any, and a map of keys strictly greater than [key]. + + Runtime is O(m + log n), where n is the size of the input map and m is the size of + the smaller of the two output maps. The O(m) term is due to the need to calculate + the length of the output maps. *) + val split + : ('k, 'v, 'cmp) t + -> 'k + -> ('k, 'v, 'cmp) t * ('k * 'v) option * ('k, 'v, 'cmp) t + + (** [append ~lower_part ~upper_part] returns [`Ok map] where [map] contains all the + [(key, value)] pairs from the two input maps if all the keys from [lower_part] are + less than all the keys from [upper_part]. Otherwise it returns + [`Overlapping_key_ranges]. + + Runtime is O(log n) where n is the size of the larger input map. This can be + significantly faster than [Map.merge] or repeated [Map.add]. + + {[ + assert (match Map.append ~lower_part ~upper_part with + | `Ok whole_map -> + Map.to_alist whole_map + = List.append (to_alist lower_part) (to_alist upper_part) + | `Overlapping_key_ranges -> true); + ]} *) + val append + : lower_part:('k, 'v, 'cmp) t + -> upper_part:('k, 'v, 'cmp) t + -> [ `Ok of ('k, 'v, 'cmp) t + | `Overlapping_key_ranges ] + + (** [subrange t ~lower_bound ~upper_bound] returns a map containing all the entries from + [t] whose keys lie inside the interval indicated by [~lower_bound] and + [~upper_bound]. If this interval is empty, an empty map is returned. + + Runtime is O(m + log n), where n is the size of the input map and m is the size of + the output map. The O(m) term is due to the need to calculate the length of the + output map. *) + val subrange + : ('k, 'v, 'cmp) t + -> lower_bound:'k Maybe_bound.t + -> upper_bound:'k Maybe_bound.t + -> ('k, 'v, 'cmp) t + + (** [fold_range_inclusive t ~min ~max ~init ~f] folds [f] (with initial value [~init]) + over all keys (and their associated values) that are in the range [[min, max]] + (inclusive). *) + val fold_range_inclusive + : ('k, 'v, 'cmp) t + -> min:'k + -> max:'k + -> init:'a + -> f:(key:'k -> data:'v -> 'a -> 'a) + -> 'a + + (** [range_to_alist t ~min ~max] returns an associative list of the elements whose keys + lie in [[min, max]] (inclusive), with the smallest key being at the head of the + list. *) + val range_to_alist : ('k, 'v, 'cmp) t -> min:'k -> max:'k -> ('k * 'v) list + + (** [closest_key t dir k] returns the [(key, value)] pair in [t] with [key] closest to + [k] that satisfies the given inequality bound. + + For example, [closest_key t `Less_than k] would be the pair with the closest key to + [k] where [key < k]. + + [to_sequence] can be used to get the same results as [closest_key]. It is less + efficient for individual lookups but more efficient for finding many elements starting + at some value. *) + val closest_key + : ('k, 'v, 'cmp) t + -> [ `Greater_or_equal_to + | `Greater_than + | `Less_or_equal_to + | `Less_than + ] + -> 'k + -> ('k * 'v) option + + (** [nth t n] finds the (key, value) pair of rank n (i.e., such that there are exactly n + keys strictly less than the found key), if one exists. O(log(length t) + n) time. *) + val nth : ('k, 'v, _) t -> int -> ('k * 'v) option + val nth_exn : ('k, 'v, _) t -> int -> ('k * 'v) + + (** [rank t k] If [k] is in [t], returns the number of keys strictly less than [k] in + [t], and [None] otherwise. *) + val rank : ('k, 'v, 'cmp) t -> 'k -> int option + + (** [to_sequence ?order ?keys_greater_or_equal_to ?keys_less_or_equal_to t] gives a + sequence of key-value pairs between [keys_less_or_equal_to] and + [keys_greater_or_equal_to] inclusive, presented in [order]. If + [keys_greater_or_equal_to > keys_less_or_equal_to], the sequence is empty. Cost is + O(log n) up front and amortized O(1) to produce each element. *) + val to_sequence + : ?order : [ `Increasing_key (** default *) | `Decreasing_key ] + -> ?keys_greater_or_equal_to : 'k + -> ?keys_less_or_equal_to : 'k + -> ('k, 'v, 'cmp) t + -> ('k * 'v) Sequence.t + + (** [M] is meant to be used in combination with OCaml applicative functor types: + + {[ + type string_to_int_map = int Map.M(String).t + ]} + + which stands for: + + {[ + type string_to_int_map = (String.t, int, String.comparator_witness) Map.t + ]} + + The point is that [int Map.M(String).t] supports deriving, whereas the second syntax + doesn't (because there is no such thing as, say, [String.sexp_of_comparator_witness] + -- instead you would want to pass the comparator directly). + + In addition, the requirements to use [@@deriving_inline][@@@end] on the key module are only those + needed to satisfy what you are trying to derive on the map itself. Say you write: + + {[ + type t = int Map.M(X).t [@@deriving_inline hash][@@@end] + ]} + + then this will be well typed exactly if [X] contains at least: + - a type [t] with no parameters + - a comparator witness + - a [hash_fold_t] function with the right type *) + module M (K : sig type t type comparator_witness end) : sig + type nonrec 'v t = (K.t, 'v, K.comparator_witness) t + end + + include For_deriving + with type ('key, 'value, 'cmp) t := ('key, 'value, 'cmp) t + + (** A polymorphic Map. *) + module Poly : S_poly + with type ('key, +'value) t = ('key, 'value, Comparator.Poly.comparator_witness) t + + (** [Using_comparator] is a similar interface as the toplevel of [Map], except the + functions take a [~comparator:('k, 'cmp) Comparator.t], whereas the functions at the + toplevel of [Map] take a [('k, 'cmp) comparator]. *) + module Using_comparator : sig + type nonrec ('k, +'v, 'cmp) t = ('k, 'v, 'cmp) t [@@deriving_inline sexp_of] + include + sig + [@@@ocaml.warning "-32"] + val sexp_of_t : + ('k -> Ppx_sexp_conv_lib.Sexp.t) -> + ('v -> Ppx_sexp_conv_lib.Sexp.t) -> + ('cmp -> Ppx_sexp_conv_lib.Sexp.t) -> + ('k, 'v, 'cmp) t -> Ppx_sexp_conv_lib.Sexp.t + end[@@ocaml.doc "@inline"] + [@@@end] + + val t_of_sexp_direct + : comparator:('k, 'cmp) Comparator.t + -> (Sexp.t -> 'k) + -> (Sexp.t -> 'v) + -> Sexp.t + -> ('k, 'v, 'cmp) t + + module Tree : sig + type ('k, +'v, 'cmp) t [@@deriving_inline sexp_of] + include + sig + [@@@ocaml.warning "-32"] + val sexp_of_t : + ('k -> Ppx_sexp_conv_lib.Sexp.t) -> + ('v -> Ppx_sexp_conv_lib.Sexp.t) -> + ('cmp -> Ppx_sexp_conv_lib.Sexp.t) -> + ('k, 'v, 'cmp) t -> Ppx_sexp_conv_lib.Sexp.t + end[@@ocaml.doc "@inline"] + [@@@end] + + val t_of_sexp_direct + : comparator:('k, 'cmp) Comparator.t + -> (Sexp.t -> 'k) + -> (Sexp.t -> 'v) + -> Sexp.t + -> ('k, 'v, 'cmp) t + + include Creators_and_accessors3_with_comparator + with type ('a, 'b, 'c) t := ('a, 'b, 'c) t + with type ('a, 'b, 'c) tree := ('a, 'b, 'c) t + + val empty_without_value_restriction : (_, _, _) t + end + + include Accessors3 + with type ('a, 'b, 'c) t := ('a, 'b, 'c) t + with type ('a, 'b, 'c) tree := ('a, 'b, 'c) Tree.t + include Creators3_with_comparator + with type ('a, 'b, 'c) t := ('a, 'b, 'c) t + with type ('a, 'b, 'c) tree := ('a, 'b, 'c) Tree.t + + val comparator : ('a, _, 'cmp) t -> ('a, 'cmp) Comparator.t + + val hash_fold_direct + : 'k Hash.folder + -> 'v Hash.folder + -> ('k, 'v, 'cmp) t Hash.folder + + (** To get around the value restriction, apply the functor and include it. You + can see an example of this in the [Poly] submodule below. *) + module Empty_without_value_restriction (K : Comparator.S1) : sig + val empty : ('a K.t, 'v, K.comparator_witness) t + end + end + + + (** {2 Modules and module types for extending [Map]} + + For use in extensions of Base, like [Core_kernel]. *) + + module With_comparator = With_comparator + module With_first_class_module = With_first_class_module + module Without_comparator = Without_comparator + + module type For_deriving = For_deriving + + module type S_poly = S_poly + module type Accessors1 = Accessors1 + module type Accessors2 = Accessors2 + module type Accessors3 = Accessors3 + module type Accessors3_with_comparator = Accessors3_with_comparator + module type Accessors_generic = Accessors_generic + module type Creators1 = Creators1 + module type Creators2 = Creators2 + module type Creators3_with_comparator = Creators3_with_comparator + module type Creators_and_accessors1 = Creators_and_accessors1 + module type Creators_and_accessors2 = Creators_and_accessors2 + module type Creators_and_accessors3_with_comparator = Creators_and_accessors3_with_comparator + module type Creators_and_accessors_generic = Creators_and_accessors_generic + module type Creators_generic = Creators_generic +end diff --git a/src/maybe_bound.ml b/src/maybe_bound.ml new file mode 100644 index 0000000..16637ba --- /dev/null +++ b/src/maybe_bound.ml @@ -0,0 +1,164 @@ +open! Import + +type 'a t = Incl of 'a | Excl of 'a | Unbounded [@@deriving_inline enumerate, sexp] +let all : 'a . 'a list -> 'a t list = + fun _all_of_a -> + Ppx_enumerate_lib.List.append + (let rec map l acc = + match l with + | [] -> Ppx_enumerate_lib.List.rev acc + | enumerate__001_::l -> map l ((Incl enumerate__001_) :: acc) in + map _all_of_a []) + (Ppx_enumerate_lib.List.append + (let rec map l acc = + match l with + | [] -> Ppx_enumerate_lib.List.rev acc + | enumerate__002_::l -> map l ((Excl enumerate__002_) :: acc) in + map _all_of_a []) [Unbounded]) +let t_of_sexp : type a. + (Ppx_sexp_conv_lib.Sexp.t -> a) -> Ppx_sexp_conv_lib.Sexp.t -> a t = + let _tp_loc = "src/maybe_bound.ml.t" in + fun _of_a -> + function + | Ppx_sexp_conv_lib.Sexp.List ((Ppx_sexp_conv_lib.Sexp.Atom + ("incl"|"Incl" as _tag))::sexp_args) as _sexp -> + (match sexp_args with + | v0::[] -> let v0 = _of_a v0 in Incl v0 + | _ -> + Ppx_sexp_conv_lib.Conv_error.stag_incorrect_n_args _tp_loc _tag + _sexp) + | Ppx_sexp_conv_lib.Sexp.List ((Ppx_sexp_conv_lib.Sexp.Atom + ("excl"|"Excl" as _tag))::sexp_args) as _sexp -> + (match sexp_args with + | v0::[] -> let v0 = _of_a v0 in Excl v0 + | _ -> + Ppx_sexp_conv_lib.Conv_error.stag_incorrect_n_args _tp_loc _tag + _sexp) + | Ppx_sexp_conv_lib.Sexp.Atom ("unbounded"|"Unbounded") -> Unbounded + | Ppx_sexp_conv_lib.Sexp.Atom ("incl"|"Incl") as sexp -> + Ppx_sexp_conv_lib.Conv_error.stag_takes_args _tp_loc sexp + | Ppx_sexp_conv_lib.Sexp.Atom ("excl"|"Excl") as sexp -> + Ppx_sexp_conv_lib.Conv_error.stag_takes_args _tp_loc sexp + | Ppx_sexp_conv_lib.Sexp.List ((Ppx_sexp_conv_lib.Sexp.Atom + ("unbounded"|"Unbounded"))::_) as sexp -> + Ppx_sexp_conv_lib.Conv_error.stag_no_args _tp_loc sexp + | Ppx_sexp_conv_lib.Sexp.List ((Ppx_sexp_conv_lib.Sexp.List _)::_) as + sexp -> + Ppx_sexp_conv_lib.Conv_error.nested_list_invalid_sum _tp_loc sexp + | Ppx_sexp_conv_lib.Sexp.List [] as sexp -> + Ppx_sexp_conv_lib.Conv_error.empty_list_invalid_sum _tp_loc sexp + | sexp -> Ppx_sexp_conv_lib.Conv_error.unexpected_stag _tp_loc sexp +let sexp_of_t : type a. + (a -> Ppx_sexp_conv_lib.Sexp.t) -> a t -> Ppx_sexp_conv_lib.Sexp.t = + fun _of_a -> + function + | Incl v0 -> + let v0 = _of_a v0 in + Ppx_sexp_conv_lib.Sexp.List [Ppx_sexp_conv_lib.Sexp.Atom "Incl"; v0] + | Excl v0 -> + let v0 = _of_a v0 in + Ppx_sexp_conv_lib.Sexp.List [Ppx_sexp_conv_lib.Sexp.Atom "Excl"; v0] + | Unbounded -> Ppx_sexp_conv_lib.Sexp.Atom "Unbounded" +[@@@end] + +type interval_comparison = + | Below_lower_bound + | In_range + | Above_upper_bound +[@@deriving_inline sexp, compare, hash] +let interval_comparison_of_sexp : + Ppx_sexp_conv_lib.Sexp.t -> interval_comparison = + let _tp_loc = "src/maybe_bound.ml.interval_comparison" in + function + | Ppx_sexp_conv_lib.Sexp.Atom ("below_lower_bound"|"Below_lower_bound") -> + Below_lower_bound + | Ppx_sexp_conv_lib.Sexp.Atom ("in_range"|"In_range") -> In_range + | Ppx_sexp_conv_lib.Sexp.Atom ("above_upper_bound"|"Above_upper_bound") -> + Above_upper_bound + | Ppx_sexp_conv_lib.Sexp.List ((Ppx_sexp_conv_lib.Sexp.Atom + ("below_lower_bound"|"Below_lower_bound"))::_) as sexp -> + Ppx_sexp_conv_lib.Conv_error.stag_no_args _tp_loc sexp + | Ppx_sexp_conv_lib.Sexp.List ((Ppx_sexp_conv_lib.Sexp.Atom + ("in_range"|"In_range"))::_) as sexp -> + Ppx_sexp_conv_lib.Conv_error.stag_no_args _tp_loc sexp + | Ppx_sexp_conv_lib.Sexp.List ((Ppx_sexp_conv_lib.Sexp.Atom + ("above_upper_bound"|"Above_upper_bound"))::_) as sexp -> + Ppx_sexp_conv_lib.Conv_error.stag_no_args _tp_loc sexp + | Ppx_sexp_conv_lib.Sexp.List ((Ppx_sexp_conv_lib.Sexp.List _)::_) as sexp + -> Ppx_sexp_conv_lib.Conv_error.nested_list_invalid_sum _tp_loc sexp + | Ppx_sexp_conv_lib.Sexp.List [] as sexp -> + Ppx_sexp_conv_lib.Conv_error.empty_list_invalid_sum _tp_loc sexp + | sexp -> Ppx_sexp_conv_lib.Conv_error.unexpected_stag _tp_loc sexp +let sexp_of_interval_comparison : + interval_comparison -> Ppx_sexp_conv_lib.Sexp.t = + function + | Below_lower_bound -> Ppx_sexp_conv_lib.Sexp.Atom "Below_lower_bound" + | In_range -> Ppx_sexp_conv_lib.Sexp.Atom "In_range" + | Above_upper_bound -> Ppx_sexp_conv_lib.Sexp.Atom "Above_upper_bound" +let compare_interval_comparison : + interval_comparison -> interval_comparison -> int = + Ppx_compare_lib.polymorphic_compare +let (hash_fold_interval_comparison : + Ppx_hash_lib.Std.Hash.state -> + interval_comparison -> Ppx_hash_lib.Std.Hash.state) + = + (fun hsv -> + fun arg -> + match arg with + | Below_lower_bound -> Ppx_hash_lib.Std.Hash.fold_int hsv 0 + | In_range -> Ppx_hash_lib.Std.Hash.fold_int hsv 1 + | Above_upper_bound -> Ppx_hash_lib.Std.Hash.fold_int hsv 2 : + Ppx_hash_lib.Std.Hash.state -> + interval_comparison -> Ppx_hash_lib.Std.Hash.state) +let (hash_interval_comparison : + interval_comparison -> Ppx_hash_lib.Std.Hash.hash_value) = + let func arg = + Ppx_hash_lib.Std.Hash.get_hash_value + (let hsv = Ppx_hash_lib.Std.Hash.create () in + hash_fold_interval_comparison hsv arg) in + fun x -> func x +[@@@end] + +let map t ~f = + match t with + | Incl incl -> Incl (f incl) + | Excl excl -> Excl (f excl) + | Unbounded -> Unbounded + +let is_lower_bound t ~of_:a ~compare = + match t with + | Incl incl -> compare incl a <= 0 + | Excl excl -> compare excl a < 0 + | Unbounded -> true + +let is_upper_bound t ~of_:a ~compare = + match t with + | Incl incl -> compare a incl <= 0 + | Excl excl -> compare a excl < 0 + | Unbounded -> true + +let bounds_crossed ~lower ~upper ~compare = + match lower with + | Unbounded -> false + | (Incl lower | Excl lower) -> + match upper with + | Unbounded -> false + | (Incl upper | Excl upper) -> + compare lower upper > 0 + +let check_interval_exn ~lower ~upper ~compare = + if bounds_crossed ~lower ~upper ~compare then + failwith "Maybe_bound.compare_to_interval_exn: lower bound > upper bound" + +let compare_to_interval_exn ~lower ~upper a ~compare = + check_interval_exn ~lower ~upper ~compare; + if not (is_lower_bound lower ~of_:a ~compare) then Below_lower_bound else + if not (is_upper_bound upper ~of_:a ~compare) then Above_upper_bound else + In_range + +let interval_contains_exn ~lower ~upper a ~compare = + match compare_to_interval_exn ~lower ~upper a ~compare with + | In_range -> true + | Below_lower_bound + | Above_upper_bound -> false + diff --git a/src/maybe_bound.mli b/src/maybe_bound.mli new file mode 100644 index 0000000..3553e56 --- /dev/null +++ b/src/maybe_bound.mli @@ -0,0 +1,63 @@ +(** Used for specifying a bound (either upper or lower) as inclusive, exclusive, or + unbounded. *) + +open! Import + +type 'a t = Incl of 'a | Excl of 'a | Unbounded [@@deriving_inline enumerate, sexp] +include +sig + [@@@ocaml.warning "-32"] + val all : 'a list -> 'a t list + include Ppx_sexp_conv_lib.Sexpable.S1 with type 'a t := 'a t +end[@@ocaml.doc "@inline"] +[@@@end] + +val map : 'a t -> f:('a -> 'b) -> 'b t + +val is_lower_bound : 'a t -> of_:'a -> compare:('a -> 'a -> int) -> bool +val is_upper_bound : 'a t -> of_:'a -> compare:('a -> 'a -> int) -> bool + +(** [interval_contains_exn ~lower ~upper x ~compare] raises if [lower] and [upper] are + crossed. *) +val interval_contains_exn + : lower : 'a t + -> upper : 'a t + -> 'a + -> compare : ('a -> 'a -> int) + -> bool + +(** [bounds_crossed ~lower ~upper ~compare] returns true if [lower > upper]. + + It ignores whether the bounds are [Incl] or [Excl]. *) +val bounds_crossed: lower:'a t -> upper: 'a t -> compare:('a -> 'a -> int) -> bool + +type interval_comparison = + | Below_lower_bound + | In_range + | Above_upper_bound +[@@deriving_inline sexp, compare, hash] +include +sig + [@@@ocaml.warning "-32"] + val sexp_of_interval_comparison : + interval_comparison -> Ppx_sexp_conv_lib.Sexp.t + val interval_comparison_of_sexp : + Ppx_sexp_conv_lib.Sexp.t -> interval_comparison + val compare_interval_comparison : + interval_comparison -> interval_comparison -> int + val hash_fold_interval_comparison : + Ppx_hash_lib.Std.Hash.state -> + interval_comparison -> Ppx_hash_lib.Std.Hash.state + val hash_interval_comparison : + interval_comparison -> Ppx_hash_lib.Std.Hash.hash_value +end[@@ocaml.doc "@inline"] +[@@@end] + +(** [compare_to_interval_exn ~lower ~upper x ~compare] raises if [lower] and [upper] are + crossed. *) +val compare_to_interval_exn + : lower : 'a t + -> upper : 'a t + -> 'a + -> compare : ('a -> 'a -> int) + -> interval_comparison diff --git a/src/monad.ml b/src/monad.ml new file mode 100644 index 0000000..d7ccf25 --- /dev/null +++ b/src/monad.ml @@ -0,0 +1,107 @@ +open! Import + +module List = List0 + +include Monad_intf + +module type Basic_general = sig + type ('a, 'i, 'j, 'd, 'e) t + + val bind + : ('a, 'i, 'j, 'd, 'e) t + -> f:('a -> ('b, 'j, 'k, 'd, 'e) t) + -> ('b, 'i, 'k, 'd, 'e) t + + val map + : [ `Define_using_bind + | `Custom of (('a, 'i, 'j, 'd, 'e) t -> f:('a -> 'b) -> ('b, 'i, 'j, 'd, 'e) t) + ] + + val return : 'a -> ('a, 'i, 'i, 'd, 'e) t +end + +module Make_general (M : Basic_general) = struct + + let bind = M.bind + let return = M.return + + let map_via_bind ma ~f = M.bind ma ~f:(fun a -> M.return (f a)) + + let map = + match M.map with + | `Define_using_bind -> map_via_bind + | `Custom x -> x + + module Monad_infix = struct + let (>>=) t f = bind t ~f + let (>>|) t f = map t ~f + end + include Monad_infix + + module Let_syntax = struct + + let return = return + include Monad_infix + + module Let_syntax = struct + let return = return + let bind = bind + let map = map + let both a b = a >>= fun a -> b >>| fun b -> (a, b) + module Open_on_rhs = struct end + end + end + + let join t = t >>= fun t' -> t' + + let ignore_m t = map t ~f:(fun _ -> ()) + + let all = + let rec loop vs = function + | [] -> return (List.rev vs) + | t :: ts -> t >>= fun v -> loop (v :: vs) ts + in + fun ts -> loop [] ts + + let rec all_unit = function + | [] -> return () + | t :: ts -> t >>= fun () -> all_unit ts + + let all_ignore = all_unit + +end + +module Make_indexed (M : Basic_indexed) + : S_indexed with type ('a, 'i, 'j) t := ('a, 'i, 'j) M.t = + Make_general (struct + type ('a, 'i, 'j, 'd, 'e) t = ('a, 'i, 'j) M.t + include (M : Basic_indexed with type ('a, 'b, 'c) t := ('a, 'b, 'c) M.t) + end) + +module Make3 (M : Basic3) : S3 with type ('a, 'd, 'e) t := ('a, 'd, 'e) M.t = + Make_general (struct + type ('a, 'i, 'j, 'd, 'e) t = ('a, 'd, 'e) M.t + include (M : Basic3 with type ('a, 'b, 'c) t := ('a, 'b, 'c) M.t) + end) + +module Make2 (M : Basic2) : S2 with type ('a, 'd) t := ('a, 'd) M.t = + Make_general (struct + type ('a, 'i, 'j, 'd, 'e) t = ('a, 'd) M.t + include (M : Basic2 with type ('a, 'b) t := ('a, 'b) M.t) + end) + +module Make (M : Basic) : S with type 'a t := 'a M.t = + Make_general (struct + type ('a, 'i, 'j, 'd, 'e) t = 'a M.t + include (M : Basic with type 'a t := 'a M.t) + end) + +module Ident = struct + type 'a t = 'a + include Make (struct + type nonrec 'a t = 'a t + let bind a ~f = f a + let return a = a + let map = `Custom (fun a ~f -> f a) + end) +end diff --git a/src/monad.mli b/src/monad.mli new file mode 100644 index 0000000..65be892 --- /dev/null +++ b/src/monad.mli @@ -0,0 +1 @@ +include Monad_intf.Monad (** @inline *) diff --git a/src/monad_intf.ml b/src/monad_intf.ml new file mode 100644 index 0000000..34ae035 --- /dev/null +++ b/src/monad_intf.ml @@ -0,0 +1,365 @@ +open! Import + +module type Basic = sig + type 'a t + val bind : 'a t -> f:('a -> 'b t) -> 'b t + val return : 'a -> 'a t + (** The following identities ought to hold (for some value of =): + + - [return x >>= f = f x] + - [t >>= fun x -> return x = t] + - [(t >>= f) >>= g = t >>= fun x -> (f x >>= g)] + + Note: [>>=] is the infix notation for [bind]) *) + + (** The [map] argument to [Monad.Make] says how to implement the monad's [map] function. + [`Define_using_bind] means to define [map t ~f = bind t ~f:(fun a -> return (f a))]. + [`Custom] overrides the default implementation, presumably with something more + efficient. + + Some other functions returned by [Monad.Make] are defined in terms of [map], so + passing in a more efficient [map] will improve their efficiency as well. *) + val map : [ `Define_using_bind + | `Custom of ('a t -> f:('a -> 'b) -> 'b t) + ] +end + +module type Infix = sig + type 'a t + + (** [t >>= f] returns a computation that sequences the computations represented by two + monad elements. The resulting computation first does [t] to yield a value [v], and + then runs the computation returned by [f v]. *) + val (>>=) : 'a t -> ('a -> 'b t) -> 'b t + + (** [t >>| f] is [t >>= (fun a -> return (f a))]. *) + val (>>|) : 'a t -> ('a -> 'b) -> 'b t +end + +(** Opening a module of this type allows one to use the [%bind] and [%map] syntax + extensions defined by ppx_let, and brings [return] into scope. *) +module type Syntax = sig + type 'a t + module Let_syntax : sig + + (** These are convenient to have in scope when programming with a monad: *) + + val return : 'a -> 'a t + include Infix with type 'a t := 'a t + + module Let_syntax : sig + val return : 'a -> 'a t + val bind : 'a t -> f:('a -> 'b t) -> 'b t + val map : 'a t -> f:('a -> 'b) -> 'b t + val both : 'a t -> 'b t -> ('a * 'b) t + module Open_on_rhs : sig end + end + end +end + +module type S_without_syntax = sig + type 'a t + include Infix with type 'a t := 'a t + + module Monad_infix : Infix with type 'a t := 'a t + + (** [bind t ~f] = [t >>= f] *) + val bind : 'a t -> f:('a -> 'b t) -> 'b t + + (** [return v] returns the (trivial) computation that returns v. *) + val return : 'a -> 'a t + + (** [map t ~f] is t >>| f. *) + val map : 'a t -> f:('a -> 'b) -> 'b t + + (** [join t] is [t >>= (fun t' -> t')]. *) + val join : 'a t t -> 'a t + + (** [ignore_m t] is [map t ~f:(fun _ -> ())]. [ignore_m] used to be called [ignore], + but we decided that was a bad name, because it shadowed the widely used + [Caml.ignore]. Some monads still do [let ignore = ignore_m] for historical + reasons. *) + val ignore_m : 'a t -> unit t + + + val all : 'a t list -> 'a list t + + (** Like [all], but ensures that every monadic value in the list produces a unit value, + all of which are discarded rather than being collected into a list. *) + val all_unit : unit t list -> unit t + + val all_ignore : unit t list -> unit t [@@deprecated "[since 2018-02] Use [all_unit]"] +end + +module type S = sig + type 'a t + include S_without_syntax with type 'a t := 'a t + include Syntax with type 'a t := 'a t +end + +(** Multi parameter monad. The second parameter gets unified across all the computation. + This is used to encode monads working on a multi parameter data structure like + ([('a,'b) result]). *) +module type Basic2 = sig + type ('a, 'e) t + val bind : ('a, 'e) t -> f:('a -> ('b, 'e) t) -> ('b, 'e) t + val map : [ `Define_using_bind + | `Custom of (('a, 'e) t -> f:('a -> 'b) -> ('b, 'e) t) + ] + val return : 'a -> ('a, _) t +end + +(** Same as Infix, except the monad type has two arguments. The second is always just + passed through. *) +module type Infix2 = sig + type ('a, 'e) t + val (>>=) : ('a, 'e) t -> ('a -> ('b, 'e) t) -> ('b, 'e) t + val (>>|) : ('a, 'e) t -> ('a -> 'b) -> ('b, 'e) t +end + +module type Syntax2 = sig + type ('a, 'e) t + + module Let_syntax : sig + val return : 'a -> ('a, _) t + include Infix2 with type ('a,'e) t := ('a,'e) t + module Let_syntax : sig + val return : 'a -> ('a, _) t + val bind : ('a, 'e) t -> f:('a -> ('b, 'e) t) -> ('b, 'e) t + val map : ('a, 'e) t -> f:('a -> 'b) -> ('b, 'e) t + val both : ('a, 'e) t -> ('b, 'e) t -> ('a * 'b, 'e) t + module Open_on_rhs : sig end + end + end +end + +(** The same as S except the monad type has two arguments. The second is always just + passed through. *) +module type S2 = sig + type ('a, 'e) t + include Infix2 with type ('a, 'e) t := ('a, 'e) t + include Syntax2 with type ('a, 'e) t := ('a, 'e) t + + module Monad_infix : Infix2 with type ('a, 'e) t := ('a, 'e) t + + val bind : ('a, 'e) t -> f:('a -> ('b, 'e) t) -> ('b, 'e) t + + val return : 'a -> ('a, _) t + + val map : ('a, 'e) t -> f:('a -> 'b) -> ('b, 'e) t + + val join : (('a, 'e) t, 'e) t -> ('a, 'e) t + + val ignore_m : (_, 'e) t -> (unit, 'e) t + + val all : ('a, 'e) t list -> ('a list, 'e) t + + val all_unit : (unit, 'e) t list -> (unit, 'e) t + + val all_ignore : (unit, 'e) t list -> (unit, 'e) t + [@@deprecated "[since 2018-02] Use [all_unit]"] +end + +(** Multi parameter monad. The second and third parameters get unified across all the + computation. *) +module type Basic3 = sig + type ('a, 'd, 'e) t + val bind : ('a, 'd, 'e) t -> f:('a -> ('b, 'd, 'e) t) -> ('b, 'd, 'e) t + val map : [ `Define_using_bind + | `Custom of (('a, 'd, 'e) t -> f:('a -> 'b) -> ('b, 'd, 'e) t) + ] + val return : 'a -> ('a, _, _) t +end + +(** Same as Infix, except the monad type has three arguments. The second and third are + always just passed through. *) +module type Infix3 = sig + type ('a, 'd, 'e) t + val (>>=) : ('a, 'd, 'e) t -> ('a -> ('b, 'd, 'e) t) -> ('b, 'd, 'e) t + val (>>|) : ('a, 'd, 'e) t -> ('a -> 'b) -> ('b, 'd, 'e) t +end + +module type Syntax3 = sig + type ('a, 'd, 'e) t + + module Let_syntax : sig + val return : 'a -> ('a, _, _) t + include Infix3 with type ('a,'d,'e) t := ('a,'d,'e) t + module Let_syntax : sig + val return : 'a -> ('a, _, _) t + val bind : ('a, 'd, 'e) t -> f:('a -> ('b, 'd, 'e) t) -> ('b, 'd, 'e) t + val map : ('a, 'd, 'e) t -> f:('a -> 'b) -> ('b, 'd, 'e) t + val both : ('a, 'd, 'e) t -> ('b, 'd, 'e) t -> ('a * 'b, 'd, 'e) t + module Open_on_rhs : sig end + end + end +end + +(** The same as S except the monad type has three arguments. The second and third are + always just passed through. *) +module type S3 = sig + type ('a, 'd, 'e) t + include Infix3 with type ('a, 'd, 'e) t := ('a, 'd, 'e) t + include Syntax3 with type ('a, 'd, 'e) t := ('a, 'd, 'e) t + + module Monad_infix : Infix3 with type ('a, 'd, 'e) t := ('a, 'd, 'e) t + + val bind : ('a, 'd, 'e) t -> f:('a -> ('b, 'd, 'e) t) -> ('b, 'd, 'e) t + + val return : 'a -> ('a, _, _) t + + val map : ('a, 'd, 'e) t -> f:('a -> 'b) -> ('b, 'd, 'e) t + + val join : (('a, 'd, 'e) t, 'd, 'e) t -> ('a, 'd, 'e) t + + val ignore_m : (_, 'd, 'e) t -> (unit, 'd, 'e) t + + val all : ('a, 'd, 'e) t list -> ('a list, 'd, 'e) t + + val all_unit : (unit, 'd, 'e) t list -> (unit, 'd, 'e) t + + val all_ignore : (unit, 'd, 'e) t list -> (unit, 'd, 'e) t + [@@deprecated "[since 2018-02] Use [all_unit]"] +end + +(** Indexed monad, in the style of Atkey. The second and third parameters are composed + across all computation. To see this more clearly, you can look at the type of bind: + + {[ + val bind : ('a, 'i, 'j) t -> f:('a -> ('b, 'j, 'k) t) -> ('b, 'i, 'k) t + ]} + + and isolate some of the type variables to see their individual behaviors: + + {[ + val bind : 'a -> f:('a -> 'b ) -> 'b + val bind : 'i, 'j -> 'j, 'k -> 'i, 'k + ]} + + For more information on Atkey-style indexed monads, see: + + {v + Parameterised Notions of Computation + Robert Atkey + http://bentnib.org/paramnotions-jfp.pdf + v} *) +module type Basic_indexed = sig + type ('a, 'i, 'j) t + val bind : ('a, 'i, 'j) t -> f:('a -> ('b, 'j, 'k) t) -> ('b, 'i, 'k) t + val map : [ `Define_using_bind + | `Custom of (('a, 'i, 'j) t -> f:('a -> 'b) -> ('b, 'i, 'j) t) + ] + val return : 'a -> ('a, 'i, 'i) t +end + +(** Same as Infix, except the monad type has three arguments. The second and third are + compose across all computation. *) +module type Infix_indexed = sig + type ('a, 'i, 'j) t + val (>>=) : ('a, 'i, 'j) t -> ('a -> ('b, 'j, 'k) t) -> ('b, 'i, 'k) t + val (>>|) : ('a, 'i, 'j) t -> ('a -> 'b) -> ('b, 'i, 'j) t +end + +module type Syntax_indexed = sig + type ('a, 'i, 'j) t + + module Let_syntax : sig + val return : 'a -> ('a, 'i, 'i) t + include Infix_indexed with type ('a,'i,'j) t := ('a,'i,'j) t + module Let_syntax : sig + val return : 'a -> ('a, 'i, 'i) t + val bind : ('a, 'i, 'j) t -> f:('a -> ('b, 'j, 'k) t) -> ('b, 'i, 'k) t + val map : ('a, 'i, 'j) t -> f:('a -> 'b) -> ('b, 'i, 'j) t + val both : ('a, 'i, 'j) t -> ('b, 'j, 'k) t -> ('a * 'b, 'i, 'k) t + module Open_on_rhs : sig end + end + end +end + +(** The same as S except the monad type has three arguments. The second and third are + composed across all computation. *) +module type S_indexed = sig + type ('a, 'i, 'j) t + include Infix_indexed with type ('a, 'i, 'j) t := ('a, 'i, 'j) t + include Syntax_indexed with type ('a, 'i, 'j) t := ('a, 'i, 'j) t + + module Monad_infix : Infix_indexed with type ('a, 'i, 'j) t := ('a, 'i, 'j) t + + val bind : ('a, 'i, 'j) t -> f:('a -> ('b, 'j, 'k) t) -> ('b, 'i, 'k) t + + val return : 'a -> ('a, 'i, 'i) t + + val map : ('a, 'i, 'j) t -> f:('a -> 'b) -> ('b, 'i, 'j) t + + val join : (('a, 'j, 'k) t, 'i, 'j) t -> ('a, 'i, 'k) t + + val ignore_m : (_, 'i, 'j) t -> (unit, 'i, 'j) t + + val all : ('a, 'i, 'i) t list -> ('a list, 'i, 'i) t + + val all_unit : (unit, 'i, 'i) t list -> (unit, 'i, 'i) t + + val all_ignore : (unit, 'i, 'i) t list -> (unit, 'i, 'i) t + [@@deprecated "[since 2018-02] Use [all_unit]"] +end + +module S_to_S2 (X : S) : (S2 with type ('a, 'e) t = 'a X.t) = struct + type ('a, 'e) t = 'a X.t + include (X : S with type 'a t := 'a X.t) +end + +module S2_to_S3 (X : S2) : (S3 with type ('a, 'd, 'e) t = ('a, 'd) X.t) = struct + type ('a, 'd, 'e) t = ('a, 'd) X.t + include (X : S2 with type ('a, 'd) t := ('a, 'd) X.t) +end + +module S_to_S_indexed (X : S) : (S_indexed with type ('a, 'i, 'j) t = 'a X.t) = struct + type ('a, 'i, 'j) t = 'a X.t + include (X : S with type 'a t := 'a X.t) +end + +module S2_to_S (X : S2) : (S with type 'a t = ('a, unit) X.t) = struct + type 'a t = ('a, unit) X.t + include (X : S2 with type ('a, 'e) t := ('a, 'e) X.t) +end + +module S3_to_S2 (X : S3) : (S2 with type ('a, 'e) t = ('a, 'e, unit) X.t) = struct + type ('a, 'e) t = ('a, 'e, unit) X.t + include (X : S3 with type ('a, 'd, 'e) t := ('a, 'd, 'e) X.t) +end + +module S_indexed_to_S2 (X : S_indexed) : (S2 with type ('a, 'e) t = ('a, 'e, 'e) X.t) = +struct + type ('a, 'e) t = ('a, 'e, 'e) X.t + include (X : S_indexed with type ('a, 'i, 'j) t := ('a, 'i, 'j) X.t) +end + +module type Monad = sig + (** A monad is an abstraction of the concept of sequencing of computations. A value of + type ['a monad] represents a computation that returns a value of type ['a]. *) + + module type Basic = Basic + module type Basic2 = Basic2 + module type Basic3 = Basic3 + module type Basic_indexed = Basic_indexed + module type Infix = Infix + module type Infix2 = Infix2 + module type Infix3 = Infix3 + module type Infix_indexed = Infix_indexed + module type Syntax = Syntax + module type Syntax2 = Syntax2 + module type Syntax3 = Syntax3 + module type Syntax_indexed = Syntax_indexed + module type S_without_syntax = S_without_syntax + module type S = S + module type S2 = S2 + module type S3 = S3 + module type S_indexed = S_indexed + + module Make (X : Basic ) : S with type 'a t := 'a X.t + module Make2 (X : Basic2) : S2 with type ('a, 'e) t := ('a, 'e) X.t + module Make3 (X : Basic3) : S3 with type ('a, 'd, 'e) t := ('a, 'd, 'e) X.t + module Make_indexed (X : Basic_indexed) : S_indexed with type ('a, 'd, 'e) t := ('a, 'd, 'e) X.t + + module Ident : S with type 'a t = 'a +end diff --git a/src/nativeint.ml b/src/nativeint.ml new file mode 100644 index 0000000..914fb03 --- /dev/null +++ b/src/nativeint.ml @@ -0,0 +1,258 @@ +open! Import +open! Caml.Nativeint +include Nativeint_replace_polymorphic_compare + +module T = struct + type t = nativeint [@@deriving_inline hash, sexp] + let (hash_fold_t : + Ppx_hash_lib.Std.Hash.state -> t -> Ppx_hash_lib.Std.Hash.state) = + hash_fold_nativeint + and (hash : t -> Ppx_hash_lib.Std.Hash.hash_value) = + let func = hash_nativeint in fun x -> func x + let t_of_sexp : Ppx_sexp_conv_lib.Sexp.t -> t = nativeint_of_sexp + let sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t = sexp_of_nativeint + [@@@end] + let compare = Nativeint_replace_polymorphic_compare.compare + + let to_string = to_string + let of_string = of_string +end + +include T +include Comparator.Make(T) +include Comparable.Validate_with_zero (struct + include T + let zero = zero + end) + + +module Conv = Int_conversions +include Conv.Make (T) +include Conv.Make_hex(struct + open Nativeint_replace_polymorphic_compare + type t = nativeint [@@deriving_inline compare, hash] + let compare : t -> t -> int = compare_nativeint + let (hash_fold_t : + Ppx_hash_lib.Std.Hash.state -> t -> Ppx_hash_lib.Std.Hash.state) = + hash_fold_nativeint + and (hash : t -> Ppx_hash_lib.Std.Hash.hash_value) = + let func = hash_nativeint in fun x -> func x + [@@@end] + + let zero = zero + let neg = neg + let (<) = (<) + let to_string i = Printf.sprintf "%nx" i + let of_string s = Caml.Scanf.sscanf s "%nx" Fn.id + + let module_name = "Base.Nativeint.Hex" + end) + +include Pretty_printer.Register (struct + type nonrec t = t + let to_string = to_string + let module_name = "Base.Nativeint" + end) + +(* Open replace_polymorphic_compare after including functor instantiations so they do not + shadow its definitions. This is here so that efficient versions of the comparison + functions are available within this module. *) +open! Nativeint_replace_polymorphic_compare + +let num_bits = Word_size.num_bits Word_size.word_size +let float_lower_bound = Float0.lower_bound_for_int num_bits +let float_upper_bound = Float0.upper_bound_for_int num_bits + +let shift_right_logical = shift_right_logical +let shift_right = shift_right +let shift_left = shift_left +let bit_not = lognot +let bit_xor = logxor +let bit_or = logor +let bit_and = logand +let min_value = min_int +let max_value = max_int +let abs = abs +let pred = pred +let succ = succ +let rem = rem +let neg = neg +let minus_one = minus_one +let one = one +let zero = zero +let to_float = to_float +let of_float_unchecked = of_float +let of_float f = + if Float_replace_polymorphic_compare.(>=) f float_lower_bound + && Float_replace_polymorphic_compare.(<=) f float_upper_bound + then + of_float f + else + Printf.invalid_argf "Nativeint.of_float: argument (%f) is out of range or NaN" + (Float0.box f) + () + +module Pow2 = struct + open! Import + open Nativeint_replace_polymorphic_compare + + module Sys = Sys0 + + let raise_s = Error.raise_s + + let non_positive_argument () = + Printf.invalid_argf "argument must be strictly positive" () + + let ( lor ) = Caml.Nativeint.logor;; + let ( lsr ) = Caml.Nativeint.shift_right_logical;; + let ( land ) = Caml.Nativeint.logand;; + + (** "ceiling power of 2" - Least power of 2 greater than or equal to x. *) + let ceil_pow2 (x : nativeint) = + if x <= 0n then non_positive_argument (); + let x = Caml.Nativeint.pred x in + let x = x lor (x lsr 1) in + let x = x lor (x lsr 2) in + let x = x lor (x lsr 4) in + let x = x lor (x lsr 8) in + let x = x lor (x lsr 16) in + (* The next line is superfluous on 32-bit architectures, but it's faster to do it + anyway than to branch *) + let x = x lor (x lsr 32) in + Caml.Nativeint.succ x + ;; + + (** "floor power of 2" - Largest power of 2 less than or equal to x. *) + let floor_pow2 x = + if x <= 0n then non_positive_argument (); + let x = x lor (x lsr 1) in + let x = x lor (x lsr 2) in + let x = x lor (x lsr 4) in + let x = x lor (x lsr 8) in + let x = x lor (x lsr 16) in + let x = x lor (x lsr 32) in + Caml.Nativeint.sub x (x lsr 1) + ;; + + let is_pow2 x = + if x <= 0n then non_positive_argument (); + (x land (Caml.Nativeint.pred x)) = 0n + ;; + + (* C stub for nativeint clz to use the CLZ/BSR instruction where possible *) + external nativeint_clz : nativeint -> int = "Base_int_math_nativeint_clz" [@@noalloc] + + (** Hacker's Delight Second Edition p106 *) + let floor_log2 i = + if Poly.( <= ) i Caml.Nativeint.zero then + raise_s (Sexp.message "[Nativeint.floor_log2] got invalid input" + ["", sexp_of_nativeint i]); + Sys.word_size_in_bits - 1 - nativeint_clz i + ;; + + (** Hacker's Delight Second Edition p106 *) + let ceil_log2 i = + if Poly.( <= ) i Caml.Nativeint.zero then + raise_s (Sexp.message "[Nativeint.ceil_log2] got invalid input" + ["", sexp_of_nativeint i]); + if Caml.Nativeint.equal i Caml.Nativeint.one + then 0 + else Sys.word_size_in_bits - nativeint_clz (Caml.Nativeint.pred i) + ;; +end +include Pow2 + +let between t ~low ~high = low <= t && t <= high +let clamp_unchecked t ~min ~max = + if t < min then min else if t <= max then t else max + +let clamp_exn t ~min ~max = + assert (min <= max); + clamp_unchecked t ~min ~max + +let clamp t ~min ~max = + if min > max then + Or_error.error_s + (Sexp.message "clamp requires [min <= max]" + [ "min", T.sexp_of_t min + ; "max", T.sexp_of_t max + ]) + else + Ok (clamp_unchecked t ~min ~max) + +let ( / ) = div +let ( * ) = mul +let ( - ) = sub +let ( + ) = add +let ( ~- ) = neg + +let incr r = r := !r + one +let decr r = r := !r - one + +let of_nativeint t = t +let of_nativeint_exn = of_nativeint +let to_nativeint t = t +let to_nativeint_exn = to_nativeint + +let popcount = Popcount.nativeint_popcount + +let of_int = Conv.int_to_nativeint +let of_int_exn = of_int +let to_int = Conv.nativeint_to_int +let to_int_exn = Conv.nativeint_to_int_exn +let to_int_trunc = Conv.nativeint_to_int_trunc +let of_int32 = Conv.int32_to_nativeint +let of_int32_exn = of_int32 +let to_int32 = Conv.nativeint_to_int32 +let to_int32_exn = Conv.nativeint_to_int32_exn +let to_int32_trunc = Conv.nativeint_to_int32_trunc +let of_int64 = Conv.int64_to_nativeint +let of_int64_exn = Conv.int64_to_nativeint_exn +let of_int64_trunc = Conv.int64_to_nativeint_trunc +let to_int64 = Conv.nativeint_to_int64 + +let pow b e = of_int_exn (Int_math.int_pow (to_int_exn b) (to_int_exn e)) +let ( ** ) b e = pow b e + +module Pre_O = struct + let ( + ) = ( + ) + let ( - ) = ( - ) + let ( * ) = ( * ) + let ( / ) = ( / ) + let ( ~- ) = ( ~- ) + let ( ** ) = ( ** ) + include (Nativeint_replace_polymorphic_compare : Comparisons.Infix with type t := t) + let abs = abs + let neg = neg + let zero = zero + let of_int_exn = of_int_exn +end + +module O = struct + include Pre_O + include Int_math.Make (struct + type nonrec t = t + include Pre_O + let rem = rem + let to_float = to_float + let of_float = of_float + let of_string = T.of_string + let to_string = T.to_string + end) + + let ( land ) = bit_and + let ( lor ) = bit_or + let ( lxor ) = bit_xor + let ( lnot ) = bit_not + let ( lsl ) = shift_left + let ( asr ) = shift_right + let ( lsr ) = shift_right_logical +end + +include O (* [Nativeint] and [Nativeint.O] agree value-wise *) + +(* Include type-specific [Replace_polymorphic_compare] at the end, after + including functor application that could shadow its definitions. This is + here so that efficient versions of the comparison functions are exported by + this module. *) +include Nativeint_replace_polymorphic_compare diff --git a/src/nativeint.mli b/src/nativeint.mli new file mode 100644 index 0000000..0ba4d53 --- /dev/null +++ b/src/nativeint.mli @@ -0,0 +1,27 @@ +(** Processor-native integers. *) + +open! Import + +include Int_intf.S with type t = nativeint + +(** {2 Conversion functions} *) + +val of_int : int -> t +val to_int : t -> int option + +val of_int32 : int32 -> t +val to_int32 : t -> int32 option + +val of_nativeint : nativeint -> t +val to_nativeint : t -> nativeint + +val of_int64 : int64 -> t option + +(** {3 Truncating conversions} + + These functions return the least-significant bits of the input. In cases where + optional conversions return [Some x], truncating conversions return [x]. *) + +val to_int_trunc : t -> int +val to_int32_trunc : t -> int32 +val of_int64_trunc : int64 -> t diff --git a/src/obj_array.ml b/src/obj_array.ml new file mode 100644 index 0000000..0989345 --- /dev/null +++ b/src/obj_array.ml @@ -0,0 +1,179 @@ +open! Import + +module Int = Int0 +module String = String0 +module Array = Array0 + +(* We maintain the property that all values of type [t] do not have the tag + [double_array_tag]. Some functions below assume this in order to avoid testing the + tag, and will segfault if this property doesn't hold. *) +type t = Caml.Obj.t array + +let invariant t = + assert (Caml.Obj.tag (Caml.Obj.repr t) <> Caml.Obj.double_array_tag); +;; + +let length = Array.length + +let swap t i j = Array.swap t i j + +let sexp_of_t t = + Sexp.Atom (String.concat ~sep:"" + [ "<Obj_array.t of length "; + Int.to_string (length t); + ">" + ]) +;; + +let zero_obj = Caml.Obj.repr (0 : int) + +(* We call [Array.create] with a value that is not a float so that the array doesn't get + tagged with [Double_array_tag]. *) +let create_zero ~len = Array.create ~len zero_obj + +let create ~len x = + (* If we can, use [Array.create] directly. *) + if (Caml.Obj.tag x) <> Caml.Obj.double_tag then begin + Array.create ~len x + end else begin + (* Otherwise use [create_zero] and set the contents *) + let t = create_zero ~len in + let x = Sys.opaque_identity x in + for i = 0 to (len - 1) do + Array.unsafe_set t i x + done; + t + end + +let empty = [||] + +type not_a_float = Not_a_float_0 | Not_a_float_1 of int +let _not_a_float_0 = Not_a_float_0 +let _not_a_float_1 = Not_a_float_1 42 + +let get t i = + (* Make the compiler believe [t] is an array not containing floats so it does not check + if [t] is tagged with [Double_array_tag]. It is NOT ok to use [int array] since (if + this function is inlined and the array contains in-heap boxed values) wrong register + typing may result, leading to a failure to register necessary GC roots. *) + Caml.Obj.repr (Array.get (Caml.Obj.magic (t : t) : not_a_float array) i : not_a_float) +;; + +let [@inline always] unsafe_get t i = + (* Make the compiler believe [t] is an array not containing floats so it does not check + if [t] is tagged with [Double_array_tag]. *) + Caml.Obj.repr + (Array.unsafe_get (Caml.Obj.magic (t : t) : not_a_float array) i : not_a_float) + +let [@inline always] unsafe_set_with_caml_modify t i obj = + (* Same comment as [unsafe_get]. Sys.opaque_identity prevents the compiler from + potentially wrongly guessing the type of the array based on the type of element, that + is prevent the implication: (Obj.tag obj = Obj.double_tag) => (Obj.tag t = + Obj.double_array_tag) which flambda has tried in the past (at least that's assuming + the compiler respects Sys.opaque_identity, which is not always the case). *) + Array.unsafe_set + (Caml.Obj.magic (t : t) : not_a_float array) + i + (Caml.Obj.obj (Sys.opaque_identity obj) : not_a_float) +;; + +let [@inline always] unsafe_set_int_assuming_currently_int t i int = + (* This skips [caml_modify], which is OK if both the old and new values are integers. *) + Array.unsafe_set (Caml.Obj.magic (t : t) : int array) i (Sys.opaque_identity int) +;; + +(* For [set] and [unsafe_set], if a pointer is involved, we first do a physical-equality + test to see if the pointer is changing. If not, we don't need to do the [set], which + saves a call to [caml_modify]. We think this physical-equality test is worth it + because it is very cheap (both values are already available from the [is_int] test) + and because [caml_modify] is expensive. *) + +let set t i obj = + (* We use [get] first but then we use [Array.unsafe_set] since we know that [i] is + valid. *) + let old_obj = get t i in + if Caml.Obj.is_int old_obj && Caml.Obj.is_int obj + then unsafe_set_int_assuming_currently_int t i (Caml.Obj.obj obj : int) + else if not (phys_equal old_obj obj) + then unsafe_set_with_caml_modify t i obj +;; + +let [@inline always] unsafe_set t i obj = + let old_obj = unsafe_get t i in + if Caml.Obj.is_int old_obj && Caml.Obj.is_int obj + then unsafe_set_int_assuming_currently_int t i (Caml.Obj.obj obj : int) + else if not (phys_equal old_obj obj) + then unsafe_set_with_caml_modify t i obj +;; + +let [@inline always] unsafe_set_omit_phys_equal_check t i obj = + let old_obj = unsafe_get t i in + if Caml.Obj.is_int old_obj && Caml.Obj.is_int obj + then unsafe_set_int_assuming_currently_int t i (Caml.Obj.obj obj : int) + else unsafe_set_with_caml_modify t i obj +;; + +let singleton obj = + create ~len:1 obj +;; + +(* Pre-condition: t.(i) is an integer. *) +let unsafe_set_assuming_currently_int t i obj = + if Caml.Obj.is_int obj + then unsafe_set_int_assuming_currently_int t i (Caml.Obj.obj obj : int) + else + (* [t.(i)] is an integer and [obj] is not, so we do not need to check if they are + equal. *) + unsafe_set_with_caml_modify t i obj +;; + +let unsafe_set_int t i int = + let old_obj = unsafe_get t i in + if Caml.Obj.is_int old_obj + then unsafe_set_int_assuming_currently_int t i int + else unsafe_set_with_caml_modify t i (Caml.Obj.repr int) +;; + +let unsafe_clear_if_pointer t i = + let old_obj = unsafe_get t i in + if not (Caml.Obj.is_int old_obj) then unsafe_set_with_caml_modify t i (Caml.Obj.repr 0); +;; + +(** [unsafe_blit] is like [Array.blit], except it uses our own for-loop to avoid + caml_modify when possible. Its performance is still not comparable to a memcpy. *) +let unsafe_blit ~src ~src_pos ~dst ~dst_pos ~len = + (* When [phys_equal src dst], we need to check whether [dst_pos < src_pos] and have the + for loop go in the right direction so that we don't overwrite data that we still need + to read. When [not (phys_equal src dst)], doing this is harmless. From a + memory-performance perspective, it doesn't matter whether one loops up or down. + Constant-stride access, forward or backward, should be indistinguishable (at least on + an intel i7). So, we don't do a check for [phys_equal src dst] and always loop up in + that case. *) + if dst_pos < src_pos + then + for i = 0 to len - 1 do + unsafe_set dst (dst_pos + i) (unsafe_get src (src_pos + i)) + done + else + for i = len - 1 downto 0 do + unsafe_set dst (dst_pos + i) (unsafe_get src (src_pos + i)) + done; +;; + +include + Blit.Make + (struct + type nonrec t = t + let create = create_zero + let length = length + let unsafe_blit = unsafe_blit + end) +;; + +let copy src = + let dst = create_zero ~len:(length src) in + blito ~src ~dst (); + dst +;; + +let truncate t ~len = Caml.Obj.truncate (Caml.Obj.repr (t : t)) len diff --git a/src/obj_array.mli b/src/obj_array.mli new file mode 100644 index 0000000..2481bb9 --- /dev/null +++ b/src/obj_array.mli @@ -0,0 +1,72 @@ +(** This module is deprecated for external use. Users should replace occurrences of + [Obj_array.t] in their code with [Obj.t Uniform_array.t]. + + This module is here for the implementing [Uniform_array] internally, and exposed + through [Not_exposed_properly] to ease the transition for users. +*) + +open! Import + +type t [@@deriving_inline sexp_of] +include +sig [@@@ocaml.warning "-32"] val sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t +end[@@ocaml.doc "@inline"] +[@@@end] + +include Blit. S with type t := t +include Invariant.S with type t := t + +(** [create ~len x] returns an obj-array of length [len], all of whose indices have value + [x]. *) +val create : len:int -> Caml.Obj.t -> t + +(** [create_zero ~len] returns an obj-array of length [len], all of whose indices have + value [Caml.Obj.repr 0]. *) +val create_zero : len:int -> t + +(** [copy t] returns a new array with the same elements as [t]. *) +val copy : t -> t + +val singleton : Caml.Obj.t -> t + +val empty : t + +val length : t -> int + +(** [get t i] and [unsafe_get t i] return the object at index [i]. [set t i o] and + [unsafe_set t i o] set index [i] to [o]. In no case is the object copied. The + [unsafe_*] variants omit the bounds check of [i]. *) +val get : t -> int -> Caml.Obj.t +val unsafe_get : t -> int -> Caml.Obj.t +val set : t -> int -> Caml.Obj.t -> unit +val unsafe_set : t -> int -> Caml.Obj.t -> unit + +val swap : t -> int -> int -> unit + +(** [unsafe_set_assuming_currently_int t i obj] sets index [i] of [t] to [obj], but only + works correctly if [Caml.Obj.is_int (get t i)]. This precondition saves a dynamic + check. + + [unsafe_set_int_assuming_currently_int] is similar, except the value being set is an + int. + + [unsafe_set_int] is similar but does not assume anything about the target. *) +val unsafe_set_assuming_currently_int : t -> int -> Caml.Obj.t -> unit + +val unsafe_set_int_assuming_currently_int : t -> int -> int -> unit +val unsafe_set_int : t -> int -> int -> unit + +(** [unsafe_set_omit_phys_equal_check] is like [unsafe_set], except it doesn't do a + [phys_equal] check to try to skip [caml_modify]. It is safe to call this even if the + values are [phys_equal]. *) +val unsafe_set_omit_phys_equal_check : t -> int -> Caml.Obj.t -> unit + +(** [unsafe_clear_if_pointer t i] prevents [t.(i)] from pointing to anything to prevent + space leaks. It does this by setting [t.(i)] to [Caml.Obj.repr 0]. As a performance hack, + it only does this when [not (Caml.Obj.is_int t.(i))]. *) +val unsafe_clear_if_pointer : t -> int -> unit + +(** [truncate t ~len] shortens [t]'s length to [len]. It is an error if [len <= 0] or + [len > length t].*) +val truncate : t -> len:int -> unit + diff --git a/src/option.ml b/src/option.ml new file mode 100644 index 0000000..591dcc0 --- /dev/null +++ b/src/option.ml @@ -0,0 +1,200 @@ +open! Import + +type 'a t = 'a option [@@deriving_inline sexp, compare, hash] +let t_of_sexp : + 'a . (Ppx_sexp_conv_lib.Sexp.t -> 'a) -> Ppx_sexp_conv_lib.Sexp.t -> 'a t = + option_of_sexp +let sexp_of_t : + 'a . ('a -> Ppx_sexp_conv_lib.Sexp.t) -> 'a t -> Ppx_sexp_conv_lib.Sexp.t = + sexp_of_option +let compare : 'a . ('a -> 'a -> int) -> 'a t -> 'a t -> int = compare_option +let hash_fold_t : + 'a . + (Ppx_hash_lib.Std.Hash.state -> 'a -> Ppx_hash_lib.Std.Hash.state) -> + Ppx_hash_lib.Std.Hash.state -> 'a t -> Ppx_hash_lib.Std.Hash.state + = hash_fold_option +[@@@end] + +let is_none = function None -> true | _ -> false + +let is_some = function Some _ -> true | _ -> false + +let value_map o ~default ~f = + match o with + | Some x -> f x + | None -> default + +let iter o ~f = + match o with + | None -> () + | Some a -> f a +;; + +let invariant f t = iter t ~f + +let map2 o1 o2 ~f = + match o1, o2 with + | Some a1, Some a2 -> Some (f a1 a2) + | _ -> None + +let call x ~f = + match f with + | None -> () + | Some f -> f x + +let value t ~default = + match t with + | None -> default + | Some x -> x +;; + +let value_exn ?here ?error ?message t = + match t with + | Some x -> x + | None -> + let error = + match here, error, message with + | None , None , None -> Error.of_string "Option.value_exn None" + | None , None , Some m -> Error.of_string m + | None , Some e, None -> e + | None , Some e, Some m -> Error.tag e ~tag:m + | Some p, None , None -> + Error.create "Option.value_exn" p Source_code_position0.sexp_of_t + | Some p, None , Some m -> + Error.create m p Source_code_position0.sexp_of_t + | Some p, Some e, _ -> + Error.create (value message ~default:"") (e, p) + (sexp_of_pair Error.sexp_of_t Source_code_position0.sexp_of_t) + in + Error.raise error +;; + +let to_array t = + match t with + | None -> [||] + | Some x -> [|x|] +;; + +let to_list t = + match t with + | None -> [] + | Some x -> [x] +;; + +let min_elt t ~compare:_ = t +let max_elt t ~compare:_ = t + +let sum (type a) (module M : Container.Summable with type t = a) t ~f = + match t with + | None -> M.zero + | Some x -> f x +;; + +let for_all t ~f = + match t with + | None -> true + | Some x -> f x +;; + +let exists t ~f = + match t with + | None -> false + | Some x -> f x +;; + +let mem t a ~equal = + match t with + | None -> false + | Some a' -> equal a a' +;; + +let length t = + match t with + | None -> 0 + | Some _ -> 1 +;; + +let is_empty = is_none + +let fold t ~init ~f = + match t with + | None -> init + | Some x -> f init x +;; + +let count t ~f = + match t with + | None -> 0 + | Some a -> if f a then 1 else 0 +;; + +let find t ~f = + match t with + | None -> None + | Some x -> if f x then Some x else None +;; + +let find_map t ~f = + match t with + | None -> None + | Some a -> f a +;; + +let equal f t t' = + match t, t' with + | None, None -> true + | Some x, Some x' -> f x x' + | _ -> false + +let some x = Some x + +let both x y = + match x,y with + | Some a, Some b -> Some (a,b) + | _ -> None + +let first_some x y = + match x with + | Some _ -> x + | None -> y + +let some_if cond x = if cond then Some x else None + +let merge a b ~f = + match a, b with + | None, x | x, None -> x + | Some a, Some b -> Some (f a b) + +let filter t ~f = + match t with + | Some v as o when f v -> o + | _ -> None + +let try_with f = + try Some (f ()) + with _ -> None + +include Monad.Make (struct + type 'a t = 'a option + let return x = Some x + let map t ~f = + match t with + | None -> None + | Some a -> Some (f a) + ;; + let map = `Custom map + let bind o ~f = + match o with + | None -> None + | Some x -> f x + end) + +let fold_result t ~init ~f = Container.fold_result ~fold ~init ~f t +let fold_until t ~init ~f = Container.fold_until ~fold ~init ~f t + +let validate ~none ~some t = + let module V = Validate in + match t with + | None -> V.name "none" (V.protect none ()) + | Some x -> V.name "some" (V.protect some x ) +;; diff --git a/src/option.mli b/src/option.mli new file mode 100644 index 0000000..cb4de45 --- /dev/null +++ b/src/option.mli @@ -0,0 +1,75 @@ +(** Option type. *) + +open! Import + +type 'a t = 'a option [@@deriving_inline compare, hash, sexp] +include +sig + [@@@ocaml.warning "-32"] + val compare : ('a -> 'a -> int) -> 'a t -> 'a t -> int + val hash_fold_t : + (Ppx_hash_lib.Std.Hash.state -> 'a -> Ppx_hash_lib.Std.Hash.state) -> + Ppx_hash_lib.Std.Hash.state -> 'a t -> Ppx_hash_lib.Std.Hash.state + include Ppx_sexp_conv_lib.Sexpable.S1 with type 'a t := 'a t +end[@@ocaml.doc "@inline"] +[@@@end] + +include Container.S1 with type 'a t := 'a t +include Equal.S1 with type 'a t := 'a t +include Invariant.S1 with type 'a t := 'a t + +(** Options form a monad, where [return x = Some x], [(None >>= f) = None], and [(Some x + >>= f) = f x]. *) +include Monad.S with type 'a t := 'a t + +(** [is_none t] returns true iff [t = None]. *) +val is_none : 'a t -> bool + +(** [is_some t] returns true iff [t = Some x]. *) +val is_some : 'a t -> bool + +(** [value_map ~default ~f] is the same as [function Some x -> f x | None -> default]. *) +val value_map : 'a t -> default:'b -> f:('a -> 'b) -> 'b + +(** [map2 o f] maps ['a option] and ['b option] to a ['c option] using [~f]. *) +val map2 : 'a t -> 'b t -> f:('a -> 'b -> 'c) -> 'c t + +(** [call x f] runs an optional function [~f] on the argument. *) +val call : 'a -> f:('a -> unit) t -> unit + +(** [value None ~default] = [default] + + [value (Some x) ~default] = [x] *) +val value : 'a t -> default:'a -> 'a + +(** [value_exn (Some x)] = [x]. [value_exn None] raises an error whose contents contain + the supplied [~here], [~error], and [message], or a default message if none are + supplied. *) +val value_exn + : ?here:Source_code_position0.t + -> ?error:Error.t + -> ?message:string + -> 'a t + -> 'a + +val some : 'a -> 'a t + +val both : 'a t -> 'b t -> ('a * 'b) t + +val first_some : 'a t -> 'a t -> 'a t + +val some_if : bool -> 'a -> 'a t + +(** [merge a b ~f] merges together the values from [a] and [b] using [f]. If both [a] and + [b] are [None], returns [None]. If only one is [Some], returns that one, and if both + are [Some], returns [Some] of the result of applying [f] to the contents of [a] and + [b]. *) +val merge : 'a t -> 'a t -> f:('a -> 'a -> 'a) -> 'a t + +val filter : 'a t -> f:('a -> bool) -> 'a t + +(** [try_with f] returns [Some x] if [f] returns [x] and [None] if [f] raises an + exception. See [Result.try_with] if you'd like to know which exception. *) +val try_with : (unit -> 'a) -> 'a t + +val validate : none:unit Validate.check -> some:'a Validate.check -> 'a t Validate.check diff --git a/src/option_array.ml b/src/option_array.ml new file mode 100644 index 0000000..569f54c --- /dev/null +++ b/src/option_array.ml @@ -0,0 +1,171 @@ +open! Import + + +(** ['a Cheap_option.t] is like ['a option], but it doesn't box [some _] values. + + There are several things that are unsafe about it: + + - [float t array] (or any array-backed container) is not memory-safe + because float array optimization is incompatible with unboxed option + optimization. You have to use [Uniform_array.t] instead of [array]. + + - Nested options (['a t t]) don't work. They are believed to be + memory-safe, but not parametric. + + - A record with [float t]s in it should be safe, but it's only [t] being + abstract that gives you safety. If the compiler was smart enough to peek + through the module signature then it could decide to construct a float + array instead. *) +module Cheap_option = struct + + (* This is taken from core_kernel. Rather than expose it in the public + interface of base, just keep a copy around here. *) + let phys_same (type a) (type b) (a : a) (b : b) = + phys_equal a (Caml.Obj.magic b : a) + + module T0 : sig + type 'a t + val none : _ t + val some : 'a -> 'a t + val is_none : _ t -> bool + val is_some : _ t -> bool + val value_exn : 'a t -> 'a + val value_unsafe : 'a t -> 'a + end + = struct + type +'a t + + (* Being a pointer, no one outside this module can construct a value that is + [phys_same] as this one. + + It would be simpler to use this value as [none], but we use an immediate instead + because it lets us avoid caml_modify when setting to [none], making certain + benchmarks significantly faster (e.g. ../bench/array_queue.exe). + + this code is duplicated in Moption, and if we find yet another place where we want + it we should reconsider making it shared. *) + let none_substitute : _ t = Caml.Obj.obj (Caml.Obj.new_block Caml.Obj.abstract_tag 1) + + let none : _ t = + (* The number was produced by + [< /dev/urandom tr -c -d '1234567890abcdef' | head -c 16]. + + The idea is that a random number will have lower probability to collide with + anything than any number we can choose ourselves. + + We are using a polymorphic variant instead of an integer constant because there + is a compiler bug where it wrongly assumes that the result of [if _ then c else + y] is not a pointer if [c] is an integer compile-time constant. This is being + fixed in https://github.com/ocaml/ocaml/pull/555. The "memory corruption" test + below demonstrates the issue. *) + Caml.Obj.magic `x6e8ee3478e1d7449 + + let is_none x = phys_equal x none + + let is_some x = not (phys_equal x none) + + let some (type a) (x : a) : a t = + if phys_same x none + then none_substitute + else (Caml.Obj.magic x) + + let value_unsafe (type a) (x : a t) : a = + if phys_equal x none_substitute + then Caml.Obj.magic none + else Caml.Obj.magic x + + let value_exn x = + if is_some x + then value_unsafe x + else failwith "Option_array.get_some_exn: the element is [None]" + + end + + module T1 = struct + include T0 + let of_option = function + | None -> none + | Some x -> some x + let to_option x = + if is_some x + then Some (value_unsafe x) + else None + + let to_sexpable = to_option + let of_sexpable = of_option + end + + include T1 + include Sexpable.Of_sexpable1(Option)(T1) +end + +type 'a t = 'a Cheap_option.t Uniform_array.t [@@deriving_inline sexp] +let t_of_sexp : + 'a . (Ppx_sexp_conv_lib.Sexp.t -> 'a) -> Ppx_sexp_conv_lib.Sexp.t -> 'a t = + let _tp_loc = "src/option_array.ml.t" in + fun _of_a -> + fun t -> Uniform_array.t_of_sexp (Cheap_option.t_of_sexp _of_a) t +let sexp_of_t : + 'a . ('a -> Ppx_sexp_conv_lib.Sexp.t) -> 'a t -> Ppx_sexp_conv_lib.Sexp.t = + fun _of_a -> + fun v -> Uniform_array.sexp_of_t (Cheap_option.sexp_of_t _of_a) v +[@@@end] + +let empty = Uniform_array.empty + +let create ~len = Uniform_array.create ~len Cheap_option.none + +let init n ~f = Uniform_array.init n ~f:(fun i -> Cheap_option.of_option (f i)) + +let init_some n ~f = Uniform_array.init n ~f:(fun i -> Cheap_option.some (f i)) + +let length = Uniform_array.length + +let get t i = Cheap_option.to_option (Uniform_array.get t i) + +let get_some_exn t i = Cheap_option.value_exn (Uniform_array.get t i) + +let is_none t i = Cheap_option.is_none (Uniform_array.get t i) + +let is_some t i = Cheap_option.is_some (Uniform_array.get t i) + +let set t i x = Uniform_array.set t i (Cheap_option.of_option x) + +let set_some t i x = Uniform_array.set t i (Cheap_option.some x) + +let set_none t i = Uniform_array.set t i Cheap_option.none + +let swap t i j = Uniform_array.swap t i j + +let unsafe_get t i = Cheap_option.to_option (Uniform_array.unsafe_get t i) + +let unsafe_get_some_exn t i = Cheap_option.value_exn (Uniform_array.unsafe_get t i) + +let unsafe_is_some t i = Cheap_option.is_some (Uniform_array.unsafe_get t i) + +let unsafe_set t i x = Uniform_array.unsafe_set t i (Cheap_option.of_option x) + +let unsafe_set_some t i x = Uniform_array.unsafe_set t i (Cheap_option.some x) + +let unsafe_set_none t i = Uniform_array.unsafe_set t i (Cheap_option.none) + +let clear t = + for i = 0 to length t - 1 + do + unsafe_set_none t i + done + +include + Blit.Make1_generic + (struct + type nonrec 'a t = 'a t + let length = length + let create_like ~len _ = create ~len + let unsafe_blit = Uniform_array.unsafe_blit + end) + +let copy = Uniform_array.copy + +module For_testing = struct + module Unsafe_cheap_option = Cheap_option +end diff --git a/src/option_array.mli b/src/option_array.mli new file mode 100644 index 0000000..a39a5d0 --- /dev/null +++ b/src/option_array.mli @@ -0,0 +1,93 @@ +(** ['a Option_array.t] is a compact representation of ['a option array]: it avoids + allocating heap objects representing [Some x], usually representing them with [x] + instead. It uses a special representation for [None] that's guaranteed to never + collide with any representation of [Some x]. *) + +open! Import + +type 'a t [@@deriving_inline sexp] +include +sig + [@@@ocaml.warning "-32"] + include Ppx_sexp_conv_lib.Sexpable.S1 with type 'a t := 'a t +end[@@ocaml.doc "@inline"] +[@@@end] + +val empty : _ t + +(** Initially filled with all [None] *) +val create : len:int -> _ t + + +val init_some : int -> f:(int -> 'a ) -> 'a t +val init : int -> f:(int -> 'a option) -> 'a t + +val length : _ t -> int + +(** [get t i] returns the element number [i] of array [t], raising if [i] is outside the + range 0 to [length t - 1]. *) +val get : 'a t -> int -> 'a option + +(** Raises if the element number [i] is [None]. *) +val get_some_exn : 'a t -> int -> 'a + +(** [is_none t i = Option.is_none (get t i)] *) +val is_none : _ t -> int -> bool + +(** [is_some t i = Option.is_some (get t i)] *) +val is_some : _ t -> int -> bool + +(** These can cause arbitrary behavior when used for an out-of-bounds array access. *) + +val unsafe_get : 'a t -> int -> 'a option +val unsafe_get_some_exn : 'a t -> int -> 'a +val unsafe_is_some : _ t -> int -> bool + +(** [set t i x] modifies array [t] in place, replacing element number [i] with [x], + raising if [i] is outside the range 0 to [length t - 1]. *) +val set : 'a t -> int -> 'a option -> unit +val set_some : 'a t -> int -> 'a -> unit +val set_none : _ t -> int -> unit + +val swap : _ t -> int -> int -> unit + +(** Replaces all the elements of the array with [None]. *) +val clear : _ t -> unit + +(** Unsafe versions of [set*]. Can cause arbitrary behaviour when used for an + out-of-bounds array access. *) + +val unsafe_set : 'a t -> int -> 'a option -> unit +val unsafe_set_some : 'a t -> int -> 'a -> unit +val unsafe_set_none : _ t -> int -> unit + +include Blit.S1 with type 'a t := 'a t + +(** Makes a (shallow) copy of the array. *) +val copy : 'a t -> 'a t + +(**/**) + +module For_testing : sig + module Unsafe_cheap_option : sig + type 'a t [@@deriving_inline sexp] + include + sig + [@@@ocaml.warning "-32"] + include Ppx_sexp_conv_lib.Sexpable.S1 with type 'a t := 'a t + end[@@ocaml.doc "@inline"] + [@@@end] + + val none : _ t + val some : 'a -> 'a t + + val is_none : _ t -> bool + val is_some : _ t -> bool + + val value_exn : 'a t -> 'a + val value_unsafe : 'a t -> 'a + + val to_option : 'a t -> 'a Option.t + val of_option : 'a Option.t -> 'a t + end +end diff --git a/src/or_error.ml b/src/or_error.ml new file mode 100644 index 0000000..4ded9a3 --- /dev/null +++ b/src/or_error.ml @@ -0,0 +1,129 @@ +open! Import + +type 'a t = ('a, Error.t) Result.t [@@deriving_inline compare, hash, sexp] +let compare : 'a . ('a -> 'a -> int) -> 'a t -> 'a t -> int = + fun _cmp__a -> + fun a__001_ -> + fun b__002_ -> Result.compare _cmp__a Error.compare a__001_ b__002_ +let hash_fold_t : + 'a . + (Ppx_hash_lib.Std.Hash.state -> 'a -> Ppx_hash_lib.Std.Hash.state) -> + Ppx_hash_lib.Std.Hash.state -> 'a t -> Ppx_hash_lib.Std.Hash.state + = + fun _hash_fold_a -> + fun hsv -> + fun arg -> Result.hash_fold_t _hash_fold_a Error.hash_fold_t hsv arg +let t_of_sexp : + 'a . (Ppx_sexp_conv_lib.Sexp.t -> 'a) -> Ppx_sexp_conv_lib.Sexp.t -> 'a t = + let _tp_loc = "src/or_error.ml.t" in + fun _of_a -> fun t -> Result.t_of_sexp _of_a Error.t_of_sexp t +let sexp_of_t : + 'a . ('a -> Ppx_sexp_conv_lib.Sexp.t) -> 'a t -> Ppx_sexp_conv_lib.Sexp.t = + fun _of_a -> fun v -> Result.sexp_of_t _of_a Error.sexp_of_t v +[@@@end] + +let invariant invariant_a t = + match t with + | Ok a -> invariant_a a + | Error error -> Error.invariant error +;; + +include (Result : Monad.S2 + with type ('a, 'b) t := ('a, 'b) Result.t + with module Let_syntax := Result.Let_syntax) + +include Applicative.Make (struct + type nonrec 'a t = 'a t + let return = return + let apply f x = + Result.combine f x + ~ok:(fun f x -> f x) + ~err:(fun e1 e2 -> Error.of_list [e1; e2]) + let map = `Custom map + end) + +module Let_syntax = struct + let return = return + include Monad_infix + module Let_syntax = struct + let return = return + let map = map + let bind = bind + let both = both (* from Applicative.Make *) + module Open_on_rhs = struct end + end +end + +let ok = Result.ok +let is_ok = Result.is_ok +let is_error = Result.is_error + +let ignore = ignore_m + +let try_with ?(backtrace = false) f = + try Ok (f ()) + with exn -> Error (Error.of_exn exn ?backtrace:(if backtrace then Some `Get else None)) +;; + +let try_with_join ?backtrace f = join (try_with ?backtrace f) + +let ok_exn = function + | Ok x -> x + | Error err -> Error.raise err +;; + +let of_exn ?backtrace exn = Error (Error.of_exn ?backtrace exn) + +let of_exn_result ?backtrace = function + | Ok _ as z -> z + | Error exn -> of_exn ?backtrace exn +;; + +let error ?strict message a sexp_of_a = + Error (Error.create ?strict message a sexp_of_a) +;; + +let error_s sexp = Error (Error.create_s sexp) + +let error_string message = Error (Error.of_string message) + +let errorf format = Printf.ksprintf error_string format + +let tag t ~tag = Result.map_error t ~f:(Error.tag ~tag) +let tag_arg t message a sexp_of_a = + Result.map_error t ~f:(fun e -> Error.tag_arg e message a sexp_of_a) +;; + +let unimplemented s = error "unimplemented" s sexp_of_string + +let combine_errors l = Result.map_error (Result.combine_errors l) ~f:Error.of_list + +let combine_errors_unit l = Result.map (combine_errors l) ~f:(fun (_ : unit list) -> ()) + +let filter_ok_at_least_one l = + let ok, errs = List.partition_map l ~f:Result.ok_fst in + match ok with + | [] -> Error (Error.of_list errs) + | _ -> Ok ok +;; + +let find_ok l = + match List.find_map l ~f:Result.ok with + | Some x -> Ok x + | None -> + Error (Error.of_list (List.map l ~f:(function + | Ok _ -> assert false + | Error err -> err))) +;; + +let find_map_ok l ~f = + With_return.with_return (fun {return} -> + Error (Error.of_list (List.map l ~f:(fun elt -> + match f elt with + | (Ok _ as x) -> return x + | Error err -> err)))) +;; + +let map = Result.map +let iter = Result.iter +let iter_error = Result.iter_error diff --git a/src/or_error.mli b/src/or_error.mli new file mode 100644 index 0000000..390270b --- /dev/null +++ b/src/or_error.mli @@ -0,0 +1,126 @@ +(** Type for tracking errors in an [Error.t]. This is a specialization of the [Result] + type, where the [Error] constructor carries an [Error.t]. + + A common idiom is to wrap a function that is not implemented on all platforms, e.g., + + {[val do_something_linux_specific : (unit -> unit) Or_error.t]} +*) + +open! Import + +(** Serialization and comparison of an [Error] force the error's lazy message. *) +type 'a t = ('a, Error.t) Result.t [@@deriving_inline compare, hash, sexp] +include +sig + [@@@ocaml.warning "-32"] + val compare : ('a -> 'a -> int) -> 'a t -> 'a t -> int + val hash_fold_t : + (Ppx_hash_lib.Std.Hash.state -> 'a -> Ppx_hash_lib.Std.Hash.state) -> + Ppx_hash_lib.Std.Hash.state -> 'a t -> Ppx_hash_lib.Std.Hash.state + include Ppx_sexp_conv_lib.Sexpable.S1 with type 'a t := 'a t +end[@@ocaml.doc "@inline"] +[@@@end] + +(** [Applicative] functions don't have quite the same semantics as + [Applicative.of_Monad(Or_error)] would give -- [apply (Error e1) (Error e2)] returns + the combination of [e1] and [e2], whereas it would only return [e1] if it were defined + using [bind]. *) +include Applicative.S with type 'a t := 'a t +include Invariant.S1 with type 'a t := 'a t +include Monad.S with type 'a t := 'a t + +val is_ok : _ t -> bool +val is_error : _ t -> bool + +val ignore : _ t -> unit t + +(** [try_with f] catches exceptions thrown by [f] and returns them in the [Result.t] as an + [Error.t]. [try_with_join] is like [try_with], except that [f] can throw exceptions + or return an [Error] directly, without ending up with a nested error; it is equivalent + to [Result.join (try_with f)]. *) +val try_with : ?backtrace:bool (** defaults to [false] *) -> (unit -> 'a ) -> 'a t +val try_with_join : ?backtrace:bool (** defaults to [false] *) -> (unit -> 'a t) -> 'a t + + +(** [ok t] returns [None] if [t] is an [Error], and otherwise returns the contents of the + [Ok] constructor. *) +val ok : 'ok t -> 'ok option + +(** [ok_exn t] throws an exception if [t] is an [Error], and otherwise returns the + contents of the [Ok] constructor. *) +val ok_exn : 'a t -> 'a + +(** [of_exn ?backtrace exn] is [Error (Error.of_exn ?backtrace exn)]. *) +val of_exn : ?backtrace:[ `Get | `This of string ] -> exn -> _ t + +(** [of_exn_result ?backtrace (Ok a) = Ok a] + + [of_exn_result ?backtrace (Error exn) = of_exn ?backtrace exn] *) +val of_exn_result : ?backtrace:[ `Get | `This of string ] -> ('a, exn) Result.t -> 'a t + +(** [error] is a wrapper around [Error.create]: + + {[ + error ?strict message a sexp_of_a + = Error (Error.create ?strict message a sexp_of_a) + ]} + + As with [Error.create], [sexp_of_a a] is lazily computed when the info is converted + to a sexp. So, if [a] is mutated in the time between the call to [create] and the + sexp conversion, those mutations will be reflected in the sexp. Use [~strict:()] to + force [sexp_of_a a] to be computed immediately. *) +val error + : ?strict : unit + -> string + -> 'a + -> ('a -> Sexp.t) + -> _ t + +val error_s : Sexp.t -> _ t + +(** [error_string message] is [Error (Error.of_string message)]. *) +val error_string : string -> _ t + +(** [errorf format arg1 arg2 ...] is [Error (sprintf format arg1 arg2 ...)]. Note that it + calculates the string eagerly, so when performance matters you may want to use [error] + instead. *) +val errorf : ('a, unit, string, _ t) format4 -> 'a + +(** [tag t ~tag] is [Result.map_error t ~f:(Error.tag ~tag)]. + [tag_arg] is similar. *) +val tag : 'a t -> tag:string -> 'a t +val tag_arg : 'a t -> string -> 'b -> ('b -> Sexp.t) -> 'a t + +(** For marking a given value as unimplemented. Typically combined with conditional + compilation, where on some platforms the function is defined normally, and on some + platforms it is defined as unimplemented. The supplied string should be the name of + the function that is unimplemented. *) +val unimplemented : string -> _ t + +val map : 'a t -> f:('a -> 'b) -> 'b t +val iter : 'a t -> f:('a -> unit) -> unit +val iter_error : _ t -> f:(Error.t -> unit) -> unit + +(** [combine_errors ts] returns [Ok] if every element in [ts] is [Ok], else it returns + [Error] with all the errors in [ts]. More precisely: + + - [combine_errors [Ok a1; ...; Ok an] = Ok [a1; ...; an]] + - {[ combine_errors [...; Error e1; ...; Error en; ...] + = Error (Error.of_list [e1; ...; en]) ]} *) +val combine_errors : 'a t list -> 'a list t + +(** [combine_errors_unit ts] returns [Ok] if every element in [ts] is [Ok ()], else it + returns [Error] with all the errors in [ts], like [combine_errors]. *) +val combine_errors_unit : unit t list -> unit t + +(** [filter_ok_at_least_one ts] returns all values in [ts] that are [Ok] if there is at + least one, otherwise it returns the same error as [combine_errors ts]. *) +val filter_ok_at_least_one : 'a t list -> 'a list t + +(** [find_ok ts] returns the first value in [ts] that is [Ok], otherwise it returns the + same error as [combine_errors ts]. *) +val find_ok : 'a t list -> 'a t + +(** [find_map_ok l ~f] returns the first value in [l] for which [f] returns [Ok], + otherwise it returns the same error as [combine_errors (List.map l ~f)]. *) +val find_map_ok : 'a list -> f:('a -> 'b t) -> 'b t diff --git a/src/ordered_collection_common.ml b/src/ordered_collection_common.ml new file mode 100644 index 0000000..9e1fda3 --- /dev/null +++ b/src/ordered_collection_common.ml @@ -0,0 +1,46 @@ +open! Import + +let invalid_argf = Printf.invalid_argf + +let [@inline never] slow_check_pos_len_exn ~pos ~len ~total_length = + if pos < 0 + then invalid_argf "Negative position: %d" pos (); + if len < 0 + then invalid_argf "Negative length: %d" len (); + (* We use [pos > total_length - len] rather than [pos + len > total_length] to avoid the + possibility of overflow. *) + if pos > total_length - len + then invalid_argf "pos + len past end: %d + %d > %d" pos len total_length () +;; + +let check_pos_len_exn ~pos ~len ~total_length = + (* This is better than [slow_check_pos_len_exn] for two reasons: + + - much less inlined code + - only one conditional jump + + The reason it works is that checking [< 0] is testing the highest order bit, so + [a < 0 || b < 0] is the same as [a lor b < 0]. + + [pos + len] can overflow, so [pos > total_length - len] is not equivalent to + [total_length - len - pos < 0], we need to test for [pos + len] overflow as + well. *) + let stop = pos + len in + if pos lor len lor stop lor (total_length - stop) < 0 then + slow_check_pos_len_exn ~pos ~len ~total_length +;; + +let get_pos_len_exn ?(pos = 0) ?len () ~total_length = + let len = match len with Some i -> i | None -> total_length - pos in + check_pos_len_exn ~pos ~len ~total_length; + pos, len +;; + +let get_pos_len ?pos ?len () ~total_length = + try Result.Ok (get_pos_len_exn () ?pos ?len ~total_length) + with Invalid_argument s -> Or_error.error_string s +;; + +module Private = struct + let slow_check_pos_len_exn = slow_check_pos_len_exn +end diff --git a/src/ordered_collection_common.mli b/src/ordered_collection_common.mli new file mode 100644 index 0000000..2397835 --- /dev/null +++ b/src/ordered_collection_common.mli @@ -0,0 +1,39 @@ +(** Functions for ordered collections. *) + +open! Import + +(** [get_pos_len], [get_pos_len_exn], and [check_pos_len_exn] are intended to be used + by functions that take a sequence (array, string, bigstring, ...) and an optional + [pos] and [len] specifying a subrange of the sequence. Such functions should call + [get_pos_len] with the length of the sequence and the optional [pos] and [len], and it + will return the [pos] and [len] specifying the range, where the default [pos] is zero + and the default [len] is to go to the end of the sequence. + + It should be the case that: + + {[ + pos >= 0 && len >= 0 && pos + len <= total_length + ]} + + Note that this allows [pos = total_length] and [len = 0], i.e., an empty subrange + at the end of the sequence. + + [get_pos_len] returns [(pos', len')] specifying a subrange where: + + {v + pos' = match pos with None -> 0 | Some i -> i + len' = match len with None -> total_length - pos' | Some i -> i + v} *) +val get_pos_len : ?pos:int -> ?len:int -> unit -> total_length:int -> (int * int) Or_error.t +val get_pos_len_exn : ?pos:int -> ?len:int -> unit -> total_length:int -> int * int + +(** [check_pos_len_exn ~pos ~len ~total_length] raises unless [pos >= 0 && len >= 0 && + pos + len <= total_length]. *) +val check_pos_len_exn : pos:int -> len:int -> total_length:int -> unit + +(*_ See the Jane Street Style Guide for an explanation of [Private] submodules: + + https://opensource.janestreet.com/standards/#private-submodules *) +module Private : sig + val slow_check_pos_len_exn : pos:int -> len:int -> total_length:int -> unit +end diff --git a/src/ordering.ml b/src/ordering.ml new file mode 100644 index 0000000..f812e99 --- /dev/null +++ b/src/ordering.ml @@ -0,0 +1,70 @@ +open! Import + +type t = Less | Equal | Greater [@@deriving_inline compare, hash, enumerate, sexp] +let compare : t -> t -> int = Ppx_compare_lib.polymorphic_compare +let (hash_fold_t : + Ppx_hash_lib.Std.Hash.state -> t -> Ppx_hash_lib.Std.Hash.state) = + (fun hsv -> + fun arg -> + match arg with + | Less -> Ppx_hash_lib.Std.Hash.fold_int hsv 0 + | Equal -> Ppx_hash_lib.Std.Hash.fold_int hsv 1 + | Greater -> Ppx_hash_lib.Std.Hash.fold_int hsv 2 : Ppx_hash_lib.Std.Hash.state + -> + t -> + Ppx_hash_lib.Std.Hash.state) +let (hash : t -> Ppx_hash_lib.Std.Hash.hash_value) = + let func arg = + Ppx_hash_lib.Std.Hash.get_hash_value + (let hsv = Ppx_hash_lib.Std.Hash.create () in hash_fold_t hsv arg) in + fun x -> func x +let all : t list = [Less; Equal; Greater] +let t_of_sexp : Ppx_sexp_conv_lib.Sexp.t -> t = + let _tp_loc = "src/ordering.ml.t" in + function + | Ppx_sexp_conv_lib.Sexp.Atom ("less"|"Less") -> Less + | Ppx_sexp_conv_lib.Sexp.Atom ("equal"|"Equal") -> Equal + | Ppx_sexp_conv_lib.Sexp.Atom ("greater"|"Greater") -> Greater + | Ppx_sexp_conv_lib.Sexp.List ((Ppx_sexp_conv_lib.Sexp.Atom + ("less"|"Less"))::_) as sexp -> + Ppx_sexp_conv_lib.Conv_error.stag_no_args _tp_loc sexp + | Ppx_sexp_conv_lib.Sexp.List ((Ppx_sexp_conv_lib.Sexp.Atom + ("equal"|"Equal"))::_) as sexp -> + Ppx_sexp_conv_lib.Conv_error.stag_no_args _tp_loc sexp + | Ppx_sexp_conv_lib.Sexp.List ((Ppx_sexp_conv_lib.Sexp.Atom + ("greater"|"Greater"))::_) as sexp -> + Ppx_sexp_conv_lib.Conv_error.stag_no_args _tp_loc sexp + | Ppx_sexp_conv_lib.Sexp.List ((Ppx_sexp_conv_lib.Sexp.List _)::_) as sexp + -> Ppx_sexp_conv_lib.Conv_error.nested_list_invalid_sum _tp_loc sexp + | Ppx_sexp_conv_lib.Sexp.List [] as sexp -> + Ppx_sexp_conv_lib.Conv_error.empty_list_invalid_sum _tp_loc sexp + | sexp -> Ppx_sexp_conv_lib.Conv_error.unexpected_stag _tp_loc sexp +let sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t = + function + | Less -> Ppx_sexp_conv_lib.Sexp.Atom "Less" + | Equal -> Ppx_sexp_conv_lib.Sexp.Atom "Equal" + | Greater -> Ppx_sexp_conv_lib.Sexp.Atom "Greater" +[@@@end] + +let equal a b = compare a b = 0 + +module Export = struct + type _ordering = t = + | Less + | Equal + | Greater +end + +let of_int n = + if n < 0 + then Less + else if n = 0 + then Equal + else Greater +;; + +let to_int = function + | Less -> -1 + | Equal -> 0 + | Greater -> 1 +;; diff --git a/src/ordering.mli b/src/ordering.mli new file mode 100644 index 0000000..8293e9d --- /dev/null +++ b/src/ordering.mli @@ -0,0 +1,74 @@ +(** [Ordering] is intended to make code that matches on the result of a comparison + more concise and easier to read. + + For example, instead of writing: + + {[ + let r = compare x y in + if r < 0 then + ... + else if r = 0 then + ... + else + ... + ]} + + you could simply write: + + {[ + match Ordering.of_int (compare x y) with + | Less -> ... + | Equal -> ... + | Greater -> ... + ]} + +*) + +open! Import + +type t = + | Less + | Equal + | Greater +[@@deriving_inline compare, enumerate, hash, sexp] +include +sig + [@@@ocaml.warning "-32"] + val compare : t -> t -> int + val all : t list + val hash_fold_t : + Ppx_hash_lib.Std.Hash.state -> t -> Ppx_hash_lib.Std.Hash.state + val hash : t -> Ppx_hash_lib.Std.Hash.hash_value + include Ppx_sexp_conv_lib.Sexpable.S with type t := t +end[@@ocaml.doc "@inline"] +[@@@end] + +include Equal.S with type t := t + +(** [of_int n] is: + + {v + Less if n < 0 + Equal if n = 0 + Greater if n > 0 + v} *) +val of_int : int -> t + +(** [to_int t] is: + + {v + Less -> -1 + Equal -> 0 + Greater -> 1 + v} + + It can be useful when writing a comparison function to allow one to return + [Ordering.t] values and transform them to [int]s later. *) +val to_int : t -> int + +module Export : sig + type _ordering = t = + | Less + | Equal + | Greater +end diff --git a/src/poly0.ml b/src/poly0.ml new file mode 100644 index 0000000..7fa2a9a --- /dev/null +++ b/src/poly0.ml @@ -0,0 +1,18 @@ +(** Primitives for polymorphic compare. *) + +(*_ Polymorphic compiler primitives can't be aliases as this doesn't play well with + inlining. (If aliased without a type annotation, the compiler would implement them + using the generic code doing a C call, and it's this code that would be inlined.) As a + result we have to copy the [external ...] declaration here. *) +external ( < ) : 'a -> 'a -> bool = "%lessthan" +external ( <= ) : 'a -> 'a -> bool = "%lessequal" +external ( <> ) : 'a -> 'a -> bool = "%notequal" +external ( = ) : 'a -> 'a -> bool = "%equal" +external ( > ) : 'a -> 'a -> bool = "%greaterthan" +external ( >= ) : 'a -> 'a -> bool = "%greaterequal" +external ascending : 'a -> 'a -> int = "%compare" +external compare : 'a -> 'a -> int = "%compare" +external equal : 'a -> 'a -> bool = "%equal" +let descending x y = compare y x +let max = Caml.max +let min = Caml.min diff --git a/src/poly0.mli b/src/poly0.mli new file mode 100644 index 0000000..01bfc8e --- /dev/null +++ b/src/poly0.mli @@ -0,0 +1,22 @@ +(** A module containing the ad-hoc polymorphic comparison functions. Useful when + you want to use polymorphic compare in some small scope of a file within which + polymorphic compare has been hidden *) + +external compare : 'a -> 'a -> int = "%compare" + +(** [ascending] is identical to [compare]. [descending x y = ascending y x]. These are + intended to be mnemonic when used like [List.sort ~compare:ascending] and [List.sort + ~compare:descending], since they cause the list to be sorted in ascending or + descending order, respectively. *) +val ascending : 'a -> 'a -> int +val descending : 'a -> 'a -> int + +external ( < ) : 'a -> 'a -> bool = "%lessthan" +external ( <= ) : 'a -> 'a -> bool = "%lessequal" +external ( <> ) : 'a -> 'a -> bool = "%notequal" +external ( = ) : 'a -> 'a -> bool = "%equal" +external ( > ) : 'a -> 'a -> bool = "%greaterthan" +external ( >= ) : 'a -> 'a -> bool = "%greaterequal" +external equal : 'a -> 'a -> bool = "%equal" +val min : 'a -> 'a -> 'a +val max : 'a -> 'a -> 'a diff --git a/src/popcount.ml b/src/popcount.ml new file mode 100644 index 0000000..34e140d --- /dev/null +++ b/src/popcount.ml @@ -0,0 +1,39 @@ +open! Import + +(* C stub for int popcount to use the POPCNT instruction where possible *) +external int_popcount : int -> int = "Base_int_math_int_popcount" [@@noalloc] + +(* To maintain javascript compatibility and enable unboxing, we implement popcount in + OCaml rather than use C stubs. Implementation adapted from: + https://en.wikipedia.org/wiki/Hamming_weight#Efficient_implementation *) +let int64_popcount = + let open Caml.Int64 in + let ( + ) = add in + let ( - ) = sub in + let ( * ) = mul in + let ( lsr ) = shift_right_logical in + let ( land ) = logand in + let m1 = 0x5555555555555555L in (* 0b01010101... *) + let m2 = 0x3333333333333333L in (* 0b00110011... *) + let m4 = 0x0f0f0f0f0f0f0f0fL in (* 0b00001111... *) + let h01 = 0x0101010101010101L in (* 1 bit set per byte *) + (fun x -> + (* gather the bit count for every pair of bits *) + let x = x - ((x lsr 1) land m1) in + (* gather the bit count for every 4 bits *) + let x = (x land m2) + ((x lsr 2) land m2) in + (* gather the bit count for every byte *) + let x = (x + (x lsr 4)) land m4 in + (* sum the bit counts in the top byte and shift it down *) + to_int ((x * h01) lsr 56)) [@inline] + +let int32_popcount = + (* On 64-bit systems, this is faster than implementing using [int32] arithmetic. *) + let mask = 0xffff_ffffL in + (fun x -> int64_popcount (Caml.Int64.logand (Caml.Int64.of_int32 x) mask)) [@inline] + +let nativeint_popcount = + match Caml.Nativeint.size with + | 32 -> (fun x -> int32_popcount (Caml.Nativeint.to_int32 x)) [@inline] + | 64 -> (fun x -> int64_popcount (Caml.Int64.of_nativeint x)) [@inline] + | _ -> assert false diff --git a/src/popcount.mli b/src/popcount.mli new file mode 100644 index 0000000..8aa4fee --- /dev/null +++ b/src/popcount.mli @@ -0,0 +1,11 @@ +(** This module exposes popcount functions (which count the number of ones in a bitstring) + for the various integer types. + + Functions are exposed in their respective modules. *) + +open! Import + +val int_popcount : int -> int +val int32_popcount : int32 -> int +val int64_popcount : int64 -> int +val nativeint_popcount : nativeint -> int diff --git a/src/pow_overflow_bounds.ml b/src/pow_overflow_bounds.ml new file mode 100644 index 0000000..b3fd8e6 --- /dev/null +++ b/src/pow_overflow_bounds.ml @@ -0,0 +1,425 @@ +(* This file was autogenerated by ../generate/generate_pow_overflow_bounds.exe *) + +open! Import + +module Array = Array0 + +(* We have to use Int64.to_int_exn instead of int constants to make + sure that file can be preprocessed on 32-bit machines. *) + +let overflow_bound_max_int32_value : int32 = + 2147483647l + +let int32_positive_overflow_bounds : int32 array = + [| 2147483647l + ; 2147483647l + ; 46340l + ; 1290l + ; 215l + ; 73l + ; 35l + ; 21l + ; 14l + ; 10l + ; 8l + ; 7l + ; 5l + ; 5l + ; 4l + ; 4l + ; 3l + ; 3l + ; 3l + ; 3l + ; 2l + ; 2l + ; 2l + ; 2l + ; 2l + ; 2l + ; 2l + ; 2l + ; 2l + ; 2l + ; 2l + ; 1l + ; 1l + ; 1l + ; 1l + ; 1l + ; 1l + ; 1l + ; 1l + ; 1l + ; 1l + ; 1l + ; 1l + ; 1l + ; 1l + ; 1l + ; 1l + ; 1l + ; 1l + ; 1l + ; 1l + ; 1l + ; 1l + ; 1l + ; 1l + ; 1l + ; 1l + ; 1l + ; 1l + ; 1l + ; 1l + ; 1l + ; 1l + ; 1l + |] + +let overflow_bound_max_int_value : int = + (-1) lsr 1 + +let int_positive_overflow_bounds : int array = + match Int_conversions.num_bits_int with + | 32 -> Array.map int32_positive_overflow_bounds ~f:Caml.Int32.to_int + | 63 -> + [| Caml.Int64.to_int 4611686018427387903L + ; Caml.Int64.to_int 4611686018427387903L + ; Caml.Int64.to_int 2147483647L + ; 1664510 + ; 46340 + ; 5404 + ; 1290 + ; 463 + ; 215 + ; 118 + ; 73 + ; 49 + ; 35 + ; 27 + ; 21 + ; 17 + ; 14 + ; 12 + ; 10 + ; 9 + ; 8 + ; 7 + ; 7 + ; 6 + ; 5 + ; 5 + ; 5 + ; 4 + ; 4 + ; 4 + ; 4 + ; 3 + ; 3 + ; 3 + ; 3 + ; 3 + ; 3 + ; 3 + ; 3 + ; 3 + ; 2 + ; 2 + ; 2 + ; 2 + ; 2 + ; 2 + ; 2 + ; 2 + ; 2 + ; 2 + ; 2 + ; 2 + ; 2 + ; 2 + ; 2 + ; 2 + ; 2 + ; 2 + ; 2 + ; 2 + ; 2 + ; 2 + ; 1 + ; 1 + |] + | 31 -> + [| 1073741823 + ; 1073741823 + ; 32767 + ; 1023 + ; 181 + ; 63 + ; 31 + ; 19 + ; 13 + ; 10 + ; 7 + ; 6 + ; 5 + ; 4 + ; 4 + ; 3 + ; 3 + ; 3 + ; 3 + ; 2 + ; 2 + ; 2 + ; 2 + ; 2 + ; 2 + ; 2 + ; 2 + ; 2 + ; 2 + ; 2 + ; 1 + ; 1 + ; 1 + ; 1 + ; 1 + ; 1 + ; 1 + ; 1 + ; 1 + ; 1 + ; 1 + ; 1 + ; 1 + ; 1 + ; 1 + ; 1 + ; 1 + ; 1 + ; 1 + ; 1 + ; 1 + ; 1 + ; 1 + ; 1 + ; 1 + ; 1 + ; 1 + ; 1 + ; 1 + ; 1 + ; 1 + ; 1 + ; 1 + ; 1 + |] + | _ -> assert false + +let overflow_bound_max_int63_on_int64_value : int64 = + 4611686018427387903L + +let int63_on_int64_positive_overflow_bounds : int64 array = + [| 4611686018427387903L + ; 4611686018427387903L + ; 2147483647L + ; 1664510L + ; 46340L + ; 5404L + ; 1290L + ; 463L + ; 215L + ; 118L + ; 73L + ; 49L + ; 35L + ; 27L + ; 21L + ; 17L + ; 14L + ; 12L + ; 10L + ; 9L + ; 8L + ; 7L + ; 7L + ; 6L + ; 5L + ; 5L + ; 5L + ; 4L + ; 4L + ; 4L + ; 4L + ; 3L + ; 3L + ; 3L + ; 3L + ; 3L + ; 3L + ; 3L + ; 3L + ; 3L + ; 2L + ; 2L + ; 2L + ; 2L + ; 2L + ; 2L + ; 2L + ; 2L + ; 2L + ; 2L + ; 2L + ; 2L + ; 2L + ; 2L + ; 2L + ; 2L + ; 2L + ; 2L + ; 2L + ; 2L + ; 2L + ; 2L + ; 1L + ; 1L + |] + +let overflow_bound_max_int64_value : int64 = + 9223372036854775807L + +let int64_positive_overflow_bounds : int64 array = + [| 9223372036854775807L + ; 9223372036854775807L + ; 3037000499L + ; 2097151L + ; 55108L + ; 6208L + ; 1448L + ; 511L + ; 234L + ; 127L + ; 78L + ; 52L + ; 38L + ; 28L + ; 22L + ; 18L + ; 15L + ; 13L + ; 11L + ; 9L + ; 8L + ; 7L + ; 7L + ; 6L + ; 6L + ; 5L + ; 5L + ; 5L + ; 4L + ; 4L + ; 4L + ; 4L + ; 3L + ; 3L + ; 3L + ; 3L + ; 3L + ; 3L + ; 3L + ; 3L + ; 2L + ; 2L + ; 2L + ; 2L + ; 2L + ; 2L + ; 2L + ; 2L + ; 2L + ; 2L + ; 2L + ; 2L + ; 2L + ; 2L + ; 2L + ; 2L + ; 2L + ; 2L + ; 2L + ; 2L + ; 2L + ; 2L + ; 2L + ; 1L + |] + +let int64_negative_overflow_bounds : int64 array = + [| -9223372036854775807L + ; -9223372036854775807L + ; -3037000499L + ; -2097151L + ; -55108L + ; -6208L + ; -1448L + ; -511L + ; -234L + ; -127L + ; -78L + ; -52L + ; -38L + ; -28L + ; -22L + ; -18L + ; -15L + ; -13L + ; -11L + ; -9L + ; -8L + ; -7L + ; -7L + ; -6L + ; -6L + ; -5L + ; -5L + ; -5L + ; -4L + ; -4L + ; -4L + ; -4L + ; -3L + ; -3L + ; -3L + ; -3L + ; -3L + ; -3L + ; -3L + ; -3L + ; -2L + ; -2L + ; -2L + ; -2L + ; -2L + ; -2L + ; -2L + ; -2L + ; -2L + ; -2L + ; -2L + ; -2L + ; -2L + ; -2L + ; -2L + ; -2L + ; -2L + ; -2L + ; -2L + ; -2L + ; -2L + ; -2L + ; -2L + ; -1L + |] diff --git a/src/ppx_compare_lib.ml b/src/ppx_compare_lib.ml new file mode 100644 index 0000000..6e6809c --- /dev/null +++ b/src/ppx_compare_lib.ml @@ -0,0 +1,113 @@ +open Import0 + +let phys_equal = phys_equal +external polymorphic_compare : 'a -> 'a -> int = "%compare" +external polymorphic_equal : 'a -> 'a -> bool = "%equal" +external ( && ) : bool -> bool -> bool = "%sequand" + +let compare_abstract ~type_name _ _ = + Printf.ksprintf failwith + "Compare called on the type %s, which is abstract in an implementation." + type_name + +let equal_abstract ~type_name _ _ = + Printf.ksprintf failwith + "Equal called on the type %s, which is abstract in an implementation." + type_name + +type 'a compare = 'a -> 'a -> int +type 'a equal = 'a -> 'a -> bool + +module Builtin = struct + let compare_bool : bool compare = Poly.compare + let compare_char : char compare = Poly.compare + let compare_float : float compare = Poly.compare + let compare_int : int compare = Poly.compare + let compare_int32 : int32 compare = Poly.compare + let compare_int64 : int64 compare = Poly.compare + let compare_nativeint : nativeint compare = Poly.compare + let compare_string : string compare = Poly.compare + let compare_unit : unit compare = Poly.compare + + let compare_array compare_elt a b = + if phys_equal a b then + 0 + else + let len_a = Array0.length a in + let len_b = Array0.length b in + let ret = compare len_a len_b in + if ret <> 0 then ret + else + let rec loop i = + if i = len_a then + 0 + else + let l = Array0.unsafe_get a i + and r = Array0.unsafe_get b i in + let res = compare_elt l r in + if res <> 0 then res + else loop (i + 1) + in + loop 0 + + let rec compare_list compare_elt a b = + match a, b with + | [] , [] -> 0 + | [] , _ -> -1 + | _ , [] -> 1 + | x::xs, y::ys -> + let res = compare_elt x y in + if res <> 0 then res + else compare_list compare_elt xs ys + + let compare_option compare_elt a b = + match a, b with + | None , None -> 0 + | None , Some _ -> -1 + | Some _, None -> 1 + | Some a, Some b -> compare_elt a b + + let compare_ref compare_elt a b = compare_elt !a !b + + let equal_bool : bool equal = Poly.equal + let equal_char : char equal = Poly.equal + let equal_int : int equal = Poly.equal + let equal_int32 : int32 equal = Poly.equal + let equal_int64 : int64 equal = Poly.equal + let equal_nativeint : nativeint equal = Poly.equal + let equal_string : string equal = Poly.equal + let equal_unit : unit equal = Poly.equal + + (* [Poly.equal] is IEEE compliant, which is not what we want here. *) + let equal_float x y = equal_int (compare_float x y) 0 + + let equal_array equal_elt a b = + phys_equal a b || + (let len_a = Array0.length a in + let len_b = Array0.length b in + equal len_a len_b && + (let rec loop i = + i = len_a || + (let l = Array0.unsafe_get a i + and r = Array0.unsafe_get b i in + equal_elt l r && loop (i + 1)) + in + loop 0)) + + let rec equal_list equal_elt a b = + match a, b with + | [] , [] -> true + | [] , _ + | _ , [] -> false + | x::xs, y::ys -> + equal_elt x y && equal_list equal_elt xs ys + + let equal_option equal_elt a b = + match a, b with + | None , None -> true + | None , Some _ + | Some _, None -> false + | Some a, Some b -> equal_elt a b + + let equal_ref equal_elt a b = equal_elt !a !b +end diff --git a/src/ppx_compare_lib.mli b/src/ppx_compare_lib.mli new file mode 100644 index 0000000..60aeaef --- /dev/null +++ b/src/ppx_compare_lib.mli @@ -0,0 +1,50 @@ +(** Runtime support for auto-generated comparators. Users are not intended to use this + module directly. *) + +val phys_equal : 'a -> 'a -> bool + +(*_ /!\ WARNING /!\ all these functions need to declared "external" in order to get the + lazy behavior for ( && ) (relied upon by [@@deriving_inline equal][@@@end]) and the type-based + specialization for equal/compare. *) +external polymorphic_compare : 'a -> 'a -> int = "%compare" +external polymorphic_equal : 'a -> 'a -> bool = "%equal" +external ( && ) : bool -> bool -> bool = "%sequand" + +type 'a compare = 'a -> 'a -> int +type 'a equal = 'a -> 'a -> bool + +(** Raise when fully applied *) +val compare_abstract : type_name:string -> _ compare +val equal_abstract : type_name:string -> _ equal + +module Builtin : sig + val compare_bool : bool compare + val compare_char : char compare + val compare_float : float compare + val compare_int : int compare + val compare_int32 : int32 compare + val compare_int64 : int64 compare + val compare_nativeint : nativeint compare + val compare_string : string compare + val compare_unit : unit compare + + val compare_array : 'a compare -> 'a array compare + val compare_list : 'a compare -> 'a list compare + val compare_option : 'a compare -> 'a option compare + val compare_ref : 'a compare -> 'a ref compare + + val equal_bool : bool equal + val equal_char : char equal + val equal_float : float equal + val equal_int : int equal + val equal_int32 : int32 equal + val equal_int64 : int64 equal + val equal_nativeint : nativeint equal + val equal_string : string equal + val equal_unit : unit equal + + val equal_array : 'a equal -> 'a array equal + val equal_list : 'a equal -> 'a list equal + val equal_option : 'a equal -> 'a option equal + val equal_ref : 'a equal -> 'a ref equal +end diff --git a/src/ppx_enumerate_lib.ml b/src/ppx_enumerate_lib.ml new file mode 100644 index 0000000..bde251f --- /dev/null +++ b/src/ppx_enumerate_lib.ml @@ -0,0 +1 @@ +module List = List diff --git a/src/ppx_hash_lib.ml b/src/ppx_hash_lib.ml new file mode 100644 index 0000000..17a4578 --- /dev/null +++ b/src/ppx_hash_lib.ml @@ -0,0 +1,6 @@ +(** This module is for use by ppx_hash, and is thus not in the interface of Base. *) +module Std = struct + + (** @canonical Base.Hash *) + module Hash = Hash +end diff --git a/src/ppx_sexp_conv_lib.ml b/src/ppx_sexp_conv_lib.ml new file mode 100644 index 0000000..44973cf --- /dev/null +++ b/src/ppx_sexp_conv_lib.ml @@ -0,0 +1 @@ +include Sexplib diff --git a/src/pretty_printer.ml b/src/pretty_printer.ml new file mode 100644 index 0000000..5a6be9f --- /dev/null +++ b/src/pretty_printer.ml @@ -0,0 +1,29 @@ +open! Import + +let r = ref [ "Base.Sexp.pp_hum" ] + +let all () = !r + +let register p = r := p :: !r + +module type S = sig + type t + val pp : Formatter.t -> t -> unit +end + +module Register_pp (M : sig + include S + val module_name : string + end) = struct + include M + let () = register (M.module_name ^ ".pp") +end + +module Register (M : sig + type t + val module_name : string + val to_string : t -> string + end) = Register_pp (struct + include M + let pp formatter t = Caml.Format.pp_print_string formatter (M.to_string t) + end) diff --git a/src/pretty_printer.mli b/src/pretty_printer.mli new file mode 100644 index 0000000..a0c40d7 --- /dev/null +++ b/src/pretty_printer.mli @@ -0,0 +1,44 @@ +(** A list of pretty printers for various types, for use in toplevels. + + [Pretty_printer] has a [string list ref] with the names of [pp] functions matching the + interface: + + {[ + val pp : Format.formatter -> t -> unit + ]} + + The names are actually OCaml identifier names, e.g., "Base.Int.pp". Code for + building toplevels (this code is not in Base) evaluates the strings to yield the + pretty printers and register them with the OCaml runtime. *) + +open! Import + +(** [all ()] returns all pretty printers that have been [register]ed. *) +val all : unit -> string list + +(** Modules that provide a pretty printer will match [S]. *) +module type S = sig + type t + val pp : Formatter.t -> t -> unit +end + +(** [Register] builds a [pp] function from a [to_string] function, and adds the + [module_name ^ ".pp"] to the list of pretty printers. The idea is to statically + guarantee that one has the desired [pp] function at the same point where the [name] is + added. *) +module Register (M : sig + type t + val module_name : string + val to_string : t -> string + end) : S with type t := M.t + +(** [Register_pp] is like [Register], but allows a custom [pp] function rather than using + [to_string]. *) +module Register_pp (M : sig + include S + val module_name : string + end) : S with type t := M.t + +(** [register name] adds [name] to the list of pretty printers. Use the [Register] + functor if possible. *) +val register : string -> unit diff --git a/src/printf.ml b/src/printf.ml new file mode 100644 index 0000000..abbe5d2 --- /dev/null +++ b/src/printf.ml @@ -0,0 +1,8 @@ +open! Import0 + +include Caml.Printf + +(** failwith, invalid_arg, and exit accepting printf's format. *) + +let failwithf fmt = ksprintf (fun s () -> failwith s) fmt +let invalid_argf fmt = ksprintf (fun s () -> invalid_arg s) fmt diff --git a/src/printf.mli b/src/printf.mli new file mode 100644 index 0000000..a0fbabc --- /dev/null +++ b/src/printf.mli @@ -0,0 +1,132 @@ +(** Functions for formatted output. + + [fprintf] and related functions format their arguments according to the given format + string. The format string is a character string which contains two types of objects: + plain characters, which are simply copied to the output channel, and conversion + specifications, each of which causes conversion and printing of arguments. + + Conversion specifications have the following form: + + {[% [flags] [width] [.precision] type]} + + In short, a conversion specification consists in the [%] character, followed by + optional modifiers and a type which is made of one or two characters. + + The types and their meanings are: + + - [d], [i]: convert an integer argument to signed decimal. + - [u], [n], [l], [L], or [N]: convert an integer argument to unsigned + decimal. Warning: [n], [l], [L], and [N] are used for [scanf], and should not be used + for [printf]. + - [x]: convert an integer argument to unsigned hexadecimal, using lowercase letters. + - [X]: convert an integer argument to unsigned hexadecimal, using uppercase letters. + - [o]: convert an integer argument to unsigned octal. + - [s]: insert a string argument. + - [S]: convert a string argument to OCaml syntax (double quotes, escapes). + - [c]: insert a character argument. + - [C]: convert a character argument to OCaml syntax (single quotes, escapes). + - [f]: convert a floating-point argument to decimal notation, in the style [dddd.ddd]. + - [F]: convert a floating-point argument to OCaml syntax ([dddd.] or [dddd.ddd] or + [d.ddd e+-dd]). + - [e] or [E]: convert a floating-point argument to decimal notation, in the style + [d.ddd e+-dd] (mantissa and exponent). + - [g] or [G]: convert a floating-point argument to decimal notation, in style [f] or + [e], [E] (whichever is more compact). Moreover, any trailing zeros are removed from + the fractional part of the result and the decimal-point character is removed if there + is no fractional part remaining. + - [h] or [H]: convert a floating-point argument to hexadecimal notation, in the style + [0xh.hhhh e+-dd] (hexadecimal mantissa, exponent in decimal and denotes a power of 2). + - [B]: convert a boolean argument to the string true or false + - [b]: convert a boolean argument (deprecated; do not use in new programs). + - [ld], [li], [lu], [lx], [lX], [lo]: convert an int32 argument to the format + specified by the second letter (decimal, hexadecimal, etc). + - [nd], [ni], [nu], [nx], [nX], [no]: convert a nativeint argument to the format + specified by the second letter. + - [Ld], [Li], [Lu], [Lx], [LX], [Lo]: convert an int64 argument to the format + specified by the second letter. + - [a]: user-defined printer. Take two arguments and apply the first one to outchan + (the current output channel) and to the second argument. The first argument must + therefore have type [out_channel -> 'b -> unit] and the second ['b]. The output + produced by the function is inserted in the output of [fprintf] at the current point. + - [t]: same as [%a], but take only one argument (with type [out_channel -> unit]) and + apply it to [outchan]. + - [{ fmt %}]: convert a format string argument to its type digest. The argument must + have the same type as the internal format string [fmt]. + - [( fmt %)]: format string substitution. Take a format string argument and substitute + it to the internal format string fmt to print following arguments. The argument must + have the same type as the internal format string fmt. + - [!]: take no argument and flush the output. + - [%]: take no argument and output one [%] character. + - [@]: take no argument and output one [@] character. + - [,]: take no argument and output nothing: a no-op delimiter for conversion + specifications. + + The optional [flags] are: + + - [-]: left-justify the output (default is right justification). + - [0]: for numerical conversions, pad with zeroes instead of spaces. + - [+]: for signed numerical conversions, prefix number with a [+] sign if positive. + - space: for signed numerical conversions, prefix number with a space if positive. + - [#]: request an alternate formatting style for the hexadecimal and octal integer + types ([x], [X], [o], [lx], [lX], [lo], [Lx], [LX], [Lo]). + + The optional [width] is an integer indicating the minimal width of the result. For + instance, [%6d] prints an integer, prefixing it with spaces to fill at least 6 + characters. + + The optional [precision] is a dot [.] followed by an integer indicating how many + digits follow the decimal point in the [%f], [%e], and [%E] conversions. For instance, + [%.4f] prints a [float] with 4 fractional digits. + + The integer in a [width] or [precision] can also be specified as [*], in which case an + extra integer argument is taken to specify the corresponding [width] or + [precision]. This integer argument precedes immediately the argument to print. For + instance, [%.*f] prints a float with as many fractional digits as the value of the + argument given before the float. +*) + +open! Import0 + +(** Same as [fprintf], but does not print anything. Useful for ignoring some material when + conditionally printing. *) +val ifprintf : 'a -> ('r, 'a, 'c, unit) format4 -> 'r + +(** Same as [fprintf], but instead of printing on an output channel, returns a string. *) +val sprintf : ('r, unit, string) format -> 'r + +(** Same as [fprintf], but instead of printing on an output channel, appends the formatted + arguments to the given extensible buffer. *) +val bprintf : Caml.Buffer.t -> ('r, Caml.Buffer.t, unit) format -> 'r + +(** Same as [sprintf], but instead of returning the string, passes it to the first + argument. *) +val ksprintf : (string -> 'a) -> ('r, unit, string, 'a) format4 -> 'r + +(** Same as [bprintf], but instead of returning immediately, passes the buffer, after + printing, to its first argument. *) +val kbprintf : (Caml.Buffer.t -> 'a) -> Caml.Buffer.t -> ('r, Caml.Buffer.t, unit, 'a) format4 -> 'r + +(** {6 Formatting error and exit functions} + + These functions have a polymorphic return type, since they do not return. Naively, + this doesn't mix well with variadic functions: if you define, say, + + {[ + let f fmt = ksprintf (fun s -> failwith s) fmt + ]} + + then you find that [f "%d" : int -> 'a], as you'd expect, and [f "%d" 7 : 'a]. The + problem with this is that ['a] unifies with (say) [int -> 'b], so [f "%d" 7 4] is not + a type error -- the [4] is simply ignored. + + To mitigate this problem, these functions all take a final unit parameter. These + rarely arise as formatting positional parameters (they can do with e.g. "%a", but not + in a useful way) so they serve as an effective signpost for + "end of formatting arguments". *) + + +(** Raises [Failure]. *) +val failwithf : ('r, unit, string, unit -> _) format4 -> 'r + +(** Raises [Invalid_arg]. *) +val invalid_argf : ('r, unit, string, unit -> _) format4 -> 'r diff --git a/src/queue.ml b/src/queue.ml new file mode 100644 index 0000000..65617c9 --- /dev/null +++ b/src/queue.ml @@ -0,0 +1,503 @@ +open! Import + + +(* [t] stores the [t.length] queue elements at consecutive increasing indices of [t.elts], + mod the capacity of [t], which is [Option_array.length t.elts]. The capacity is + required to be a power of two (user-requested capacities are rounded up to the nearest + power), so that mod can quickly be computed using [land t.mask], where [t.mask = + capacity t - 1]. So, queue element [i] is at [t.elts.( (t.front + i) land t.mask )]. + + [num_mutations] is used to detect modification during iteration. *) +type 'a t = + { mutable num_mutations : int + ; mutable front : int + ; mutable mask : int + ; mutable length : int + ; mutable elts : 'a Option_array.t + } +[@@deriving_inline sexp_of] +let sexp_of_t : + 'a . ('a -> Ppx_sexp_conv_lib.Sexp.t) -> 'a t -> Ppx_sexp_conv_lib.Sexp.t = + fun _of_a -> + function + | { num_mutations = v_num_mutations; front = v_front; mask = v_mask; + length = v_length; elts = v_elts } -> + let bnds = [] in + let bnds = + let arg = Option_array.sexp_of_t _of_a v_elts in + (Ppx_sexp_conv_lib.Sexp.List + [Ppx_sexp_conv_lib.Sexp.Atom "elts"; arg]) + :: bnds in + let bnds = + let arg = sexp_of_int v_length in + (Ppx_sexp_conv_lib.Sexp.List + [Ppx_sexp_conv_lib.Sexp.Atom "length"; arg]) + :: bnds in + let bnds = + let arg = sexp_of_int v_mask in + (Ppx_sexp_conv_lib.Sexp.List + [Ppx_sexp_conv_lib.Sexp.Atom "mask"; arg]) + :: bnds in + let bnds = + let arg = sexp_of_int v_front in + (Ppx_sexp_conv_lib.Sexp.List + [Ppx_sexp_conv_lib.Sexp.Atom "front"; arg]) + :: bnds in + let bnds = + let arg = sexp_of_int v_num_mutations in + (Ppx_sexp_conv_lib.Sexp.List + [Ppx_sexp_conv_lib.Sexp.Atom "num_mutations"; arg]) + :: bnds in + Ppx_sexp_conv_lib.Sexp.List bnds +[@@@end] + +module type S = Queue_intf.S + +let inc_num_mutations t = t.num_mutations <- t.num_mutations + 1 + +let capacity t = t.mask + 1 + +let elts_index t i = (t.front + i) land t.mask + +let unsafe_get t i = Option_array.unsafe_get_some_exn t.elts (elts_index t i) +let unsafe_is_set t i = Option_array.unsafe_is_some t.elts (elts_index t i) +let unsafe_set t i a = Option_array.unsafe_set_some t.elts (elts_index t i) a +let unsafe_unset t i = Option_array.unsafe_set_none t.elts (elts_index t i) + +let check_index_exn t i = + if i < 0 || i >= t.length + then Error.raise_s + (Sexp.message "Queue index out of bounds" + [ "index" , i |> Int.sexp_of_t + ; "length", t.length |> Int.sexp_of_t ]) +;; + +let get t i = check_index_exn t i; unsafe_get t i +let set t i a = + check_index_exn t i; + inc_num_mutations t; + unsafe_set t i a; +;; + +let is_empty t = t.length = 0 + +let length { length; _ } = length + +let ensure_no_mutation t num_mutations = + if t.num_mutations <> num_mutations + then Error.raise_s + (Sexp.message "mutation of queue during iteration" + [ "", t |> sexp_of_t (fun _ -> Sexp.Atom "_")]) +;; + +let compare = + let rec unsafe_compare_from compare_elt pos ~t1 ~t2 ~len1 ~len2 ~mut1 ~mut2 = + match pos = len1, pos = len2 with + | true , true -> 0 + | true , false -> -1 + | false, true -> 1 + | false, false -> + let x = compare_elt (unsafe_get t1 pos) (unsafe_get t2 pos) in + ensure_no_mutation t1 mut1; + ensure_no_mutation t2 mut2; + match x with + | 0 -> unsafe_compare_from compare_elt (pos + 1) ~t1 ~t2 ~len1 ~len2 ~mut1 ~mut2 + | n -> n + in + fun compare_elt t1 t2 -> + if phys_equal t1 t2 then + 0 + else + unsafe_compare_from compare_elt 0 ~t1 ~t2 + ~len1:t1.length + ~len2:t2.length + ~mut1:t1.num_mutations + ~mut2:t2.num_mutations +;; + +let equal = + let rec unsafe_equal_from equal_elt pos ~t1 ~t2 ~mut1 ~mut2 ~len = + pos = len + || + (let b = equal_elt (unsafe_get t1 pos) (unsafe_get t2 pos) in + ensure_no_mutation t1 mut1; + ensure_no_mutation t2 mut2; + b && unsafe_equal_from equal_elt (pos + 1) ~t1 ~t2 ~mut1 ~mut2 ~len) + in + fun equal_elt t1 t2 -> + phys_equal t1 t2 + || + (let len1 = t1.length in + let len2 = t2.length in + len1 = len2 + && + unsafe_equal_from equal_elt 0 ~t1 ~t2 + ~len:len1 + ~mut1:t1.num_mutations + ~mut2:t2.num_mutations) +;; + +let invariant invariant_a t = + let + { num_mutations + ; mask = _ + ; elts + ; front + ; length } = t + in + assert (front >= 0); + assert (front < capacity t); + let capacity = capacity t in + assert (capacity = Option_array.length elts); + assert (capacity >= 1); + assert (Int.is_pow2 capacity); + assert (length >= 0); + assert (length <= capacity); + for i = 0 to capacity- 1 do + if i < t.length + then (invariant_a (unsafe_get t i); ensure_no_mutation t num_mutations) + else assert (not (unsafe_is_set t i)) + done +;; + +let create (type a) ?capacity () : a t = + let capacity = + match capacity with + | None -> 1 + | Some capacity -> + if capacity < 0 + then Error.raise_s + (Sexp.message "cannot have queue with negative capacity" + [ "capacity", capacity |> Int.sexp_of_t ]) + else if capacity = 0 + then 1 + else Int.ceil_pow2 capacity + in + { num_mutations = 0 + ; front = 0 + ; mask = capacity - 1 + ; length = 0 + ; elts = Option_array.create ~len:capacity + } +;; + +let blit_to_array ~src dst = + assert (src.length <= Option_array.length dst); + let front_len = Int.min src.length (capacity src - src.front) in + let rest_len = src.length - front_len in + Option_array.blit ~len:front_len ~src:src.elts ~src_pos:src.front ~dst ~dst_pos:0; + Option_array.blit ~len:rest_len ~src:src.elts ~src_pos:0 ~dst ~dst_pos:front_len; +;; + +let set_capacity t desired_capacity = + (* We allow arguments less than 1 to [set_capacity], but translate them to 1 to simplify + the code that relies on the array length being a power of 2. *) + inc_num_mutations t; + let new_capacity = Int.ceil_pow2 (max 1 (max desired_capacity t.length)) in + if new_capacity <> capacity t then begin + let dst = Option_array.create ~len:new_capacity in + blit_to_array ~src:t dst; + t.front <- 0; + t.mask <- new_capacity - 1; + t.elts <- dst; + end; +;; + +let enqueue t a = + inc_num_mutations t; + if t.length = capacity t then set_capacity t (2 * t.length); + unsafe_set t t.length a; + t.length <- t.length + 1; +;; + +let dequeue_nonempty t = + inc_num_mutations t; + let elts = t.elts in + let front = t.front in + let res = Option_array.get_some_exn elts front in + Option_array.set_none elts front; + t.front <- elts_index t 1; + t.length <- t.length - 1; + res +;; + +let dequeue_exn t = + if is_empty t + then raise Caml.Queue.Empty + else dequeue_nonempty t +;; + +let dequeue t = + if is_empty t + then None + else Some (dequeue_nonempty t) +;; + +let front_nonempty t = Option_array.unsafe_get_some_exn t.elts t.front +let last_nonempty t = unsafe_get t (t.length - 1) + +let peek t = + if is_empty t + then None + else Some (front_nonempty t) +;; + +let peek_exn t = + if is_empty t + then raise Caml.Queue.Empty + else front_nonempty t +;; + +let last t = + if is_empty t + then None + else Some (last_nonempty t) +;; + +let last_exn t = + if is_empty t + then raise Caml.Queue.Empty + else last_nonempty t +;; + +let clear t = + inc_num_mutations t; + if t.length > 0 then begin + for i = 0 to t.length - 1 do + unsafe_unset t i; + done; + t.length <- 0; + t.front <- 0; + end; +;; + +let blit_transfer ~src ~dst ?len () = + inc_num_mutations src; + inc_num_mutations dst; + let len = + match len with + | None -> src.length + | Some len -> + if len < 0 + then Error.raise_s + (Sexp.message "Queue.blit_transfer: negative length" + [ "length", len |> Int.sexp_of_t ]); + min len src.length + in + if len > 0 then begin + set_capacity dst (max (capacity dst) (dst.length + len)); + let dst_start = dst.front + dst.length in + for i = 0 to len - 1 do + (* This is significantly faster than simply [enqueue dst (dequeue_nonempty src)] *) + let src_i = (src.front + i) land src.mask in + let dst_i = (dst_start + i) land dst.mask in + Option_array.unsafe_set_some dst.elts dst_i + (Option_array.unsafe_get_some_exn src.elts src_i); + Option_array.unsafe_set_none src.elts src_i; + done; + dst.length <- dst.length + len; + src.front <- (src.front + len) land src.mask; + src.length <- src.length - len; + end; +;; + +let enqueue_all t l = + (* Traversing the list up front to compute its length is probably (but not definitely) + better than doubling the underlying array size several times for large queues. *) + set_capacity t (Int.max (capacity t) (t.length + List.length l)); + List.iter l ~f:(fun x -> enqueue t x) +;; + +let fold t ~init ~f = + if t.length = 0 + then init + else begin + let num_mutations = t.num_mutations in + let r = ref init in + for i = 0 to t.length - 1 do + r := f !r (unsafe_get t i); + ensure_no_mutation t num_mutations; + done; + !r + end; +;; + +let foldi t ~init ~f = + let i = ref 0 in + fold t ~init ~f:(fun acc a -> + let acc = f !i acc a in + i := !i + 1; + acc) +;; + + +(* [iter] is implemented directly because implementing it in terms of [fold] is + slower. *) +let iter t ~f = + let num_mutations = t.num_mutations in + for i = 0 to t.length - 1 do + f (unsafe_get t i); + ensure_no_mutation t num_mutations; + done; +;; + +let iteri t ~f = + let num_mutations = t.num_mutations in + for i = 0 to t.length - 1 do + f i (unsafe_get t i); + ensure_no_mutation t num_mutations; + done; +;; + +module C = + Indexed_container.Make (struct + type nonrec 'a t = 'a t + let fold = fold + let iter = `Custom iter + let length = `Custom length + let foldi = `Custom foldi + let iteri = `Custom iteri + end) + +let count = C.count +let exists = C.exists +let find = C.find +let find_map = C.find_map +let fold_result = C.fold_result +let fold_until = C.fold_until +let for_all = C.for_all +let max_elt = C.max_elt +let mem = C.mem +let min_elt = C.min_elt +let sum = C.sum +let to_list = C.to_list + +let counti = C.counti +let existsi = C.existsi +let find_mapi = C.find_mapi +let findi = C.findi +let for_alli = C.for_alli + + +(* For [concat_map], [filter_map], and [filter], we don't create [t_result] with [t]'s + capacity because we have no idea how many elements [t_result] will ultimately hold. *) +let concat_map t ~f = + let t_result = create () in + iter t ~f:(fun a -> List.iter (f a) ~f:(fun b -> enqueue t_result b)); + t_result +;; + +let concat_mapi t ~f = + let t_result = create () in + iteri t ~f:(fun i a -> List.iter (f i a) ~f:(fun b -> enqueue t_result b)); + t_result +;; + +let filter_map t ~f = + let t_result = create () in + iter t ~f:(fun a -> + match f a with + | None -> () + | Some b -> enqueue t_result b); + t_result +;; + +let filter_mapi t ~f = + let t_result = create () in + iteri t ~f:(fun i a -> + match f i a with + | None -> () + | Some b -> enqueue t_result b); + t_result +;; + +let filter t ~f = + let t_result = create () in + iter t ~f:(fun a -> if f a then enqueue t_result a); + t_result +;; + +let filteri t ~f = + let t_result = create () in + iteri t ~f:(fun i a -> if f i a then enqueue t_result a); + t_result +;; + +let filter_inplace t ~f = + let t2 = filter t ~f in + clear t; + blit_transfer ~src:t2 ~dst:t (); +;; + +let filteri_inplace t ~f = + let t2 = filteri t ~f in + clear t; + blit_transfer ~src:t2 ~dst:t (); +;; + +let copy src = + let dst = create ~capacity:src.length () in + blit_to_array ~src dst.elts; + dst.length <- src.length; + dst +;; + +let of_list l = + (* Traversing the list up front to compute its length is probably (but not definitely) + better than doubling the underlying array size several times for large queues. *) + let t = create ~capacity:(List.length l) () in + List.iter l ~f:(fun x -> enqueue t x); + t +;; + +(* The queue [t] returned by [create] will have [t.length = 0], [t.front = 0], and + [capacity t = Int.ceil_pow2 len]. So, we only have to set [t.length] to [len] after + the blit to maintain all the invariants: [t.length] is equal to the number of elements + in the queue, [t.front] is the array index of the first element in the queue, and + [capacity t = Option_array.length t.elts]. *) +let init len ~f = + if len < 0 + then Error.raise_s + (Sexp.message "Queue.init: negative length" + [ "length", len |> Int.sexp_of_t ]); + let t = create ~capacity:len () in + assert (Option_array.length t.elts >= len); + for i = 0 to len - 1 do + Option_array.unsafe_set_some t.elts i (f i); + done; + t.length <- len; + t +;; + +let of_array a = init (Array.length a) ~f:(Array.unsafe_get a) + +let to_array t = Array.init t.length ~f:(fun i -> unsafe_get t i) + +let map ta ~f = + let num_mutations = ta.num_mutations in + let tb = create ~capacity:ta.length () in + tb.length <- ta.length; + for i = 0 to ta.length - 1 do + let b = f (unsafe_get ta i) in + ensure_no_mutation ta num_mutations; + Option_array.unsafe_set_some tb.elts i b; + done; + tb +;; + +let mapi t ~f = + let i = ref 0 in + map t ~f:(fun a -> + let result = f !i a in + i := !i + 1; + result) +;; + +let singleton x = + let t = create () in + enqueue t x; + t +;; + +let sexp_of_t sexp_of_a t = to_list t |> List.sexp_of_t sexp_of_a +let t_of_sexp a_of_sexp sexp = List.t_of_sexp a_of_sexp sexp |> of_list diff --git a/src/queue.mli b/src/queue.mli new file mode 100644 index 0000000..4e095d9 --- /dev/null +++ b/src/queue.mli @@ -0,0 +1 @@ +include Queue_intf.Queue (** @inline *) diff --git a/src/queue_intf.ml b/src/queue_intf.ml new file mode 100644 index 0000000..a403a47 --- /dev/null +++ b/src/queue_intf.ml @@ -0,0 +1,135 @@ +open! Import + +(** An interface for queues that follows Base's conventions, as opposed to OCaml's + standard [Queue] module. *) +module type S = sig + type 'a t [@@deriving_inline sexp] + include + sig + [@@@ocaml.warning "-32"] + include Ppx_sexp_conv_lib.Sexpable.S1 with type 'a t := 'a t + end[@@ocaml.doc "@inline"] + [@@@end] + + include Indexed_container.S1 with type 'a t := 'a t + + (** [singleton a] returns a queue with one element. *) + val singleton : 'a -> 'a t + + (** [of_list list] returns a queue [t] with the elements of [list] in the same order as + the elements of [list] (i.e. the first element of [t] is the first element of the + list). *) + val of_list : 'a list -> 'a t + val of_array : 'a array -> 'a t + + (** [init n ~f] is equivalent to [of_list (List.init n ~f)]. *) + val init : int -> f:(int -> 'a) -> 'a t + + (** [enqueue t a] adds [a] to the end of [t].*) + val enqueue : 'a t -> 'a -> unit + + (** [enqueue_all t list] adds all elements in [list] to [t] in order of [list]. *) + val enqueue_all : 'a t -> 'a list -> unit + + (** [dequeue t] removes and returns the front element of [t], if any. *) + val dequeue : 'a t -> 'a option + val dequeue_exn : 'a t -> 'a + + (** [peek t] returns but does not remove the front element of [t], if any. *) + val peek : 'a t -> 'a option + val peek_exn : 'a t -> 'a + + (** [clear t] discards all elements from [t]. *) + val clear : _ t -> unit + + (** [copy t] returns a copy of [t]. *) + val copy : 'a t -> 'a t + + val map : 'a t -> f:( 'a -> 'b) -> 'b t + val mapi : 'a t -> f:(int -> 'a -> 'b) -> 'b t + + (** Creates a new queue with elements equal to [List.concat_map ~f (to_list t)]. *) + val concat_map : 'a t -> f:( 'a -> 'b list) -> 'b t + val concat_mapi : 'a t -> f:(int -> 'a -> 'b list) -> 'b t + + (** [filter_map] creates a new queue with elements equal to [List.filter_map ~f (to_list + t)]. *) + val filter_map : 'a t -> f:( 'a -> 'b option) -> 'b t + val filter_mapi : 'a t -> f:(int -> 'a -> 'b option) -> 'b t + + (** [filter] is like [filter_map], except with [List.filter]. *) + val filter : 'a t -> f:( 'a -> bool) -> 'a t + val filteri : 'a t -> f:(int -> 'a -> bool) -> 'a t + + (** [filter_inplace t ~f] removes all elements of [t] that don't satisfy [f]. If [f] + raises, [t] is unchanged. This is inplace in that it modifies [t]; however, it uses + space linear in the final length of [t]. *) + val filter_inplace : 'a t -> f:( 'a -> bool) -> unit + val filteri_inplace : 'a t -> f:(int -> 'a -> bool) -> unit + +end + +module type Queue = sig + (** A queue implemented with an array. + + The implementation will grow the array as necessary. The array will + never automatically be shrunk, but the size can be interrogated and set + with [capacity] and [set_capacity]. + + Iteration functions ([iter], [fold], [map], [concat_map], [filter], + [filter_map], [filter_inplace], and some functions from [Container.S1]) + will raise if the queue is modified during iteration. + + Also see {!Linked_queue}, which has different performance characteristics. *) + + module type S = S + + type 'a t [@@deriving_inline compare] + include + sig + [@@@ocaml.warning "-32"] + val compare : ('a -> 'a -> int) -> 'a t -> 'a t -> int + end[@@ocaml.doc "@inline"] + [@@@end] + + include S with type 'a t := 'a t + + include Equal. S1 with type 'a t := 'a t + include Invariant.S1 with type 'a t := 'a t + + (** Create an empty queue. *) + val create + : ?capacity : int (** default is [1]. *) + -> unit + -> _ t + + (** [last t] returns the most recently enqueued element in [t], if any. *) + val last : 'a t -> 'a option + val last_exn : 'a t -> 'a + + (** Transfers up to [len] elements from the front of [src] to the end of [dst], removing + them from [src]. It is an error if [len < 0]. + + Aside from a call to [set_capacity dst] if needed, runs in O([len]) time *) + val blit_transfer + : src : 'a t + -> dst : 'a t + -> ?len : int (** default is [length src] *) + -> unit + -> unit + + (** [get t i] returns the [i]'th element in [t], where the 0'th element is at the front of + [t] and the [length t - 1] element is at the back. *) + val get : 'a t -> int -> 'a + val set : 'a t -> int -> 'a -> unit + + (** Returns the current length of the backing array. *) + val capacity : _ t -> int + + (** [set_capacity t c] sets the capacity of [t]'s backing array to at least [max c (length + t)]. If [t]'s capacity changes, then this involves allocating a new backing array and + copying the queue elements over. [set_capacity] may decrease the capacity of [t], if + [c < capacity t]. *) + val set_capacity : _ t -> int -> unit + +end diff --git a/src/random.ml b/src/random.ml new file mode 100644 index 0000000..b7b579d --- /dev/null +++ b/src/random.ml @@ -0,0 +1,242 @@ +open! Import +open Caml.Random + +module Array = Array0 +module Int = Int0 + +(* Unfortunately, because the standard library does not expose + [Caml.Random.State.default], we have to construct our own. We then build the + [Caml.Random.int], [Caml.Random.bool] functions and friends using that default state in + exactly the same way as the standard library. + + One other trickiness is that we need access to the unexposed [Caml.Random.State.assign] + function, which accesses the unexposed state representation. So, we copy the + [State.repr] type definition and [assign] function to here from the standard library, + and use [Obj.magic] to get access to the underlying implementation. *) + +(* Regression tests ought to be deterministic because that way anyone who breaks the test + knows that it's their code that broke the test. If tests are nondeterministic, a test + failure may instead happen because the test runner got unlucky and uncovered an + existing bug in the code supposedly being "protected" by the test in question. *) +let forbid_nondeterminism_in_tests ~allow_in_tests = + if am_testing then + match allow_in_tests with + | Some true -> () + | None | Some false -> + failwith "\ +initializing Random with a nondeterministic seed is forbidden in inline tests" +;; + +external random_seed: unit -> int array = "caml_sys_random_seed";; +let random_seed ?allow_in_tests () = + forbid_nondeterminism_in_tests ~allow_in_tests; + random_seed () +;; + +module State = struct + include State + + let make_self_init ?allow_in_tests () = + forbid_nondeterminism_in_tests ~allow_in_tests; + make_self_init () + ;; + + type repr = { st : int array; mutable idx : int } + + let assign t1 t2 = + let t1 = (Caml.Obj.magic t1 : repr) in + let t2 = (Caml.Obj.magic t2 : repr) in + Array.blit ~src:t2.st ~src_pos:0 ~dst:t1.st ~dst_pos:0 + ~len:(Array.length t1.st); + t1.idx <- t2.idx; + ;; + + let full_init t seed = assign t (make seed) + + let default = + (* We define Core's default random state as a copy of OCaml's default random state. + This means that programs that use Core.Random will see the same sequence of random + bits as if they had used Caml.Random. However, because [get_state] returns a + copy, Core.Random and OCaml.Random are not using the same state. If a program used + both, each of them would go through the same sequence of random bits. To avoid + that, we reset OCaml's random state to a different seed, giving it a different + sequence. *) + let t = Caml.Random.get_state () in + Caml.Random.init 137; + t + ;; + + let int_on_64bits t bound = + if bound <= 0x3FFFFFFF (* (1 lsl 30) - 1 *) + then int t bound + else Caml.Int64.to_int (int64 t (Caml.Int64.of_int bound)) + ;; + + let int_on_32bits t bound = + (* Not always true with the JavaScript backend. *) + if bound <= 0x3FFFFFFF (* (1 lsl 30) - 1 *) + then int t bound + else Caml.Int32.to_int (int32 t (Caml.Int32.of_int bound)) + ;; + + let int = + match Word_size.word_size with + | W64 -> int_on_64bits + | W32 -> int_on_32bits + ;; + + let full_range_int64 = + let open Caml.Int64 in + let bits state = of_int (bits state) in + fun state -> + logxor (bits state) + (logxor + (shift_left (bits state) 30) + (shift_left (bits state) 60)) + ;; + + let full_range_int32 = + let open Caml.Int32 in + let bits state = of_int (bits state) in + fun state -> + logxor (bits state) + (shift_left (bits state) 30) + ;; + + let full_range_int_on_64bits state = + Caml.Int64.to_int (full_range_int64 state) + ;; + + let full_range_int_on_32bits state = + Caml.Int32.to_int (full_range_int32 state) + ;; + + let full_range_int = + match Word_size.word_size with + | W64 -> full_range_int_on_64bits + | W32 -> full_range_int_on_32bits + ;; + + let full_range_nativeint_on_64bits state = + Caml.Int64.to_nativeint (full_range_int64 state) + ;; + + let full_range_nativeint_on_32bits state = + Caml.Nativeint.of_int32 (full_range_int32 state) + ;; + + let full_range_nativeint = + match Word_size.word_size with + | W64 -> full_range_nativeint_on_64bits + | W32 -> full_range_nativeint_on_32bits + ;; + + let [@inline never] raise_crossed_bounds name lower_bound upper_bound string_of_bound = + Printf.failwithf "Random.%s: crossed bounds [%s > %s]" + name (string_of_bound lower_bound) (string_of_bound upper_bound) () + ;; + + let int_incl = + let rec in_range state lo hi = + let int = full_range_int state in + if int >= lo && int <= hi + then int + else in_range state lo hi + in + fun state lo hi -> + if lo > hi then raise_crossed_bounds "int" lo hi Int.to_string; + let diff = hi - lo in + if diff = Int.max_value + then lo + ((full_range_int state) land Int.max_value) + else if diff >= 0 + then lo + int state (Int.succ diff) + else in_range state lo hi + ;; + + let int32_incl = + let open Int32_replace_polymorphic_compare in + let rec in_range state lo hi = + let int = full_range_int32 state in + if int >= lo && int <= hi + then int + else in_range state lo hi + in + let open Caml.Int32 in + fun state lo hi -> + if lo > hi then raise_crossed_bounds "int32" lo hi to_string; + let diff = sub hi lo in + if diff = max_int + then add lo (logand (full_range_int32 state) max_int) + else if diff >= 0l + then add lo (int32 state (succ diff)) + else in_range state lo hi + ;; + + let nativeint_incl = + let open Nativeint_replace_polymorphic_compare in + let rec in_range state lo hi = + let int = full_range_nativeint state in + if int >= lo && int <= hi + then int + else in_range state lo hi + in + let open Caml.Nativeint in + fun state lo hi -> + if lo > hi then raise_crossed_bounds "nativeint" lo hi to_string; + let diff = sub hi lo in + if diff = max_int + then add lo (logand (full_range_nativeint state) max_int) + else if diff >= 0n + then add lo (nativeint state (succ diff)) + else in_range state lo hi + ;; + + let int64_incl = + let open Int64_replace_polymorphic_compare in + let rec in_range state lo hi = + let int = full_range_int64 state in + if int >= lo && int <= hi + then int + else in_range state lo hi + in + let open Caml.Int64 in + fun state lo hi -> + if lo > hi then raise_crossed_bounds "int64" lo hi to_string; + let diff = sub hi lo in + if diff = max_int + then add lo (logand (full_range_int64 state) max_int) + else if diff >= 0L + then add lo (int64 state (succ diff)) + else in_range state lo hi + ;; + + let float_range state lo hi = + let open Float_replace_polymorphic_compare in + if lo > hi then raise_crossed_bounds "float" lo hi Caml.string_of_float; + lo +. float state (hi -. lo) + ;; +end + +let default = State.default + +let bits () = State.bits default + +let int x = State.int default x +let int32 x = State.int32 default x +let nativeint x = State.nativeint default x +let int64 x = State.int64 default x +let float x = State.float default x + +let int_incl x y = State.int_incl default x y +let int32_incl x y = State.int32_incl default x y +let nativeint_incl x y = State.nativeint_incl default x y +let int64_incl x y = State.int64_incl default x y +let float_range x y = State.float_range default x y + +let bool () = State.bool default + +let full_init seed = State.full_init default seed +let init seed = full_init [| seed |] +let self_init ?allow_in_tests () = full_init (random_seed ?allow_in_tests ()) + +let set_state s = State.assign default s diff --git a/src/random.mli b/src/random.mli new file mode 100644 index 0000000..5c3e3b8 --- /dev/null +++ b/src/random.mli @@ -0,0 +1,129 @@ +(** Pseudo-random number generation. + + This is a wrapper of the standard library's [Random] library, though it does not share + state with that library. +*) + +(*_ + (***********************************************************************) + (* *) + (* Objective Caml *) + (* *) + (* Damien Doligez, projet Para, INRIA Rocquencourt *) + (* *) + (* Copyright 1996 Institut National de Recherche en Informatique et *) + (* en Automatique. All rights reserved. This file is distributed *) + (* under the terms of the Apache 2.0 license. See ../THIRD-PARTY.txt *) + (* for details. *) + (* *) + (***********************************************************************) *) + +open! Import + +(** {6 Basic functions} *) + +(** Note that all of these "basic" functions mutate a global random state. *) + +(** Initialize the generator, using the argument as a seed. The same seed will always + yield the same sequence of numbers. *) +val init : int -> unit + +(** Same as {!Random.init} but takes more data as seed. *) +val full_init : int array -> unit + +(** Initialize the generator with a more-or-less random seed chosen in a system-dependent + way. By default, [self_init] is disallowed in inline tests, as it's often used for no + good reason and it just creates nondeterministic failures for everyone. Passing + [~allow_in_tests:true] removes this restriction in case you legitimately want + nondeterministic values, like in [Filename.temp_dir]. *) +val self_init : ?allow_in_tests:bool -> unit -> unit + +(** Return 30 random bits in a nonnegative integer. @before 3.12.0 used a different + algorithm (affects all the following functions) *) +val bits : unit -> int + +(** [Random.int bound] returns a random integer between 0 (inclusive) and [bound] + (exclusive). [bound] must be greater than 0. *) +val int : int -> int + +(** [Random.int32 bound] returns a random integer between 0 (inclusive) and [bound] + (exclusive). [bound] must be greater than 0. *) +val int32 : int32 -> int32 + +(** [Random.nativeint bound] returns a random integer between 0 (inclusive) and [bound] + (exclusive). [bound] must be greater than 0. *) +val nativeint : nativeint -> nativeint + +(** [Random.int64 bound] returns a random integer between 0 (inclusive) and [bound] + (exclusive). [bound] must be greater than 0. *) +val int64 : int64 -> int64 + +(** [Random.float bound] returns a random floating-point number between 0 (inclusive) and + [bound] (exclusive). If [bound] is negative, the result is negative or zero. If + [bound] is 0, the result is 0. *) +val float : float -> float + +(** Produces a random value between the given inclusive bounds. Raises if bounds are + given in decreasing order. *) +val int_incl : int -> int -> int +val int32_incl : int32 -> int32 -> int32 +val nativeint_incl : nativeint -> nativeint -> nativeint +val int64_incl : int64 -> int64 -> int64 + +(** Produces a value between the given bounds (inclusive and exclusive, respectively). + Raises if bounds are given in decreasing order. *) +val float_range : float -> float -> float + +(** [Random.bool ()] returns [true] or [false] with probability 0.5 each. *) +val bool : unit -> bool + +(** {6 Advanced functions} *) + +(** The functions from module [State] manipulate the current state of the random generator + explicitly. This allows using one or several deterministic PRNGs, even in a + multi-threaded program, without interference from other parts of the program. + + Note that [Random.get_state] from the standard library is not exposed, because it + misleadingly makes a copy of random state, which is not typically the desired outcome + for accessing the shared state. + + Obtaining multiple generators with good independence properties is nontrivial; see + the [Splittable_random] library for that. *) +module State : sig + type t + + (** This gives access to the default random state, allowing user code to share (and + thereby mutate) the random state used by the main functions in [Random]. *) + val default : t + + (** Creates a new state and initializes it with the given seed. *) + val make : int array -> t + + (** Creates a new state and initializes it with a system-dependent low-entropy seed. *) + val make_self_init : ?allow_in_tests:bool -> unit -> t + + val copy : t -> t + + (** These functions are the same as the basic functions, except that they use (and + update) the given PRNG state instead of the default one. *) + + val bits : t -> int + + val int : t -> int -> int + val int32 : t -> int32 -> int32 + val nativeint : t -> nativeint -> nativeint + val int64 : t -> int64 -> int64 + val float : t -> float -> float + + val int_incl : t -> int -> int -> int + val int32_incl : t -> int32 -> int32 -> int32 + val nativeint_incl : t -> nativeint -> nativeint -> nativeint + val int64_incl : t -> int64 -> int64 -> int64 + + val float_range : t -> float -> float -> float + + val bool : t -> bool +end + +(** Sets the state of the generator used by the basic functions. *) +val set_state : State.t -> unit diff --git a/src/ref.ml b/src/ref.ml new file mode 100644 index 0000000..15bfa7a --- /dev/null +++ b/src/ref.ml @@ -0,0 +1,47 @@ +open! Import + +(* In the definition of [t], we do not have [[@@deriving_inline compare, sexp][@@@end]] because + in general, syntax extensions tend to use the implementation when available rather than + using the alias. Here that would lead to use the record representation [ { mutable + contents : 'a } ] which would result in different (and unwanted) behavior. *) +type 'a t = 'a ref = { mutable contents : 'a } + +include (struct + type 'a t = 'a ref [@@deriving_inline compare, equal, sexp] + let compare : 'a . ('a -> 'a -> int) -> 'a t -> 'a t -> int = compare_ref + let equal : 'a . ('a -> 'a -> bool) -> 'a t -> 'a t -> bool = equal_ref + let t_of_sexp : + 'a . (Ppx_sexp_conv_lib.Sexp.t -> 'a) -> Ppx_sexp_conv_lib.Sexp.t -> 'a t = + ref_of_sexp + let sexp_of_t : + 'a . ('a -> Ppx_sexp_conv_lib.Sexp.t) -> 'a t -> Ppx_sexp_conv_lib.Sexp.t = + sexp_of_ref + [@@@end] +end : sig + type 'a t = 'a ref [@@deriving_inline compare, equal, sexp] + include + sig + [@@@ocaml.warning "-32"] + val compare : ('a -> 'a -> int) -> 'a t -> 'a t -> int + val equal : ('a -> 'a -> bool) -> 'a t -> 'a t -> bool + include Ppx_sexp_conv_lib.Sexpable.S1 with type 'a t := 'a t + end[@@ocaml.doc "@inline"] + [@@@end] + end with type 'a t := 'a t) + +external create : 'a -> 'a t = "%makemutable" +external ( ! ) : 'a t -> 'a = "%field0" +external ( := ) : 'a t -> 'a -> unit = "%setfield0" + +let swap t1 t2 = + let tmp = !t1 in + t1 := !t2; + t2 := tmp + +let replace t f = t := f !t + +let set_temporarily t a ~f = + let restore_to = !t in + t := a; + Exn.protect ~f ~finally:(fun () -> t := restore_to); +;; diff --git a/src/ref.mli b/src/ref.mli new file mode 100644 index 0000000..7024c80 --- /dev/null +++ b/src/ref.mli @@ -0,0 +1,30 @@ +(** Module for the type [ref], mutable indirection cells [r] containing a value of type + ['a], accessed with [!r] and set by [r := a]. *) + +open! Import + +type 'a t = 'a Caml.ref = { mutable contents : 'a } +[@@deriving_inline compare, equal, sexp] +include +sig + [@@@ocaml.warning "-32"] + val compare : ('a -> 'a -> int) -> 'a t -> 'a t -> int + val equal : ('a -> 'a -> bool) -> 'a t -> 'a t -> bool + include Ppx_sexp_conv_lib.Sexpable.S1 with type 'a t := 'a t +end[@@ocaml.doc "@inline"] +[@@@end] + +(*_ defined as externals to avoid breaking the inliner *) +external create : 'a -> 'a t = "%makemutable" +external ( ! ) : 'a t -> 'a = "%field0" +external ( := ) : 'a t -> 'a -> unit = "%setfield0" + +(** [swap t1 t2] swaps the values in [t1] and [t2]. *) +val swap : 'a t -> 'a t -> unit + +(** [replace t f] is [t := f !t] *) +val replace : 'a t -> ('a -> 'a) -> unit + +(** [set_temporarily t a ~f] sets [t] to [a], calls [f ()], and then restores [t] to its + value prior to [set_temporarily] being called, whether [f] returns or raises. *) +val set_temporarily : 'a t -> 'a -> f:(unit -> 'b) -> 'b diff --git a/src/result.ml b/src/result.ml new file mode 100644 index 0000000..3e35ef6 --- /dev/null +++ b/src/result.ml @@ -0,0 +1,188 @@ +open! Import + +type ('a, 'b) t = ('a, 'b) Caml.result = + | Ok of 'a + | Error of 'b +[@@deriving_inline sexp, compare, hash] +let t_of_sexp : type a b. + (Ppx_sexp_conv_lib.Sexp.t -> a) -> + (Ppx_sexp_conv_lib.Sexp.t -> b) -> Ppx_sexp_conv_lib.Sexp.t -> (a, b) t + = + let _tp_loc = "src/result.ml.t" in + fun _of_a -> + fun _of_b -> + function + | Ppx_sexp_conv_lib.Sexp.List ((Ppx_sexp_conv_lib.Sexp.Atom + ("ok"|"Ok" as _tag))::sexp_args) as _sexp -> + (match sexp_args with + | v0::[] -> let v0 = _of_a v0 in Ok v0 + | _ -> + Ppx_sexp_conv_lib.Conv_error.stag_incorrect_n_args _tp_loc + _tag _sexp) + | Ppx_sexp_conv_lib.Sexp.List ((Ppx_sexp_conv_lib.Sexp.Atom + ("error"|"Error" as _tag))::sexp_args) as _sexp -> + (match sexp_args with + | v0::[] -> let v0 = _of_b v0 in Error v0 + | _ -> + Ppx_sexp_conv_lib.Conv_error.stag_incorrect_n_args _tp_loc + _tag _sexp) + | Ppx_sexp_conv_lib.Sexp.Atom ("ok"|"Ok") as sexp -> + Ppx_sexp_conv_lib.Conv_error.stag_takes_args _tp_loc sexp + | Ppx_sexp_conv_lib.Sexp.Atom ("error"|"Error") as sexp -> + Ppx_sexp_conv_lib.Conv_error.stag_takes_args _tp_loc sexp + | Ppx_sexp_conv_lib.Sexp.List ((Ppx_sexp_conv_lib.Sexp.List _)::_) as + sexp -> + Ppx_sexp_conv_lib.Conv_error.nested_list_invalid_sum _tp_loc sexp + | Ppx_sexp_conv_lib.Sexp.List [] as sexp -> + Ppx_sexp_conv_lib.Conv_error.empty_list_invalid_sum _tp_loc sexp + | sexp -> Ppx_sexp_conv_lib.Conv_error.unexpected_stag _tp_loc sexp +let sexp_of_t : type a b. + (a -> Ppx_sexp_conv_lib.Sexp.t) -> + (b -> Ppx_sexp_conv_lib.Sexp.t) -> (a, b) t -> Ppx_sexp_conv_lib.Sexp.t + = + fun _of_a -> + fun _of_b -> + function + | Ok v0 -> + let v0 = _of_a v0 in + Ppx_sexp_conv_lib.Sexp.List [Ppx_sexp_conv_lib.Sexp.Atom "Ok"; v0] + | Error v0 -> + let v0 = _of_b v0 in + Ppx_sexp_conv_lib.Sexp.List + [Ppx_sexp_conv_lib.Sexp.Atom "Error"; v0] +let compare : + 'a 'b . + ('a -> 'a -> int) -> ('b -> 'b -> int) -> ('a, 'b) t -> ('a, 'b) t -> int + = + fun _cmp__a -> + fun _cmp__b -> + fun a__001_ -> + fun b__002_ -> + if Ppx_compare_lib.phys_equal a__001_ b__002_ + then 0 + else + (match (a__001_, b__002_) with + | (Ok _a__003_, Ok _b__004_) -> _cmp__a _a__003_ _b__004_ + | (Ok _, _) -> (-1) + | (_, Ok _) -> 1 + | (Error _a__005_, Error _b__006_) -> _cmp__b _a__005_ _b__006_) +let hash_fold_t : type a b. + (Ppx_hash_lib.Std.Hash.state -> a -> Ppx_hash_lib.Std.Hash.state) -> + (Ppx_hash_lib.Std.Hash.state -> b -> Ppx_hash_lib.Std.Hash.state) -> + Ppx_hash_lib.Std.Hash.state -> (a, b) t -> Ppx_hash_lib.Std.Hash.state + = + fun _hash_fold_a -> + fun _hash_fold_b -> + fun hsv -> + fun arg -> + match arg with + | Ok _a0 -> + let hsv = Ppx_hash_lib.Std.Hash.fold_int hsv 0 in + let hsv = hsv in _hash_fold_a hsv _a0 + | Error _a0 -> + let hsv = Ppx_hash_lib.Std.Hash.fold_int hsv 1 in + let hsv = hsv in _hash_fold_b hsv _a0 +[@@@end] + +include Monad.Make2 (struct + type nonrec ('a, 'b) t = ('a,'b) t + + let bind x ~f = match x with + | Error _ as x -> x + | Ok x -> f x + + let map x ~f = match x with + | Error _ as x -> x + | Ok x -> Ok (f x) + + let map = `Custom map + + let return x = Ok x + end) + +let ignore = ignore_m + +let fail x = Error x;; +let failf format = Printf.ksprintf fail format + +let map_error t ~f = match t with + | Ok _ as x -> x + | Error x -> Error (f x) + +let is_ok = function + | Ok _ -> true + | Error _ -> false + +let is_error = function + | Ok _ -> false + | Error _ -> true + +let ok = function + | Ok x -> Some x + | Error _ -> None + +let error = function + | Ok _ -> None + | Error x -> Some x + +let of_option opt ~error = + match opt with + | Some x -> Ok x + | None -> Error error + +let iter v ~f = match v with + | Ok x -> f x + | Error _ -> () + +let iter_error v ~f = match v with + | Ok _ -> () + | Error x -> f x + +let ok_fst = function + | Ok x -> `Fst x + | Error x -> `Snd x + +let ok_if_true bool ~error = + if bool + then Ok () + else Error error + +let try_with f = + try Ok (f ()) + with exn -> Error exn + +let ok_unit = Ok () + +let ok_exn = function + | Ok x -> x + | Error exn -> raise exn + +let ok_or_failwith = function + | Ok x -> x + | Error str -> failwith str + +module Export = struct + type ('ok, 'err) _result = + ('ok, 'err) t = + | Ok of 'ok + | Error of 'err + + let is_error = is_error + let is_ok = is_ok +end + +let combine t1 t2 ~ok ~err = + match t1, t2 with + | Ok _, Error e | Error e, Ok _ -> Error e + | Ok ok1 , Ok ok2 -> Ok (ok ok1 ok2 ) + | Error err1, Error err2 -> Error (err err1 err2) +;; + +let combine_errors l = + let ok, errs = List1.partition_map l ~f:ok_fst in + match errs with + | [] -> Ok ok + | _ :: _ -> Error errs +;; + +let combine_errors_unit l = map (combine_errors l) ~f:(fun (_ : unit list) -> ()) diff --git a/src/result.mli b/src/result.mli new file mode 100644 index 0000000..4e3501c --- /dev/null +++ b/src/result.mli @@ -0,0 +1,108 @@ +(** [Result] is often used to handle error messages. *) + +open! Import + +(** ['ok] is a function's expected return type, and ['err] is often an error message + string. + + {[ + let ric_of_ticker = function + | "IBM" -> Ok "IBM.N" + | "MSFT" -> Ok "MSFT.OQ" + | "AA" -> Ok "AA.N" + | "CSCO" -> Ok "CSCO.OQ" + | _ as ticker -> Error (sprintf "can't find ric of %s" ticker) + ]} + + The return type of ric_of_ticker could be [string option], but [(string, string) + Result.t] gives more control over the error message. *) +type ('ok, 'err) t = ('ok, 'err) Caml.result = + | Ok of 'ok + | Error of 'err +[@@deriving_inline sexp, compare, hash] +include +sig + [@@@ocaml.warning "-32"] + include + Ppx_sexp_conv_lib.Sexpable.S2 with type ('ok,'err) t := ('ok, 'err) t + val compare : + ('ok -> 'ok -> int) -> + ('err -> 'err -> int) -> ('ok, 'err) t -> ('ok, 'err) t -> int + val hash_fold_t : + (Ppx_hash_lib.Std.Hash.state -> 'ok -> Ppx_hash_lib.Std.Hash.state) -> + (Ppx_hash_lib.Std.Hash.state -> 'err -> Ppx_hash_lib.Std.Hash.state) + -> + Ppx_hash_lib.Std.Hash.state -> + ('ok, 'err) t -> Ppx_hash_lib.Std.Hash.state +end[@@ocaml.doc "@inline"] +[@@@end] + +include Monad.S2 with type ('a,'err) t := ('a,'err) t + +val ignore : (_, 'err) t -> (unit, 'err) t + +val fail : 'err -> (_, 'err) t + +(** e.g., [failf "Couldn't find bloogle %s" (Bloogle.to_string b)]. *) +val failf : ('a, unit, string, (_, string) t) format4 -> 'a + +val is_ok : (_, _) t -> bool +val is_error : (_, _) t -> bool + +val ok : ('ok, _ ) t -> 'ok option +val ok_exn : ('ok, exn ) t -> 'ok +val ok_or_failwith : ('ok, string) t -> 'ok + +val error : (_ , 'err) t -> 'err option + +val of_option : 'ok option -> error:'err -> ('ok, 'err) t + +val iter : ('ok, _ ) t -> f:('ok -> unit) -> unit +val iter_error : (_ , 'err) t -> f:('err -> unit) -> unit + +val map : ('ok, 'err) t -> f:('ok -> 'c) -> ('c , 'err) t +val map_error : ('ok, 'err) t -> f:('err -> 'c) -> ('ok, 'c ) t + +(** Returns [Ok] if both are [Ok] and [Error] otherwise. *) +val combine + : ('ok1, 'err) t + -> ('ok2, 'err) t + -> ok: ('ok1 -> 'ok2 -> 'ok3) + -> err:('err -> 'err -> 'err) + -> ('ok3, 'err) t + +(** [combine_errors ts] returns [Ok] if every element in [ts] is [Ok], else it returns + [Error] with all the errors in [ts]. + + This is similar to [all] from [Monad.S2], with the difference that [all] only returns + the first error. *) +val combine_errors : ('ok, 'err) t list -> ('ok list, 'err list) t + +(** [combine_errors_unit] returns [Ok] if every element in [ts] is [Ok ()], else it + returns [Error] with all the errors in [ts], like [combine_errors]. *) +val combine_errors_unit : (unit, 'err) t list -> (unit, 'err list) t + +(** [ok_fst] is useful with [List.partition_map]. Continuing the above example: + {[ + let rics, errors = List.partition_map ~f:Result.ok_fst + (List.map ~f:ric_of_ticker ["AA"; "F"; "CSCO"; "AAPL"]) ]} *) +val ok_fst : ('ok, 'err) t -> [ `Fst of 'ok | `Snd of 'err ] + +(** [ok_if_true] returns [Ok ()] if [bool] is true, and [Error error] if it is false. *) +val ok_if_true : bool -> error : 'err -> (unit, 'err) t + +val try_with : (unit -> 'a) -> ('a, exn) t + +(** [ok_unit = Ok ()], used to avoid allocation as a performance hack. *) +val ok_unit : (unit, _) t + +module Export : sig + type ('ok, 'err) _result + = ('ok, 'err) t + = Ok of 'ok + | Error of 'err + + val is_ok : (_, _) t -> bool + val is_error : (_, _) t -> bool +end + diff --git a/src/runtime.js b/src/runtime.js new file mode 100644 index 0000000..dfab440 --- /dev/null +++ b/src/runtime.js @@ -0,0 +1,118 @@ +//Provides: Base_int_math_int_popcount const +function Base_int_math_int_popcount(v) { + v = v - ((v >>> 1) & 0x55555555); + v = (v & 0x33333333) + ((v >>> 2) & 0x33333333); + return ((v + (v >>> 4) & 0xF0F0F0F) * 0x1010101) >>> 24; +} + +//Provides: Base_clear_caml_backtrace_pos const +function Base_clear_caml_backtrace_pos(x) { + return 0; +} + +//Provides: Base_int_math_int32_clz const +function Base_int_math_int32_clz(x) { + var n = 32; + var y; + y = x >>16; if (y != 0) { n = n -16; x = y; } + y = x >> 8; if (y != 0) { n = n - 8; x = y; } + y = x >> 4; if (y != 0) { n = n - 4; x = y; } + y = x >> 2; if (y != 0) { n = n - 2; x = y; } + y = x >> 1; if (y != 0) return n - 2; + return n - x; +} + +//Provides: Base_int_math_int_clz const +//Requires: Base_int_math_int32_clz +function Base_int_math_int_clz(x) { return Base_int_math_int32_clz(x); } + +//Provides: Base_int_math_nativeint_clz const +//Requires: Base_int_math_int32_clz +function Base_int_math_nativeint_clz(x) { return Base_int_math_int32_clz(x); } + +//Provides: Base_int_math_int64_clz const +//Requires: caml_int64_shift_right_unsigned, caml_int64_is_zero, caml_int64_to_int32 +function Base_int_math_int64_clz(x) { + var n = 64; + var y; + y = caml_int64_shift_right_unsigned(x, 32); + if (!caml_int64_is_zero(y)) { n = n -32; x = y; } + y = caml_int64_shift_right_unsigned(x, 16); + if (!caml_int64_is_zero(y)) { n = n -16; x = y; } + y = caml_int64_shift_right_unsigned(x, 8); + if (!caml_int64_is_zero(y)) { n = n - 8; x = y; } + y = caml_int64_shift_right_unsigned(x, 4); + if (!caml_int64_is_zero(y)) { n = n - 4; x = y; } + y = caml_int64_shift_right_unsigned(x, 2); + if (!caml_int64_is_zero(y)) { n = n - 2; x = y; } + y = caml_int64_shift_right_unsigned(x, 1); + if (!caml_int64_is_zero(y)) return n - 2; + return n - caml_int64_to_int32(x); +} + +//Provides: Base_int_math_int_pow_stub const +function Base_int_math_int_pow_stub(base, exponent) { + var one = 1; + var mul = [one, base, one, one]; + var res = one; + while (!exponent==0) { + mul[1] = (mul[1] * mul[3]) | 0; + mul[2] = (mul[1] * mul[1]) | 0; + mul[3] = (mul[2] * mul[1]) | 0; + res = (res * mul[exponent & 3]) | 0; + exponent = exponent >> 2; + } + return res; +} + +//Provides: Base_int_math_int64_pow_stub const +//Requires: caml_int64_mul, caml_int64_is_zero, caml_int64_shift_right_unsigned +function Base_int_math_int64_pow_stub(base, exponent) { + var one = [255,1,0,0]; + var mul = [one, base, one, one]; + var res = one; + while (!caml_int64_is_zero(exponent)) { + mul[1] = caml_int64_mul(mul[1], mul[3]); + mul[2] = caml_int64_mul(mul[1], mul[1]); + mul[3] = caml_int64_mul(mul[2], mul[1]); + res = caml_int64_mul(res, mul[exponent[1] & 3]); + exponent = caml_int64_shift_right_unsigned(exponent, 2); + } + return res; +} + +//Provides: Base_internalhash_fold_int64 +//Requires: caml_hash_mix_int64 +var Base_internalhash_fold_int64 = caml_hash_mix_int64; +//Provides: Base_internalhash_fold_int +//Requires: caml_hash_mix_int +var Base_internalhash_fold_int = caml_hash_mix_int; +//Provides: Base_internalhash_fold_float +//Requires: caml_hash_mix_float +var Base_internalhash_fold_float = caml_hash_mix_float; +//Provides: Base_internalhash_fold_string +//Requires: caml_hash_mix_string +var Base_internalhash_fold_string = caml_hash_mix_string; +//Provides: Base_internalhash_get_hash_value +//Requires: caml_hash_mix_final +function Base_internalhash_get_hash_value(seed) { + var h = caml_hash_mix_final(seed); + return h & 0x3FFFFFFF; +} + +//Provides: Base_hash_string mutable +//Requires: caml_hash +function Base_hash_string(s) { + return caml_hash(1,1,0,s) +} +//Provides: Base_hash_double const +//Requires: caml_hash +function Base_hash_double(d) { + return caml_hash(1,1,0,d); +} + +//Provides: Base_am_testing const +//Weakdef +function Base_am_testing(x) { + return 0; +} diff --git a/src/select-bytes-set-primitives/select.ml b/src/select-bytes-set-primitives/select.ml new file mode 100644 index 0000000..20cfcbe --- /dev/null +++ b/src/select-bytes-set-primitives/select.ml @@ -0,0 +1,20 @@ +let () = + let ver, output = + try + match Sys.argv with + | [|_; "-ocaml-version"; v; "-o"; fn|] -> + (Scanf.sscanf v "%d.%d" (fun major minor -> (major, minor)), + fn) + | _ -> raise Exit + with _ -> + failwith "bad command line arguments" + in + let prefix = + if ver >= (4, 04) then "bytes" else "string" + in + let oc = open_out output in + Printf.fprintf oc {| +external set : %s -> int -> char -> unit = "%%%s_safe_set" +external unsafe_set : %s -> int -> char -> unit = "%%%s_unsafe_set" +|} prefix prefix prefix prefix; + close_out oc diff --git a/src/select-int63-backend/select.ml b/src/select-int63-backend/select.ml new file mode 100644 index 0000000..535eca9 --- /dev/null +++ b/src/select-int63-backend/select.ml @@ -0,0 +1,29 @@ +let () = + let portable_int63, arch_sixtyfour, output = + try + match Sys.argv with + | [|_; "-portable-int63"; x; "-arch-sixtyfour"; y; "-o"; fn|] -> + let x = + match x with + | "true" | "!false" -> true + | "false" | "!true" -> false + | _ -> failwith "invalid value for -portable-int63" + in + (x, + bool_of_string y, + fn) + | _ -> raise Exit + with _ -> + failwith "bad command line arguments" + in + let backend = + if portable_int63 then + "Dynamic" + else if arch_sixtyfour then + "Native" + else + "Emulated" + in + let oc = open_out output in + Printf.fprintf oc "include Int63_backends.%s" backend; + close_out oc diff --git a/src/sequence.ml b/src/sequence.ml new file mode 100644 index 0000000..48badf4 --- /dev/null +++ b/src/sequence.ml @@ -0,0 +1,1097 @@ +open! Import +open Container_intf.Export + +module Array = Array0 +module List = List0 + +module Step = struct + (* 'a is an item in the sequence, 's is the state that will produce the remainder of + the sequence *) + type ('a,'s) t = + | Done + | Skip of 's + | Yield of 'a * 's + [@@deriving_inline sexp_of] + let sexp_of_t : type a s. + (a -> Ppx_sexp_conv_lib.Sexp.t) -> + (s -> Ppx_sexp_conv_lib.Sexp.t) -> (a, s) t -> Ppx_sexp_conv_lib.Sexp.t + = + fun _of_a -> + fun _of_s -> + function + | Done -> Ppx_sexp_conv_lib.Sexp.Atom "Done" + | Skip v0 -> + let v0 = _of_s v0 in + Ppx_sexp_conv_lib.Sexp.List + [Ppx_sexp_conv_lib.Sexp.Atom "Skip"; v0] + | Yield (v0, v1) -> + let v0 = _of_a v0 + and v1 = _of_s v1 in + Ppx_sexp_conv_lib.Sexp.List + [Ppx_sexp_conv_lib.Sexp.Atom "Yield"; v0; v1] + [@@@end] +end + +open Step + +(* 'a is an item in the sequence, 's is the state that will produce the remainder of the + sequence *) +type +_ t = + | Sequence : 's * ('s -> ('a,'s) Step.t) -> 'a t + +type 'a sequence = 'a t + +module Expert = struct + let next_step (Sequence (s, f)) = + match f s with + | Done -> Done + | Skip s -> Skip (Sequence (s, f)) + | Yield (a, s) -> Yield (a, Sequence (s, f)) + ;; + + let delayed_fold_step s ~init ~f ~finish = + let rec loop s next finish f acc = + match next s with + | Done -> finish acc + | Skip s -> f acc None ~k:(loop s next finish f) + | Yield (a, s) -> f acc (Some a) ~k:(loop s next finish f) + in + match s with + | Sequence(s, next) -> loop s next finish f init + ;; +end + +let unfold_step ~init ~f = + Sequence (init,f) + +let unfold ~init ~f = + unfold_step ~init + ~f:(fun s -> + match f s with + | None -> Step.Done + | Some(a,s) -> Step.Yield(a,s)) + +let unfold_with s ~init ~f = + match s with + | Sequence(s, next) -> + Sequence((init, s) , + (fun (seed, s) -> + match next s with + | Done -> Done + | Skip s -> Skip (seed, s) + | Yield(a,s) -> + match f seed a with + | Done -> Done + | Skip seed -> Skip (seed, s) + | Yield(a,seed) -> Yield(a,(seed,s)))) + +let unfold_with_and_finish s ~init ~running_step ~inner_finished ~finishing_step = + match s with + | Sequence (s, next) -> + Sequence (`Inner_running (init, s), (fun state -> + match state with + | `Inner_running (state, inner_state) -> begin + match next inner_state with + | Done -> + Skip (`Inner_finished (inner_finished state)) + | Skip inner_state -> + Skip (`Inner_running (state, inner_state)) + | Yield (x, inner_state) -> + match running_step state x with + | Done -> Done + | Skip state -> + Skip (`Inner_running (state, inner_state)) + | Yield (y, state) -> + Yield (y, `Inner_running (state, inner_state)) + end + | `Inner_finished state -> begin + match finishing_step state with + | Done -> Done + | Skip state -> + Skip (`Inner_finished state) + | Yield (y, state) -> + Yield (y, `Inner_finished state) + end)) + +let of_list l = + unfold_step ~init:l + ~f:(function + | [] -> Done + | x::l -> Yield(x,l)) + + +let fold t ~init ~f = + let rec loop seed v next f = + match next seed with + | Done -> v + | Skip s -> loop s v next f + | Yield(a,s) -> loop s (f v a) next f + in + match t with + | Sequence(seed, next) -> loop seed init next f + +let to_list_rev t = + fold t ~init:[] ~f:(fun l x -> x::l) + + +let to_list (Sequence(s,next)) = + let safe_to_list t = + List.rev (to_list_rev t) + in + let rec to_list s next i = + if i = 0 then safe_to_list (Sequence(s,next)) + else + match next s with + | Done -> [] + | Skip s -> to_list s next i + | Yield(a,s) -> a::(to_list s next (i-1)) + in + to_list s next 500 + +let sexp_of_t sexp_of_a t = sexp_of_list sexp_of_a (to_list t) + +let range ?(stride=1) ?(start=`inclusive) ?(stop=`exclusive) start_v stop_v = + let step = + match stop with + | `inclusive when stride >= 0 -> + fun i -> if i > stop_v then Done else Yield(i, i + stride) + | `inclusive -> + fun i -> if i < stop_v then Done else Yield(i, i + stride) + | `exclusive when stride >= 0 -> + fun i -> if i >= stop_v then Done else Yield(i,i + stride) + | `exclusive -> + fun i -> if i <= stop_v then Done else Yield(i,i + stride) + in + let init = + match start with + | `inclusive -> start_v + | `exclusive -> start_v + stride + in + unfold_step ~init ~f:step + +let of_lazy t_lazy = + unfold_step ~init:t_lazy ~f:(fun t_lazy -> + let Sequence (s, next) = Lazy.force t_lazy in + match next s with + | Done -> Done + | Skip s -> Skip (let v = Sequence (s, next) in lazy v) + | Yield (x, s) -> Yield (x, let v = Sequence (s, next) in lazy v)) + +let map t ~f = + match t with + | Sequence(seed, next) -> + Sequence(seed, + fun seed -> + match next seed with + | Done -> Done + | Skip s -> Skip s + | Yield(a,s) -> Yield(f a,s)) + + +let mapi t ~f = + match t with + | Sequence(s, next) -> + Sequence((0,s), + fun (i,s) -> + match next s with + | Done -> Done + | Skip s -> Skip (i,s) + | Yield(a,s) -> Yield(f i a,(i+1,s))) + +let folding_map t ~init ~f = + unfold_with t ~init ~f:(fun acc x -> + let acc, x = f acc x in + Yield (x, acc)) + +let folding_mapi t ~init ~f = + unfold_with t ~init:(0, init) ~f:(fun (i, acc) x -> + let acc, x = f i acc x in + Yield (x, (i+1, acc))) + +let filter t ~f = + match t with + | Sequence(seed, next) -> + Sequence(seed, + fun seed -> + match next seed with + | Done -> Done + | Skip s -> Skip s + | Yield(a,s) when f a -> Yield(a,s) + | Yield (_,s) -> Skip s) + +let filteri t ~f = + map ~f:snd ( + filter (mapi t ~f:(fun i s -> (i,s))) + ~f:(fun (i,s) -> f i s)) + +let length t = + let rec loop i s next = + match next s with + | Done -> i + | Skip s -> loop i s next + | Yield(_,s) -> loop (i+1) s next + in + match t with + | Sequence (seed, next) -> loop 0 seed next + +let to_list_rev_with_length t = + fold t ~init:([],0) ~f:(fun (l,i) x -> (x::l,i+1)) + +let to_array t = + let (l,len) = to_list_rev_with_length t in + match l with + | [] -> [||] + | x::l -> + let a = Array.create ~len x in + let rec loop i l = + match l with + | [] -> assert (i = -1) + | x::l -> a.(i) <- x; loop (i-1) l + in + loop (len - 2) l; + a + +let find t ~f = + let rec loop s next f = + match next s with + | Done -> None + | Yield(a,_) when f a -> Some a + | Yield(_,s) | Skip s -> loop s next f + in + match t with + | Sequence (seed, next) -> loop seed next f + +let find_map t ~f = + let rec loop s next f = + match next s with + | Done -> None + | Yield(a,s) -> + (match f a with + | None -> loop s next f + | some_b -> some_b) + | Skip s -> loop s next f + in + match t with + | Sequence (seed, next) -> loop seed next f + + +let find_mapi t ~f = + let rec loop s next f i = + match next s with + | Done -> None + | Yield(a,s) -> + (match f i a with + | None -> loop s next f (i+1) + | some_b -> some_b) + | Skip s -> loop s next f i + in + match t with + | Sequence (seed, next) -> loop seed next f 0 + +let for_all t ~f = + let rec loop s next f = + match next s with + | Done -> true + | Yield(a,_) when not (f a) -> false + | Yield (_,s) | Skip s -> loop s next f + in + match t with + | Sequence (seed, next) -> loop seed next f + +let for_alli t ~f = + let rec loop s next f i = + match next s with + | Done -> true + | Yield(a,_) when not (f i a) -> false + | Yield (_,s) -> loop s next f (i+1) + | Skip s -> loop s next f i + in + match t with + | Sequence (seed, next) -> loop seed next f 0 + +let exists t ~f = + let rec loop s next f = + match next s with + | Done -> false + | Yield(a,_) when f a -> true + | Yield(_,s) | Skip s -> loop s next f + in + match t with + | Sequence (seed, next) -> loop seed next f + +let existsi t ~f = + let rec loop s next f i = + match next s with + | Done -> false + | Yield(a,_) when f i a -> true + | Yield(_,s) -> loop s next f (i+1) + | Skip s -> loop s next f i + in + match t with + | Sequence (seed, next) -> loop seed next f 0 + +let iter t ~f = + let rec loop seed next f = + match next seed with + | Done -> () + | Skip s -> loop s next f + | Yield(a,s) -> + begin + f a; + loop s next f + end + in + match t with + | Sequence(seed, next) -> loop seed next f + +let is_empty t = + let rec loop s next = + match next s with + | Done -> true + | Skip s -> loop s next + | Yield _ -> false + in + match t with + | Sequence(seed, next) -> loop seed next + +let mem t a ~equal = + let rec loop s next a = + match next s with + | Done -> false + | Yield(b,_) when equal a b -> true + | Yield(_,s) | Skip s -> loop s next a + in + match t with + | Sequence(seed, next) -> loop seed next a + +let empty = + Sequence((), fun () -> Done) + +let bind t ~f = + unfold_step + ~f:(function + | Sequence(seed,next), rest -> + match next seed with + | Done -> + begin + match rest with + | Sequence(seed, next) -> + match next seed with + | Done -> Done + | Skip s -> Skip (empty, Sequence(s, next)) + | Yield(a, s) -> Skip(f a, Sequence(s, next)) + end + | Skip s -> Skip (Sequence(s,next), rest) + | Yield(a,s) -> Yield(a, (Sequence(s,next) , rest))) + ~init:(empty,t) + +let return x = + unfold_step ~init:(Some x) + ~f:(function + | None -> Done + | Some x -> Yield(x,None)) + +include Monad.Make(struct + type nonrec 'a t = 'a t + let map = `Custom map + let bind = bind + let return = return + end) + +let nth s n = + if n < 0 then None + else + let rec loop i s next = + match next s with + | Done -> None + | Skip s -> loop i s next + | Yield(a,s) -> if phys_equal i 0 then Some a else loop (i-1) s next + in + match s with + | Sequence(s,next) -> + loop n s next + +let nth_exn s n = + if n < 0 then raise (Invalid_argument "Sequence.nth") + else + match nth s n with + | None -> failwith "Sequence.nth" + | Some x -> x + +module Merge_with_duplicates_element = struct + type ('a, 'b) t = + | Left of 'a + | Right of 'b + | Both of 'a * 'b + [@@deriving_inline compare, hash, sexp] + let compare : + 'a 'b . + ('a -> 'a -> int) -> ('b -> 'b -> int) -> ('a, 'b) t -> ('a, 'b) t -> int + = + fun _cmp__a -> + fun _cmp__b -> + fun a__001_ -> + fun b__002_ -> + if Ppx_compare_lib.phys_equal a__001_ b__002_ + then 0 + else + (match (a__001_, b__002_) with + | (Left _a__003_, Left _b__004_) -> _cmp__a _a__003_ _b__004_ + | (Left _, _) -> (-1) + | (_, Left _) -> 1 + | (Right _a__005_, Right _b__006_) -> _cmp__b _a__005_ _b__006_ + | (Right _, _) -> (-1) + | (_, Right _) -> 1 + | (Both (_a__007_, _a__009_), Both (_b__008_, _b__010_)) -> + (match _cmp__a _a__007_ _b__008_ with + | 0 -> _cmp__b _a__009_ _b__010_ + | n -> n)) + let hash_fold_t : type a b. + (Ppx_hash_lib.Std.Hash.state -> a -> Ppx_hash_lib.Std.Hash.state) -> + (Ppx_hash_lib.Std.Hash.state -> b -> Ppx_hash_lib.Std.Hash.state) -> + Ppx_hash_lib.Std.Hash.state -> (a, b) t -> Ppx_hash_lib.Std.Hash.state + = + fun _hash_fold_a -> + fun _hash_fold_b -> + fun hsv -> + fun arg -> + match arg with + | Left _a0 -> + let hsv = Ppx_hash_lib.Std.Hash.fold_int hsv 0 in + let hsv = hsv in _hash_fold_a hsv _a0 + | Right _a0 -> + let hsv = Ppx_hash_lib.Std.Hash.fold_int hsv 1 in + let hsv = hsv in _hash_fold_b hsv _a0 + | Both (_a0, _a1) -> + let hsv = Ppx_hash_lib.Std.Hash.fold_int hsv 2 in + let hsv = let hsv = hsv in _hash_fold_a hsv _a0 in + _hash_fold_b hsv _a1 + let t_of_sexp : type a b. + (Ppx_sexp_conv_lib.Sexp.t -> a) -> + (Ppx_sexp_conv_lib.Sexp.t -> b) -> Ppx_sexp_conv_lib.Sexp.t -> (a, b) t + = + let _tp_loc = "src/sequence.ml.Merge_with_duplicates_element.t" in + fun _of_a -> + fun _of_b -> + function + | Ppx_sexp_conv_lib.Sexp.List ((Ppx_sexp_conv_lib.Sexp.Atom + ("left"|"Left" as _tag))::sexp_args) as _sexp -> + (match sexp_args with + | v0::[] -> let v0 = _of_a v0 in Left v0 + | _ -> + Ppx_sexp_conv_lib.Conv_error.stag_incorrect_n_args _tp_loc + _tag _sexp) + | Ppx_sexp_conv_lib.Sexp.List ((Ppx_sexp_conv_lib.Sexp.Atom + ("right"|"Right" as _tag))::sexp_args) as _sexp -> + (match sexp_args with + | v0::[] -> let v0 = _of_b v0 in Right v0 + | _ -> + Ppx_sexp_conv_lib.Conv_error.stag_incorrect_n_args _tp_loc + _tag _sexp) + | Ppx_sexp_conv_lib.Sexp.List ((Ppx_sexp_conv_lib.Sexp.Atom + ("both"|"Both" as _tag))::sexp_args) as _sexp -> + (match sexp_args with + | v0::v1::[] -> + let v0 = _of_a v0 + and v1 = _of_b v1 in Both (v0, v1) + | _ -> + Ppx_sexp_conv_lib.Conv_error.stag_incorrect_n_args _tp_loc + _tag _sexp) + | Ppx_sexp_conv_lib.Sexp.Atom ("left"|"Left") as sexp -> + Ppx_sexp_conv_lib.Conv_error.stag_takes_args _tp_loc sexp + | Ppx_sexp_conv_lib.Sexp.Atom ("right"|"Right") as sexp -> + Ppx_sexp_conv_lib.Conv_error.stag_takes_args _tp_loc sexp + | Ppx_sexp_conv_lib.Sexp.Atom ("both"|"Both") as sexp -> + Ppx_sexp_conv_lib.Conv_error.stag_takes_args _tp_loc sexp + | Ppx_sexp_conv_lib.Sexp.List ((Ppx_sexp_conv_lib.Sexp.List _)::_) as + sexp -> + Ppx_sexp_conv_lib.Conv_error.nested_list_invalid_sum _tp_loc sexp + | Ppx_sexp_conv_lib.Sexp.List [] as sexp -> + Ppx_sexp_conv_lib.Conv_error.empty_list_invalid_sum _tp_loc sexp + | sexp -> Ppx_sexp_conv_lib.Conv_error.unexpected_stag _tp_loc sexp + let sexp_of_t : type a b. + (a -> Ppx_sexp_conv_lib.Sexp.t) -> + (b -> Ppx_sexp_conv_lib.Sexp.t) -> (a, b) t -> Ppx_sexp_conv_lib.Sexp.t + = + fun _of_a -> + fun _of_b -> + function + | Left v0 -> + let v0 = _of_a v0 in + Ppx_sexp_conv_lib.Sexp.List + [Ppx_sexp_conv_lib.Sexp.Atom "Left"; v0] + | Right v0 -> + let v0 = _of_b v0 in + Ppx_sexp_conv_lib.Sexp.List + [Ppx_sexp_conv_lib.Sexp.Atom "Right"; v0] + | Both (v0, v1) -> + let v0 = _of_a v0 + and v1 = _of_b v1 in + Ppx_sexp_conv_lib.Sexp.List + [Ppx_sexp_conv_lib.Sexp.Atom "Both"; v0; v1] + [@@@end] +end + +let merge_with_duplicates (Sequence (s1, next1)) (Sequence (s2, next2)) ~compare = + let unshadowed_compare = compare in + let open Merge_with_duplicates_element in + let next = function + | Skip s1, s2 -> Skip (next1 s1, s2) + | s1, Skip s2 -> Skip (s1, next2 s2) + | (Yield (a, s1') as s1), (Yield (b, s2') as s2) -> + let comparison = unshadowed_compare a b in + if comparison < 0 + then Yield (Left a, (Skip s1', s2)) + else if comparison = 0 + then Yield (Both (a, b), (Skip s1', Skip s2')) + else Yield (Right b, (s1, Skip s2')) + | Done, Done -> Done + | Yield (a, s1), Done -> Yield (Left a, (Skip s1, Done)) + | Done, Yield (b, s2) -> Yield (Right b, (Done, Skip s2)) + in + Sequence((Skip s1, Skip s2), next) + +let merge s1 s2 ~compare = + map (merge_with_duplicates s1 s2 ~compare) + ~f:(function Left x | Right x | Both (x, _) -> x) + +let hd s = + let rec loop s next = + match next s with + | Done -> None + | Skip s -> loop s next + | Yield(a,_) -> Some a + in + match s with + | Sequence (s,next) -> loop s next + +let hd_exn s = + match hd s with + | None -> failwith "hd_exn" + | Some a -> a + +let tl s = + let rec loop s next = + match next s with + | Done -> None + | Skip s -> loop s next + | Yield(_,a) -> Some a + in + match s with + | Sequence (s,next) -> + match loop s next with + | None -> None + | Some s -> Some (Sequence(s,next)) + +let tl_eagerly_exn s = + match tl s with + | None -> failwith "Sequence.tl_exn" + | Some s -> s + +let lift_identity next s = + match next s with + | Done -> Done + | Skip s -> Skip (`Identity s) + | Yield(a,s) -> Yield(a, `Identity s) + +let next s = + let rec loop s next = + match next s with + | Done -> None + | Skip s -> loop s next + | Yield(a,s) -> Some (a, Sequence(s, next)) + in + match s with + | Sequence(s, next) -> loop s next + +let filter_opt s = + match s with + | Sequence(s, next) -> + Sequence(s, + fun s -> + match next s with + | Done -> Done + | Skip s -> Skip s + | Yield(None, s) -> Skip s + | Yield(Some a, s) -> Yield(a, s)) + +let filter_map s ~f = + filter_opt (map s ~f) + +let filter_mapi s ~f = + filter_map (mapi s ~f:(fun i s -> (i,s))) + ~f:(fun (i, s) -> f i s) + +let split_n s n = + let rec loop s i accum next = + if i <= 0 then + (List.rev accum, Sequence(s,next)) + else + match next s with + | Done -> (List.rev accum, empty) + | Skip s -> loop s i accum next + | Yield(a,s) -> loop s (i-1) (a::accum) next + in + match s with + | Sequence(s, next) -> loop s n [] next + +let chunks_exn t n = + if n <= 0 + then raise (Invalid_argument "Sequence.chunks_exn") + else + unfold_step ~init:t ~f:(fun t -> + match split_n t n with + | [], _empty -> Done + | _::_ as xs, t -> Yield (xs, t)) + +let findi s ~f = + find (mapi s ~f:(fun i s -> (i,s))) + ~f:(fun (i,s) -> f i s) + +let find_exn s ~f = + match find s ~f with + | None -> failwith "Sequence.find_exn" + | Some x -> x + +let append s1 s2 = + match s1, s2 with + | Sequence(s1, next1), Sequence(s2, next2) -> + Sequence(`First_list s1, + function + | `First_list s1 -> + begin + match next1 s1 with + | Done -> Skip (`Second_list s2) + | Skip s1 -> Skip (`First_list s1) + | Yield(a,s1) -> Yield(a, `First_list s1) + end + | `Second_list s2 -> + begin + match next2 s2 with + | Done -> Done + | Skip s2 -> Skip (`Second_list s2) + | Yield(a,s2) -> Yield(a, `Second_list s2) + end) + +let concat_map s ~f = bind s ~f + +let concat s = concat_map s ~f:Fn.id + +let concat_mapi s ~f = + concat_map (mapi s ~f:(fun i s -> (i,s))) + ~f:(fun (i,s) -> f i s) + +let zip (Sequence (s1, next1)) (Sequence (s2, next2)) = + let next = function + | Yield (a, s1), Yield (b, s2) -> Yield ((a, b), (Skip s1, Skip s2)) + | Done, _ + | _, Done -> Done + | Skip s1, s2 -> Skip (next1 s1, s2) + | s1, Skip s2 -> Skip (s1, next2 s2) + in + Sequence ((Skip s1, Skip s2), next) + +let zip_full (Sequence(s1,next1)) (Sequence(s2,next2)) = + let next = function + | Yield (a, s1), Yield (b, s2) -> Yield (`Both (a, b), (Skip s1, Skip s2)) + | Done, Done -> Done + | Skip s1, s2 -> Skip (next1 s1, s2) + | s1, Skip s2 -> Skip (s1, next2 s2) + | Done, Yield(b, s2) -> Yield((`Right b), (Done, next2 s2)) + | Yield(a, s1), Done -> Yield((`Left a), (next1 s1, Done)) + in + Sequence ((Skip s1, Skip s2), next) + +let bounded_length (Sequence(seed,next)) ~at_most = + let rec loop i seed next = + if i > at_most then `Greater + else + match next seed with + | Done -> `Is i + | Skip seed -> loop i seed next + | Yield(_, seed) -> loop (i+1) seed next + in + loop 0 seed next + +let length_is_bounded_by ?(min=(-1)) ?max t = + let length_is_at_least (Sequence(s,next)) = + let rec loop s acc = + if acc >= min then true else + match next s with + | Done -> false + | Skip s -> loop s acc + | Yield(_,s) -> loop s (acc + 1) + in loop s 0 + in + match max with + | None -> length_is_at_least t + | Some max -> + begin + match bounded_length t ~at_most:max with + | `Is len when len >= min -> true + | _ -> false + end + +let iteri s ~f = + iter (mapi s ~f:(fun i s -> (i, s))) + ~f:(fun (i, s) -> f i s) + +let foldi s ~init ~f = + fold ~init (mapi s ~f:(fun i s -> (i,s))) + ~f:(fun acc (i, s) -> f i acc s) + +let reduce s ~f = + match next s with + | None -> None + | Some(a, s) -> Some (fold s ~init:a ~f) + +let reduce_exn s ~f = + match reduce s ~f with + | None -> failwith "Sequence.reduce_exn" + | Some res -> res + +let group (Sequence (s, next)) ~break = + unfold_step ~init:(Some ([], s)) ~f:(function + | None -> Done + | Some (acc, s) -> + match acc, next s with + | _, Skip s -> Skip (Some (acc, s)) + | [], Done -> Done + | acc, Done -> Yield (List.rev acc, None) + | [], Yield (cur, s) -> Skip (Some ([cur], s)) + | prev :: _ as acc, Yield (cur, s) -> + if break prev cur + then Yield (List.rev acc, Some ([cur], s)) + else Skip (Some (cur :: acc, s))) +;; + +let find_consecutive_duplicate (Sequence(s, next)) ~equal = + let rec loop last_elt s = + match next s with + | Done -> None + | Skip s -> loop last_elt s + | Yield(a,s) -> + match last_elt with + | Some b when equal a b -> Some (b, a) + | None | Some _ -> loop (Some a) s + in + loop None s + +let remove_consecutive_duplicates s ~equal = + unfold_with s ~init:None + ~f:(fun prev a -> + match prev with + | Some b when equal a b -> Skip(Some a) + | None | Some _ -> Yield(a, Some a)) + +let count s ~f = + length (filter s ~f) + +let counti t ~f = + length (filteri t ~f) + +let sum m t ~f = Container.sum ~fold m t ~f +let min_elt t ~compare = Container.min_elt ~fold t ~compare +let max_elt t ~compare = Container.max_elt ~fold t ~compare + +let init n ~f = + unfold_step ~init:0 + ~f:(fun i -> + if i >= n then Done + else Yield(f i, i + 1)) + +let sub s ~pos ~len = + if pos < 0 || len < 0 then failwith "Sequence.sub"; + match s with + | Sequence(s, next) -> + Sequence((0,s), + (fun (i, s) -> + if i - pos >= len then Done + else + match next s with + | Done -> Done + | Skip s -> Skip (i, s) + | Yield(a, s) when i >= pos -> Yield (a,(i + 1, s)) + | Yield(_, s) -> Skip(i + 1, s))) + +let take s len = + if len < 0 then failwith "Sequence.take"; + match s with + | Sequence(s, next) -> + Sequence((0,s), + (fun (i, s) -> + if i >= len then Done + else + match next s with + | Done -> Done + | Skip s -> Skip (i, s) + | Yield(a, s) -> Yield (a,(i + 1, s)))) + +let drop s len = + if len < 0 then failwith "Sequence.drop"; + match s with + | Sequence(s, next) -> + Sequence((0,s), + (fun (i, s) -> + match next s with + | Done -> Done + | Skip s -> Skip (i, s) + | Yield(a, s) when i >= len -> Yield (a,(i + 1, s)) + | Yield(_, s) -> Skip (i+1, s))) + +let take_while s ~f = + match s with + | Sequence(s, next) -> + Sequence(s, + fun s -> + match next s with + | Done -> Done + | Skip s -> Skip s + | Yield (a, s) when f a -> Yield(a,s) + | Yield (_,_) -> Done) + +let drop_while s ~f = + match s with + | Sequence(s, next) -> + Sequence(`Dropping s, + function + |`Dropping s -> + begin + match next s with + | Done -> Done + | Skip s -> Skip (`Dropping s) + | Yield(a, s) when f a -> Skip (`Dropping s) + | Yield(a, s) -> Yield(a, `Identity s) + end + | `Identity s -> lift_identity next s) + +let shift_right s x = + match s with + | Sequence(seed, next) -> + Sequence(`Consing (seed, x), + function + | `Consing (seed, x) -> Yield(x, `Identity seed) + | `Identity s -> lift_identity next s) + +let shift_right_with_list s l = + append (of_list l) s + +let shift_left = drop + +module Infix = struct + let (@) = append +end + +let intersperse s ~sep = + match s with + | Sequence(s, next) -> + Sequence(`Init s, + function + | `Init s -> + begin + match next s with + | Done -> Done + | Skip s -> Skip (`Init s) + | Yield(a, s) -> Yield(a, `Running s) + end + | `Running s -> + begin + match next s with + | Done -> Done + | Skip s -> Skip (`Running s) + | Yield(a, s) -> Yield(sep, `Putting(a,s)) + end + | `Putting(a,s) -> Yield(a,`Running s)) + +let repeat x = + unfold_step ~init:x ~f:(fun x -> Yield(x, x)) + +let cycle_list_exn xs = + if List.is_empty xs then raise (Invalid_argument "Sequence.cycle_list_exn"); + let s = of_list xs in + concat_map ~f:(fun () -> s) (repeat ()) + +let cartesian_product sa sb = + concat_map sa + ~f:(fun a -> zip (repeat a) sb) + +let singleton x = return x + +let delayed_fold s ~init ~f ~finish = + Expert.delayed_fold_step s ~init ~finish ~f:(fun acc option ~k -> + match option with + | None -> k acc + | Some a -> f acc a ~k) + +let fold_m ~bind ~return t ~init ~f = + Expert.delayed_fold_step t + ~init + ~f:(fun acc option ~k -> + match option with + | None -> bind (return acc) ~f:k + | Some a -> bind (f acc a) ~f:k) + ~finish:return +;; + +let iter_m ~bind ~return t ~f = + Expert.delayed_fold_step t + ~init:() + ~f:(fun () option ~k -> + match option with + | None -> bind (return ()) ~f:k + | Some a -> bind (f a) ~f:k) + ~finish:return +;; + +let fold_until s ~init ~f ~finish = + let rec loop s next f = + fun acc -> + match next s with + | Done -> finish acc + | Skip s -> loop s next f acc + | Yield(a, s) -> match (f acc a : ('a, 'b) Continue_or_stop.t) with + | Stop x -> x + | Continue acc -> loop s next f acc + in + match s with + | Sequence(s, next) -> loop s next f init + +let fold_result s ~init ~f = + let rec loop s next f = + fun acc -> + match next s with + | Done -> Result.return acc + | Skip s -> loop s next f acc + | Yield(a, s) -> match (f acc a : (_, _) Result.t) with + | Error _ as e -> e + | Ok acc -> loop s next f acc + in + match s with + | Sequence(s, next) -> loop s next f init + +let force_eagerly t = of_list (to_list t) + +let memoize (type a) (Sequence (s, next)) = + let module M = struct + type t = T of (a, t) Step.t Lazy.t + end in + let rec memoize s = M.T (lazy (find_step s)) + and find_step s = + match next s with + | Done -> Done + | Skip s -> find_step s + | Yield (a, s) -> Yield (a, memoize s) + in + Sequence (memoize s, (fun (M.T l) -> Lazy.force l)) + +let drop_eagerly s len = + let rec loop i ~len s next = + if i >= len then Sequence(s, next) + else + match next s with + | Done -> empty + | Skip s -> loop i ~len s next + | Yield(_,s) -> loop (i+1) ~len s next + in + match s with + | Sequence(s, next) -> loop 0 ~len s next + +let drop_while_option (Sequence (s, next)) ~f = + let rec loop s = + match next s with + | Done -> None + | Skip s -> loop s + | Yield (x, s) -> if f x then loop s else Some (x, Sequence (s, next)) + in + loop s + +let compare compare_a t1 t2 = + With_return.with_return (fun r -> + iter (zip_full t1 t2) ~f:(function + | `Left _ -> r.return 1 + | `Right _ -> r.return (-1) + | `Both (v1, v2) -> + let c = compare_a v1 v2 in + if c <> 0 + then r.return c); + 0); +;; + +let round_robin list = + let next (todo_stack, done_stack) = + match todo_stack with + | Sequence (s, f) :: todo_stack -> + begin + match f s with + | Yield (x, s) -> Yield (x, (todo_stack, Sequence (s, f) :: done_stack)) + | Skip s -> Skip (Sequence (s, f) :: todo_stack, done_stack) + | Done -> Skip (todo_stack, done_stack) + end + | [] -> + if List.is_empty done_stack + then Done + else Skip (List.rev done_stack, []) + in + let state = list, [] in + Sequence (state, next) + +let interleave (Sequence (s1, f1)) = + let next (todo_stack, done_stack, s1) = + match todo_stack with + | Sequence (s2, f2) :: todo_stack -> + begin + match f2 s2 with + | Yield (x, s2) -> Yield (x, (todo_stack, Sequence (s2, f2) :: done_stack, s1)) + | Skip s2 -> Skip (todo_stack, Sequence (s2, f2) :: done_stack, s1) + | Done -> Skip (todo_stack, done_stack, s1) + end + | [] -> + begin + match f1 s1, done_stack with + | Yield (t, s1), _ -> Skip (List.rev (t :: done_stack), [], s1) + | Skip s1 , _ -> Skip (List.rev done_stack , [], s1) + | Done , _::_ -> Skip (List.rev done_stack , [], s1) + | Done , [] -> Done + end + in + let state = [], [], s1 in + Sequence (state, next) + +let interleaved_cartesian_product s1 s2 = + map s1 ~f:(fun x1 -> + map s2 ~f:(fun x2 -> + (x1, x2))) + |> interleave + +module Generator = struct + + type 'elt steps = Wrap of ('elt, unit -> 'elt steps) Step.t + + let unwrap (Wrap step) = step + + module T = struct + type ('a, 'elt) t = ('a -> 'elt steps) -> 'elt steps + let return x = (); fun k -> k x + let bind m ~f = (); fun k -> m (fun a -> let m' = f a in m' k) + let map m ~f = (); fun k -> m (fun a -> k (f a)) + let map = `Custom map + end + include T + include Monad.Make2 (T) + + let yield e = (); fun k -> Wrap (Yield (e, k)) + + let to_steps t = t (fun () -> Wrap Done) + + let of_sequence sequence = + delayed_fold sequence + ~init:() + ~f:(fun () x ~k f -> Wrap (Yield (x, fun () -> k () f))) + ~finish:return + + let run t = + let init () = to_steps t in + let f thunk = unwrap (thunk ()) in + unfold_step ~init ~f + +end diff --git a/src/sequence.mli b/src/sequence.mli new file mode 100644 index 0000000..5f18e29 --- /dev/null +++ b/src/sequence.mli @@ -0,0 +1,487 @@ +(** A sequence of elements that can be produced one at a time, on demand, normally with no + sharing. + + The elements are computed on demand, possibly repeating work if they are demanded + multiple times. A sequence can be built by unfolding from some initial state, which + will in practice often be other containers. + + Most functions constructing a sequence will not immediately compute any elements of + the sequence. These functions will always return in O(1), but traversing the + resulting sequence may be more expensive. The most they will do immediately is + generate a new internal state and a new step function. + + Functions that transform existing sequences sometimes have to reconstruct some suffix + of the input sequence, even if it is unmodified. For example, calling [drop 1] will + return a sequence with a slightly larger state and whose elements all cost slightly + more to traverse. Because this is sometimes undesirable (for example, applying [drop + 1] n times will cost O(n) per element traversed in the result), there are also more + eager versions of many functions (whose names are suffixed with [_eagerly]) that do + more work up front. A function has the [_eagerly] suffix iff it matches both of these + conditions: + + - It might consume an element from an input [t] before returning. + + - It only returns a [t] (not paired with something else, not wrapped in an [option], + etc.). If it returns anything other than a [t] and it has at least one [t] input, + it's probably demanding elements from the input [t] anyway. + + Only [*_exn] functions can raise exceptions, except if the function underlying the + sequence (the [f] passed to [unfold]) raises, in which case the exception will + cascade. *) + +open! Import + +type +'a t [@@deriving_inline compare, sexp_of] +include +sig + [@@@ocaml.warning "-32"] + val compare : ('a -> 'a -> int) -> 'a t -> 'a t -> int + val sexp_of_t : + ('a -> Ppx_sexp_conv_lib.Sexp.t) -> 'a t -> Ppx_sexp_conv_lib.Sexp.t +end[@@ocaml.doc "@inline"] +[@@@end] +type 'a sequence = 'a t + +include Indexed_container.S1 with type 'a t := 'a t +include Monad.S with type 'a t := 'a t + +(** [empty] is a sequence with no elements. *) +val empty : _ t + +(** [next] returns the next element of a sequence and the next tail if the sequence is not + finished. *) +val next : 'a t -> ('a * 'a t) option + +(** A [Step] describes the next step of the sequence construction. [Done] indicates the + sequence is finished. [Skip] indicates the sequence continues with another state + without producing the next element yet. [Yield] outputs an element and introduces a + new state. + + Modifying ['s] doesn't violate any {e internal} invariants, but it may violate some + undocumented expectations. For example, one might expect that producing an element + from the same point in the sequence would always give the same value, but if the state + can mutate, that is not so. *) +module Step : sig + type ('a, 's) t = + | Done + | Skip of 's + | Yield of 'a * 's + [@@deriving_inline sexp_of] + include + sig + [@@@ocaml.warning "-32"] + val sexp_of_t : + ('a -> Ppx_sexp_conv_lib.Sexp.t) -> + ('s -> Ppx_sexp_conv_lib.Sexp.t) -> + ('a, 's) t -> Ppx_sexp_conv_lib.Sexp.t + end[@@ocaml.doc "@inline"] + [@@@end] +end + +(** [unfold_step ~init ~f] constructs a sequence by giving an initial state [init] and a + function [f] explaining how to continue the next step from a given state. *) +val unfold_step : init:'s -> f:('s -> ('a, 's) Step.t) -> 'a t + +(** [unfold ~init f] is a simplified version of [unfold_step] that does not allow + [Skip]. *) +val unfold : init:'s -> f:('s -> ('a * 's) option) -> 'a t + +(** [unfold_with t ~init ~f] folds a state through the sequence [t] to create a new + sequence *) +val unfold_with : 'a t -> init:'s -> f:('s -> 'a -> ('b, 's) Step.t) -> 'b t + +(** [unfold_with_and_finish t ~init ~running_step ~inner_finished ~finishing_step] folds a + state through [t] to create a new sequence (like [unfold_with t ~init + ~f:running_step]), and then continues the new sequence by unfolding the final state + (like [unfold_step ~init:(inner_finished final_state) ~f:finishing_step]). *) +val unfold_with_and_finish + : 'a t + -> init : 's_a + -> running_step : ('s_a -> 'a -> ('b, 's_a) Step.t) + -> inner_finished : ('s_a -> 's_b) + -> finishing_step : ('s_b -> ('b, 's_b) Step.t) + -> 'b t + +(** Returns the nth element. *) +val nth : 'a t -> int -> 'a option +val nth_exn : 'a t -> int -> 'a + +(** [folding_map] is a version of [map] that threads an accumulator through calls to + [f]. *) +val folding_map : 'a t -> init:'b -> f:( 'b -> 'a -> 'b * 'c) -> 'c t +val folding_mapi : 'a t -> init:'b -> f:(int -> 'b -> 'a -> 'b * 'c) -> 'c t + +val mapi : 'a t -> f:(int -> 'a -> 'b) -> 'b t + +val filteri : 'a t -> f: (int -> 'a -> bool) -> 'a t + +val filter : 'a t -> f: ('a -> bool) -> 'a t + +(** [merge t1 t2 ~compare] merges two sorted sequences [t1] and [t2], returning a sorted + sequence, all according to [compare]. If two elements are equal, the one from [t1] is + preferred. The behavior is undefined if the inputs aren't sorted. *) +val merge : 'a t -> 'a t -> compare:('a -> 'a -> int) -> 'a t + +module Merge_with_duplicates_element : sig + type ('a, 'b) t = + | Left of 'a + | Right of 'b + | Both of 'a * 'b + [@@deriving_inline compare, hash, sexp] + include + sig + [@@@ocaml.warning "-32"] + val compare : + ('a -> 'a -> int) -> + ('b -> 'b -> int) -> ('a, 'b) t -> ('a, 'b) t -> int + val hash_fold_t : + (Ppx_hash_lib.Std.Hash.state -> 'a -> Ppx_hash_lib.Std.Hash.state) -> + (Ppx_hash_lib.Std.Hash.state -> 'b -> Ppx_hash_lib.Std.Hash.state) -> + Ppx_hash_lib.Std.Hash.state -> + ('a, 'b) t -> Ppx_hash_lib.Std.Hash.state + include Ppx_sexp_conv_lib.Sexpable.S2 with type ('a,'b) t := ('a, 'b) t + end[@@ocaml.doc "@inline"] + [@@@end] +end + +(** [merge_with_duplicates_element t1 t2 ~compare] is like [merge], except that for each + element it indicates which input(s) the element comes from, using + [Merge_with_duplicates_element]. *) +val merge_with_duplicates + : 'a t + -> 'b t + -> compare:('a -> 'b -> int) + -> ('a, 'b) Merge_with_duplicates_element.t t + +val hd : 'a t -> 'a option +val hd_exn : 'a t -> 'a + +(** [tl t] and [tl_eagerly_exn t] immediately evaluates the first element of [t] and + returns the unevaluated tail. *) +val tl : 'a t -> 'a t option +val tl_eagerly_exn : 'a t -> 'a t + +(** [find_exn t ~f] returns the first element of [t] that satisfies [f]. It raises if + there is no such element. *) +val find_exn : 'a t -> f:('a -> bool) -> 'a + +(** Like [for_all], but passes the index as an argument. *) +val for_alli : 'a t -> f:(int -> 'a -> bool) -> bool + +(** [append t1 t2] first produces the elements of [t1], then produces the elements of + [t2]. *) +val append : 'a t -> 'a t -> 'a t + +(** [concat tt] produces the elements of each inner sequence sequentially. If any inner + sequences are infinite, elements of subsequent inner sequences will not be reached. *) +val concat : 'a t t -> 'a t + +(** [concat_map t ~f] is [concat (map t ~f)].*) +val concat_map : 'a t -> f:('a -> 'b t) -> 'b t + +(** [concat_mapi t ~f] is like concat_map, but passes the index as an argument. *) +val concat_mapi : 'a t -> f:(int -> 'a -> 'b t) -> 'b t + +(** [interleave tt] produces each element of the inner sequences of [tt] eventually, even + if any or all of the inner sequences are infinite. The elements of each inner + sequence are produced in order with respect to that inner sequence. The manner of + interleaving among the separate inner sequences is deterministic but unspecified. *) +val interleave : 'a t t -> 'a t + +(** [round_robin list] is like [interleave (of_list list)], except that the manner of + interleaving among the inner sequences is guaranteed to be round-robin. The input + sequences may be of different lengths; an empty sequence is dropped from subsequent + rounds of interleaving. *) +val round_robin : 'a t list -> 'a t + +(** Transforms a pair of sequences into a sequence of pairs. The length of the returned + sequence is the length of the shorter input. The remaining elements of the longer + input are discarded. + + WARNING: Unlike [List.zip], this will not error out if the two input sequences are of + different lengths, because [zip] may have already returned some elements by the time + this becomes apparent. *) +val zip : 'a t -> 'b t -> ('a * 'b) t + +(** [zip_full] is like [zip], but if one sequence ends before the other, then it keeps + producing elements from the other sequence until it has ended as well. *) +val zip_full: 'a t -> 'b t -> [ `Left of 'a | `Both of 'a * 'b | `Right of 'b ] t + +(** [reduce_exn f [a1; ...; an]] is [f (... (f (f a1 a2) a3) ...) an]. It fails on the + empty sequence. *) +val reduce_exn : 'a t -> f:('a -> 'a -> 'a) -> 'a +val reduce : 'a t -> f:('a -> 'a -> 'a) -> 'a option + +(** [group l ~break] returns a sequence of lists (i.e., groups) whose concatenation is + equal to the original sequence. Each group is broken where [break] returns true on a + pair of successive elements. + + Example: + + {[ + group ~break:(<>) (of_list ['M';'i';'s';'s';'i';'s';'s';'i';'p';'p';'i']) -> + + of_list [['M'];['i'];['s';'s'];['i'];['s';'s'];['i'];['p';'p'];['i']] ]} *) +val group : 'a t -> break:('a -> 'a -> bool) -> 'a list t + +(** [find_consecutive_duplicate t ~equal] returns the first pair of consecutive elements + [(a1, a2)] in [t] such that [equal a1 a2]. They are returned in the same order as + they appear in [t]. *) +val find_consecutive_duplicate : 'a t -> equal:('a -> 'a -> bool) -> ('a * 'a) option + +(** The same sequence with consecutive duplicates removed. The relative order of the + other elements is unaffected. *) +val remove_consecutive_duplicates : 'a t -> equal:('a -> 'a -> bool) -> 'a t + +(** [range ?stride ?start ?stop start_i stop_i] is the sequence of integers from [start_i] + to [stop_i], stepping by [stride]. If [stride] < 0 then we need [start_i] > [stop_i] + for the result to be nonempty (or [start_i] >= [stop_i] in the case where both bounds + are inclusive). *) +val range + : ?stride:int (** default is [1] *) + -> ?start:[`inclusive|`exclusive] (** default is [`inclusive] *) + -> ?stop:[`inclusive|`exclusive] (** default is [`exclusive] *) + -> int + -> int + -> int t + +(** [init n ~f] is [[(f 0); (f 1); ...; (f (n-1))]]. It is an error if [n < 0]. *) +val init : int -> f:(int -> 'a) -> 'a t + +(** [filter_map t ~f] produce mapped elements of [t] which are not [None]. *) +val filter_map : 'a t -> f:('a -> 'b option) -> 'b t + +(** [filter_mapi] is just like [filter_map], but it also passes in the index of each + element to [f]. *) +val filter_mapi : 'a t -> f:(int -> 'a -> 'b option) -> 'b t + +(** [filter_opt t] produces the elements of [t] which are not [None]. [filter_opt t] = + [filter_map t ~f:ident]. *) +val filter_opt : 'a option t -> 'a t + +(** [sub t ~pos ~len] is the [len]-element subsequence of [t], starting at [pos]. If the + sequence is shorter than [pos + len], it returns [ t[pos] ... t[l-1] ], where [l] is + the length of the sequence. *) +val sub : 'a t -> pos:int -> len:int -> 'a t + +(** [take t n] produces the first [n] elements of [t]. *) +val take : 'a t -> int -> 'a t + +(** [drop t n] produces all elements of [t] except the first [n] elements. If there are + fewer than [n] elements in [t], there is no error; the resulting sequence simply + produces no elements. Usually you will probably want to use [drop_eagerly] because it + can be significantly cheaper. *) +val drop : 'a t -> int -> 'a t + +(** [drop_eagerly t n] immediately consumes the first [n] elements of [t] and returns the + unevaluated tail of [t]. *) +val drop_eagerly : 'a t -> int -> 'a t + +(** [take_while t ~f] produces the longest prefix of [t] for which [f] applied to each + element is [true]. *) +val take_while : 'a t -> f : ('a -> bool) -> 'a t + +(** [drop_while t ~f] produces the suffix of [t] beginning with the first element of [t] + for which [f] is [false]. Usually you will probably want to use [drop_while_option] + because it can be significantly cheaper. *) +val drop_while : 'a t -> f : ('a -> bool) -> 'a t + +(** [drop_while_option t ~f] immediately consumes the elements from [t] until the + predicate [f] fails and returns the first element that failed along with the + unevaluated tail of [t]. The first element is returned separately because the + alternatives would mean forcing the consumer to evaluate the first element again (if + the previous state of the sequence is returned) or take on extra cost for each element + (if the element is added to the final state of the sequence using [shift_right]). *) +val drop_while_option : 'a t -> f : ('a -> bool) -> ('a * 'a t) option + +(** [split_n t n] immediately consumes the first [n] elements of [t] and returns the + consumed prefix, as a list, along with the unevaluated tail of [t]. *) +val split_n : 'a t -> int -> 'a list * 'a t + +(** [chunks_exn t n] produces lists of elements of [t], up to [n] elements at a time. The + last list may contain fewer than [n] elements. No list contains zero elements. If [n] + is not positive, it raises. *) +val chunks_exn : 'a t -> int -> 'a list t + +(** [shift_right t a] produces [a] and then produces each element of [t]. *) +val shift_right : 'a t -> 'a -> 'a t + +(** [shift_right_with_list t l] produces the elements of [l], then produces the elements + of [t]. It is better to call [shift_right_with_list] with a list of size n than + [shift_right] n times; the former will require O(1) work per element produced and the + latter O(n) work per element produced. *) +val shift_right_with_list : 'a t -> 'a list -> 'a t + +(** [shift_left t n] is a synonym for [drop t n].*) +val shift_left : 'a t -> int -> 'a t + +module Infix : sig + val ( @ ) : 'a t -> 'a t -> 'a t +end + +(** Returns a sequence with all possible pairs. The stepper function of the second + sequence passed as argument may be applied to the same state multiple times, so be + careful using [cartesian_product] with expensive or side-effecting functions. If the + second sequence is infinite, some values in the first sequence may not be reached. *) +val cartesian_product : 'a t -> 'b t -> ('a * 'b) t + +(** Returns a sequence that eventually reaches every possible pair of elements of the + inputs, even if either or both are infinite. The step function of both inputs may be + applied to the same state repeatedly, so be careful using + [interleaved_cartesian_product] with expensive or side-effecting functions. *) +val interleaved_cartesian_product : 'a t -> 'b t -> ('a * 'b) t + +(** [intersperse xs ~sep] produces [sep] between adjacent elements of [xs], e.g., + [intersperse [1;2;3] ~sep:0 = [1;0;2;0;3]]. *) +val intersperse : 'a t -> sep:'a -> 'a t + +(** [cycle_list_exn xs] repeats the elements of [xs] forever. If [xs] is empty, it + raises. *) +val cycle_list_exn : 'a list -> 'a t + +(** [repeat a] repeats [a] forever. *) +val repeat : 'a -> 'a t + +(** [singleton a] produces [a] exactly once. *) +val singleton : 'a -> 'a t + +(** [delayed_fold] allows to do an on-demand fold, while maintaining a state. + + It is possible to exit early by not calling [k] in [f]. It is also possible to call + [k] multiple times. This results in the rest of the sequence being folded over + multiple times, independently. + + Note that [delayed_fold], when targeting JavaScript, can result in stack overflow as + JavaScript doesn't generally have tail call optimization. *) +val delayed_fold + : 'a t + -> init:'s + -> f:('s -> 'a -> k:('s -> 'r) -> 'r) (** [k] stands for "continuation" *) + -> finish:('s -> 'r) + -> 'r + +(** [fold_m] is a monad-friendly version of [fold]. Supply it with the monad's [return] + and [bind], and it will chain them through the computation. *) +val fold_m + : bind:('acc_m -> f:('acc -> 'acc_m) -> 'acc_m) + -> return:('acc -> 'acc_m) + -> 'elt t + -> init:'acc + -> f:('acc -> 'elt -> 'acc_m) + -> 'acc_m + +(** [iter_m] is a monad-friendly version of [iter]. Supply it with the monad's [return] + and [bind], and it will chain them through the computation. *) +val iter_m + : bind:('unit_m -> f:(unit -> 'unit_m) -> 'unit_m) + -> return:(unit -> 'unit_m) + -> 'elt t + -> f:('elt -> 'unit_m) + -> 'unit_m + +(** [to_list_rev t] returns a list of the elements of [t], in reverse order. It is faster + than [to_list]. *) +val to_list_rev : 'a t -> 'a list + +val of_list : 'a list -> 'a t + +(** [of_lazy t_lazy] produces a sequence that forces [t_lazy] the first time it needs to + compute an element. *) +val of_lazy : 'a t Lazy.t -> 'a t + +(** [memoize t] produces each element of [t], but also memoizes them so that if you + consume the same element multiple times it is only computed once. It's a non-eager + version of [force_eagerly]. *) +val memoize : 'a t -> 'a t + +(** [force_eagerly t] precomputes the sequence. It is behaviorally equivalent to [of_list + (to_list t)], but may at some point have a more efficient implementation. It's an + eager version of [memoize]. *) +val force_eagerly : 'a t -> 'a t + +(** [bounded_length ~at_most t] returns [`Is len] if [len = length t <= at_most], and + otherwise returns [`Greater]. Walks through only as much of the sequence as + necessary. Always returns [`Greater] if [at_most < 0]. *) +val bounded_length : _ t -> at_most:int -> [ `Is of int | `Greater ] + +(** [length_is_bounded_by ~min ~max t] returns true if [min <= length t] and [length t <= + max] When [min] or [max] are not provided, the check for that bound is omitted. Walks + through only as much of the sequence as necessary. *) +val length_is_bounded_by: ?min:int -> ?max:int -> _ t -> bool + +(** [Generator] is a monadic interface to generate sequences in a direct style, similar to + Python's generators. + + Here are some examples: + + {[ + open Generator + + let rec traverse_list = function + | [] -> return () + | x :: xs -> yield x >>= fun () -> traverse_list xs + + let traverse_option = function + | None -> return () + | Some x -> yield x + + let traverse_array arr = + let n = Array.length arr in + let rec loop i = + if i >= n then return () else yield arr.(i) >>= fun () -> loop (i + 1) + in + loop 0 + + let rec traverse_bst = function + | Node.Empty -> return () + | Node.Branch (left, value, right) -> + traverse_bst left >>= fun () -> + yield value >>= fun () -> + traverse_bst right + + let sequence_of_list x = Generator.run (traverse_list x) + let sequence_of_option x = Generator.run (traverse_option x) + let sequence_of_array x = Generator.run (traverse_array x) + let sequence_of_bst x = Generator.run (traverse_bst x) + ]} *) + +module Generator : sig + include Monad.S2 + val yield : 'elt -> (unit, 'elt) t + val of_sequence : 'elt sequence -> (unit, 'elt) t + val run : (unit, 'elt) t -> 'elt sequence +end + +(** The functions in [Expert] expose internal structure which is normally meant to be + hidden. For example, at least when [f] is purely functional, it is not intended for + client code to distinguish between + + {[ + List.filter xs ~f + |> Sequence.of_list + ]} + + and + + {[ + Sequence.of_list xs + |> Sequence.filter ~f + ]} + + But sometimes for operational reasons it still makes sense to distinguish them. For + example, being able to handle [Skip]s explicitly allows breaking up some + computationally expensive sequences into smaller chunks of work. *) +module Expert : sig + (** [next_step] returns the next step in a sequence's construction. It is like [next], + but it also allows observing [Skip] steps. *) + val next_step : 'a t -> ('a, 'a t) Step.t + + (** [delayed_fold_step] is liked [delayed_fold], but [f] takes an option where [None] + represents a [Skip] step. *) + val delayed_fold_step + : 'a t + -> init:'s + -> f:('s -> 'a option -> k:('s -> 'r) -> 'r) (** [k] stands for "continuation" *) + -> finish:('s -> 'r) + -> 'r +end diff --git a/src/set.ml b/src/set.ml new file mode 100644 index 0000000..13af40f --- /dev/null +++ b/src/set.ml @@ -0,0 +1,1314 @@ +(***********************************************************************) +(* *) +(* Objective Caml *) +(* *) +(* Xavier Leroy, projet Cristal, INRIA Rocquencourt *) +(* *) +(* Copyright 1996 Institut National de Recherche en Informatique et *) +(* en Automatique. All rights reserved. This file is distributed *) +(* under the terms of the Apache 2.0 license. See ../THIRD-PARTY.txt *) +(* for details. *) +(* *) +(***********************************************************************) + +(* Sets over ordered types *) + +open! Import + +include Set_intf + +let with_return = With_return.with_return + + +module Tree0 = struct + type 'a t = + | Empty + (* (Leaf x) is the same as (Node (Empty, x, Empty, 1, 1)) but uses less space. *) + | Leaf of 'a + (* first int is height, second is sub-tree size *) + | Node of 'a t * 'a * 'a t * int * int + + type 'a tree = 'a t + + (* Sets are represented by balanced binary trees (the heights of the children differ by + at most 2. *) + let height = function + | Empty -> 0 + | Leaf _ -> 1 + | Node(_, _, _, h, _) -> h + ;; + + let length = function + | Empty -> 0 + | Leaf _ -> 1 + | Node(_, _, _, _, s) -> s + ;; + + let invariants = + let in_range lower upper compare_elt v = + (match lower with + | None -> true + | Some lower -> compare_elt lower v < 0 + ) + && (match upper with + | None -> true + | Some upper -> compare_elt v upper < 0 + ) + in + let rec loop lower upper compare_elt t = + match t with + | Empty -> true + | Leaf v -> in_range lower upper compare_elt v + | Node (l, v, r, h, n) -> + let hl = height l and hr = height r in + abs (hl - hr) <= 2 + && h = (max hl hr) + 1 + && n = length l + length r + 1 + && in_range lower upper compare_elt v + && loop lower (Some v) compare_elt l + && loop (Some v) upper compare_elt r + in + fun t ~compare_elt -> loop None None compare_elt t + ;; + + let is_empty = function Empty -> true | Leaf _ | Node _ -> false + + (* Creates a new node with left son l, value v and right son r. + We must have all elements of l < v < all elements of r. + l and r must be balanced and | height l - height r | <= 2. + Inline expansion of height for better speed. *) + + let create l v r = + let hl = match l with Empty -> 0 | Leaf _ -> 1 | Node(_,_,_,h,_) -> h in + let hr = match r with Empty -> 0 | Leaf _ -> 1 | Node(_,_,_,h,_) -> h in + let h = if hl >= hr then hl + 1 else hr + 1 in + if h = 1 + then Leaf v + else begin + let sl = match l with Empty -> 0 | Leaf _ -> 1 | Node(_,_,_,_,s) -> s in + let sr = match r with Empty -> 0 | Leaf _ -> 1 | Node(_,_,_,_,s) -> s in + Node (l, v, r, h, sl + sr + 1) + end + + (* We must call [f] with increasing indexes, because the bin_prot reader in + Core_kernel.Set needs it. *) + let of_increasing_iterator_unchecked ~len ~f = + let rec loop n ~f i = + match n with + | 0 -> Empty + | 1 -> + let k = f i in + Leaf k + | 2 -> + let kl = f i in + let k = f (i + 1) in + create (Leaf kl) k (Empty) + | 3 -> + let kl = f i in + let k = f (i + 1) in + let kr = f (i + 2) in + create (Leaf kl) k (Leaf kr) + | n -> + let left_length = n lsr 1 in + let right_length = n - left_length - 1 in + let left = loop left_length ~f i in + let k = f (i + left_length) in + let right = loop right_length ~f (i + left_length + 1) in + create left k right + in + loop len ~f 0 + + let of_sorted_array_unchecked array ~compare_elt = + let array_length = Array.length array in + let next = + (* We don't check if the array is sorted or keys are duplicated, because that + checking is slower than the whole [of_sorted_array] function *) + if array_length < 2 || compare_elt array.(0) array.(1) < 0 + then (fun i -> array.(i)) + else (fun i -> array.(array_length - 1 - i)) + in + of_increasing_iterator_unchecked ~len:array_length ~f:next + ;; + + let of_sorted_array array ~compare_elt = + match array with + | [||] | [|_|] -> Result.Ok (of_sorted_array_unchecked array ~compare_elt) + | _ -> + with_return (fun r -> + let increasing = + match compare_elt array.(0) array.(1) with + | 0 -> r.return (Or_error.error_string "of_sorted_array: duplicated elements") + | i -> i < 0 + in + for i = 1 to Array.length array - 2 do + match compare_elt array.(i) array.(i+1) with + | 0 -> r.return (Or_error.error_string "of_sorted_array: duplicated elements") + | i -> + if Poly.(<>) (i < 0) increasing then + r.return (Or_error.error_string "of_sorted_array: elements are not ordered") + done; + Result.Ok (of_sorted_array_unchecked array ~compare_elt) + ) + + (* Same as create, but performs one step of rebalancing if necessary. + Assumes l and r balanced and | height l - height r | <= 3. + Inline expansion of create for better speed in the most frequent case + where no rebalancing is required. *) + + let bal l v r = + let hl = match l with Empty -> 0 | Leaf _ -> 1 | Node(_,_,_,h,_) -> h in + let hr = match r with Empty -> 0 | Leaf _ -> 1 | Node(_,_,_,h,_) -> h in + if hl > hr + 2 then begin + match l with + | Empty -> assert false + | Leaf _ -> assert false (* because h(l)>h(r)+2 and h(leaf)=1 *) + | Node (ll, lv, lr, _, _) -> + if height ll >= height lr then + create ll lv (create lr v r) + else begin + match lr with + | Empty -> assert false + | Leaf lrv -> + assert (is_empty ll); + create (create ll lv Empty) lrv (create Empty v r) + | Node(lrl, lrv, lrr, _, _)-> + create (create ll lv lrl) lrv (create lrr v r) + end + end else if hr > hl + 2 then begin + match r with + Empty -> assert false + | Leaf rv -> create (create l v Empty) rv Empty + | Node(rl, rv, rr, _, _) -> + if height rr >= height rl then + create (create l v rl) rv rr + else begin + match rl with + Empty -> assert false + | Leaf rlv -> + assert (is_empty rr); + create (create l v Empty) rlv (create Empty rv rr) + | Node(rll, rlv, rlr, _, _) -> + create (create l v rll) rlv (create rlr rv rr) + end + end else begin + let h = if hl >= hr then hl + 1 else hr + 1 in + let sl = match l with Empty -> 0 | Leaf _ -> 1 | Node(_,_,_,_,s) -> s in + let sr = match r with Empty -> 0 | Leaf _ -> 1 | Node(_,_,_,_,s) -> s in + if h = 1 + then Leaf v + else Node (l, v, r, h, sl + sr + 1) + end + + (* Insertion of one element *) + + exception Same + + let add t x ~compare_elt = + let rec aux = function + | Empty -> Leaf x + | Leaf v -> + let c = compare_elt x v in + if c = 0 then + raise Same + else if c < 0 then + bal (Leaf x) v Empty + else + bal Empty v (Leaf x) + | Node(l, v, r, _, _) -> + let c = compare_elt x v in + if c = 0 then + raise Same + else if c < 0 then + bal (aux l) v r + else + bal l v (aux r) + in + try aux t with Same -> t + ;; + + (* Same as create and bal, but no assumptions are made on the relative heights of l and + r. *) + let rec join l v r ~compare_elt = + match (l, r) with + | (Empty, _) -> add r v ~compare_elt + | (_, Empty) -> add l v ~compare_elt + | (Leaf lv, _) -> add (add r v ~compare_elt) lv ~compare_elt + | (_, Leaf rv) -> add (add l v ~compare_elt) rv ~compare_elt + | (Node (ll, lv, lr, lh, _), Node (rl, rv, rr, rh, _)) -> + if lh > rh + 2 then bal ll lv (join lr v r ~compare_elt) else + if rh > lh + 2 then bal (join l v rl ~compare_elt) rv rr else + create l v r + ;; + + (* Smallest and greatest element of a set *) + let rec min_elt = function + | Empty -> None + | Leaf v + | Node(Empty, v, _, _, _) -> Some v + | Node(l, _, _, _, _) -> min_elt l + ;; + + exception Set_min_elt_exn_of_empty_set [@@deriving_inline sexp] + let () = + Ppx_sexp_conv_lib.Conv.Exn_converter.add + ([%extension_constructor Set_min_elt_exn_of_empty_set]) + (function + | Set_min_elt_exn_of_empty_set -> + Ppx_sexp_conv_lib.Sexp.Atom + "src/set.ml.Tree0.Set_min_elt_exn_of_empty_set" + | _ -> assert false) + [@@@end] + exception Set_max_elt_exn_of_empty_set [@@deriving_inline sexp] + let () = + Ppx_sexp_conv_lib.Conv.Exn_converter.add + ([%extension_constructor Set_max_elt_exn_of_empty_set]) + (function + | Set_max_elt_exn_of_empty_set -> + Ppx_sexp_conv_lib.Sexp.Atom + "src/set.ml.Tree0.Set_max_elt_exn_of_empty_set" + | _ -> assert false) + [@@@end] + + let min_elt_exn t = + match min_elt t with + | None -> raise Set_min_elt_exn_of_empty_set + | Some v -> v + ;; + + let fold_until t ~init ~f ~finish = + let rec fold_until_helper ~f t acc = + match t with + | Empty -> Continue_or_stop.Continue acc + | Leaf value -> f acc value + | Node(left, value, right, _, _) -> + match fold_until_helper ~f left acc with + | Stop _a as x -> x + | Continue acc -> + match f acc value with + | Stop _a as x -> x + | Continue a -> fold_until_helper ~f right a + in + match fold_until_helper ~f t init with + | Continue x -> finish x + | Stop x -> x + ;; + + let rec max_elt = function + | Empty -> None + | Leaf v + | Node(_, v, Empty, _, _) -> Some v + | Node(_, _, r, _, _) -> max_elt r + ;; + + let max_elt_exn t = + match max_elt t with + | None -> raise Set_max_elt_exn_of_empty_set + | Some v -> v + ;; + + (* Remove the smallest element of the given set *) + + let rec remove_min_elt = function + | Empty -> invalid_arg "Set.remove_min_elt" + | Leaf _ -> Empty + | Node(Empty, _, r, _, _) -> r + | Node(l, v, r, _, _) -> bal (remove_min_elt l) v r + ;; + + (* Merge two trees l and r into one. All elements of l must precede the elements of r. + Assume | height l - height r | <= 2. *) + let merge t1 t2 = + match (t1, t2) with + | (Empty, t) -> t + | (t, Empty) -> t + | (_, _) -> bal t1 (min_elt_exn t2) (remove_min_elt t2) + ;; + + (* Merge two trees l and r into one. All elements of l must precede the elements of r. + No assumption on the heights of l and r. *) + let concat t1 t2 ~compare_elt = + match (t1, t2) with + | Empty, t | t, Empty -> t + | (_, _) -> join t1 (min_elt_exn t2) (remove_min_elt t2) ~compare_elt + ;; + + let split t x ~compare_elt = + let rec split t = + match t with + | Empty -> (Empty, None, Empty) + | Leaf v -> + let c = compare_elt x v in + if c = 0 then (Empty, Some v, Empty) + else if c < 0 then (Empty, None, Leaf v) + else (Leaf v, None, Empty) + | Node (l, v, r, _, _) -> + let c = compare_elt x v in + if c = 0 then (l, Some v, r) + else if c < 0 then + let (ll, maybe_elt, rl) = split l in + (ll, maybe_elt, join rl v r ~compare_elt) + else + let (lr, maybe_elt, rr) = split r in + (join l v lr ~compare_elt, maybe_elt, rr) + in + split t + ;; + + (* Implementation of the set operations *) + + let empty = Empty + + let rec mem t x ~compare_elt = + match t with + | Empty -> false + | Leaf v -> + let c = compare_elt x v in + c = 0 + | Node(l, v, r, _, _) -> + let c = compare_elt x v in + c = 0 || mem (if c < 0 then l else r) x ~compare_elt + ;; + + let singleton x = Leaf x + + let remove t x ~compare_elt = + let rec aux t = + match t with + | Empty -> raise Same + | Leaf v -> if compare_elt x v = 0 then Empty else raise Same + | Node(l, v, r, _, _) -> + let c = compare_elt x v in + if c = 0 then + merge l r + else if c < 0 then + bal (aux l) v r + else + bal l v (aux r) + in + try aux t with Same -> t + ;; + + let remove_index t i ~compare_elt:_ = + let rec aux t i = + match t with + | Empty -> raise Same + | Leaf _ -> if i = 0 then Empty else raise Same + | Node (l, v, r, _, _) -> + let l_size = length l in + let c = Poly.compare i l_size in + if c = 0 then + merge l r + else if c < 0 then + bal (aux l i) v r + else + bal l v (aux r (i - l_size - 1)) + in + try aux t i with Same -> t + ;; + + let union s1 s2 ~compare_elt = + let rec union s1 s2 = + match s1, s2 with + | Empty, t | t, Empty -> t + | Leaf v1, _ -> union (Node(Empty, v1, Empty, 1, 1)) s2 + | _, Leaf v2 -> union s1 (Node(Empty, v2, Empty, 1, 1)) + | (Node(l1, v1, r1, h1, _), Node(l2, v2, r2, h2, _)) -> + if h1 >= h2 + then + if h2 = 1 + then add s1 v2 ~compare_elt + else begin + let (l2, _, r2) = split s2 v1 ~compare_elt in + join (union l1 l2) v1 (union r1 r2) ~compare_elt + end + else + if h1 = 1 + then add s2 v1 ~compare_elt + else begin + let (l1, _, r1) = split s1 v2 ~compare_elt in + join (union l1 l2) v2 (union r1 r2) ~compare_elt + end + in + union s1 s2 + ;; + + let union_list ~comparator ~to_tree xs = + let compare_elt = comparator.Comparator.compare in + List.fold xs ~init:empty ~f:(fun ac x -> union ac (to_tree x) ~compare_elt) + ;; + + let inter s1 s2 ~compare_elt = + let rec inter s1 s2 = + match s1, s2 with + | Empty, _ | _, Empty -> Empty + | ((Leaf elt as singleton), other_set) + | (other_set, (Leaf elt as singleton)) -> + if mem other_set elt ~compare_elt then singleton else Empty + | (Node (l1, v1, r1, _, _), t2) -> + match split t2 v1 ~compare_elt with + | (l2, None, r2) -> concat (inter l1 l2) (inter r1 r2) ~compare_elt + | (l2, Some v1, r2) -> join (inter l1 l2) v1 (inter r1 r2) ~compare_elt + in + inter s1 s2 + ;; + + let diff s1 s2 ~compare_elt = + let rec diff s1 s2 = + match s1, s2 with + | (Empty, _) -> Empty + | (t1, Empty) -> t1 + | (Leaf v1, t2) -> diff (Node(Empty, v1, Empty, 1, 1)) t2 + | (Node(l1, v1, r1, _, _), t2) -> + match split t2 v1 ~compare_elt with + | (l2, None, r2) -> + join (diff l1 l2) v1 (diff r1 r2) ~compare_elt + | (l2, Some _, r2) -> + concat (diff l1 l2) (diff r1 r2) ~compare_elt + in + diff s1 s2 + ;; + + module Enum = struct + type increasing + type decreasing + type ('a, 'direction) t = End | More of 'a * 'a tree * ('a, 'direction) t + + let rec cons s (e : (_, increasing) t) : (_, increasing) t = + match s with + | Empty -> e + | Leaf v -> (More (v, Empty, e)) + | Node (l, v, r, _, _) -> cons l (More (v, r, e)) + ;; + + let rec cons_right s (e : (_, decreasing) t) : (_, decreasing) t = + match s with + | Empty -> e + | Leaf v -> More (v, Empty, e) + | Node (l, v, r, _, _) -> cons_right r (More (v, l, e)) + ;; + + let of_set s : (_, increasing) t = cons s End + + let of_set_right s : (_, decreasing) t = cons_right s End + + let starting_at_increasing t key compare : (_, increasing) t = + let rec loop t e = + match t with + | Empty -> e + | Leaf v -> loop (Node (Empty, v, Empty, 1, 1)) e + | Node(_, v, r, _, _) when compare v key < 0 -> loop r e + | Node(l, v, r, _, _) -> loop l (More(v, r, e)) + in + loop t End + ;; + + let starting_at_decreasing t key compare : (_, decreasing) t = + let rec loop t e = + match t with + | Empty -> e + | Leaf v -> loop (Node (Empty, v, Empty, 1, 1)) e + | Node(l, v, _, _, _) when compare v key > 0 -> loop l e + | Node(l, v, r, _, _) -> loop r (More(v, l, e)) + in + loop t End + ;; + + let compare compare_elt e1 e2 = + let rec loop e1 e2 = + match e1, e2 with + | End, End -> 0 + | End, _ -> -1 + | _, End -> 1 + | More (v1, r1, e1), More (v2, r2, e2) -> + let c = compare_elt v1 v2 in + if c <> 0 + then c + else loop (cons r1 e1) (cons r2 e2) + in + loop e1 e2 + ;; + + let rec iter ~f = function + | End -> () + | More (a, tree, enum) -> + f a; + iter (cons tree enum) ~f + ;; + + let iter2 compare_elt t1 t2 ~f = + let rec loop t1 t2 = + match t1, t2 with + | End, End -> () + | End, _ -> iter t2 ~f:(fun a -> f (`Right a)) + | _, End -> iter t1 ~f:(fun a -> f (`Left a)) + | More (a1, tree1, enum1), More (a2, tree2, enum2) -> + let compare_result = compare_elt a1 a2 in + if compare_result = 0 then begin + f (`Both (a1, a2)); + loop (cons tree1 enum1) (cons tree2 enum2) + end else if compare_result < 0 then begin + f (`Left a1); + loop (cons tree1 enum1) t2 + end else begin + f (`Right a2); + loop t1 (cons tree2 enum2) + end + in + loop t1 t2 + + let symmetric_diff t1 t2 ~compare_elt = + let step state : ((_,_) Either.t, _) Sequence.Step.t = + match state with + | End, End -> + Done + | End, More (elt, tree, enum) -> + Yield (Second elt, (End, cons tree enum)) + | More (elt, tree, enum), End -> + Yield (First elt, (cons tree enum, End)) + | (More (a1, tree1, enum1) as left), (More (a2, tree2, enum2) as right) -> + let compare_result = compare_elt a1 a2 in + if compare_result = 0 then begin + let next_state = + if phys_equal tree1 tree2 + then (enum1, enum2) + else (cons tree1 enum1, cons tree2 enum2) + in + Skip next_state + end else if compare_result < 0 then begin + Yield (First a1, (cons tree1 enum1, right)) + end else begin + Yield (Second a2, (left, cons tree2 enum2)) + end + in + Sequence.unfold_step ~init:(of_set t1, of_set t2) ~f:step + ;; + + end + + let to_sequence_increasing comparator ~from_elt t = + let next enum = + match enum with + | Enum.End -> Sequence.Step.Done + | Enum.More (k, t, e) -> Sequence.Step.Yield (k, Enum.cons t e) + in + let init = + match from_elt with + | None -> Enum.of_set t + | Some key -> Enum.starting_at_increasing t key comparator.Comparator.compare + in + Sequence.unfold_step ~init ~f:next + ;; + + let to_sequence_decreasing comparator ~from_elt t = + let next enum = + match enum with + | Enum.End -> Sequence.Step.Done + | Enum.More (k, t, e) -> Sequence.Step.Yield (k, Enum.cons_right t e) + in + let init = + match from_elt with + | None -> Enum.of_set_right t + | Some key -> Enum.starting_at_decreasing t key comparator.Comparator.compare + in + Sequence.unfold_step ~init ~f:next + ;; + + let to_sequence comparator ?(order = `Increasing) + ?greater_or_equal_to ?less_or_equal_to t = + let inclusive_bound side t bound = + let compare_elt = comparator.Comparator.compare in + let l, maybe, r = split t bound ~compare_elt in + let t = side (l, r) in + match maybe with + | None -> t + | Some elt -> add t elt ~compare_elt + in + match order with + | `Increasing -> + let t = Option.fold less_or_equal_to ~init:t ~f:(inclusive_bound fst) in + to_sequence_increasing comparator ~from_elt:greater_or_equal_to t + | `Decreasing -> + let t = Option.fold greater_or_equal_to ~init:t ~f:(inclusive_bound snd) in + to_sequence_decreasing comparator ~from_elt:less_or_equal_to t + ;; + + let merge_to_sequence comparator ?(order = `Increasing) + ?greater_or_equal_to ?less_or_equal_to t t' = + Sequence.merge_with_duplicates + (to_sequence comparator ~order ?greater_or_equal_to ?less_or_equal_to t) + (to_sequence comparator ~order ?greater_or_equal_to ?less_or_equal_to t') + ~compare:begin + match order with + | `Increasing -> comparator.compare + | `Decreasing -> Fn.flip comparator.compare + end + ;; + + let compare compare_elt s1 s2 = + Enum.compare compare_elt (Enum.of_set s1) (Enum.of_set s2) + ;; + + let iter2 s1 s2 ~compare_elt = + Enum.iter2 compare_elt (Enum.of_set s1) (Enum.of_set s2) + + let equal s1 s2 ~compare_elt = compare compare_elt s1 s2 = 0 + + let is_subset s1 ~of_:s2 ~compare_elt = + let rec is_subset s1 ~of_:s2 = + match s1, s2 with + | Empty, _ -> true + | _, Empty -> false + | Leaf v1, t2 -> mem t2 v1 ~compare_elt + | Node (l1, v1, r1, _, _), Leaf v2 -> + begin match l1, r1 with + | Empty, Empty -> + (* This case shouldn't occur in practice because we should have constructed + a Leaf rather than a Node with two Empty subtrees *) + compare_elt v1 v2 = 0 + | _, _ -> false + end + | Node (l1, v1, r1, _, _), (Node (l2, v2, r2, _, _) as t2) -> + let c = compare_elt v1 v2 in + if c = 0 + then is_subset l1 ~of_:l2 && is_subset r1 ~of_:r2 + (* Note that height and size don't matter here. *) + else if c < 0 then + is_subset (Node (l1, v1, Empty, 0, 0)) ~of_:l2 && is_subset r1 ~of_:t2 + else + is_subset (Node (Empty, v1, r1, 0, 0)) ~of_:r2 && is_subset l1 ~of_:t2 + in + is_subset s1 ~of_:s2 + ;; + + let iter t ~f = + let rec iter = function + | Empty -> () + | Leaf v -> f v + | Node(l, v, r, _, _) -> iter l; f v; iter r + in + iter t + ;; + + let symmetric_diff = Enum.symmetric_diff + + let rec fold s ~init:accu ~f = + match s with + | Empty -> accu + | Leaf v -> f accu v + | Node(l, v, r, _, _) -> fold ~f r ~init:(f (fold ~f l ~init:accu) v) + ;; + + let hash_fold_t_ignoring_structure hash_fold_elem state t = + fold t ~init:(hash_fold_int state (length t)) ~f:hash_fold_elem + ;; + + let count t ~f = Container.count ~fold t ~f + let sum m t ~f = Container.sum ~fold m t ~f + + let rec fold_right s ~init:accu ~f = + match s with + | Empty -> accu + | Leaf v -> f v accu + | Node(l, v, r, _, _) -> fold_right ~f l ~init:(f v (fold_right ~f r ~init:accu)) + ;; + + let rec for_all t ~f:p = match t with + | Empty -> true + | Leaf v -> p v + | Node(l, v, r, _, _) -> p v && for_all ~f:p l && for_all ~f:p r + ;; + + let rec exists t ~f:p = match t with + | Empty -> false + | Leaf v -> p v + | Node(l, v, r, _, _) -> p v || exists ~f:p l || exists ~f:p r + ;; + + let filter s ~f:p ~compare_elt = + let rec filt accu = function + | Empty -> accu + | Leaf v -> if p v then add accu v ~compare_elt else accu + | Node(l, v, r, _, _) -> + filt (filt (if p v then add accu v ~compare_elt else accu) l) r + in + filt Empty s + ;; + + let filter_map s ~f:p ~compare_elt = + let rec filt accu = function + | Empty -> accu + | Leaf v -> + (match p v with + | None -> accu + | Some v -> add accu v ~compare_elt) + | Node(l, v, r, _, _) -> + filt (filt (match p v with + | None -> accu + | Some v -> add accu v ~compare_elt) l) r + in + filt Empty s + ;; + + let partition_tf s ~f:p ~compare_elt = + let rec part ((t, f) as accu) = function + | Empty -> accu + | Leaf v -> if p v then (add t v ~compare_elt, f) else (t, add f v ~compare_elt) + | Node(l, v, r, _, _) -> + part (part ( + if p v + then (add t v ~compare_elt, f) + else (t, add f v ~compare_elt)) l) r + in + part (Empty, Empty) s + ;; + + let rec elements_aux accu = function + | Empty -> accu + | Leaf v -> v :: accu + | Node(l, v, r, _, _) -> elements_aux (v :: elements_aux accu r) l + ;; + + let elements s = elements_aux [] s + + let choose t = + match t with + | Empty -> None + | Leaf v -> Some v + | Node (_, v, _, _, _) -> Some v + ;; + + let choose_exn t = + match choose t with + | None -> + raise Caml.Not_found + | Some v -> v + ;; + + let of_list lst ~compare_elt = + List.fold lst ~init:empty ~f:(fun t x -> add t x ~compare_elt) + ;; + + let to_list s = elements s + + let of_array a ~compare_elt = + Array.fold a ~init:empty ~f:(fun t x -> add t x ~compare_elt) + ;; + + (* faster but equivalent to [Array.of_list (to_list t)] *) + let to_array = function + | Empty -> [||] + | Leaf v -> [| v |] + | Node (l, v, r, _, s) -> + let res = Array.create ~len:s v in + let pos_ref = ref 0 in + let rec loop = function + (* Invariant: on entry and on exit to [loop], !pos_ref is the next + available cell in the array. *) + | Empty -> () + | Leaf v -> + res.(!pos_ref) <- v; + incr pos_ref + | Node (l, v, r, _, _) -> + loop l; + res.(!pos_ref) <- v; + incr pos_ref; + loop r + in + loop l; + (* res.(!pos_ref) is already initialized (by Array.create ~len:above). *) + incr pos_ref; + loop r; + res + ;; + + let map t ~f ~compare_elt = fold t ~init:empty ~f:(fun t x -> add t (f x) ~compare_elt) + + let group_by set ~equiv ~compare_elt = + let rec loop set equiv_classes = + if is_empty set + then equiv_classes + else + let x = choose_exn set in + let equiv_x, not_equiv_x = + partition_tf set ~f:(fun elt -> phys_equal x elt || equiv x elt) ~compare_elt + in + loop not_equiv_x (equiv_x :: equiv_classes) + in + loop set [] + ;; + + let rec find t ~f = + match t with + | Empty -> None + | Leaf v -> if f v then Some v else None + | Node(l, v, r, _, _) -> + if f v then Some v + else + match find l ~f with + | None -> find r ~f + | Some _ as r -> r + ;; + + let rec find_map t ~f = + match t with + | Empty -> None + | Leaf v -> f v + | Node(l, v, r, _, _) -> + match f v with + | Some _ as r -> r + | None -> + match find_map l ~f with + | None -> find_map r ~f + | Some _ as r -> r + ;; + + let find_exn t ~f = + match find t ~f with + | None -> failwith "Set.find_exn failed to find a matching element" + | Some e -> e + ;; + + let rec nth t i = + match t with + | Empty -> None + | Leaf v -> if i = 0 then Some v else None + | Node (l, v, r, _, s) -> + if i >= s then None + else begin + let l_size = length l in + let c = Poly.compare i l_size in + if c < 0 then nth l i + else if c = 0 then Some v + else nth r (i - l_size - 1) + end + ;; + + let stable_dedup_list xs ~compare_elt = + let rec loop xs leftovers already_seen = + match xs with + | [] -> List.rev leftovers + | hd :: tl -> + if mem already_seen hd ~compare_elt + then loop tl leftovers already_seen + else loop tl (hd :: leftovers) (add already_seen hd ~compare_elt) + in + loop xs [] empty + ;; + + let t_of_sexp_direct a_of_sexp sexp ~compare_elt = + match sexp with + | Sexp.List lst -> + let elt_lst = List.map lst ~f:a_of_sexp in + let set = of_list elt_lst ~compare_elt in + if length set = List.length lst then + set + else + let compare (_, e) (_, e') = compare_elt e e' in + begin match List.find_a_dup (List.zip_exn lst elt_lst) ~compare with + | None -> assert false + | Some (el_sexp, _) -> + of_sexp_error "Set.t_of_sexp: duplicate element in set" el_sexp + end + | sexp -> of_sexp_error "Set.t_of_sexp: list needed" sexp + ;; + + let sexp_of_t sexp_of_a t = + Sexp.List (fold_right t ~init:[] ~f:(fun el acc -> sexp_of_a el :: acc)) + ;; + + module Named = struct + type nonrec ('a, 'cmp) t = { + tree : 'a t; + name : string; + } + + let is_subset (subset : _ t) ~of_:(superset : _ t) ~sexp_of_elt + ~compare_elt = + let invalid_elements = diff subset.tree superset.tree ~compare_elt in + if is_empty invalid_elements + then Ok () + else begin + let invalid_elements_sexp = sexp_of_t sexp_of_elt invalid_elements in + Or_error.error_s ( + Sexp.message (subset.name ^ " is not a subset of " ^ superset.name) + [ "invalid_elements", invalid_elements_sexp ]) + end + + let equal s1 s2 ~sexp_of_elt ~compare_elt = + Or_error.combine_errors_unit + [ is_subset s1 ~of_:s2 ~sexp_of_elt ~compare_elt + ; is_subset s2 ~of_:s1 ~sexp_of_elt ~compare_elt + ] + end +end + +type ('a, 'comparator) t = + { (* [comparator] is the first field so that polymorphic equality fails on a map due + to the functional value in the comparator. + Note that this does not affect polymorphic [compare]: that still produces + nonsense. *) + comparator : ('a, 'comparator) Comparator.t; + tree : 'a Tree0.t; + } + +type ('a, 'comparator) tree = 'a Tree0.t + +let like { tree = _; comparator } tree = { tree; comparator } + +let compare_elt t = t.comparator.Comparator.compare + +module Accessors = struct + let comparator t = t.comparator + let invariants t = Tree0.invariants t.tree ~compare_elt:(compare_elt t) + let length t = Tree0.length t.tree + let is_empty t = Tree0.is_empty t.tree + let elements t = Tree0.elements t.tree + let min_elt t = Tree0.min_elt t.tree + let min_elt_exn t = Tree0.min_elt_exn t.tree + let max_elt t = Tree0.max_elt t.tree + let max_elt_exn t = Tree0.max_elt_exn t.tree + let choose t = Tree0.choose t.tree + let choose_exn t = Tree0.choose_exn t.tree + let to_list t = Tree0.to_list t.tree + let to_array t = Tree0.to_array t.tree + let fold t ~init ~f = Tree0.fold t.tree ~init ~f + let fold_until t ~init ~f = Tree0.fold_until t.tree ~init ~f + let fold_right t ~init ~f = Tree0.fold_right t.tree ~init ~f + let fold_result t ~init ~f = Container.fold_result ~fold ~init ~f t + + let iter t ~f = Tree0.iter t.tree ~f + let iter2 a b ~f = Tree0.iter2 a.tree b.tree ~f ~compare_elt:(compare_elt a) + let exists t ~f = Tree0.exists t.tree ~f + let for_all t ~f = Tree0.for_all t.tree ~f + let count t ~f = Tree0.count t.tree ~f + let sum m t ~f = Tree0.sum m t.tree ~f + let find t ~f = Tree0.find t.tree ~f + let find_exn t ~f = Tree0.find_exn t.tree ~f + let find_map t ~f = Tree0.find_map t.tree ~f + let mem t a = Tree0.mem t.tree a ~compare_elt:(compare_elt t) + let filter t ~f = like t (Tree0.filter t.tree ~f ~compare_elt:(compare_elt t)) + let add t a = like t (Tree0.add t.tree a ~compare_elt:(compare_elt t)) + let remove t a = like t (Tree0.remove t.tree a ~compare_elt:(compare_elt t)) + let union t1 t2 = like t1 (Tree0.union t1.tree t2.tree ~compare_elt:(compare_elt t1)) + let inter t1 t2 = like t1 (Tree0.inter t1.tree t2.tree ~compare_elt:(compare_elt t1)) + let diff t1 t2 = like t1 (Tree0.diff t1.tree t2.tree ~compare_elt:(compare_elt t1)) + let symmetric_diff t1 t2 = + Tree0.symmetric_diff t1.tree t2.tree ~compare_elt:(compare_elt t1) + let compare_direct t1 t2 = Tree0.compare (compare_elt t1) t1.tree t2.tree + let equal t1 t2 = Tree0.equal t1.tree t2.tree ~compare_elt:(compare_elt t1) + let is_subset t ~of_ = Tree0.is_subset t.tree ~of_:of_.tree ~compare_elt:(compare_elt t) + + module Named = struct + type nonrec ('a, 'cmp) t = { + set : ('a, 'cmp) t; + name : string; + } + + let to_named_tree { set; name } = { + Tree0.Named. + tree = set.tree; + name; + } + + let is_subset (subset : (_, _) t) ~of_:(superset : (_, _) t) = + Tree0.Named.is_subset (to_named_tree subset) + ~of_:(to_named_tree superset) + ~compare_elt:(compare_elt subset.set) + ~sexp_of_elt:subset.set.comparator.sexp_of_t + + let equal t1 t2 = + Or_error.combine_errors_unit + [ is_subset t1 ~of_:t2 + ; is_subset t2 ~of_:t1 + ] + end + + let partition_tf t ~f = + let (tree_t, tree_f) = Tree0.partition_tf t.tree ~f ~compare_elt:(compare_elt t) in + like t tree_t, like t tree_f + ;; + let split t a = + let (tree1, b, tree2) = Tree0.split t.tree a ~compare_elt:(compare_elt t) in + like t tree1, b, like t tree2 + ;; + let group_by t ~equiv = + List.map (Tree0.group_by t.tree ~equiv ~compare_elt:(compare_elt t)) ~f:(like t) + ;; + let nth t i = Tree0.nth t.tree i + let remove_index t i = like t (Tree0.remove_index t.tree i ~compare_elt:(compare_elt t)) + let sexp_of_t sexp_of_a _ t = Tree0.sexp_of_t sexp_of_a t.tree + let to_sequence ?order ?greater_or_equal_to ?less_or_equal_to t = + Tree0.to_sequence t.comparator ?order ?greater_or_equal_to ?less_or_equal_to t.tree + let merge_to_sequence ?order ?greater_or_equal_to ?less_or_equal_to t t' = + Tree0.merge_to_sequence + t.comparator + ?order + ?greater_or_equal_to + ?less_or_equal_to + t.tree + t'.tree + let hash_fold_direct hash_fold_key state t = + Tree0.hash_fold_t_ignoring_structure hash_fold_key state t.tree +end + +include Accessors + +let compare _ _ t1 t2 = compare_direct t1 t2 + +module Tree = struct + type ('a, 'comparator) t = ('a, 'comparator) tree + + let ce comparator = comparator.Comparator.compare + + let t_of_sexp_direct ~comparator a_of_sexp sexp = + Tree0.t_of_sexp_direct ~compare_elt:(ce comparator) a_of_sexp sexp + + let empty_without_value_restriction = Tree0.empty + let empty ~comparator:_ = empty_without_value_restriction + let singleton ~comparator:_ e = Tree0.singleton e + + let length t = Tree0.length t + let invariants ~comparator t = Tree0.invariants t ~compare_elt:(ce comparator) + let is_empty t = Tree0.is_empty t + let elements t = Tree0.elements t + let min_elt t = Tree0.min_elt t + let min_elt_exn t = Tree0.min_elt_exn t + let max_elt t = Tree0.max_elt t + let max_elt_exn t = Tree0.max_elt_exn t + let choose t = Tree0.choose t + let choose_exn t = Tree0.choose_exn t + let to_list t = Tree0.to_list t + let to_array t = Tree0.to_array t + + let iter t ~f = Tree0.iter t ~f + let exists t ~f = Tree0.exists t ~f + let for_all t ~f = Tree0.for_all t ~f + let count t ~f = Tree0.count t ~f + let sum m t ~f = Tree0.sum m t ~f + let find t ~f = Tree0.find t ~f + let find_exn t ~f = Tree0.find_exn t ~f + let find_map t ~f = Tree0.find_map t ~f + + let fold t ~init ~f = Tree0.fold t ~init ~f + let fold_until t ~init ~f = Tree0.fold_until t ~init ~f + let fold_right t ~init ~f = Tree0.fold_right t ~init ~f + + let map ~comparator t ~f = Tree0.map t ~f ~compare_elt:(ce comparator) + let filter ~comparator t ~f = Tree0.filter t ~f ~compare_elt:(ce comparator) + let filter_map ~comparator t ~f = Tree0.filter_map t ~f ~compare_elt:(ce comparator) + let partition_tf ~comparator t ~f = Tree0.partition_tf t ~f ~compare_elt:(ce comparator) + + let iter2 ~comparator a b ~f = Tree0.iter2 a b ~f ~compare_elt:(ce comparator) + + let mem ~comparator t a = Tree0.mem t a ~compare_elt:(ce comparator) + let add ~comparator t a = Tree0.add t a ~compare_elt:(ce comparator) + let remove ~comparator t a = Tree0.remove t a ~compare_elt:(ce comparator) + + let union ~comparator t1 t2 = Tree0.union t1 t2 ~compare_elt:(ce comparator) + let inter ~comparator t1 t2 = Tree0.inter t1 t2 ~compare_elt:(ce comparator) + let diff ~comparator t1 t2 = Tree0.diff t1 t2 ~compare_elt:(ce comparator) + let symmetric_diff ~comparator t1 t2 = + Tree0.symmetric_diff t1 t2 ~compare_elt:(ce comparator) + let compare_direct ~comparator t1 t2 = Tree0.compare (ce comparator) t1 t2 + let equal ~comparator t1 t2 = Tree0.equal t1 t2 ~compare_elt:(ce comparator) + + let is_subset ~comparator t ~of_ = Tree0.is_subset t ~of_ ~compare_elt:(ce comparator) + + let of_list ~comparator l = Tree0.of_list l ~compare_elt:(ce comparator) + let of_array ~comparator a = Tree0.of_array a ~compare_elt:(ce comparator) + let of_sorted_array_unchecked ~comparator a = + Tree0.of_sorted_array_unchecked a ~compare_elt:(ce comparator) + let of_increasing_iterator_unchecked ~comparator:_ ~len ~f = + Tree0.of_increasing_iterator_unchecked ~len ~f + let of_sorted_array ~comparator a = + Tree0.of_sorted_array a ~compare_elt:(ce comparator) + + let union_list ~comparator l = Tree0.union_list l ~to_tree:Fn.id ~comparator + let stable_dedup_list ~comparator xs = + Tree0.stable_dedup_list xs ~compare_elt:(ce comparator) + ;; + let group_by ~comparator t ~equiv = Tree0.group_by t ~equiv ~compare_elt:(ce comparator) + let split ~comparator t a = Tree0.split t a ~compare_elt:(ce comparator) + let nth t i = Tree0.nth t i + let remove_index ~comparator t i = Tree0.remove_index t i ~compare_elt:(ce comparator) + let sexp_of_t sexp_of_a _ t = Tree0.sexp_of_t sexp_of_a t + + let to_tree t = t + let of_tree ~comparator:_ t = t + + let to_sequence ~comparator ?order ?greater_or_equal_to ?less_or_equal_to t = + Tree0.to_sequence comparator ?order ?greater_or_equal_to ?less_or_equal_to t + + let merge_to_sequence ~comparator ?order ?greater_or_equal_to ?less_or_equal_to t t' = + Tree0.merge_to_sequence comparator ?order ?greater_or_equal_to ?less_or_equal_to t t' + + let fold_result t ~init ~f = Container.fold_result ~fold ~init ~f t + + module Named = struct + include Tree0.Named + + let is_subset ~comparator t1 ~of_:t2 = + Tree0.Named.is_subset t1 ~of_:t2 ~compare_elt:(ce comparator) + ~sexp_of_elt:comparator.Comparator.sexp_of_t + + let equal ~comparator t1 t2 = + Tree0.Named.equal t1 t2 ~compare_elt:(ce comparator) + ~sexp_of_elt:comparator.Comparator.sexp_of_t + end +end + +module Using_comparator = struct + type nonrec ('elt, 'cmp) t = ('elt, 'cmp) t + + include Accessors + + let to_tree t = t.tree + + let of_tree ~comparator tree = { comparator; tree } + + let t_of_sexp_direct ~comparator a_of_sexp sexp = + of_tree ~comparator + (Tree0.t_of_sexp_direct ~compare_elt:comparator.compare a_of_sexp sexp) + + let empty ~comparator = { comparator; tree = Tree0.empty } + + module Empty_without_value_restriction(Elt : Comparator.S1) = struct + let empty = { comparator = Elt.comparator; tree = Tree0.empty } + end + + let singleton ~comparator e = { comparator; tree = Tree0.singleton e } + + let union_list ~comparator l = + of_tree ~comparator (Tree0.union_list ~comparator ~to_tree l) + ;; + + let of_sorted_array_unchecked ~comparator array = + let tree = Tree0.of_sorted_array_unchecked array ~compare_elt:comparator.Comparator.compare in + { comparator; tree } + ;; + + let of_increasing_iterator_unchecked ~comparator ~len ~f = + of_tree ~comparator (Tree0.of_increasing_iterator_unchecked ~len ~f) + ;; + + let of_sorted_array ~comparator array = + Or_error.Monad_infix.( + Tree0.of_sorted_array array ~compare_elt:comparator.Comparator.compare + >>| fun tree -> { comparator; tree }) + ;; + + let of_list ~comparator l = + { comparator; tree = Tree0.of_list l ~compare_elt:comparator.Comparator.compare } + ;; + + let of_array ~comparator a = + { comparator; tree = Tree0.of_array a ~compare_elt:comparator.Comparator.compare } + ;; + + let stable_dedup_list ~comparator xs = + Tree0.stable_dedup_list xs ~compare_elt:comparator.Comparator.compare; + ;; + + let map ~comparator t ~f = + { comparator; tree = Tree0.map t.tree ~f ~compare_elt:comparator.Comparator.compare } + ;; + + let filter_map ~comparator t ~f = + { comparator; + tree = Tree0.filter_map t.tree ~f ~compare_elt:comparator.Comparator.compare; + } + ;; + + module Tree = Tree +end + +type ('elt, 'cmp) comparator = + (module Comparator.S with type t = 'elt and type comparator_witness = 'cmp) + +let comparator_s (type k cmp) t : (k, cmp) comparator = + (module struct + type t = k + type comparator_witness = cmp + let comparator = t.comparator + end) + +let to_comparator (type elt cmp) ((module M) : (elt, cmp) comparator) = M.comparator + +let empty m = Using_comparator.empty ~comparator:(to_comparator m) +let singleton m a = Using_comparator.singleton ~comparator:(to_comparator m) a +let union_list m a = Using_comparator.union_list ~comparator:(to_comparator m) a +let of_sorted_array_unchecked m a = Using_comparator.of_sorted_array_unchecked ~comparator:(to_comparator m) a +let of_increasing_iterator_unchecked m ~len ~f = + Using_comparator.of_increasing_iterator_unchecked ~comparator:(to_comparator m) ~len ~f +let of_sorted_array m a = Using_comparator.of_sorted_array ~comparator:(to_comparator m) a +let of_list m a = Using_comparator.of_list ~comparator:(to_comparator m) a +let of_array m a = Using_comparator.of_array ~comparator:(to_comparator m) a +let stable_dedup_list m a = Using_comparator.stable_dedup_list ~comparator:(to_comparator m) a +let map m a ~f = Using_comparator.map ~comparator:(to_comparator m) a ~f +let filter_map m a ~f = Using_comparator.filter_map ~comparator:(to_comparator m) a ~f + +module M (Elt : sig type t type comparator_witness end) = struct + type nonrec t = (Elt.t, Elt.comparator_witness) t +end + +module type Sexp_of_m = sig type t [@@deriving_inline sexp_of] + include + sig [@@@ocaml.warning "-32"] val sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t + end[@@ocaml.doc "@inline"] + [@@@end] end +module type M_of_sexp = sig + type t [@@deriving_inline of_sexp] + include + sig [@@@ocaml.warning "-32"] val t_of_sexp : Ppx_sexp_conv_lib.Sexp.t -> t + end[@@ocaml.doc "@inline"] + [@@@end] include Comparator.S with type t := t +end +module type Compare_m = sig end +module type Hash_fold_m = Hasher.S + +let sexp_of_m__t (type elt) (module Elt : Sexp_of_m with type t = elt) t = + sexp_of_t Elt.sexp_of_t (fun _ -> Sexp.Atom "_") t + +let m__t_of_sexp (type elt cmp) + (module Elt : M_of_sexp with type t = elt and type comparator_witness = cmp) + sexp = + Using_comparator.t_of_sexp_direct ~comparator:Elt.comparator Elt.t_of_sexp sexp + +let compare_m__t (module Elt : Compare_m) t1 t2 = + compare_direct t1 t2 + +let hash_fold_m__t (type elt) (module Elt : Hash_fold_m with type t = elt) state = + hash_fold_direct Elt.hash_fold_t state + +let hash_m__t folder t = + let state = hash_fold_m__t folder (Hash.create ()) t in + Hash.get_hash_value state + +module Poly = struct + type comparator_witness = Comparator.Poly.comparator_witness + type nonrec ('elt, 'cmp) set = ('elt, comparator_witness) t + type nonrec 'elt t = ('elt, comparator_witness) t + type nonrec 'elt tree = ('elt, comparator_witness) tree + type nonrec 'elt named = ('elt, comparator_witness) Named.t + + include Accessors + + let comparator = Comparator.Poly.comparator + + include Using_comparator.Empty_without_value_restriction(Comparator.Poly) + + let singleton a = Using_comparator.singleton ~comparator a + let union_list a = Using_comparator.union_list ~comparator a + let of_sorted_array_unchecked a = Using_comparator.of_sorted_array_unchecked ~comparator a + let of_increasing_iterator_unchecked ~len ~f = + Using_comparator.of_increasing_iterator_unchecked ~comparator ~len ~f + let of_sorted_array a = Using_comparator.of_sorted_array ~comparator a + let of_list a = Using_comparator.of_list ~comparator a + let of_array a = Using_comparator.of_array ~comparator a + let stable_dedup_list a = Using_comparator.stable_dedup_list ~comparator a + let map a ~f = Using_comparator.map ~comparator a ~f + let filter_map a ~f = Using_comparator.filter_map ~comparator a ~f + + let of_tree tree = { comparator; tree } + let to_tree t = t.tree +end diff --git a/src/set.mli b/src/set.mli new file mode 100644 index 0000000..5b64bbd --- /dev/null +++ b/src/set.mli @@ -0,0 +1 @@ +include Set_intf.Set (** @inline *) diff --git a/src/set_intf.ml b/src/set_intf.ml new file mode 100644 index 0000000..db1605a --- /dev/null +++ b/src/set_intf.ml @@ -0,0 +1,1278 @@ +open! Import +open! T + +module type Elt_plain = sig + type t [@@deriving_inline compare, sexp_of] + include + sig + [@@@ocaml.warning "-32"] + val compare : t -> t -> int + val sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t + end[@@ocaml.doc "@inline"] + [@@@end] +end + +module Without_comparator = Map_intf.Without_comparator +module With_comparator = Map_intf.With_comparator +module With_first_class_module = Map_intf.With_first_class_module + +include Container_intf.Export + +module Merge_to_sequence_element = Sequence.Merge_with_duplicates_element + +module type Accessors_generic = sig + + include Container.Generic_phantom + + type ('a, 'cmp) tree + + (** The [options] type is used to make [Accessors_generic] flexible as to whether a + comparator is required to be passed to certain functions. *) + type ('a, 'cmp, 'z) options + + type 'cmp cmp + + val invariants + : ('a, 'cmp, + ('a, 'cmp) t -> bool + ) options + + (** override [Container]'s [mem] *) + val mem : ('a, 'cmp, ('a, 'cmp) t -> 'a elt -> bool) options + val add + : ('a, 'cmp, + ('a, 'cmp) t -> 'a elt -> ('a, 'cmp) t + ) options + val remove + : ('a, 'cmp, + ('a, 'cmp) t -> 'a elt -> ('a, 'cmp) t + ) options + val union + : ('a, 'cmp, + ('a, 'cmp) t -> ('a, 'cmp) t -> ('a, 'cmp) t + ) options + val inter + : ('a, 'cmp, + ('a, 'cmp) t -> ('a, 'cmp) t -> ('a, 'cmp) t + ) options + val diff + : ('a, 'cmp, + ('a, 'cmp) t -> ('a, 'cmp) t -> ('a, 'cmp) t + ) options + val symmetric_diff + : ('a, 'cmp, + ('a, 'cmp) t -> ('a, 'cmp) t + -> ('a elt, 'a elt) Either.t Sequence.t + ) options + val compare_direct + : ('a, 'cmp, + ('a, 'cmp) t -> ('a, 'cmp) t -> int + ) options + val equal + : ('a, 'cmp, + ('a, 'cmp) t -> ('a, 'cmp) t -> bool + ) options + val is_subset + : ('a, 'cmp, + ('a, 'cmp) t -> of_:('a, 'cmp) t -> bool + ) options + + type ('a, 'cmp) named + module Named : sig + val is_subset + : ('a, 'cmp, + ('a, 'cmp) named + -> of_:('a, 'cmp) named + -> unit Or_error.t + ) options + + val equal + : ('a, 'cmp, + ('a, 'cmp) named + -> ('a, 'cmp) named + -> unit Or_error.t + ) options + end + + val fold_until + : ('a, _) t + -> init:'b + -> f:('b -> 'a elt -> ('b, 'final) Continue_or_stop.t) + -> finish:('b -> 'final) + -> 'final + val fold_right + : ('a, _) t + -> init:'b + -> f:('a elt -> 'b -> 'b) + -> 'b + val iter2 + : ('a, 'cmp, + ('a, 'cmp) t + -> ('a, 'cmp) t + -> f:([ `Left of 'a elt | `Right of 'a elt | `Both of 'a elt * 'a elt ] -> unit) + -> unit + ) options + val filter + : ('a, 'cmp, + ('a, 'cmp) t -> f:('a elt -> bool) -> ('a, 'cmp) t + ) options + val partition_tf + : ('a, 'cmp, + ('a, 'cmp) t + -> f:('a elt -> bool) + -> ('a, 'cmp) t * ('a, 'cmp) t + ) options + + val elements : ('a, _) t -> 'a elt list + + val min_elt : ('a, _) t -> 'a elt option + val min_elt_exn : ('a, _) t -> 'a elt + val max_elt : ('a, _) t -> 'a elt option + val max_elt_exn : ('a, _) t -> 'a elt + val choose : ('a, _) t -> 'a elt option + val choose_exn : ('a, _) t -> 'a elt + + val split + : ('a, 'cmp, + ('a, 'cmp) t + -> 'a elt + -> ('a, 'cmp) t * 'a elt option * ('a, 'cmp) t + ) options + + val group_by + : ('a, 'cmp, + ('a, 'cmp) t + -> equiv:('a elt -> 'a elt -> bool) + -> ('a, 'cmp) t list + ) options + + val find_exn : ('a, _) t -> f:('a elt -> bool) -> 'a elt + val nth : ('a, _) t -> int -> 'a elt option + val remove_index + : ('a, 'cmp, + ('a, 'cmp) t -> int -> ('a, 'cmp) t + ) options + + val to_tree : ('a, 'cmp) t -> ('a elt, 'cmp) tree + + val to_sequence + : ('a, 'cmp, + ?order:[ `Increasing | `Decreasing ] + -> ?greater_or_equal_to:'a elt + -> ?less_or_equal_to:'a elt + -> ('a, 'cmp) t + -> 'a elt Sequence.t + ) options + + val merge_to_sequence + : ('a, 'cmp, + ?order:[ `Increasing | `Decreasing ] + -> ?greater_or_equal_to:'a elt + -> ?less_or_equal_to:'a elt + -> ('a, 'cmp) t + -> ('a, 'cmp) t + -> ('a elt, 'a elt) Merge_to_sequence_element.t Sequence.t + ) options +end + +module type Accessors0 = sig + include Container.S0 + type tree + type comparator_witness + val invariants : t -> bool + val mem : t -> elt -> bool + val add : t -> elt -> t + val remove : t -> elt -> t + val union : t -> t -> t + val inter : t -> t -> t + val diff : t -> t -> t + val symmetric_diff : t -> t -> (elt, elt) Either.t Sequence.t + val compare_direct : t -> t -> int + val equal : t -> t -> bool + val is_subset : t -> of_:t -> bool + + type named + module Named : sig + val is_subset : named -> of_:named -> unit Or_error.t + val equal : named -> named -> unit Or_error.t + end + + val fold_until + : t + -> init:'b + -> f:('b -> elt -> ('b, 'final) Continue_or_stop.t) + -> finish:('b -> 'final) + -> 'final + val fold_right : t -> init:'b -> f:(elt -> 'b -> 'b) -> 'b + val iter2 + : t -> t -> f:([ `Left of elt | `Right of elt | `Both of elt * elt ] -> unit) -> unit + val filter : t -> f:(elt -> bool) -> t + val partition_tf : t -> f:(elt -> bool) -> t * t + val elements : t -> elt list + val min_elt : t -> elt option + val min_elt_exn : t -> elt + val max_elt : t -> elt option + val max_elt_exn : t -> elt + val choose : t -> elt option + val choose_exn : t -> elt + val split : t -> elt -> t * elt option * t + val group_by : t -> equiv:(elt -> elt -> bool) -> t list + val find_exn : t -> f:(elt -> bool) -> elt + val nth : t -> int -> elt option + val remove_index : t -> int -> t + val to_tree : t -> tree + val to_sequence + : ?order:[ `Increasing | `Decreasing ] + -> ?greater_or_equal_to:elt + -> ?less_or_equal_to:elt + -> t + -> elt Sequence.t + val merge_to_sequence + : ?order:[ `Increasing | `Decreasing ] + -> ?greater_or_equal_to:elt + -> ?less_or_equal_to:elt + -> t + -> t + -> (elt, elt) Merge_to_sequence_element.t Sequence.t +end + +module type Accessors1 = sig + include Container.S1 + type 'a tree + type comparator_witness + val invariants : _ t -> bool + val mem : 'a t -> 'a -> bool + val add : 'a t -> 'a -> 'a t + val remove : 'a t -> 'a -> 'a t + val union : 'a t -> 'a t -> 'a t + val inter : 'a t -> 'a t -> 'a t + val diff : 'a t -> 'a t -> 'a t + val symmetric_diff : 'a t -> 'a t -> ('a, 'a) Either.t Sequence.t + val compare_direct : 'a t -> 'a t -> int + val equal : 'a t -> 'a t -> bool + val is_subset : 'a t -> of_:'a t -> bool + + type 'a named + module Named : sig + val is_subset : 'a named -> of_:'a named -> unit Or_error.t + val equal : 'a named -> 'a named -> unit Or_error.t + end + + val fold_until + : 'a t + -> init:'b + -> f:('b -> 'a -> ('b, 'final) Continue_or_stop.t) + -> finish:('b -> 'final) + -> 'final + val fold_right : 'a t -> init:'b -> f:('a -> 'b -> 'b) -> 'b + val iter2 + : 'a t -> 'a t -> f:([ `Left of 'a | `Right of 'a | `Both of 'a * 'a ] -> unit) -> unit + val filter : 'a t -> f:('a -> bool) -> 'a t + val partition_tf : 'a t -> f:('a -> bool) -> 'a t * 'a t + val elements : 'a t -> 'a list + val min_elt : 'a t -> 'a option + val min_elt_exn : 'a t -> 'a + val max_elt : 'a t -> 'a option + val max_elt_exn : 'a t -> 'a + val choose : 'a t -> 'a option + val choose_exn : 'a t -> 'a + val split : 'a t -> 'a -> 'a t * 'a option * 'a t + val group_by : 'a t -> equiv:('a -> 'a -> bool) -> 'a t list + val find_exn : 'a t -> f:('a -> bool) -> 'a + val nth : 'a t -> int -> 'a option + val remove_index : 'a t -> int -> 'a t + val to_tree : 'a t -> 'a tree + val to_sequence + : ?order:[ `Increasing | `Decreasing ] + -> ?greater_or_equal_to:'a + -> ?less_or_equal_to:'a + -> 'a t + -> 'a Sequence.t + val merge_to_sequence + : ?order:[ `Increasing | `Decreasing ] + -> ?greater_or_equal_to:'a + -> ?less_or_equal_to:'a + -> 'a t + -> 'a t + -> ('a, 'a) Merge_to_sequence_element.t Sequence.t +end + +module type Accessors2 = sig + include Container.S1_phantom_invariant + type ('a, 'cmp) tree + val invariants : (_, _) t -> bool + val mem : ('a, _) t -> 'a -> bool + val add : ('a, 'cmp) t -> 'a -> ('a, 'cmp) t + val remove : ('a, 'cmp) t -> 'a -> ('a, 'cmp) t + val union : ('a, 'cmp) t -> ('a, 'cmp) t -> ('a, 'cmp) t + val inter : ('a, 'cmp) t -> ('a, 'cmp) t -> ('a, 'cmp) t + val diff : ('a, 'cmp) t -> ('a, 'cmp) t -> ('a, 'cmp) t + val symmetric_diff : ('a, 'cmp) t -> ('a, 'cmp) t -> ('a, 'a) Either.t Sequence.t + val compare_direct : ('a, 'cmp) t -> ('a, 'cmp) t -> int + val equal : ('a, 'cmp) t -> ('a, 'cmp) t -> bool + val is_subset : ('a, 'cmp) t -> of_:('a, 'cmp) t -> bool + + type ('a, 'cmp) named + module Named : sig + val is_subset : ('a, 'cmp) named -> of_:('a, 'cmp) named -> unit Or_error.t + val equal : ('a, 'cmp) named -> ('a, 'cmp) named -> unit Or_error.t + end + + val fold_until + : ('a, _) t + -> init:'b + -> f:('b -> 'a -> ('b, 'final) Continue_or_stop.t) + -> finish:('b -> 'final) + -> 'final + + val fold_right : ('a, _) t -> init:'b -> f:('a -> 'b -> 'b) -> 'b + val iter2 + : ('a, 'cmp) t + -> ('a, 'cmp) t -> f:([ `Left of 'a | `Right of 'a | `Both of 'a * 'a ] -> unit) + -> unit + val filter : ('a, 'cmp) t -> f:('a -> bool) -> ('a, 'cmp) t + val partition_tf : ('a, 'cmp) t -> f:('a -> bool) -> ('a, 'cmp) t * ('a, 'cmp) t + val elements : ('a, _) t -> 'a list + val min_elt : ('a, _) t -> 'a option + val min_elt_exn : ('a, _) t -> 'a + val max_elt : ('a, _) t -> 'a option + val max_elt_exn : ('a, _) t -> 'a + val choose : ('a, _) t -> 'a option + val choose_exn : ('a, _) t -> 'a + val split : ('a, 'cmp) t -> 'a -> ('a, 'cmp) t * 'a option * ('a, 'cmp) t + val group_by : ('a, 'cmp) t -> equiv:('a -> 'a -> bool) -> ('a, 'cmp) t list + val find_exn : ('a, _) t -> f:('a -> bool) -> 'a + val nth : ('a, _) t -> int -> 'a option + val remove_index : ('a, 'cmp) t -> int -> ('a, 'cmp) t + val to_tree : ('a, 'cmp) t -> ('a, 'cmp) tree + val to_sequence + : ?order:[ `Increasing | `Decreasing ] + -> ?greater_or_equal_to:'a + -> ?less_or_equal_to:'a + -> ('a, 'cmp) t + -> 'a Sequence.t + val merge_to_sequence + : ?order:[ `Increasing | `Decreasing ] + -> ?greater_or_equal_to:'a + -> ?less_or_equal_to:'a + -> ('a, 'cmp) t + -> ('a, 'cmp) t + -> ('a, 'a) Merge_to_sequence_element.t Sequence.t +end + +module type Accessors2_with_comparator = sig + include Container.S1_phantom_invariant + type ('a, 'cmp) tree + val invariants : comparator:('a, 'cmp) Comparator.t -> ('a, 'cmp) t -> bool + val mem : comparator:('a, 'cmp) Comparator.t -> ('a, 'cmp) t -> 'a -> bool + val add + : comparator:('a, 'cmp) Comparator.t -> ('a, 'cmp) t -> 'a -> ('a, 'cmp) t + val remove + : comparator:('a, 'cmp) Comparator.t -> ('a, 'cmp) t -> 'a -> ('a, 'cmp) t + val union + : comparator:('a, 'cmp) Comparator.t -> ('a, 'cmp) t -> ('a, 'cmp) t -> ('a, 'cmp) t + val inter + : comparator:('a, 'cmp) Comparator.t -> ('a, 'cmp) t -> ('a, 'cmp) t -> ('a, 'cmp) t + val diff + : comparator:('a, 'cmp) Comparator.t -> ('a, 'cmp) t -> ('a, 'cmp) t -> ('a, 'cmp) t + val symmetric_diff + : comparator:('a, 'cmp) Comparator.t -> ('a, 'cmp) t -> ('a, 'cmp) t + -> ('a, 'a) Either.t Sequence.t + val compare_direct + : comparator:('a, 'cmp) Comparator.t -> ('a, 'cmp) t -> ('a, 'cmp) t -> int + val equal + : comparator:('a, 'cmp) Comparator.t -> ('a, 'cmp) t -> ('a, 'cmp) t -> bool + val is_subset + : comparator:('a, 'cmp) Comparator.t -> ('a, 'cmp) t -> of_:('a, 'cmp) t -> bool + + type ('a, 'cmp) named + module Named : sig + val is_subset + : comparator:('a, 'cmp) Comparator.t + -> ('a, 'cmp) named + -> of_:('a, 'cmp) named + -> unit Or_error.t + val equal + : comparator:('a, 'cmp) Comparator.t + -> ('a, 'cmp) named + -> ('a, 'cmp) named + -> unit Or_error.t + end + + val fold_until + : ('a, _) t + -> init:'accum + -> f:('accum -> 'a -> ('accum, 'final) Continue_or_stop.t) + -> finish:('accum -> 'final) + -> 'final + + val fold_right : ('a, _) t -> init:'accum -> f:('a -> 'accum -> 'accum) -> 'accum + val iter2 + : comparator:('a, 'cmp) Comparator.t + -> ('a, 'cmp) t + -> ('a, 'cmp) t + -> f:([ `Left of 'a | `Right of 'a | `Both of 'a * 'a ] -> unit) + -> unit + val filter + : comparator:('a, 'cmp) Comparator.t -> ('a, 'cmp) t -> f:('a -> bool) -> ('a, 'cmp) t + val partition_tf + : comparator:('a, 'cmp) Comparator.t + -> ('a, 'cmp) t -> f:('a -> bool) -> ('a, 'cmp) t * ('a, 'cmp) t + val elements : ('a, _) t -> 'a list + val min_elt : ('a, _) t -> 'a option + val min_elt_exn : ('a, _) t -> 'a + val max_elt : ('a, _) t -> 'a option + val max_elt_exn : ('a, _) t -> 'a + val choose : ('a, _) t -> 'a option + val choose_exn : ('a, _) t -> 'a + val split + : comparator:('a, 'cmp) Comparator.t + -> ('a, 'cmp) t -> 'a -> ('a, 'cmp) t * 'a option * ('a, 'cmp) t + val group_by + : comparator:('a, 'cmp) Comparator.t + -> ('a, 'cmp) t -> equiv:('a -> 'a -> bool) -> ('a, 'cmp) t list + val find_exn : ('a, _) t -> f:('a -> bool) -> 'a + val nth : ('a, _) t -> int -> 'a option + + val remove_index + : comparator:('a, 'cmp) Comparator.t -> ('a, 'cmp) t -> int -> ('a, 'cmp) t + val to_tree : ('a, 'cmp) t -> ('a, 'cmp) tree + val to_sequence + : comparator:('a, 'cmp) Comparator.t + -> ?order:[ `Increasing | `Decreasing ] + -> ?greater_or_equal_to:'a + -> ?less_or_equal_to:'a + -> ('a, 'cmp) t + -> 'a Sequence.t + val merge_to_sequence + : comparator:('a, 'cmp) Comparator.t + -> ?order:[ `Increasing | `Decreasing ] + -> ?greater_or_equal_to:'a + -> ?less_or_equal_to:'a + -> ('a, 'cmp) t + -> ('a, 'cmp) t + -> ('a, 'a) Merge_to_sequence_element.t Sequence.t +end + +(** Consistency checks (same as in [Container]). *) +module Check_accessors (T : T2) (Tree : T2) (Elt : T1) (Named : T2) (Cmp : T1) (Options : T3) + (M : Accessors_generic + with type ('a, 'b, 'c) options := ('a, 'b, 'c) Options.t + with type ('a, 'b) t := ('a, 'b) T.t + with type ('a, 'b) tree := ('a, 'b) Tree.t + with type 'a elt := 'a Elt.t + with type 'cmp cmp := 'cmp Cmp.t + with type ('a, 'b) named := ('a, 'b) Named.t) += struct end + +module Check_accessors0 (M : Accessors0) = + Check_accessors + (struct type ('a, 'b) t = M.t end) + (struct type ('a, 'b) t = M.tree end) + (struct type 'a t = M.elt end) + (struct type ('a, 'b) t = M.named end) + (struct type 'a t = M.comparator_witness end) + (Without_comparator) + (M) + +module Check_accessors1 (M : Accessors1) = + Check_accessors + (struct type ('a, 'b) t = 'a M.t end) + (struct type ('a, 'b) t = 'a M.tree end) + (struct type 'a t = 'a end) + (struct type ('a, 'b) t = 'a M.named end) + (struct type 'a t = M.comparator_witness end) + (Without_comparator) + (M) + +module Check_accessors2 (M : Accessors2) = + Check_accessors + (struct type ('a, 'b) t = ('a, 'b) M.t end) + (struct type ('a, 'b) t = ('a, 'b) M.tree end) + (struct type 'a t = 'a end) + (struct type ('a, 'b) t = ('a, 'b) M.named end) + (struct type 'a t = 'a end) + (Without_comparator) + (M) + +module Check_accessors2_with_comparator (M : Accessors2_with_comparator) = + Check_accessors + (struct type ('a, 'b) t = ('a, 'b) M.t end) + (struct type ('a, 'b) t = ('a, 'b) M.tree end) + (struct type 'a t = 'a end) + (struct type ('a, 'b) t = ('a, 'b) M.named end) + (struct type 'a t = 'a end) + (With_comparator) + (M) + +module type Creators_generic = sig + type ('a, 'cmp) t + type ('a, 'cmp) set + type ('a, 'cmp) tree + type 'a elt + type ('a, 'cmp, 'z) options + type 'cmp cmp + + val empty : ('a, 'cmp, ('a, 'cmp) t) options + val singleton : ('a, 'cmp, 'a elt -> ('a, 'cmp) t) options + val union_list + : ('a, 'cmp, + ('a, 'cmp) t list -> ('a, 'cmp) t + ) options + val of_list : ('a, 'cmp, 'a elt list -> ('a, 'cmp) t) options + val of_array : ('a, 'cmp, 'a elt array -> ('a, 'cmp) t) options + + val of_sorted_array : ('a, 'cmp, 'a elt array -> ('a, 'cmp) t Or_error.t) options + + val of_sorted_array_unchecked : ('a, 'cmp, 'a elt array -> ('a, 'cmp) t) options + + val of_increasing_iterator_unchecked + : ('a, 'cmp, len:int -> f:(int -> 'a elt) -> ('a, 'cmp) t) options + + val stable_dedup_list : ('a, _, 'a elt list -> 'a elt list) options + + (** The types of [map] and [filter_map] are subtle. The input set, [('a, _) set], + reflects the fact that these functions take a set of *any* type, with any + comparator, while the output set, [('b, 'cmp) t], reflects that the output set has + the particular ['cmp] of the creation function. The comparator can come in one of + three ways, depending on which set module is used + + - [Set.map] -- comparator comes as an argument + - [Set.Poly.map] -- comparator is polymorphic comparison + - [Foo.Set.map] -- comparator is [Foo.comparator] *) + val map + : ('b, 'cmp, ('a, _) set -> f:('a -> 'b elt ) -> ('b, 'cmp) t + ) options + val filter_map + : ('b, 'cmp, ('a, _) set -> f:('a -> 'b elt option) -> ('b, 'cmp) t + ) options + + val of_tree + : ('a, 'cmp, + ('a elt, 'cmp) tree -> ('a, 'cmp) t + ) options +end + +module type Creators0 = sig + type ('a, 'cmp) set + type t + type tree + type elt + type comparator_witness + val empty : t + val singleton : elt -> t + val union_list : t list -> t + val of_list : elt list -> t + val of_array : elt array -> t + val of_sorted_array : elt array -> t Or_error.t + val of_sorted_array_unchecked : elt array -> t + val of_increasing_iterator_unchecked : len:int -> f:(int -> elt) -> t + val stable_dedup_list : elt list -> elt list + val map : ('a, _) set -> f:('a -> elt ) -> t + val filter_map : ('a, _) set -> f:('a -> elt option) -> t + val of_tree : tree -> t +end + +module type Creators1 = sig + type ('a, 'cmp) set + type 'a t + type 'a tree + type comparator_witness + val empty : 'a t + val singleton : 'a -> 'a t + val union_list : 'a t list -> 'a t + val of_list : 'a list -> 'a t + val of_array : 'a array -> 'a t + val of_sorted_array : 'a array -> 'a t Or_error.t + val of_sorted_array_unchecked : 'a array -> 'a t + val of_increasing_iterator_unchecked : len:int -> f:(int -> 'a) -> 'a t + val stable_dedup_list : 'a list -> 'a list + val map : ('a, _) set -> f:('a -> 'b ) -> 'b t + val filter_map : ('a, _) set -> f:('a -> 'b option) -> 'b t + val of_tree : 'a tree -> 'a t +end + +module type Creators2 = sig + type ('a, 'cmp) set + type ('a, 'cmp) t + type ('a, 'cmp) tree + val empty : ('a, 'cmp) t + val singleton : 'a -> ('a, 'cmp) t + val union_list : ('a, 'cmp) t list -> ('a, 'cmp) t + val of_list : 'a list -> ('a, 'cmp) t + val of_array : 'a array -> ('a, 'cmp) t + val of_sorted_array : 'a array -> ('a, 'cmp) t Or_error.t + val of_sorted_array_unchecked : 'a array -> ('a, 'cmp) t + val of_increasing_iterator_unchecked : len:int -> f:(int -> 'a) -> ('a, 'cmp) t + val stable_dedup_list : 'a list -> 'a list + val map : ('a, _) set -> f:('a -> 'b ) -> ('b, 'cmp) t + val filter_map : ('a, _) set -> f:('a -> 'b option) -> ('b, 'cmp) t + val of_tree : ('a, 'cmp) tree -> ('a, 'cmp) t +end + +module type Creators2_with_comparator = sig + type ('a, 'cmp) set + type ('a, 'cmp) t + type ('a, 'cmp) tree + val empty : comparator:('a, 'cmp) Comparator.t -> ('a, 'cmp) t + val singleton : comparator:('a, 'cmp) Comparator.t -> 'a -> ('a, 'cmp) t + val union_list : comparator:('a, 'cmp) Comparator.t -> ('a, 'cmp) t list + -> ('a, 'cmp) t + val of_list : comparator:('a, 'cmp) Comparator.t -> 'a list + -> ('a, 'cmp) t + val of_array : comparator:('a, 'cmp) Comparator.t -> 'a array + -> ('a, 'cmp) t + val of_sorted_array : comparator:('a, 'cmp) Comparator.t -> 'a array + -> ('a, 'cmp) t Or_error.t + val of_sorted_array_unchecked : comparator:('a, 'cmp) Comparator.t -> 'a array + -> ('a, 'cmp) t + val of_increasing_iterator_unchecked + : comparator:('a, 'cmp) Comparator.t -> len:int -> f:(int -> 'a) -> ('a, 'cmp) t + val stable_dedup_list : comparator:('a, 'cmp) Comparator.t -> 'a list -> 'a list + val map : comparator:('b, 'cmp) Comparator.t -> ('a, _) set + -> f:('a -> 'b ) -> ('b, 'cmp) t + val filter_map : comparator:('b, 'cmp) Comparator.t -> ('a, _) set + -> f:('a -> 'b option) -> ('b, 'cmp) t + val of_tree : comparator:('a, 'cmp) Comparator.t + -> ('a, 'cmp) tree -> ('a, 'cmp) t +end + +module Check_creators (T : T2) (Tree : T2) (Elt : T1) (Cmp : T1) (Options : T3) + (M : Creators_generic + with type ('a, 'b, 'c) options := ('a, 'b, 'c) Options.t + with type ('a, 'b) t := ('a, 'b) T.t + with type ('a, 'b) tree := ('a, 'b) Tree.t + with type 'a elt := 'a Elt.t + with type 'cmp cmp := 'cmp Cmp.t) += struct end + +module Check_creators0 (M : Creators0) = + Check_creators + (struct type ('a, 'b) t = M.t end) + (struct type ('a, 'b) t = M.tree end) + (struct type 'a t = M.elt end) + (struct type 'cmp t = M.comparator_witness end) + (Without_comparator) + (M) + +module Check_creators1 (M : Creators1) = + Check_creators + (struct type ('a, 'b) t = 'a M.t end) + (struct type ('a, 'b) t = 'a M.tree end) + (struct type 'a t = 'a end) + (struct type 'cmp t = M.comparator_witness end) + (Without_comparator) + (M) + +module Check_creators2 (M : Creators2) = + Check_creators + (struct type ('a, 'b) t = ('a, 'b) M.t end) + (struct type ('a, 'b) t = ('a, 'b) M.tree end) + (struct type 'a t = 'a end) + (struct type 'cmp t = 'cmp end) + (Without_comparator) + (M) + +module Check_creators2_with_comparator (M : Creators2_with_comparator) = + Check_creators + (struct type ('a, 'b) t = ('a, 'b) M.t end) + (struct type ('a, 'b) t = ('a, 'b) M.tree end) + (struct type 'a t = 'a end) + (struct type 'cmp t = 'cmp end) + (With_comparator) + (M) + +module type Creators_and_accessors_generic = sig + include Accessors_generic + include Creators_generic + with type ('a, 'b, 'c) options := ('a, 'b, 'c) options + with type ('a, 'b) t := ('a, 'b) t + with type ('a, 'b) tree := ('a, 'b) tree + with type 'a elt := 'a elt + with type 'cmp cmp := 'cmp cmp +end + +module type Creators_and_accessors0 = sig + include Accessors0 + include Creators0 + with type t := t + with type tree := tree + with type elt := elt + with type comparator_witness := comparator_witness +end + +module type Creators_and_accessors1 = sig + include Accessors1 + include Creators1 + with type 'a t := 'a t + with type 'a tree := 'a tree + with type comparator_witness := comparator_witness +end + +module type Creators_and_accessors2 = sig + include Accessors2 + include Creators2 + with type ('a, 'b) t := ('a, 'b) t + with type ('a, 'b) tree := ('a, 'b) tree +end + +module type Creators_and_accessors2_with_comparator = sig + include Accessors2_with_comparator + include Creators2_with_comparator + with type ('a, 'b) t := ('a, 'b) t + with type ('a, 'b) tree := ('a, 'b) tree +end + +module type S_poly = Creators_and_accessors1 + +module type For_deriving = sig + type ('a, 'b) t + + module type Sexp_of_m = sig type t [@@deriving_inline sexp_of] + include + sig [@@@ocaml.warning "-32"] val sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t + end[@@ocaml.doc "@inline"] + [@@@end] end + module type M_of_sexp = sig + type t [@@deriving_inline of_sexp] + include + sig [@@@ocaml.warning "-32"] val t_of_sexp : Ppx_sexp_conv_lib.Sexp.t -> t + end[@@ocaml.doc "@inline"] + [@@@end] include Comparator.S with type t := t + end + module type Compare_m = sig end + module type Hash_fold_m = Hasher.S + + val sexp_of_m__t + : (module Sexp_of_m with type t = 'elt) + -> ('elt, 'cmp) t + -> Sexp.t + + val m__t_of_sexp + : (module M_of_sexp with type t = 'elt and type comparator_witness = 'cmp) + -> Sexp.t + -> ('elt, 'cmp) t + + val compare_m__t : (module Compare_m) -> ('elt, 'cmp) t -> ('elt, 'cmp) t -> int + + val hash_fold_m__t + : (module Hash_fold_m with type t = 'elt) + -> (Hash.state -> ('elt, _) t -> Hash.state) + + val hash_m__t + : (module Hash_fold_m with type t = 'elt) + -> (('elt, _) t -> int) +end + +module type Set = sig + (** This module defines the [Set] module for [Base]. Functions that construct a set take + as an argument the comparator for the element type. *) + + (** The type of a set. The first type parameter identifies the type of the element, and + the second identifies the comparator, which determines the comparison function that + is used for ordering elements in this set. Many operations (e.g., {!union}), + require that they be passed sets with the same element type and the same comparator + type. *) + type ('elt, 'cmp) t [@@deriving_inline compare] + include + sig + [@@@ocaml.warning "-32"] + val compare : + ('elt -> 'elt -> int) -> + ('cmp -> 'cmp -> int) -> ('elt, 'cmp) t -> ('elt, 'cmp) t -> int + end[@@ocaml.doc "@inline"] + [@@@end] + + type ('k, 'cmp) comparator = + (module Comparator.S with type t = 'k and type comparator_witness = 'cmp) + + (** Tests internal invariants of the set data structure. Returns true on success. *) + val invariants : (_, _) t -> bool + + (** Returns a first-class module that can be used to build other map/set/etc + with the same notion of comparison. *) + val comparator_s : ('a, 'cmp) t -> ('a, 'cmp) comparator + + val comparator : ('a, 'cmp) t -> ('a, 'cmp) Comparator.t + + (** Creates an empty set based on the provided comparator. *) + val empty : ('a, 'cmp) comparator -> ('a, 'cmp) t + + (** Creates a set based on the provided comparator that contains only the provided + element. *) + val singleton : ('a, 'cmp) comparator -> 'a -> ('a, 'cmp) t + + (** Returns the cardinality of the set. [O(1)]. *) + val length : (_, _) t -> int + + (** [is_empty t] is [true] iff [t] is empty. [O(1)]. *) + val is_empty : (_, _) t -> bool + + (** [mem t a] returns [true] iff [a] is in [t]. [O(log n)]. *) + val mem : ('a, _) t -> 'a -> bool + + (** [add t a] returns a new set with [a] added to [t], or returns [t] if [mem t a]. + [O(log n)]. *) + val add : ('a, 'cmp) t -> 'a -> ('a, 'cmp) t + + (** [remove t a] returns a new set with [a] removed from [t] if [mem t a], or returns [t] + otherwise. [O(log n)]. *) + val remove : ('a, 'cmp) t -> 'a -> ('a, 'cmp) t + + (** [union t1 t2] returns the union of the two sets. [O(length t1 + length t2)]. *) + val union : ('a, 'cmp) t -> ('a, 'cmp) t -> ('a, 'cmp) t + + (** [union c list] returns the union of all the sets in [list]. The + [comparator] argument is required for the case where [list] is empty. + [O(max(List.length list, n log n))], where [n] is the sum of sizes of the input sets. *) + val union_list : ('a, 'cmp) comparator -> ('a, 'cmp) t list -> ('a, 'cmp) t + + (** [inter t1 t2] computes the intersection of sets [t1] and [t2]. [O(length t1 + + length t2)]. *) + val inter : ('a, 'cmp) t -> ('a, 'cmp) t -> ('a, 'cmp) t + + (** [diff t1 t2] computes the set difference [t1 - t2], i.e., the set containing all + elements in [t1] that are not in [t2]. [O(length t1 + length t2)]. *) + val diff : ('a, 'cmp) t -> ('a, 'cmp) t -> ('a, 'cmp) t + + (** [symmetric_diff t1 t2] returns a sequence of changes between [t1] and [t2]. It is + intended to be efficient in the case where [t1] and [t2] share a large amount of + structure. *) + val symmetric_diff + : ('a, 'cmp) t + -> ('a, 'cmp) t + -> ('a, 'a) Either.t Sequence.t + + (** [compare_direct t1 t2] compares the sets [t1] and [t2]. It returns the same result + as [compare], but unlike compare, doesn't require arguments to be passed in for the + type parameters of the set. [O(length t1 + length t2)]. *) + val compare_direct : ('a, 'cmp) t -> ('a, 'cmp) t -> int + + (** Hash function: a building block to use when hashing data structures containing sets in + them. [hash_fold_direct hash_fold_key] is compatible with [compare_direct] iff + [hash_fold_key] is compatible with [(comparator s).compare] of the set [s] being + hashed. *) + val hash_fold_direct + : 'a Hash.folder + -> ('a, 'cmp) t Hash.folder + + (** [equal t1 t2] returns [true] iff the two sets have the same elements. [O(length t1 + + length t2)] *) + val equal : ('a, 'cmp) t -> ('a, 'cmp) t -> bool + + (** [exists t ~f] returns [true] iff there exists an [a] in [t] for which [f a]. [O(n)], + but returns as soon as it finds an [a] for which [f a]. *) + val exists : ('a, _) t -> f:('a -> bool) -> bool + + (** [for_all t ~f] returns [true] iff for all [a] in [t], [f a]. [O(n)], but returns as + soon as it finds an [a] for which [not (f a)]. *) + val for_all : ('a, _) t -> f:('a -> bool) -> bool + + (** [count t] returns the number of elements of [t] for which [f] returns [true]. + [O(n)]. *) + val count : ('a, _) t -> f:('a -> bool) -> int + + (** [sum t] returns the sum of [f t] for each [t] in the set. + [O(n)]. *) + val sum + : (module Container.Summable with type t = 'sum) + -> ('a, _) t -> f:('a -> 'sum) -> 'sum + + (** [find t f] returns an element of [t] for which [f] returns true, with no guarantee as + to which element is returned. [O(n)], but returns as soon as a suitable element is + found. *) + val find : ('a, _) t -> f:('a -> bool) -> 'a option + + (** [find_map t f] returns [b] for some [a] in [t] for which [f a = Some b]. If no such + [a] exists, then [find] returns [None]. [O(n)], but returns as soon as a suitable + element is found. *) + val find_map : ('a, _) t -> f:('a -> 'b option) -> 'b option + + (** Like [find], but throws an exception on failure. *) + val find_exn : ('a, _) t -> f:('a -> bool) -> 'a + + (** [nth t i] returns the [i]th smallest element of [t], in [O(log n)] time. The + smallest element has [i = 0]. Returns [None] if [i < 0] or [i >= length t]. *) + val nth : ('a, _) t -> int -> 'a option + + (** [remove_index t i] returns a version of [t] with the [i]th smallest element removed, + in [O(log n)] time. The smallest element has [i = 0]. Returns [t] if [i < 0] or + [i >= length t]. *) + val remove_index : ('a, 'cmp) t -> int -> ('a, 'cmp) t + + (** [is_subset t1 ~of_:t2] returns true iff [t1] is a subset of [t2]. *) + val is_subset : ('a, 'cmp) t -> of_:('a, 'cmp) t -> bool + + (** [Named] allows the validation of subset and equality relationships between sets. A + [Named.t] is a record of a set and a name, where the name is used in error messages, + and [Named.is_subset] and [Named.equal] validate subset and equality relationships + respectively. + + The error message for, e.g., + {[ + Named.is_subset { set = set1; name = "set1" } ~of_:{set = set2; name = "set2" } + ]} + + looks like + {v + ("set1 is not a subset of set2" (invalid_elements (...elements of set1 - set2...))) + v} + + so [name] should be a noun phrase that doesn't sound awkward in the above error + message. Even though it adds verbosity, choosing [name]s that start with the phrase + "the set of" often makes the error message sound more natural. + *) + module Named : sig + type nonrec ('a, 'cmp) t = { + set : ('a, 'cmp) t; + name : string; + } + + (** [is_subset t1 ~of_:t2] returns [Ok ()] if [t1] is a subset of [t2] and a + human-readable error otherwise. *) + val is_subset : ('a, 'cmp) t -> of_:('a, 'cmp) t -> unit Or_error.t + + (** [equal t1 t2] returns [Ok ()] if [t1] is equal to [t2] and a human-readable + error otherwise. *) + val equal : ('a, 'cmp) t -> ('a, 'cmp) t -> unit Or_error.t + end + + (** The list or array given to [of_list] and [of_array] need not be sorted. *) + val of_list : ('a, 'cmp) comparator -> 'a list -> ('a, 'cmp) t + val of_array : ('a, 'cmp) comparator -> 'a array -> ('a, 'cmp) t + + (** [to_list] and [to_array] produce sequences sorted in ascending order according to the + comparator. *) + val to_list : ('a, _) t -> 'a list + val to_array : ('a, _) t -> 'a array + + (** Create set from sorted array. The input must be sorted (either in ascending or + descending order as given by the comparator) and contain no duplicates, otherwise the + result is an error. The complexity of this function is [O(n)]. *) + val of_sorted_array + : ('a, 'cmp) comparator + -> 'a array + -> ('a, 'cmp) t Or_error.t + + (** Similar to [of_sorted_array], but without checking the input array. *) + val of_sorted_array_unchecked + : ('a, 'cmp) comparator + -> 'a array + -> ('a, 'cmp) t + + (** [of_increasing_iterator_unchecked c ~len ~f] behaves like [of_sorted_array_unchecked c + (Array.init len ~f)], with the additional restriction that a decreasing order is not + supported. The advantage is not requiring you to allocate an intermediate array. [f] + will be called with 0, 1, ... [len - 1], in order. *) + val of_increasing_iterator_unchecked + : ('a, 'cmp) comparator + -> len:int + -> f:(int -> 'a) + -> ('a, 'cmp) t + + (** [stable_dedup_list] is here rather than in the [List] module because the + implementation relies crucially on sets, and because doing so allows one to avoid uses + of polymorphic comparison by instantiating the functor at a different implementation + of [Comparator] and using the resulting [stable_dedup_list]. *) + val stable_dedup_list : ('a, _) comparator -> 'a list -> 'a list + + (** [map c t ~f] returns a new set created by applying [f] to every element in + [t]. The returned set is based on the provided [comparator]. [O(n log n)]. *) + val map + : ('b, 'cmp) comparator + -> ('a, _) t + -> f:('a -> 'b) + -> ('b, 'cmp) t + + (** Like {!map}, except elements for which [f] returns [None] will be dropped. *) + val filter_map + : ('b, 'cmp) comparator + -> ('a, _) t + -> f:('a -> 'b option) + -> ('b, 'cmp) t + + (** [filter t ~f] returns the subset of [t] for which [f] evaluates to true. [O(n log + n)]. *) + val filter : ('a, 'cmp) t -> f:('a -> bool) -> ('a, 'cmp) t + + (** [fold t ~init ~f] folds over the elements of the set from smallest to largest. *) + val fold + : ('a, _) t + -> init:'accum + -> f:('accum -> 'a -> 'accum) + -> 'accum + + (** [fold_result ~init ~f] folds over the elements of the set from smallest to + largest, short circuiting the fold if [f accum x] is an [Error _] *) + val fold_result + : ('a, _) t + -> init:'accum + -> f:('accum -> 'a -> ('accum, 'e) Result.t) + -> ('accum, 'e) Result.t + + (** [fold_until t ~init ~f] is a short-circuiting version of [fold]. If [f] + returns [Stop _] the computation ceases and results in that value. If [f] returns + [Continue _], the fold will proceed. *) + val fold_until + : ('a, _) t + -> init:'accum + -> f:('accum -> 'a -> ('accum, 'final) Continue_or_stop.t) + -> finish:('accum -> 'final) + -> 'final + + (** Like {!fold}, except that it goes from the largest to the smallest element. *) + val fold_right + : ('a, _) t + -> init:'accum + -> f:('a -> 'accum -> 'accum) + -> 'accum + + (** [iter t ~f] calls [f] on every element of [t], going in order from the smallest to + largest. *) + val iter : ('a, _) t -> f:('a -> unit) -> unit + + (** Iterate two sets side by side. Complexity is [O(m+n)] where [m] and [n] are the sizes + of the two input sets. As an example, with the inputs [0; 1] and [1; 2], [f] will be + called with [`Left 0]; [`Both (1, 1)]; and [`Right 2]. *) + val iter2 + : ('a, 'cmp) t + -> ('a, 'cmp) t + -> f:([`Left of 'a | `Right of 'a | `Both of 'a * 'a] -> unit) + -> unit + + (** if [a, b = partition_tf set ~f] then [a] is the elements on which [f] produced [true], + and [b] is the elements on which [f] produces [false]. *) + val partition_tf + : ('a, 'cmp) t + -> f:('a -> bool) + -> ('a, 'cmp) t * ('a, 'cmp) t + + (** Same as {!to_list}. *) + val elements : ('a, _) t -> 'a list + + (** Returns the smallest element of the set. [O(log n)]. *) + val min_elt : ('a, _) t -> 'a option + + (** Like {!min_elt}, but throws an exception when given an empty set. *) + val min_elt_exn : ('a, _) t -> 'a + + (** Returns the largest element of the set. [O(log n)]. *) + val max_elt : ('a, _) t -> 'a option + + (** Like {!max_elt}, but throws an exception when given an empty set. *) + val max_elt_exn : ('a, _) t -> 'a + + (** returns an arbitrary element, or [None] if the set is empty. *) + val choose : ('a, _) t -> 'a option + + (** Like {!choose}, but throws an exception on an empty set. *) + val choose_exn : ('a, _) t -> 'a + + (** [split t x] produces a triple [(t1, maybe_x, t2)] where [t1] is the set of elements + strictly less than [x], [maybe_x] is the member (if any) of [t] which compares equal + to [x], and [t2] is the set of elements strictly larger than [x]. *) + val split : ('a, 'cmp) t -> 'a -> ('a, 'cmp) t * 'a option * ('a, 'cmp) t + + (** if [equiv] is an equivalence predicate, then [group_by set ~equiv] produces a list + of equivalence classes (i.e., a set-theoretic quotient). E.g., + + {[ + let chars = Set.of_list ['A'; 'a'; 'b'; 'c'] in + let equiv c c' = Char.equal (Char.uppercase c) (Char.uppercase c') in + group_by chars ~equiv + ]} + + produces: + + {[ + [Set.of_list ['A';'a']; Set.singleton 'b'; Set.singleton 'c'] + ]} + + [group_by] runs in O(n^2) time, so if you have a comparison function, it's usually + much faster to use [Set.of_list]. *) + val group_by : ('a, 'cmp) t -> equiv:('a -> 'a -> bool) -> ('a, 'cmp) t list + + (** [to_sequence t] converts the set [t] to a sequence of the elements between + [greater_or_equal_to] and [less_or_equal_to] inclusive in the order indicated by + [order]. If [greater_or_equal_to > less_or_equal_to] the sequence is empty. Cost is + O(log n) up front and amortized O(1) for each element produced. *) + val to_sequence + : ?order : [ `Increasing (** default *) | `Decreasing ] + -> ?greater_or_equal_to : 'a + -> ?less_or_equal_to : 'a + -> ('a, 'cmp) t + -> 'a Sequence.t + + (** Produces the elements of the two sets between [greater_or_equal_to] and + [less_or_equal_to] in [order], noting whether each element appears in the left set, + the right set, or both. In the both case, both elements are returned, in case the + caller can distinguish between elements that are equal to the sets' comparator. Runs + in O(length t + length t'). *) + module Merge_to_sequence_element : sig + type ('a, 'b) t = ('a, 'b) Sequence.Merge_with_duplicates_element.t = + | Left of 'a + | Right of 'b + | Both of 'a * 'b + [@@deriving_inline compare, sexp] + include + sig + [@@@ocaml.warning "-32"] + val compare : + ('a -> 'a -> int) -> + ('b -> 'b -> int) -> ('a, 'b) t -> ('a, 'b) t -> int + include Ppx_sexp_conv_lib.Sexpable.S2 with type ('a,'b) t := ('a, 'b) t + end[@@ocaml.doc "@inline"] + [@@@end] + end + + val merge_to_sequence + : ?order : [ `Increasing (** default *) | `Decreasing ] + -> ?greater_or_equal_to : 'a + -> ?less_or_equal_to : 'a + -> ('a, 'cmp) t + -> ('a, 'cmp) t + -> ('a, 'a) Merge_to_sequence_element.t Sequence.t + + (** [M] is meant to be used in combination with OCaml applicative functor types: + + {[ + type string_set = Set.M(String).t + ]} + + which stands for: + + {[ + type string_set = (String.t, String.comparator_witness) Set.t + ]} + + The point is that [Set.M(String).t] supports deriving, whereas the second syntax + doesn't (because there is no such thing as, say, String.sexp_of_comparator_witness, + instead you would want to pass the comparator directly). *) + module M (Elt : sig type t type comparator_witness end) : sig + type nonrec t = (Elt.t, Elt.comparator_witness) t + end + + include For_deriving with type ('a, 'b) t := ('a, 'b) t + + (** A polymorphic Set. *) + module Poly : S_poly with type 'elt t = ('elt, Comparator.Poly.comparator_witness) t + + (** Using comparator is a similar interface as the toplevel of [Set], except the functions + take a [~comparator:('elt, 'cmp) Comparator.t] where the functions at the toplevel of + [Set] takes a [('elt, 'cmp) comparator]. *) + module Using_comparator : sig + type nonrec ('elt, 'cmp) t = ('elt, 'cmp) t [@@deriving_inline sexp_of] + include + sig + [@@@ocaml.warning "-32"] + val sexp_of_t : + ('elt -> Ppx_sexp_conv_lib.Sexp.t) -> + ('cmp -> Ppx_sexp_conv_lib.Sexp.t) -> + ('elt, 'cmp) t -> Ppx_sexp_conv_lib.Sexp.t + end[@@ocaml.doc "@inline"] + [@@@end] + + val t_of_sexp_direct + : comparator:('elt, 'cmp) Comparator.t + -> (Sexp.t -> 'elt) + -> Sexp.t + -> ('elt, 'cmp) t + + module Tree : sig + (** A [Tree.t] contains just the tree data structure that a set is based on, without + including the comparator. Accordingly, any operation on a [Tree.t] must also take + as an argument the corresponding comparator. *) + type ('a, 'cmp) t [@@deriving_inline sexp_of] + include + sig + [@@@ocaml.warning "-32"] + val sexp_of_t : + ('a -> Ppx_sexp_conv_lib.Sexp.t) -> + ('cmp -> Ppx_sexp_conv_lib.Sexp.t) -> + ('a, 'cmp) t -> Ppx_sexp_conv_lib.Sexp.t + end[@@ocaml.doc "@inline"] + [@@@end] + + val t_of_sexp_direct + : comparator:('elt, 'cmp) Comparator.t + -> (Sexp.t -> 'elt) + -> Sexp.t + -> ('elt, 'cmp) t + + module Named : sig + type nonrec ('a, 'cmp) t = { + tree : ('a, 'cmp) t; + name : string; + } + + val is_subset + : comparator:('a, 'cmp) Comparator.t + -> ('a, 'cmp) t + -> of_:('a, 'cmp) t + -> unit Or_error.t + + val equal + : comparator:('a, 'cmp) Comparator.t + -> ('a, 'cmp) t + -> ('a, 'cmp) t + -> unit Or_error.t + end + + include Creators_and_accessors2_with_comparator + with type ('a, 'b) set := ('a, 'b) t + with type ('a, 'b) t := ('a, 'b) t + with type ('a, 'b) tree := ('a, 'b) t + with type ('a, 'b) named := ('a, 'b) Named.t + with module Named := Named + + val empty_without_value_restriction : (_, _) t + end + + include Accessors2 + with type ('a, 'b) t := ('a, 'b) t + with type ('a, 'b) tree := ('a, 'b) Tree.t + with type ('a, 'b) named := ('a, 'b) Named.t + + include Creators2_with_comparator + with type ('a, 'b) t := ('a, 'b) t + with type ('a, 'b) tree := ('a, 'b) Tree.t + with type ('a, 'b) set := ('a, 'b) t + + val comparator : ('a, 'cmp) t -> ('a, 'cmp) Comparator.t + + val hash_fold_direct + : 'elt Hash.folder + -> ('elt, 'cmp) t Hash.folder + + module Empty_without_value_restriction (Elt : Comparator.S1) : sig + val empty : ('a Elt.t, Elt.comparator_witness) t + end + end + + (** {2 Modules and module types for extending [Set]} + + For use in extensions of Base, like [Core_kernel]. *) + + module With_comparator = With_comparator + module With_first_class_module = With_first_class_module + module Without_comparator = Without_comparator + + module type For_deriving = For_deriving + + module type S_poly = S_poly + module type Accessors0 = Accessors0 + module type Accessors1 = Accessors1 + module type Accessors2 = Accessors2 + module type Accessors2_with_comparator = Accessors2_with_comparator + module type Accessors_generic = Accessors_generic + module type Creators0 = Creators0 + module type Creators1 = Creators1 + module type Creators2 = Creators2 + module type Creators2_with_comparator = Creators2_with_comparator + module type Creators_and_accessors0 = Creators_and_accessors0 + module type Creators_and_accessors1 = Creators_and_accessors1 + module type Creators_and_accessors2 = Creators_and_accessors2 + module type Creators_and_accessors2_with_comparator = Creators_and_accessors2_with_comparator + module type Creators_generic = Creators_generic + module type Elt_plain = Elt_plain +end diff --git a/src/sexp.ml b/src/sexp.ml new file mode 100644 index 0000000..079acef --- /dev/null +++ b/src/sexp.ml @@ -0,0 +1,42 @@ +open Hash.Builtin +open Ppx_compare_lib.Builtin + +include (Sexplib0.Sexp : module type of Sexplib0.Sexp with type t := Sexplib0.Sexp.t) + +(** Type of S-expressions *) +type t = Sexplib0.Sexp.t = Atom of string | List of t list +[@@deriving_inline compare, hash] +let rec compare : t -> t -> int = + fun a__001_ -> + fun b__002_ -> + if Ppx_compare_lib.phys_equal a__001_ b__002_ + then 0 + else + (match (a__001_, b__002_) with + | (Atom _a__003_, Atom _b__004_) -> compare_string _a__003_ _b__004_ + | (Atom _, _) -> (-1) + | (_, Atom _) -> 1 + | (List _a__005_, List _b__006_) -> + compare_list compare _a__005_ _b__006_) +let rec (hash_fold_t : + Ppx_hash_lib.Std.Hash.state -> t -> Ppx_hash_lib.Std.Hash.state) = + (fun hsv -> + fun arg -> + match arg with + | Atom _a0 -> + let hsv = Ppx_hash_lib.Std.Hash.fold_int hsv 0 in + let hsv = hsv in hash_fold_string hsv _a0 + | List _a0 -> + let hsv = Ppx_hash_lib.Std.Hash.fold_int hsv 1 in + let hsv = hsv in hash_fold_list hash_fold_t hsv _a0 : Ppx_hash_lib.Std.Hash.state + -> + t -> + Ppx_hash_lib.Std.Hash.state) +and (hash : t -> Ppx_hash_lib.Std.Hash.hash_value) = + let func arg = + Ppx_hash_lib.Std.Hash.get_hash_value + (let hsv = Ppx_hash_lib.Std.Hash.create () in hash_fold_t hsv arg) in + fun x -> func x +[@@@end] + +let of_string = () diff --git a/src/sexp.mli b/src/sexp.mli new file mode 100644 index 0000000..282c13b --- /dev/null +++ b/src/sexp.mli @@ -0,0 +1,20 @@ +(** Type of S-expressions *) +type t = Sexplib0.Sexp.t = Atom of string | List of t list +[@@deriving_inline compare, hash] +include +sig + [@@@ocaml.warning "-32"] + val compare : t -> t -> int + val hash_fold_t : + Ppx_hash_lib.Std.Hash.state -> t -> Ppx_hash_lib.Std.Hash.state + val hash : t -> Ppx_hash_lib.Std.Hash.hash_value +end[@@ocaml.doc "@inline"] +[@@@end] + +include module type of Sexplib0.Sexp with type t := Sexplib0.Sexp.t + +(** Base has never had an [of_string] function. We expose a deprecated [of_string] here + so that people can find it (e.g. with merlin), and learn what we recommend. This + [of_string] has type [unit] because we don't want it to be accidentally used. *) +val of_string : unit +[@@deprecated "[since 2018-02] Use [Parsexp.Single.parse_string_exn]"] diff --git a/src/sexp_with_comparable.ml b/src/sexp_with_comparable.ml new file mode 100644 index 0000000..87005aa --- /dev/null +++ b/src/sexp_with_comparable.ml @@ -0,0 +1,3 @@ +include Sexp + +include Comparable.Make (Sexp) diff --git a/src/sexp_with_comparable.mli b/src/sexp_with_comparable.mli new file mode 100644 index 0000000..a92ccd4 --- /dev/null +++ b/src/sexp_with_comparable.mli @@ -0,0 +1,6 @@ +(*_ This module is separated from Sexp to avoid circular dependencies as many things use + s-expressions *) + +include module type of struct include Sexp end (** @inline *) + +include Comparable.S with type t := t diff --git a/src/sexpable.ml b/src/sexpable.ml new file mode 100644 index 0000000..ad37db9 --- /dev/null +++ b/src/sexpable.ml @@ -0,0 +1,79 @@ +open! Import + +include Sexplib0.Sexpable + +module Of_sexpable + (Sexpable : S) + (M : sig + type t + val to_sexpable : t -> Sexpable.t + val of_sexpable : Sexpable.t -> t + end) + : S with type t := M.t = +struct + let t_of_sexp sexp = + let s = Sexpable.t_of_sexp sexp in + (try M.of_sexpable s with exn -> of_sexp_error_exn exn sexp) + + let sexp_of_t t = Sexpable.sexp_of_t (M.to_sexpable t) +end + +module Of_sexpable1 + (Sexpable : S1) + (M : sig + type 'a t + val to_sexpable : 'a t -> 'a Sexpable.t + val of_sexpable : 'a Sexpable.t -> 'a t + end) + : S1 with type 'a t := 'a M.t = +struct + let t_of_sexp a_of_sexp sexp = + let s = Sexpable.t_of_sexp a_of_sexp sexp in + (try M.of_sexpable s with exn -> of_sexp_error_exn exn sexp) + + let sexp_of_t sexp_of_a t = Sexpable.sexp_of_t sexp_of_a (M.to_sexpable t) +end + +module Of_sexpable2 (Sexpable : S2) + (M : sig + type ('a, 'b) t + val to_sexpable : ('a, 'b) t -> ('a, 'b) Sexpable.t + val of_sexpable : ('a, 'b) Sexpable.t -> ('a, 'b) t + end) + : S2 with type ('a, 'b) t := ('a, 'b) M.t = +struct + let t_of_sexp a_of_sexp b_of_sexp sexp = + let s = Sexpable.t_of_sexp a_of_sexp b_of_sexp sexp in + (try M.of_sexpable s with exn -> of_sexp_error_exn exn sexp) + + let sexp_of_t sexp_of_a sexp_of_b t = + Sexpable.sexp_of_t sexp_of_a sexp_of_b (M.to_sexpable t) +end + +module Of_sexpable3 (Sexpable : S3) + (M : sig + type ('a, 'b, 'c) t + val to_sexpable : ('a, 'b, 'c) t -> ('a, 'b, 'c) Sexpable.t + val of_sexpable : ('a, 'b, 'c) Sexpable.t -> ('a, 'b, 'c) t + end) + : S3 with type ('a, 'b, 'c) t := ('a, 'b, 'c) M.t = +struct + let t_of_sexp a_of_sexp b_of_sexp c_of_sexp sexp = + let s = Sexpable.t_of_sexp a_of_sexp b_of_sexp c_of_sexp sexp in + (try M.of_sexpable s with exn -> of_sexp_error_exn exn sexp) + + let sexp_of_t sexp_of_a sexp_of_b sexp_of_c t = + Sexpable.sexp_of_t sexp_of_a sexp_of_b sexp_of_c (M.to_sexpable t) +end + +module Of_stringable (M : Stringable.S) : S with type t := M.t = struct + let t_of_sexp sexp = + match sexp with + | Sexp.Atom s -> + (try M.of_string s with exn -> of_sexp_error_exn exn sexp) + | Sexp.List _ -> + of_sexp_error + "Sexpable.Of_stringable.t_of_sexp expected an atom, but got a list" sexp + + let sexp_of_t t = Sexp.Atom (M.to_string t) +end diff --git a/src/sexpable.mli b/src/sexpable.mli new file mode 100644 index 0000000..9c1ef50 --- /dev/null +++ b/src/sexpable.mli @@ -0,0 +1,78 @@ +(** Provides functors for making modules sexpable. New code should use the [[@@deriving + sexp]] syntax directly. These module types ([S], [S1], [S2], and [S3]) are exported + for backwards compatibility only. *) + +open! Import + +module type S = sig + type t + + val t_of_sexp : Sexp.t -> t + val sexp_of_t : t -> Sexp.t +end + +module type S1 = sig + type 'a t + + val t_of_sexp : (Sexp.t -> 'a) -> Sexp.t -> 'a t + val sexp_of_t : ('a -> Sexp.t) -> 'a t -> Sexp.t +end + +module type S2 = sig + type ('a, 'b) t + + val t_of_sexp : (Sexp.t -> 'a) -> (Sexp.t -> 'b) -> Sexp.t -> ('a, 'b) t + val sexp_of_t : ('a -> Sexp.t) -> ('b -> Sexp.t) -> ('a, 'b) t -> Sexp.t +end + +module type S3 = sig + type ('a, 'b, 'c) t + + val t_of_sexp + : (Sexp.t -> 'a) -> (Sexp.t -> 'b) -> (Sexp.t -> 'c) -> Sexp.t + -> ('a, 'b, 'c) t + + val sexp_of_t + : ('a -> Sexp.t) -> ('b -> Sexp.t) -> ('c -> Sexp.t) -> ('a, 'b, 'c) t + -> Sexp.t +end + +(** For when you want the sexp representation of one type to be the same as that for + some other isomorphic type. *) +module Of_sexpable + (Sexpable : S) + (M : sig + type t + val to_sexpable : t -> Sexpable.t + val of_sexpable : Sexpable.t -> t + end) + : S with type t := M.t + +module Of_sexpable1 + (Sexpable : S1) + (M : sig + type 'a t + val to_sexpable : 'a t -> 'a Sexpable.t + val of_sexpable : 'a Sexpable.t -> 'a t + end) + : S1 with type 'a t := 'a M.t + +module Of_sexpable2 + (Sexpable : S2) + (M : sig + type ('a, 'b) t + val to_sexpable : ('a, 'b) t -> ('a, 'b) Sexpable.t + val of_sexpable : ('a, 'b) Sexpable.t -> ('a, 'b) t + end) + : S2 with type ('a, 'b) t := ('a, 'b) M.t + +module Of_sexpable3 + (Sexpable : S3) + (M : sig + type ('a, 'b, 'c) t + val to_sexpable : ('a, 'b, 'c) t -> ('a, 'b, 'c) Sexpable.t + val of_sexpable : ('a, 'b, 'c) Sexpable.t -> ('a, 'b, 'c) t + end) + : S3 with type ('a, 'b, 'c) t := ('a, 'b, 'c) M.t + +module Of_stringable (M : Stringable.S) : S with type t := M.t diff --git a/src/sexplib.ml b/src/sexplib.ml new file mode 100644 index 0000000..148fd14 --- /dev/null +++ b/src/sexplib.ml @@ -0,0 +1,10 @@ +(** This module is for use by ppx_sexp_conv, and is thus not in the interface of + Base. *) +module Conv_error = Sexplib0.Sexp_conv_error +module Conv = Sexplib0.Sexp_conv + +(** @canonical Base.Sexp *) +module Sexp = Sexp + +(** @canonical Base.Sexpable *) +module Sexpable = Sexpable diff --git a/src/sign.ml b/src/sign.ml new file mode 100644 index 0000000..5b9848b --- /dev/null +++ b/src/sign.ml @@ -0,0 +1,27 @@ +open! Import + +include Sign0 +include Identifiable.Make(Sign0) + +(* Open [Replace_polymorphic_compare] after including functor applications so + they do not shadow its definitions. This is here so that efficient versions + of the comparison functions are available within this module. *) +open! Replace_polymorphic_compare + +let to_float = function + | Neg -> -1. + | Zero -> 0. + | Pos -> 1. + +let flip = function + | Neg -> Pos + | Zero -> Zero + | Pos -> Neg + +let ( * ) t t' = of_int (to_int t * to_int t') + +(* Include type-specific [Replace_polymorphic_compare at the end, after any + functor applications that could shadow its definitions. This is here so + that efficient versions of the comparison functions are exported by this + module. *) +include Replace_polymorphic_compare diff --git a/src/sign.mli b/src/sign.mli new file mode 100644 index 0000000..89a9d7d --- /dev/null +++ b/src/sign.mli @@ -0,0 +1,32 @@ +(** A type for representing the sign of a numeric value. *) + +open! Import + +type t = Sign0.t = Neg | Zero | Pos [@@deriving_inline enumerate, hash] +include +sig + [@@@ocaml.warning "-32"] + val all : t list + val hash_fold_t : + Ppx_hash_lib.Std.Hash.state -> t -> Ppx_hash_lib.Std.Hash.state + val hash : t -> Ppx_hash_lib.Std.Hash.hash_value +end[@@ocaml.doc "@inline"] +[@@@end] + +(** This provides [to_string]/[of_string], sexp conversion, Map, Hashtbl, etc. *) +include Identifiable.S with type t := t + +val of_int : int -> t + +(** Map [Neg/Zero/Pos] to [-1/0/1] respectively. *) +val to_int : t -> int + +(** Map [Neg/Zero/Pos] to [-1./0./1.] respectively. + (There is no [of_float] here, but see {!Float.sign_exn}.) *) +val to_float : t -> float + +(** Map [Neg/Zero/Pos] to [Pos/Zero/Neg] respectively. *) +val flip : t -> t + +(** [Neg * Neg = Pos], etc. *) +val ( * ) : t -> t -> t diff --git a/src/sign0.ml b/src/sign0.ml new file mode 100644 index 0000000..a7b8b9c --- /dev/null +++ b/src/sign0.ml @@ -0,0 +1,85 @@ +(* This is broken off to avoid circular dependency between Sign and Comparable. *) + +open! Import + +type t = Neg | Zero | Pos [@@deriving_inline sexp, compare, hash, enumerate] +let t_of_sexp : Ppx_sexp_conv_lib.Sexp.t -> t = + let _tp_loc = "src/sign0.ml.t" in + function + | Ppx_sexp_conv_lib.Sexp.Atom ("neg"|"Neg") -> Neg + | Ppx_sexp_conv_lib.Sexp.Atom ("zero"|"Zero") -> Zero + | Ppx_sexp_conv_lib.Sexp.Atom ("pos"|"Pos") -> Pos + | Ppx_sexp_conv_lib.Sexp.List ((Ppx_sexp_conv_lib.Sexp.Atom + ("neg"|"Neg"))::_) as sexp -> + Ppx_sexp_conv_lib.Conv_error.stag_no_args _tp_loc sexp + | Ppx_sexp_conv_lib.Sexp.List ((Ppx_sexp_conv_lib.Sexp.Atom + ("zero"|"Zero"))::_) as sexp -> + Ppx_sexp_conv_lib.Conv_error.stag_no_args _tp_loc sexp + | Ppx_sexp_conv_lib.Sexp.List ((Ppx_sexp_conv_lib.Sexp.Atom + ("pos"|"Pos"))::_) as sexp -> + Ppx_sexp_conv_lib.Conv_error.stag_no_args _tp_loc sexp + | Ppx_sexp_conv_lib.Sexp.List ((Ppx_sexp_conv_lib.Sexp.List _)::_) as sexp + -> Ppx_sexp_conv_lib.Conv_error.nested_list_invalid_sum _tp_loc sexp + | Ppx_sexp_conv_lib.Sexp.List [] as sexp -> + Ppx_sexp_conv_lib.Conv_error.empty_list_invalid_sum _tp_loc sexp + | sexp -> Ppx_sexp_conv_lib.Conv_error.unexpected_stag _tp_loc sexp +let sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t = + function + | Neg -> Ppx_sexp_conv_lib.Sexp.Atom "Neg" + | Zero -> Ppx_sexp_conv_lib.Sexp.Atom "Zero" + | Pos -> Ppx_sexp_conv_lib.Sexp.Atom "Pos" +let compare : t -> t -> int = Ppx_compare_lib.polymorphic_compare +let (hash_fold_t : + Ppx_hash_lib.Std.Hash.state -> t -> Ppx_hash_lib.Std.Hash.state) = + (fun hsv -> + fun arg -> + match arg with + | Neg -> Ppx_hash_lib.Std.Hash.fold_int hsv 0 + | Zero -> Ppx_hash_lib.Std.Hash.fold_int hsv 1 + | Pos -> Ppx_hash_lib.Std.Hash.fold_int hsv 2 : Ppx_hash_lib.Std.Hash.state + -> + t -> + Ppx_hash_lib.Std.Hash.state) +let (hash : t -> Ppx_hash_lib.Std.Hash.hash_value) = + let func arg = + Ppx_hash_lib.Std.Hash.get_hash_value + (let hsv = Ppx_hash_lib.Std.Hash.create () in hash_fold_t hsv arg) in + fun x -> func x +let all : t list = [Neg; Zero; Pos] +[@@@end] + +module Replace_polymorphic_compare = struct + let ( < ) (x : t) y = Poly.( < ) x y + let ( <= ) (x : t) y = Poly.( <= ) x y + let ( <> ) (x : t) y = Poly.( <> ) x y + let ( = ) (x : t) y = Poly.( = ) x y + let ( > ) (x : t) y = Poly.( > ) x y + let ( >= ) (x : t) y = Poly.( >= ) x y + + let ascending (x : t) y = Poly.ascending x y + let descending (x : t) y = Poly.descending x y + let compare (x : t) y = Poly.compare x y + let equal (x : t) y = Poly.equal x y + let max (x : t) y = if x >= y then x else y + let min (x : t) y = if x <= y then x else y +end + +let of_string s = t_of_sexp (sexp_of_string s) +let to_string t = string_of_sexp (sexp_of_t t) + +let to_int = function + | Neg -> -1 + | Zero -> 0 + | Pos -> 1 + +let _ = hash (* Ignore the hash function produced by [@@deriving_inline hash][@@@end] *) +let hash = to_int + +let module_name = "Base.Sign" + +let of_int n = + if n < 0 + then Neg + else if n = 0 + then Zero + else Pos diff --git a/src/sign_or_nan.ml b/src/sign_or_nan.ml new file mode 100644 index 0000000..33f2268 --- /dev/null +++ b/src/sign_or_nan.ml @@ -0,0 +1,117 @@ +open! Import + +module T = struct + type t = Neg | Zero | Pos | Nan [@@deriving_inline sexp, compare, hash, enumerate] + let t_of_sexp : Ppx_sexp_conv_lib.Sexp.t -> t = + let _tp_loc = "src/sign_or_nan.ml.T.t" in + function + | Ppx_sexp_conv_lib.Sexp.Atom ("neg"|"Neg") -> Neg + | Ppx_sexp_conv_lib.Sexp.Atom ("zero"|"Zero") -> Zero + | Ppx_sexp_conv_lib.Sexp.Atom ("pos"|"Pos") -> Pos + | Ppx_sexp_conv_lib.Sexp.Atom ("nan"|"Nan") -> Nan + | Ppx_sexp_conv_lib.Sexp.List ((Ppx_sexp_conv_lib.Sexp.Atom + ("neg"|"Neg"))::_) as sexp -> + Ppx_sexp_conv_lib.Conv_error.stag_no_args _tp_loc sexp + | Ppx_sexp_conv_lib.Sexp.List ((Ppx_sexp_conv_lib.Sexp.Atom + ("zero"|"Zero"))::_) as sexp -> + Ppx_sexp_conv_lib.Conv_error.stag_no_args _tp_loc sexp + | Ppx_sexp_conv_lib.Sexp.List ((Ppx_sexp_conv_lib.Sexp.Atom + ("pos"|"Pos"))::_) as sexp -> + Ppx_sexp_conv_lib.Conv_error.stag_no_args _tp_loc sexp + | Ppx_sexp_conv_lib.Sexp.List ((Ppx_sexp_conv_lib.Sexp.Atom + ("nan"|"Nan"))::_) as sexp -> + Ppx_sexp_conv_lib.Conv_error.stag_no_args _tp_loc sexp + | Ppx_sexp_conv_lib.Sexp.List ((Ppx_sexp_conv_lib.Sexp.List _)::_) as sexp + -> Ppx_sexp_conv_lib.Conv_error.nested_list_invalid_sum _tp_loc sexp + | Ppx_sexp_conv_lib.Sexp.List [] as sexp -> + Ppx_sexp_conv_lib.Conv_error.empty_list_invalid_sum _tp_loc sexp + | sexp -> Ppx_sexp_conv_lib.Conv_error.unexpected_stag _tp_loc sexp + let sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t = + function + | Neg -> Ppx_sexp_conv_lib.Sexp.Atom "Neg" + | Zero -> Ppx_sexp_conv_lib.Sexp.Atom "Zero" + | Pos -> Ppx_sexp_conv_lib.Sexp.Atom "Pos" + | Nan -> Ppx_sexp_conv_lib.Sexp.Atom "Nan" + let compare : t -> t -> int = Ppx_compare_lib.polymorphic_compare + let (hash_fold_t : + Ppx_hash_lib.Std.Hash.state -> t -> Ppx_hash_lib.Std.Hash.state) = + (fun hsv -> + fun arg -> + match arg with + | Neg -> Ppx_hash_lib.Std.Hash.fold_int hsv 0 + | Zero -> Ppx_hash_lib.Std.Hash.fold_int hsv 1 + | Pos -> Ppx_hash_lib.Std.Hash.fold_int hsv 2 + | Nan -> Ppx_hash_lib.Std.Hash.fold_int hsv 3 : Ppx_hash_lib.Std.Hash.state + -> + t -> + Ppx_hash_lib.Std.Hash.state) + let (hash : t -> Ppx_hash_lib.Std.Hash.hash_value) = + let func arg = + Ppx_hash_lib.Std.Hash.get_hash_value + (let hsv = Ppx_hash_lib.Std.Hash.create () in hash_fold_t hsv arg) in + fun x -> func x + let all : t list = [Neg; Zero; Pos; Nan] + [@@@end] + + let of_string s = t_of_sexp (sexp_of_string s) + let to_string t = string_of_sexp (sexp_of_t t) + + let module_name = "Base.Sign_or_nan" +end + +module Replace_polymorphic_compare = struct + let ( < ) (x : T.t) y = Poly.( < ) x y + let ( <= ) (x : T.t) y = Poly.( <= ) x y + let ( <> ) (x : T.t) y = Poly.( <> ) x y + let ( = ) (x : T.t) y = Poly.( = ) x y + let ( > ) (x : T.t) y = Poly.( > ) x y + let ( >= ) (x : T.t) y = Poly.( >= ) x y + + let ascending (x : T.t) y = Poly.ascending x y + let descending (x : T.t) y = Poly.descending x y + let compare (x : T.t) y = Poly.compare x y + let equal (x : T.t) y = Poly.equal x y + let max (x : T.t) y = if x >= y then x else y + let min (x : T.t) y = if x <= y then x else y +end + +include T +include Identifiable.Make(T) + +(* Open [Replace_polymorphic_compare] after including functor applications so they do not + shadow its definitions. This is here so that efficient versions of the comparison + functions are available within this module. *) +open! Replace_polymorphic_compare + +let of_sign = function + | Sign.Neg -> Neg + | Sign.Zero -> Zero + | Sign.Pos -> Pos + +let to_sign_exn = function + | Neg -> Sign.Neg + | Zero -> Sign.Zero + | Pos -> Sign.Pos + | Nan -> invalid_arg "Base.Sign_or_nan.to_sign_exn: Nan" + +let of_int n = of_sign (Sign.of_int n) + +let to_int_exn t = Sign.to_int (to_sign_exn t) + +let flip = function + | Neg -> Pos + | Zero -> Zero + | Pos -> Neg + | Nan -> Nan + +let ( * ) t t' = + match t, t' with + | Nan, _ | _, Nan -> + Nan + | _ -> + of_sign (Sign.( * ) (to_sign_exn t) (to_sign_exn t')) + +(* Include [Replace_polymorphic_compare] at the end, after any functor applications that + could shadow its definitions. This is here so that efficient versions of the comparison + functions are exported by this module. *) +include Replace_polymorphic_compare diff --git a/src/sign_or_nan.mli b/src/sign_or_nan.mli new file mode 100644 index 0000000..83d1e6f --- /dev/null +++ b/src/sign_or_nan.mli @@ -0,0 +1,34 @@ +(** An extension to [Sign] with a [Nan] constructor, for representing the sign + of float-like numeric values. *) + +open! Import + +type t = Neg | Zero | Pos | Nan [@@deriving_inline enumerate, hash] +include +sig + [@@@ocaml.warning "-32"] + val all : t list + val hash_fold_t : + Ppx_hash_lib.Std.Hash.state -> t -> Ppx_hash_lib.Std.Hash.state + val hash : t -> Ppx_hash_lib.Std.Hash.hash_value +end[@@ocaml.doc "@inline"] +[@@@end] + +(** This provides [to_string]/[of_string], sexp conversion, Map, Hashtbl, etc. *) +include Identifiable.S with type t := t + +val of_int : int -> t + +(** Map [Neg/Zero/Pos] to [-1/0/1] respectively. [Nan] raises. *) +val to_int_exn : t -> int + +val of_sign : Sign.t -> t + +(** [Nan] raises. *) +val to_sign_exn : t -> Sign.t + +(** Map [Neg/Zero/Pos/Nan] to [Pos/Zero/Neg/Nan] respectively. *) +val flip : t -> t + +(** [Neg * Neg = Pos], etc. If either argument is [Nan] then the result is [Nan]. *) +val ( * ) : t -> t -> t diff --git a/src/source_code_position.ml b/src/source_code_position.ml new file mode 100644 index 0000000..9b4edad --- /dev/null +++ b/src/source_code_position.ml @@ -0,0 +1,20 @@ +open! Import + + +(* This is lifted out of [M] because [Source_code_position0] exports [String0] + as [String], which does not export a hash function. *) +let hash_override { Caml.Lexing. pos_fname; pos_lnum; pos_bol; pos_cnum } = + String.hash pos_fname + lxor Int.hash pos_lnum + lxor Int.hash pos_bol + lxor Int.hash pos_cnum +;; + +module M = struct + include Source_code_position0 + + let hash = hash_override +end + +include M +include Comparable.Make_using_comparator(M) diff --git a/src/source_code_position.mli b/src/source_code_position.mli new file mode 100644 index 0000000..8b9d424 --- /dev/null +++ b/src/source_code_position.mli @@ -0,0 +1,32 @@ +(** One typically obtains a [Source_code_position.t] using a [[%here]] expression, which + is implemented by the [ppx_here] preprocessor. *) + +open! Import + +(** See INRIA's OCaml documentation for a description of these fields. + + [sexp_of_t] uses the form ["FILE:LINE:COL"], and does not have a corresponding + [of_sexp]. *) +type t + = Caml.Lexing.position + = { pos_fname : string + ; pos_lnum : int + ; pos_bol : int + ; pos_cnum : int + } +[@@deriving_inline hash, sexp_of] +include +sig + [@@@ocaml.warning "-32"] + val hash_fold_t : + Ppx_hash_lib.Std.Hash.state -> t -> Ppx_hash_lib.Std.Hash.state + val hash : t -> Ppx_hash_lib.Std.Hash.hash_value + val sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t +end[@@ocaml.doc "@inline"] +[@@@end] + +include Comparable.S with type t := t + +(** [to_string t] converts [t] to the form ["FILE:LINE:COL"]. *) +val to_string : t -> string + diff --git a/src/source_code_position0.ml b/src/source_code_position0.ml new file mode 100644 index 0000000..92f8fd6 --- /dev/null +++ b/src/source_code_position0.ml @@ -0,0 +1,179 @@ +open! Import + +module Int = Int0 +module String = String0 + +module T = struct + type t = Caml.Lexing.position = + { pos_fname : string; + pos_lnum : int; + pos_bol : int; + pos_cnum : int; + } + [@@deriving_inline compare, hash, sexp] + let compare : t -> t -> int = + fun a__001_ -> + fun b__002_ -> + if Ppx_compare_lib.phys_equal a__001_ b__002_ + then 0 + else + (match compare_string a__001_.pos_fname b__002_.pos_fname with + | 0 -> + (match compare_int a__001_.pos_lnum b__002_.pos_lnum with + | 0 -> + (match compare_int a__001_.pos_bol b__002_.pos_bol with + | 0 -> compare_int a__001_.pos_cnum b__002_.pos_cnum + | n -> n) + | n -> n) + | n -> n) + let (hash_fold_t : + Ppx_hash_lib.Std.Hash.state -> t -> Ppx_hash_lib.Std.Hash.state) = + fun hsv -> + fun arg -> + let hsv = + let hsv = + let hsv = let hsv = hsv in hash_fold_string hsv arg.pos_fname in + hash_fold_int hsv arg.pos_lnum in + hash_fold_int hsv arg.pos_bol in + hash_fold_int hsv arg.pos_cnum + let (hash : t -> Ppx_hash_lib.Std.Hash.hash_value) = + let func arg = + Ppx_hash_lib.Std.Hash.get_hash_value + (let hsv = Ppx_hash_lib.Std.Hash.create () in hash_fold_t hsv arg) in + fun x -> func x + let t_of_sexp : Ppx_sexp_conv_lib.Sexp.t -> t = + let _tp_loc = "src/source_code_position0.ml.T.t" in + function + | Ppx_sexp_conv_lib.Sexp.List field_sexps as sexp -> + let pos_fname_field = ref None + and pos_lnum_field = ref None + and pos_bol_field = ref None + and pos_cnum_field = ref None + and duplicates = ref [] + and extra = ref [] in + let rec iter = + function + | (Ppx_sexp_conv_lib.Sexp.List ((Ppx_sexp_conv_lib.Sexp.Atom + field_name)::_field_sexp::[]))::tail -> + ((match field_name with + | "pos_fname" -> + (match !pos_fname_field with + | None -> + let fvalue = string_of_sexp _field_sexp in + pos_fname_field := (Some fvalue) + | Some _ -> duplicates := (field_name :: (!duplicates))) + | "pos_lnum" -> + (match !pos_lnum_field with + | None -> + let fvalue = int_of_sexp _field_sexp in + pos_lnum_field := (Some fvalue) + | Some _ -> duplicates := (field_name :: (!duplicates))) + | "pos_bol" -> + (match !pos_bol_field with + | None -> + let fvalue = int_of_sexp _field_sexp in + pos_bol_field := (Some fvalue) + | Some _ -> duplicates := (field_name :: (!duplicates))) + | "pos_cnum" -> + (match !pos_cnum_field with + | None -> + let fvalue = int_of_sexp _field_sexp in + pos_cnum_field := (Some fvalue) + | Some _ -> duplicates := (field_name :: (!duplicates))) + | _ -> + if !Ppx_sexp_conv_lib.Conv.record_check_extra_fields + then extra := (field_name :: (!extra)) + else ()); + iter tail) + | (Ppx_sexp_conv_lib.Sexp.List ((Ppx_sexp_conv_lib.Sexp.Atom + field_name)::[]))::tail -> + ((let _ = field_name in + if !Ppx_sexp_conv_lib.Conv.record_check_extra_fields + then extra := (field_name :: (!extra)) + else ()); + iter tail) + | (Ppx_sexp_conv_lib.Sexp.Atom _|Ppx_sexp_conv_lib.Sexp.List _ as + sexp)::_ + -> + Ppx_sexp_conv_lib.Conv_error.record_only_pairs_expected _tp_loc + sexp + | [] -> () in + (iter field_sexps; + (match !duplicates with + | _::_ -> + Ppx_sexp_conv_lib.Conv_error.record_duplicate_fields _tp_loc + (!duplicates) sexp + | [] -> + (match !extra with + | _::_ -> + Ppx_sexp_conv_lib.Conv_error.record_extra_fields _tp_loc + (!extra) sexp + | [] -> + (match ((!pos_fname_field), (!pos_lnum_field), + (!pos_bol_field), (!pos_cnum_field)) + with + | (Some pos_fname_value, Some pos_lnum_value, Some + pos_bol_value, Some pos_cnum_value) -> + { + pos_fname = pos_fname_value; + pos_lnum = pos_lnum_value; + pos_bol = pos_bol_value; + pos_cnum = pos_cnum_value + } + | _ -> + Ppx_sexp_conv_lib.Conv_error.record_undefined_elements + _tp_loc sexp + [((Ppx_sexp_conv_lib.Conv.(=) (!pos_fname_field) None), + "pos_fname"); + ((Ppx_sexp_conv_lib.Conv.(=) (!pos_lnum_field) None), + "pos_lnum"); + ((Ppx_sexp_conv_lib.Conv.(=) (!pos_bol_field) None), + "pos_bol"); + ((Ppx_sexp_conv_lib.Conv.(=) (!pos_cnum_field) None), + "pos_cnum")])))) + | Ppx_sexp_conv_lib.Sexp.Atom _ as sexp -> + Ppx_sexp_conv_lib.Conv_error.record_list_instead_atom _tp_loc sexp + let sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t = + function + | { pos_fname = v_pos_fname; pos_lnum = v_pos_lnum; pos_bol = v_pos_bol; + pos_cnum = v_pos_cnum } -> + let bnds = [] in + let bnds = + let arg = sexp_of_int v_pos_cnum in + (Ppx_sexp_conv_lib.Sexp.List + [Ppx_sexp_conv_lib.Sexp.Atom "pos_cnum"; arg]) + :: bnds in + let bnds = + let arg = sexp_of_int v_pos_bol in + (Ppx_sexp_conv_lib.Sexp.List + [Ppx_sexp_conv_lib.Sexp.Atom "pos_bol"; arg]) + :: bnds in + let bnds = + let arg = sexp_of_int v_pos_lnum in + (Ppx_sexp_conv_lib.Sexp.List + [Ppx_sexp_conv_lib.Sexp.Atom "pos_lnum"; arg]) + :: bnds in + let bnds = + let arg = sexp_of_string v_pos_fname in + (Ppx_sexp_conv_lib.Sexp.List + [Ppx_sexp_conv_lib.Sexp.Atom "pos_fname"; arg]) + :: bnds in + Ppx_sexp_conv_lib.Sexp.List bnds + [@@@end] +end + +include T +include Comparator.Make(T) + +(* This is the same function as Ppx_here.lift_position_as_string. *) +let make_location_string ~pos_fname ~pos_lnum ~pos_cnum ~pos_bol = + String.concat + [ pos_fname + ; ":"; Int.to_string pos_lnum + ; ":"; Int.to_string (pos_cnum - pos_bol) + ] + +let to_string {Caml.Lexing.pos_fname; pos_lnum; pos_cnum; pos_bol} = + make_location_string ~pos_fname ~pos_lnum ~pos_cnum ~pos_bol + +let sexp_of_t t = Sexp.Atom (to_string t) diff --git a/src/stack.ml b/src/stack.ml new file mode 100644 index 0000000..721df1a --- /dev/null +++ b/src/stack.ml @@ -0,0 +1,209 @@ +open! Import + +include Stack_intf + +let raise_s = Error.raise_s + +(* This implementation is similar to [Deque] in that it uses an array of ['a] and + a mutable [int] to indicate what in the array is used. We choose to implement [Stack] + directly rather than on top of [Deque] for performance reasons. E.g. a simple + microbenchmark shows that push/pop is about 20% faster. *) +type 'a t = + { mutable length : int; + mutable elts : 'a Option_array.t; + } +[@@deriving_inline sexp_of] +let sexp_of_t : + 'a . ('a -> Ppx_sexp_conv_lib.Sexp.t) -> 'a t -> Ppx_sexp_conv_lib.Sexp.t = + fun _of_a -> + function + | { length = v_length; elts = v_elts } -> + let bnds = [] in + let bnds = + let arg = Option_array.sexp_of_t _of_a v_elts in + (Ppx_sexp_conv_lib.Sexp.List + [Ppx_sexp_conv_lib.Sexp.Atom "elts"; arg]) + :: bnds in + let bnds = + let arg = sexp_of_int v_length in + (Ppx_sexp_conv_lib.Sexp.List + [Ppx_sexp_conv_lib.Sexp.Atom "length"; arg]) + :: bnds in + Ppx_sexp_conv_lib.Sexp.List bnds +[@@@end] + +let sexp_of_t_internal = sexp_of_t +let sexp_of_t = `Rebound_later +let _ = sexp_of_t + +let capacity t = Option_array.length t.elts + +let invariant invariant_a ({ length; elts } as t) : unit = + try + assert (0 <= length && length <= Option_array.length elts); + for i = 0 to length - 1 do + invariant_a (Option_array.get_some_exn elts i); + done; + (* We maintain the invariant that unused elements are unset to avoid a space + leak. *) + for i = length to Option_array.length elts - 1 do + assert (not (Option_array.is_some elts i)) + done; + with exn -> + raise_s (Sexp.message "Stack.invariant failed" + [ "exn", exn |> Exn.sexp_of_t + ; "stack", t |> sexp_of_t_internal sexp_of_opaque ]) +;; + +let create (type a) () : a t = + { length = 0; + elts = Option_array.empty; + } +;; + +let length t = t.length + +let is_empty t = length t = 0 + +(* The order in which elements are visited has been chosen so as to be backwards + compatible with both [Linked_stack] and [Caml.Stack] *) +let fold t ~init ~f = + let r = ref init in + for i = t.length - 1 downto 0 do + r := f !r (Option_array.get_some_exn t.elts i) + done; + !r +;; + +let iter t ~f = + for i = t.length - 1 downto 0 do + f (Option_array.get_some_exn t.elts i) + done; +;; + +module C = + Container.Make (struct + type nonrec 'a t = 'a t + let fold = fold + let iter = `Custom iter + let length = `Custom length + end) + +let mem = C.mem +let exists = C.exists +let for_all = C.for_all +let count = C.count +let sum = C.sum +let find = C.find +let find_map = C.find_map +let to_list = C.to_list +let to_array = C.to_array +let min_elt = C.min_elt +let max_elt = C.max_elt +let fold_result = C.fold_result +let fold_until = C.fold_until + +let of_list (type a) (l : a list) = + if List.is_empty l then + create () + else begin + let length = List.length l in + let elts = Option_array.create ~len:(2 * length) in + let r = ref l in + for i = length - 1 downto 0 do + match !r with + | [] -> assert false + | a :: l -> + Option_array.set_some elts i a; + r := l + done; + { length; elts } + end +;; + +let sexp_of_t sexp_of_a t = List.sexp_of_t sexp_of_a (to_list t) + +let t_of_sexp a_of_sexp sexp = of_list (List.t_of_sexp a_of_sexp sexp) + +let resize t size = + let arr = Option_array.create ~len:size in + Option_array.blit ~src:t.elts ~dst:arr ~src_pos:0 ~dst_pos:0 ~len:t.length; + t.elts <- arr +;; + +let set_capacity t new_capacity = + let new_capacity = max new_capacity (length t) in + if new_capacity <> capacity t then + resize t new_capacity +;; + +let push t a = + if t.length = Option_array.length t.elts then + resize t (2 * (t.length + 1)); + Option_array.set_some t.elts t.length a; + t.length <- t.length + 1; +;; + +let pop_nonempty t = + let i = t.length - 1 in + let result = Option_array.get_some_exn t.elts i in + Option_array.set_none t.elts i; + t.length <- i; + result +;; + +let pop_error = Error.of_string "Stack.pop of empty stack" + +let pop t = + if is_empty t + then None + else Some (pop_nonempty t) +;; + +let pop_exn t = + if is_empty t + then Error.raise pop_error + else pop_nonempty t +;; + +let top_nonempty t = Option_array.get_some_exn t.elts (t.length - 1) + +let top_error = Error.of_string "Stack.top of empty stack" + +let top t = + if is_empty t + then None + else Some (top_nonempty t) +;; + +let top_exn t = + if is_empty t + then Error.raise top_error + else top_nonempty t; +;; + +let copy { length; elts } = + { length; + elts = Option_array.copy elts; + } +;; + +let clear t = + if t.length > 0 then begin + for i = 0 to t.length - 1 do + Option_array.set_none t.elts i; + done; + t.length <- 0; + end; +;; + +let until_empty t f = + let rec loop () = if t.length > 0 then (f (pop_nonempty t); loop ()) in + loop () +;; + +let singleton x = + let t = create () in + push t x; + t +;; diff --git a/src/stack.mli b/src/stack.mli new file mode 100644 index 0000000..937ae3c --- /dev/null +++ b/src/stack.mli @@ -0,0 +1 @@ +include Stack_intf.Stack (** @inline *) diff --git a/src/stack_intf.ml b/src/stack_intf.ml new file mode 100644 index 0000000..f2c3ebd --- /dev/null +++ b/src/stack_intf.ml @@ -0,0 +1,79 @@ +(** An interface for stacks that follows [Core]'s conventions, as opposed to OCaml's + standard [Stack] module. *) + +open! Import + +module type S = sig + + type 'a t [@@deriving_inline sexp] + include + sig + [@@@ocaml.warning "-32"] + include Ppx_sexp_conv_lib.Sexpable.S1 with type 'a t := 'a t + end[@@ocaml.doc "@inline"] + [@@@end] + + include Invariant.S1 with type 'a t := 'a t + + (** [fold], [iter], [find], and [find_map] visit the elements in order from the top of + the stack to the bottom. [to_list] and [to_array] return the elements in order from + the top of the stack to the bottom. + + Iteration functions ([iter], [fold], etc.) have unspecified behavior (although they + should still be memory-safe) when the stack is mutated while they are running (e.g. + by having the passed-in function call [push] or [pop] on the stack). + *) + include Container.S1 with type 'a t := 'a t + + (** [of_list l] returns a stack whose top is the first element of [l] and bottom is the + last element of [l]. *) + val of_list : 'a list -> 'a t + + (** [create ()] returns an empty stack. *) + val create : unit -> _ t + + (** [singleton a] creates a new stack containing only [a]. *) + val singleton : 'a -> 'a t + + (** [push t a] adds [a] to the top of stack [t]. *) + val push : 'a t -> 'a -> unit + + (** [pop t] removes and returns the top element of [t] as [Some a], or returns [None] if + [t] is empty. *) + val pop : 'a t -> 'a option + val pop_exn : 'a t -> 'a + + (** [top t] returns [Some a], where [a] is the top of [t], unless [is_empty t], in which + case [top] returns [None]. *) + val top : 'a t -> 'a option + val top_exn : 'a t -> 'a + + (** [clear t] discards all elements from [t]. *) + val clear : _ t -> unit + + (** [copy t] returns a copy of [t]. *) + val copy : 'a t -> 'a t + + (** [until_empty t f] repeatedly pops an element [a] off of [t] and runs [f a], until + [t] becomes empty. It is fine if [f] adds more elements to [t], in which case the + most-recently-added element will be processed next. *) + val until_empty : 'a t -> ('a -> unit) -> unit +end + +(** A stack implemented with an array. + + The implementation will grow the array as necessary, and will never automatically + shrink the array. One can use [set_capacity] to explicitly resize the array. *) +module type Stack = sig + module type S = S + + include S (** @open *) + + (** [capacity t] returns the length of the array backing [t]. *) + val capacity : _ t -> int + + (** [set_capacity t capacity] sets the length of the array backing [t] to [max capacity + (length t)]. To shrink as much as possible, do [set_capacity t 0]. *) + val set_capacity : _ t -> int -> unit +end + diff --git a/src/staged.ml b/src/staged.ml new file mode 100644 index 0000000..5ce96e3 --- /dev/null +++ b/src/staged.ml @@ -0,0 +1,6 @@ +open! Import + +type 'a t = 'a + +let stage = Fn.id +let unstage = Fn.id diff --git a/src/staged.mli b/src/staged.mli new file mode 100644 index 0000000..902e290 --- /dev/null +++ b/src/staged.mli @@ -0,0 +1,43 @@ +(** A type for making staging explicit in the type of a function. + + For example, you might want to have a function that creates a function for allocating + unique identifiers. Rather than using the type: + + {[ + val make_id_allocator : unit -> unit -> int + ]} + + you would have + + {[ + val make_id_allocator : unit -> (unit -> int) Staged.t + ]} + + Such a function could be defined as follows: + + {[ + let make_id_allocator () = + let ctr = ref 0 in + stage (fun () -> incr ctr; !ctr) + ]} + + and could be invoked as follows: + + {[ + let (id1,id2) = + let alloc = unstage (make_id_allocator ()) in + (alloc (), alloc ()) + ]} + + both {!stage} and {!unstage} functions are available in the toplevel namespace. + + (Note that in many cases, including perhaps the one above, it's preferable to create a + custom type rather than use [Staged].) *) + +open! Import + +type +'a t + +val stage : 'a -> 'a t +val unstage : 'a t -> 'a + diff --git a/src/string.ml b/src/string.ml new file mode 100644 index 0000000..4a2467f --- /dev/null +++ b/src/string.ml @@ -0,0 +1,1267 @@ +open! Import + +module Array = Array0 +module Bytes = Bytes0 + +include String0 + +let invalid_argf = Printf.invalid_argf + +let raise_s = Error.raise_s + +let stage = Staged.stage + +module T = struct + type t = string [@@deriving_inline hash, sexp] + let (hash_fold_t : + Ppx_hash_lib.Std.Hash.state -> t -> Ppx_hash_lib.Std.Hash.state) = + hash_fold_string + and (hash : t -> Ppx_hash_lib.Std.Hash.hash_value) = + let func = hash_string in fun x -> func x + let t_of_sexp : Ppx_sexp_conv_lib.Sexp.t -> t = string_of_sexp + let sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t = sexp_of_string + [@@@end] + let compare = compare +end + +include T +include Comparator.Make(T) + +type elt = char + +let is_substring_at_gen = + let rec loop ~str ~str_pos ~sub ~sub_pos ~sub_len ~char_equal = + if sub_pos = sub_len + then true + else if char_equal (unsafe_get str str_pos) (unsafe_get sub sub_pos) + then loop ~str ~str_pos:(str_pos + 1) ~sub ~sub_pos:(sub_pos + 1) ~sub_len ~char_equal + else false + in + fun str ~pos:str_pos ~substring:sub ~char_equal -> + let str_len = length str in + let sub_len = length sub in + if str_pos < 0 || str_pos > str_len + then begin + invalid_argf "String.is_substring_at: invalid index %d for string of length %d" + str_pos str_len () + end; + str_pos + sub_len <= str_len + && loop ~str ~str_pos ~sub ~sub_pos:0 ~sub_len ~char_equal + +let is_suffix_gen string ~suffix ~char_equal = + let string_len = length string in + let suffix_len = length suffix in + string_len >= suffix_len + && is_substring_at_gen string + ~pos:(string_len - suffix_len) + ~substring:suffix + ~char_equal +;; + +let is_prefix_gen string ~prefix ~char_equal = + let string_len = length string in + let prefix_len = length prefix in + string_len >= prefix_len + && is_substring_at_gen string ~pos:0 ~substring:prefix ~char_equal +;; + +module Caseless = struct + module T = struct + type t = string [@@deriving_inline sexp] + let t_of_sexp : Ppx_sexp_conv_lib.Sexp.t -> t = string_of_sexp + let sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t = sexp_of_string + [@@@end] + + let char_compare_caseless c1 c2 = Char.compare (Char.lowercase c1) (Char.lowercase c2) + let char_equal_caseless c1 c2 = Char.equal (Char.lowercase c1) (Char.lowercase c2) + + let rec compare_loop ~pos ~string1 ~len1 ~string2 ~len2 = + if pos = len1 + then if pos = len2 + then 0 + else -1 + else if pos = len2 + then 1 + else begin + let c = char_compare_caseless (unsafe_get string1 pos) (unsafe_get string2 pos) in + match c with + | 0 -> compare_loop ~pos:(pos + 1) ~string1 ~len1 ~string2 ~len2 + | _ -> c + end + + let compare string1 string2 = + if phys_equal string1 string2 + then 0 + else begin + compare_loop ~pos:0 + ~string1 ~len1:(String.length string1) + ~string2 ~len2:(String.length string2) + end + + let hash_fold_t state t = + let len = length t in + let state = ref (hash_fold_int state len) in + for pos = 0 to len - 1 do + state := hash_fold_char !state (Char.lowercase (unsafe_get t pos)) + done; + !state + + let hash t = Hash.run hash_fold_t t + + let is_suffix s ~suffix = is_suffix_gen s ~suffix ~char_equal:char_equal_caseless + let is_prefix s ~prefix = is_prefix_gen s ~prefix ~char_equal:char_equal_caseless + end + + include T + include Comparable.Make(T) +end + +(* This is copied/adapted from 'blit.ml'. + [sub], [subo] could be implemented using [Blit.Make(Bytes)] plus unsafe casts to/from + string but were inlined here to avoid using [Bytes.unsafe_of_string] as much as possible. + Also note that [blit] and [blito] will be deprecated and removed in the future. +*) +let sub src ~pos ~len = + if pos = 0 && len = String.length src + then src + else begin + Ordered_collection_common.check_pos_len_exn ~pos ~len ~total_length:(length src); + let dst = Bytes.create len in + if len > 0 then Bytes.unsafe_blit_string ~src ~src_pos:pos ~dst ~dst_pos:0 ~len; + Bytes.unsafe_to_string ~no_mutation_while_string_reachable:dst + end + +let subo ?(pos = 0) ?len src = + sub src ~pos ~len:(match len with Some i -> i | None -> length src - pos) + +let blit = Bytes.blit_string +let blito ~src ?(src_pos = 0) ?(src_len = length src - src_pos) ~dst ?(dst_pos = 0) () = + blit ~src ~src_pos ~len:src_len ~dst ~dst_pos + +let rec contains_unsafe t ~pos ~end_ char = + pos < end_ + && (Char.equal (unsafe_get t pos) char + || contains_unsafe t ~pos:(pos + 1) ~end_ char) + +let contains ?(pos = 0) ?len t char = + let total_length = String.length t in + let len = Option.value len ~default:(total_length - pos) in + Ordered_collection_common.check_pos_len_exn ~pos ~len ~total_length; + contains_unsafe t ~pos ~end_:(pos + len) char +;; + +let is_empty t = length t = 0 + +let index t char = + try Some (index_exn t char) + with Not_found_s _ | Caml.Not_found -> None + +let rindex t char = + try Some (rindex_exn t char) + with Not_found_s _ | Caml.Not_found -> None + +let index_from t pos char = + try Some (index_from_exn t pos char) + with Not_found_s _ | Caml.Not_found -> None + +let rindex_from t pos char = + try Some (rindex_from_exn t pos char) + with Not_found_s _ | Caml.Not_found -> None + +module Search_pattern = struct + + type t = string * int array [@@deriving_inline sexp_of] + let sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t = + function + | (v0, v1) -> + let v0 = sexp_of_string v0 + and v1 = sexp_of_array sexp_of_int v1 in + Ppx_sexp_conv_lib.Sexp.List [v0; v1] + [@@@end] + + (* Find max number of matched characters at [next_text_char], given the current + [matched_chars]. Try to extend the current match, if chars don't match, try to match + fewer chars. If chars match then extend the match. *) + let kmp_internal_loop ~matched_chars ~next_text_char ~pattern ~kmp_arr = + let matched_chars = ref matched_chars in + while !matched_chars > 0 + && Char.( <> ) next_text_char (unsafe_get pattern !matched_chars) do + matched_chars := Array.unsafe_get kmp_arr (!matched_chars - 1) + done; + if Char.equal next_text_char (unsafe_get pattern !matched_chars) then + matched_chars := !matched_chars + 1; + !matched_chars + ;; + + (* Classic KMP pre-processing of the pattern: build the int array, which, for each i, + contains the length of the longest non-trivial prefix of s which is equal to a suffix + ending at s.[i] *) + let create pattern = + let n = length pattern in + let kmp_arr = Array.create ~len:n (-1) in + if n > 0 then begin + Array.unsafe_set kmp_arr 0 0; + let matched_chars = ref 0 in + for i = 1 to n - 1 do + matched_chars := + kmp_internal_loop + ~matched_chars:!matched_chars + ~next_text_char:(unsafe_get pattern i) + ~pattern + ~kmp_arr; + Array.unsafe_set kmp_arr i !matched_chars + done + end; + (pattern, kmp_arr) + ;; + + (* Classic KMP: use the pre-processed pattern to optimize look-behinds on non-matches. + We return int to avoid allocation in [index_exn]. -1 means no match. *) + let index_internal ?(pos=0) (pattern, kmp_arr) ~in_:text = + if pos < 0 || pos > length text - length pattern then + -1 + else begin + let j = ref pos in + let matched_chars = ref 0 in + let k = length pattern in + let n = length text in + while !j < n && !matched_chars < k do + let next_text_char = unsafe_get text !j in + matched_chars := + kmp_internal_loop + ~matched_chars:!matched_chars + ~next_text_char + ~pattern + ~kmp_arr; + j := !j + 1 + done; + if !matched_chars = k then + !j - k + else + -1 + end + ;; + + let matches t str = index_internal t ~in_:str >= 0 + + let index ?pos t ~in_ = + let p = index_internal ?pos t ~in_ in + if p < 0 then + None + else + Some p + ;; + + let index_exn ?pos t ~in_ = + let p = index_internal ?pos t ~in_ in + if p >= 0 then + p + else + raise_s (Sexp.message "Substring not found" + ["substring", sexp_of_string (fst t)]) + ;; + + let index_all (pattern, kmp_arr) ~may_overlap ~in_:text = + if length pattern = 0 then + List.init (1 + length text) ~f:Fn.id + else begin + let matched_chars = ref 0 in + let k = length pattern in + let n = length text in + let found = ref [] in + for j = 0 to n do + if !matched_chars = k then begin + found := (j - k)::!found; + (* we just found a match in the previous iteration *) + match may_overlap with + | true -> matched_chars := Array.unsafe_get kmp_arr (k - 1) + | false -> matched_chars := 0 + end; + if j < n then begin + let next_text_char = unsafe_get text j in + matched_chars := + kmp_internal_loop + ~matched_chars:!matched_chars + ~next_text_char + ~pattern + ~kmp_arr + end + done; + List.rev !found + end + ;; + + let replace_first ?pos t ~in_:s ~with_ = + match index ?pos t ~in_:s with + | None -> s + | Some i -> + let len_s = length s in + let len_t = length (fst t) in + let len_with = length with_ in + let dst = Bytes.create (len_s + len_with - len_t) in + blit ~src:s ~src_pos:0 ~dst ~dst_pos:0 ~len:i; + blit ~src:with_ ~src_pos:0 ~dst ~dst_pos:i ~len:len_with; + blit ~src:s ~src_pos:(i + len_t) ~dst ~dst_pos:(i + len_with) ~len:(len_s - i - len_t); + Bytes.unsafe_to_string ~no_mutation_while_string_reachable:dst + ;; + + + let replace_all t ~in_:s ~with_ = + let matches = index_all t ~may_overlap:false ~in_:s in + match matches with + | [] -> s + | _::_ -> + let len_s = length s in + let len_t = length (fst t) in + let len_with = length with_ in + let num_matches = List.length matches in + let dst = Bytes.create (len_s + (len_with - len_t) * num_matches) in + let next_dst_pos = ref 0 in + let next_src_pos = ref 0 in + List.iter matches ~f:(fun i -> + let len = i - !next_src_pos in + blit ~src:s ~src_pos:!next_src_pos ~dst ~dst_pos:!next_dst_pos ~len; + blit ~src:with_ ~src_pos:0 ~dst ~dst_pos:(!next_dst_pos + len) ~len:len_with; + next_dst_pos := !next_dst_pos + len + len_with; + next_src_pos := !next_src_pos + len + len_t; + ); + blit ~src:s ~src_pos:!next_src_pos ~dst ~dst_pos:!next_dst_pos + ~len:(len_s - !next_src_pos); + Bytes.unsafe_to_string ~no_mutation_while_string_reachable:dst + ;; +end + +let substr_index ?pos t ~pattern = + Search_pattern.index ?pos (Search_pattern.create pattern) ~in_:t +;; + +let substr_index_exn ?pos t ~pattern = + Search_pattern.index_exn ?pos (Search_pattern.create pattern) ~in_:t +;; + +let substr_index_all t ~may_overlap ~pattern = + Search_pattern.index_all (Search_pattern.create pattern) ~may_overlap ~in_:t +;; + +let substr_replace_first ?pos t ~pattern = + Search_pattern.replace_first ?pos (Search_pattern.create pattern) ~in_:t +;; + +let substr_replace_all t ~pattern = + Search_pattern.replace_all (Search_pattern.create pattern) ~in_:t +;; + +let is_substring t ~substring = + Option.is_some (substr_index t ~pattern:substring) +;; + +let of_string = Fn.id +let to_string = Fn.id + +let init n ~f = + if n < 0 then invalid_argf "String.init %d" n (); + let t = Bytes.create n in + for i = 0 to n - 1 do + Bytes.set t i (f i); + done; + Bytes.unsafe_to_string ~no_mutation_while_string_reachable:t +;; + +let to_list s = + let rec loop acc i = + if i < 0 then + acc + else + loop (s.[i] :: acc) (i-1) + in + loop [] (length s - 1) + +let to_list_rev s = + let len = length s in + let rec loop acc i = + if i = len then + acc + else + loop (s.[i] :: acc) (i+1) + in + loop [] 0 + +let rev t = + let len = length t in + let res = Bytes.create len in + for i = 0 to len - 1 do + unsafe_set res i (unsafe_get t (len - 1 - i)) + done; + Bytes.unsafe_to_string ~no_mutation_while_string_reachable:res +;; + +(** Efficient string splitting *) + +let lsplit2_exn line ~on:delim = + let pos = index_exn line delim in + (sub line ~pos:0 ~len:pos, + sub line ~pos:(pos+1) ~len:(length line - pos - 1) + ) + +let rsplit2_exn line ~on:delim = + let pos = rindex_exn line delim in + (sub line ~pos:0 ~len:pos, + sub line ~pos:(pos+1) ~len:(length line - pos - 1) + ) + +let lsplit2 line ~on = + try Some (lsplit2_exn line ~on) with Not_found_s _ | Caml.Not_found -> None + +let rsplit2 line ~on = + try Some (rsplit2_exn line ~on) with Not_found_s _ | Caml.Not_found -> None + +let rec char_list_mem l (c:char) = + match l with + | [] -> false + | hd::tl -> Char.equal hd c || char_list_mem tl c + +let split_gen str ~on = + let is_delim = + match on with + | `char c' -> (fun c -> Char.equal c c') + | `char_list l -> (fun c -> char_list_mem l c) + in + let len = length str in + let rec loop acc last_pos pos = + if pos = -1 then + sub str ~pos:0 ~len:last_pos :: acc + else + if is_delim str.[pos] then + let pos1 = pos + 1 in + let sub_str = sub str ~pos:pos1 ~len:(last_pos - pos1) in + loop (sub_str :: acc) pos (pos - 1) + else loop acc last_pos (pos - 1) + in + loop [] len (len - 1) +;; + +let split str ~on = split_gen str ~on:(`char on) ;; + +let split_on_chars str ~on:chars = + split_gen str ~on:(`char_list chars) +;; + +let split_lines = + let back_up_at_newline ~t ~pos ~eol = + pos := !pos - (if !pos > 0 && Char.equal t.[!pos - 1] '\r' then 2 else 1); + eol := !pos + 1; + in + fun t -> + let n = length t in + if n = 0 + then [] + else + (* Invariant: [-1 <= pos < eol]. *) + let pos = ref (n - 1) in + let eol = ref n in + let ac = ref [] in + (* We treat the end of the string specially, because if the string ends with a + newline, we don't want an extra empty string at the end of the output. *) + if Char.equal t.[!pos] '\n' then back_up_at_newline ~t ~pos ~eol; + while !pos >= 0 do + if Char.( <> ) t.[!pos] '\n' + then decr pos + else + (* Because [pos < eol], we know that [start <= eol]. *) + let start = !pos + 1 in + ac := sub t ~pos:start ~len:(!eol - start) :: !ac; + back_up_at_newline ~t ~pos ~eol + done; + sub t ~pos:0 ~len:!eol :: !ac +;; + +let is_suffix s ~suffix = is_suffix_gen s ~suffix ~char_equal:Char.equal +let is_prefix s ~prefix = is_prefix_gen s ~prefix ~char_equal:Char.equal + +let is_substring_at s ~pos ~substring = + is_substring_at_gen s ~pos ~substring ~char_equal:Char.equal + +let wrap_sub_n t n ~name ~pos ~len ~on_error = + if n < 0 then + invalid_arg (name ^ " expecting nonnegative argument") + else + try + sub t ~pos ~len + with _ -> + on_error + +let drop_prefix t n = wrap_sub_n ~name:"drop_prefix" t n ~pos:n ~len:(length t - n) ~on_error:"" +let drop_suffix t n = wrap_sub_n ~name:"drop_suffix" t n ~pos:0 ~len:(length t - n) ~on_error:"" +let prefix t n = wrap_sub_n ~name:"prefix" t n ~pos:0 ~len:n ~on_error:t +let suffix t n = wrap_sub_n ~name:"suffix" t n ~pos:(length t - n) ~len:n ~on_error:t + +let lfindi ?(pos=0) t ~f = + let n = length t in + let rec loop i = + if i = n then None + else if f i t.[i] then Some i + else loop (i + 1) + in + loop pos +;; + +let find t ~f = + match lfindi t ~f:(fun _ c -> f c) with + | None -> None | Some i -> Some t.[i] + +let find_map t ~f = + let n = length t in + let rec loop i = + if i = n then None + else + match f t.[i] with + | None -> loop (i + 1) + | Some _ as res -> res + in + loop 0 +;; + +let rfindi ?pos t ~f = + let rec loop i = + if i < 0 then None + else begin + if f i t.[i] then Some i + else loop (i - 1) + end + in + let pos = + match pos with + | Some pos -> pos + | None -> length t - 1 + in + loop pos +;; + +let last_non_drop ~drop t = rfindi t ~f:(fun _ c -> not (drop c)) + +let rstrip ?(drop=Char.is_whitespace) t = + match last_non_drop t ~drop with + | None -> "" + | Some i -> + if i = length t - 1 + then t + else prefix t (i + 1) +;; + +let first_non_drop ~drop t = lfindi t ~f:(fun _ c -> not (drop c)) + +let lstrip ?(drop=Char.is_whitespace) t = + match first_non_drop t ~drop with + | None -> "" + | Some 0 -> t + | Some n -> drop_prefix t n +;; + +(* [strip t] could be implemented as [lstrip (rstrip t)]. The implementation + below saves (at least) a factor of two allocation, by only allocating the + final result. This also saves some amount of time. *) +let strip ?(drop=Char.is_whitespace) t = + let length = length t in + if length = 0 || not (drop t.[0] || drop t.[length - 1]) + then t + else + match first_non_drop t ~drop with + | None -> "" + | Some first -> + match last_non_drop t ~drop with + | None -> assert false + | Some last -> sub t ~pos:first ~len:(last - first + 1) +;; + +let mapi t ~f = + let l = length t in + let t' = Bytes.create l in + for i = 0 to l - 1 do + Bytes.unsafe_set t' i (f i t.[i]) + done; + Bytes.unsafe_to_string ~no_mutation_while_string_reachable:t' + +(* repeated code to avoid requiring an extra allocation for a closure on each call. *) +let map t ~f = + let l = length t in + let t' = Bytes.create l in + for i = 0 to l - 1 do + Bytes.unsafe_set t' i (f t.[i]) + done; + Bytes.unsafe_to_string ~no_mutation_while_string_reachable:t' + +let to_array s = Array.init (length s) ~f:(fun i -> s.[i]) + +let exists = + let rec loop s i ~len ~f = i < len && (f s.[i] || loop s (i + 1) ~len ~f) in + fun s ~f -> loop s 0 ~len:(length s) ~f +;; + +let for_all = + let rec loop s i ~len ~f = i = len || (f s.[i] && loop s (i + 1) ~len ~f) in + fun s ~f -> loop s 0 ~len:(length s) ~f +;; + +let fold t ~init ~f = + let n = length t in + let rec loop i ac = if i = n then ac else loop (i + 1) (f ac t.[i]) in + loop 0 init +;; + +let foldi t ~init ~f = + let n = length t in + let rec loop i ac = if i = n then ac else loop (i + 1) (f i ac t.[i]) in + loop 0 init +;; + +let count t ~f = Container.count ~fold t ~f +let sum m t ~f = Container.sum ~fold m t ~f + +let min_elt t = Container.min_elt ~fold t +let max_elt t = Container.max_elt ~fold t +let fold_result t ~init ~f = Container.fold_result ~fold ~init ~f t +let fold_until t ~init ~f = Container.fold_until ~fold ~init ~f t + +let mem = + let rec loop t c ~pos:i ~len = + i < len && (Char.equal c (unsafe_get t i) || loop t c ~pos:(i + 1) ~len) + in + fun t c -> + loop t c ~pos:0 ~len:(length t) +;; + +let tr ~target ~replacement s = + if Char.equal target replacement + then s + else if mem s target + then map s ~f:(fun c -> if Char.equal c target then replacement else c) + else s +;; + +let tr_inplace ~target ~replacement s = (* destructive version of tr *) + for i = 0 to Bytes.length s - 1 do + if Char.equal (Bytes.unsafe_get s i) target then Bytes.unsafe_set s i replacement + done + +let tr_multi ~target ~replacement = + if is_empty target + then stage Fn.id + else if is_empty replacement + then invalid_arg "tr_multi replacement is empty string" + else + match Bytes_tr.tr_create_map ~target ~replacement with + | None -> stage Fn.id + | Some tr_map -> + stage (fun s -> + if exists s ~f:(fun c -> Char.(<>) c (unsafe_get tr_map (Char.to_int c))) + then map s ~f:(fun c -> unsafe_get tr_map (Char.to_int c)) + else s) + +(* fast version, if we ever need it: + {[ + let concat_array ~sep ar = + let ar_len = Array.length ar in + if ar_len = 0 then "" + else + let sep_len = length sep in + let res_len_ref = ref (sep_len * (ar_len - 1)) in + for i = 0 to ar_len - 1 do + res_len_ref := !res_len_ref + length ar.(i) + done; + let res = create !res_len_ref in + let str_0 = ar.(0) in + let len_0 = length str_0 in + blit ~src:str_0 ~src_pos:0 ~dst:res ~dst_pos:0 ~len:len_0; + let pos_ref = ref len_0 in + for i = 1 to ar_len - 1 do + let pos = !pos_ref in + blit ~src:sep ~src_pos:0 ~dst:res ~dst_pos:pos ~len:sep_len; + let new_pos = pos + sep_len in + let str_i = ar.(i) in + let len_i = length str_i in + blit ~src:str_i ~src_pos:0 ~dst:res ~dst_pos:new_pos ~len:len_i; + pos_ref := new_pos + len_i + done; + res + ]} *) + +let concat_array ?sep ar = concat ?sep (Array.to_list ar) + +let concat_map ?sep s ~f = concat_array ?sep (Array.map (to_array s) ~f) + +(* [filter t f] is implemented by the following algorithm. + + Let [n = length t]. + + 1. Find the lowest [i] such that [not (f t.[i])]. + + 2. If there is no such [i], then return [t]. + + 3. If there is such an [i], allocate a string, [out], to hold the result. [out] has + length [n - 1], which is the maximum possible output size given that there is at least + one character not satisfying [f]. + + 4. Copy characters at indices 0 ... [i - 1] from [t] to [out]. + + 5. Walk through characters at indices [i+1] ... [n-1] of [t], copying those that + satisfy [f] from [t] to [out]. + + 6. If we completely filled [out], then return it. If not, return the prefix of [out] + that we did fill in. + + This algorithm has the property that it doesn't allocate a new string if there's + nothing to filter, which is a common case. *) +let filter t ~f = + let n = length t in + let i = ref 0 in + while !i < n && f t.[!i]; do + incr i + done; + if !i = n then + t + else begin + let out = Bytes.create (n - 1) in + blit ~src:t ~src_pos:0 ~dst:out ~dst_pos:0 ~len:!i; + let out_pos = ref !i in + incr i; + while !i < n; do + let c = t.[!i] in + if f c then (Bytes.set out !out_pos c; incr out_pos); + incr i + done; + let out = Bytes.unsafe_to_string ~no_mutation_while_string_reachable:out in + if !out_pos = n - 1 then + out + else + sub out ~pos:0 ~len:!out_pos + end +;; + +let chop_prefix s ~prefix = + if is_prefix s ~prefix then + Some (drop_prefix s (length prefix)) + else + None + +let chop_prefix_exn s ~prefix = + match chop_prefix s ~prefix with + | Some str -> str + | None -> + raise (Invalid_argument + (Printf.sprintf "String.chop_prefix_exn %S %S" s prefix)) + +let chop_suffix s ~suffix = + if is_suffix s ~suffix then + Some (drop_suffix s (length suffix)) + else + None + +let chop_suffix_exn s ~suffix = + match chop_suffix s ~suffix with + | Some str -> str + | None -> + raise (Invalid_argument + (Printf.sprintf "String.chop_suffix_exn %S %S" s suffix)) + +(* There used to be a custom implementation that was faster for very short strings + (peaking at 40% faster for 4-6 char long strings). + This new function is around 20% faster than the default hash function, but slower + than the previous custom implementation. However, the new OCaml function is well + behaved, and this implementation is less likely to diverge from the default OCaml + implementation does, which is a desirable property. (The only way to avoid the + divergence is to expose the macro redefined in hash_stubs.c in the hash.h header of + the OCaml compiler.) *) +module Hash = struct + external hash : string -> int = "Base_hash_string" [@@noalloc] +end + +(* [include Hash] to make the [external] version override the [hash] from + [Hashable.Make_binable], so that we get a little bit of a speedup by exposing it as + external in the mli. *) +let _ = hash +include Hash + +include Comparable.Validate (T) + +(* for interactive top-levels -- modules deriving from String should have String's pretty + printer. *) +let pp = Caml.Format.pp_print_string + +let of_char c = make 1 c + +let of_char_list l = + let t = Bytes.create (List.length l) in + List.iteri l ~f:(fun i c -> Bytes.set t i c); + Bytes.unsafe_to_string ~no_mutation_while_string_reachable:t + +module Escaping = struct + (* If this is changed, make sure to update [escape], which attempts to ensure all the + invariants checked here. *) + let build_and_validate_escapeworthy_map escapeworthy_map escape_char func = + let escapeworthy_map = + if List.Assoc.mem escapeworthy_map ~equal:Char.equal escape_char + then escapeworthy_map + else (escape_char, escape_char) :: escapeworthy_map + in + let arr = Array.create ~len:256 (-1) in + let vals = Array.create ~len:256 false in + let rec loop = function + | [] -> Ok arr + | (c_from, c_to) :: l -> + let k, v = match func with + | `Escape -> Char.to_int c_from, c_to + | `Unescape -> Char.to_int c_to, c_from + in + if arr.(k) <> -1 || vals.(Char.to_int v) then + Or_error.error_s + (Sexp.message "escapeworthy_map not one-to-one" + [ "c_from", sexp_of_char c_from + ; "c_to", sexp_of_char c_to + ; "escapeworthy_map", + sexp_of_list (sexp_of_pair sexp_of_char sexp_of_char) + escapeworthy_map + ]) + else (arr.(k) <- Char.to_int v; vals.(Char.to_int v) <- true; loop l) + in + loop escapeworthy_map + ;; + + let escape_gen ~escapeworthy_map ~escape_char = + match + build_and_validate_escapeworthy_map escapeworthy_map escape_char `Escape + with + | Error _ as x -> x + | Ok escapeworthy -> + Ok (fun src -> + (* calculate a list of (index of char to escape * escaped char) first, the order + is from tail to head *) + let to_escape_len = ref 0 in + let to_escape = + foldi src ~init:[] ~f:(fun i acc c -> + match escapeworthy.(Char.to_int c) with + | -1 -> acc + | n -> + (* (index of char to escape * escaped char) *) + incr to_escape_len; + (i, Char.unsafe_of_int n) :: acc) + in + match to_escape with + | [] -> src + | _ -> + (* [to_escape] divide [src] to [List.length to_escape + 1] pieces separated by + the chars to escape. + + Lets take + {[ + escape_gen_exn + ~escapeworthy_map:[('a', 'A'); ('b', 'B'); ('c', 'C')] + ~escape_char:'_' + ]} + for example, and assume the string to escape is + + "000a111b222c333" + + then [to_escape] is [(11, 'C'); (7, 'B'); (3, 'A')]. + + Then we create a [dst] of length [length src + 3] to store the + result, copy piece "333" to [dst] directly, then copy '_' and 'C' to [dst]; + then move on to next; after 3 iterations, copy piece "000" and we are done. + + Finally the result will be + + "000_A111_B222_C333" *) + let src_len = length src in + let dst_len = src_len + !to_escape_len in + let dst = Bytes.create dst_len in + let rec loop last_idx last_dst_pos = function + | [] -> + (* copy "000" at last *) + blit ~src ~src_pos:0 ~dst ~dst_pos:0 ~len:last_idx + | (idx, escaped_char) :: to_escape -> (*[idx] = the char to escape*) + (* take first iteration for example *) + (* calculate length of "333", minus 1 because we don't copy 'c' *) + let len = last_idx - idx - 1 in + (* set the dst_pos to copy to *) + let dst_pos = last_dst_pos - len in + (* copy "333", set [src_pos] to [idx + 1] to skip 'c' *) + blit ~src ~src_pos:(idx + 1) ~dst ~dst_pos ~len; + (* backoff [dst_pos] by 2 to copy '_' and 'C' *) + let dst_pos = dst_pos - 2 in + Bytes.set dst dst_pos escape_char; + Bytes.set dst (dst_pos + 1) escaped_char; + loop idx dst_pos to_escape + in + (* set [last_dst_pos] and [last_idx] to length of [dst] and [src] first *) + loop src_len dst_len to_escape; + Bytes.unsafe_to_string ~no_mutation_while_string_reachable:dst + ) + ;; + + let escape_gen_exn ~escapeworthy_map ~escape_char = + Or_error.ok_exn (escape_gen ~escapeworthy_map ~escape_char) |> stage + ;; + + let escape ~escapeworthy ~escape_char = + (* For [escape_gen_exn], we don't know how to fix invalid escapeworthy_map so we have + to raise exception; but in this case, we know how to fix duplicated elements in + escapeworthy list, so we just fix it instead of raising exception to make this + function easier to use. *) + let escapeworthy_map = + escapeworthy + |> List.dedup_and_sort ~compare:Char.compare + |> List.map ~f:(fun c -> (c, c)) + in + escape_gen_exn ~escapeworthy_map ~escape_char + ;; + + (* In an escaped string, any char is either `Escaping, `Escaped or `Literal. For + example, the escape statuses of chars in string "a_a__" with escape_char = '_' are + + a : `Literal + _ : `Escaping + a : `Escaped + _ : `Escaping + _ : `Escaped + + [update_escape_status str ~escape_char i previous_status] gets escape status of + str.[i] basing on escape status of str.[i - 1] *) + let update_escape_status str ~escape_char i = function + | `Escaping -> `Escaped + | `Literal + | `Escaped -> if Char.equal str.[i] escape_char then `Escaping else `Literal + ;; + + let unescape_gen ~escapeworthy_map ~escape_char = + match + build_and_validate_escapeworthy_map escapeworthy_map escape_char `Unescape + with + | Error _ as x -> x + | Ok escapeworthy -> + Ok (fun src -> + (* Continue the example in [escape_gen_exn], now we unescape + + "000_A111_B222_C333" + + back to + + "000a111b222c333" + + Then [to_unescape] is [14; 9; 4], which is indexes of '_'s. + + Then we create a string [dst] to store the result, copy "333" to it, then copy + 'c', then move on to next iteration. After 3 iterations copy "000" and we are + done. *) + (* indexes of escape chars *) + let to_unescape = + let rec loop i status acc = + if i >= length src then acc + else + let status = update_escape_status src ~escape_char i status in + loop (i + 1) status + (match status with + | `Escaping -> i :: acc + | `Escaped | `Literal -> acc) + in + loop 0 `Literal [] + in + match to_unescape with + | [] -> src + | idx::to_unescape' -> + let dst = Bytes.create (length src - List.length to_unescape) in + let rec loop last_idx last_dst_pos = function + | [] -> + (* copy "000" at last *) + blit ~src ~src_pos:0 ~dst ~dst_pos:0 ~len:last_idx + | idx::to_unescape -> (* [idx] = index of escaping char *) + (* take 1st iteration as example, calculate the length of "333", minus 2 to + skip '_C' *) + let len = last_idx - idx - 2 in + (* point [dst_pos] to the position to copy "333" to *) + let dst_pos = last_dst_pos - len in + (* copy "333" *) + blit ~src ~src_pos:(idx + 2) ~dst ~dst_pos ~len; + (* backoff [dst_pos] by 1 to copy 'c' *) + let dst_pos = dst_pos - 1 in + Bytes.set dst dst_pos ( match escapeworthy.(Char.to_int src.[idx + 1]) with + | -1 -> src.[idx + 1] + | n -> Char.unsafe_of_int n); + (* update [last_dst_pos] and [last_idx] *) + loop idx dst_pos to_unescape + in + ( if idx < length src - 1 then + (* set [last_dst_pos] and [last_idx] to length of [dst] and [src] *) + loop (length src) (Bytes.length dst) to_unescape + else + (* for escaped string ending with an escaping char like "000_", just ignore + the last escaping char *) + loop (length src - 1) (Bytes.length dst) to_unescape' + ); + Bytes.unsafe_to_string ~no_mutation_while_string_reachable:dst + ) + ;; + + let unescape_gen_exn ~escapeworthy_map ~escape_char = + Or_error.ok_exn (unescape_gen ~escapeworthy_map ~escape_char) |> stage + ;; + + let unescape ~escape_char = + unescape_gen_exn ~escapeworthy_map:[] ~escape_char + + let preceding_escape_chars str ~escape_char pos = + let rec loop p cnt = + if (p < 0) || (Char.( <> ) str.[p] escape_char) then + cnt + else + loop (p - 1) (cnt + 1) + in + loop (pos - 1) 0 + ;; + + (* In an escaped string, any char is either `Escaping, `Escaped or `Literal. For + example, the escape statuses of chars in string "a_a__" with escape_char = '_' are + + a : `Literal + _ : `Escaping + a : `Escaped + _ : `Escaping + _ : `Escaped + + [update_escape_status str ~escape_char i previous_status] gets escape status of + str.[i] basing on escape status of str.[i - 1] *) + let update_escape_status str ~escape_char i = function + | `Escaping -> `Escaped + | `Literal + | `Escaped -> if Char.equal str.[i] escape_char then `Escaping else `Literal + ;; + + let escape_status str ~escape_char pos = + let odd = (preceding_escape_chars str ~escape_char pos) mod 2 = 1 in + match odd, Char.equal str.[pos] escape_char with + | true, (true|false) -> `Escaped + | false, true -> `Escaping + | false, false -> `Literal + ;; + + let check_bound str pos function_name = + if pos >= length str || pos < 0 then + invalid_argf "%s: out of bounds" function_name () + ;; + + let is_char_escaping str ~escape_char pos = + check_bound str pos "is_char_escaping"; + match escape_status str ~escape_char pos with + | `Escaping -> true + | `Escaped | `Literal -> false + ;; + + let is_char_escaped str ~escape_char pos = + check_bound str pos "is_char_escaped"; + match escape_status str ~escape_char pos with + | `Escaped -> true + | `Escaping | `Literal -> false + ;; + + let is_char_literal str ~escape_char pos = + check_bound str pos "is_char_literal"; + match escape_status str ~escape_char pos with + | `Literal -> true + | `Escaped | `Escaping -> false + ;; + + let index_from str ~escape_char pos char = + check_bound str pos "index_from"; + let rec loop i status = + if i >= pos + && (match status with `Literal -> true | `Escaped | `Escaping -> false) + && Char.equal str.[i] char + then Some i + else ( + let i = i + 1 in + if i >= length str then None + else loop i (update_escape_status str ~escape_char i status)) + in + loop pos (escape_status str ~escape_char pos) + ;; + + let index_from_exn str ~escape_char pos char = + match index_from str ~escape_char pos char with + | None -> + raise_s + (Sexp.message "index_from_exn: not found" + [ "str" , sexp_of_t str + ; "escape_char" , sexp_of_char escape_char + ; "pos" , sexp_of_int pos + ; "char" , sexp_of_char char + ]) + | Some pos -> pos + ;; + + let index str ~escape_char char = index_from str ~escape_char 0 char + let index_exn str ~escape_char char = index_from_exn str ~escape_char 0 char + + let rindex_from str ~escape_char pos char = + check_bound str pos "rindex_from"; + (* if the target char is the same as [escape_char], we have no way to determine which + escape_char is literal, so just return None *) + if Char.equal char escape_char then None + else + let rec loop pos = + if pos < 0 then None + else ( + let escape_chars = preceding_escape_chars str ~escape_char pos in + if escape_chars mod 2 = 0 && Char.equal str.[pos] char + then Some pos else loop (pos - escape_chars - 1)) + in + loop pos + ;; + + let rindex_from_exn str ~escape_char pos char = + match rindex_from str ~escape_char pos char with + | None -> + raise_s + (Sexp.message "rindex_from_exn: not found" + [ "str" , sexp_of_t str + ; "escape_char" , sexp_of_char escape_char + ; "pos" , sexp_of_int pos + ; "char" , sexp_of_char char + ]) + | Some pos -> pos + ;; + + let rindex str ~escape_char char = + if is_empty str + then None + else rindex_from str ~escape_char (length str - 1) char + ;; + + let rindex_exn str ~escape_char char = + rindex_from_exn str ~escape_char (length str - 1) char + ;; + + (* [split_gen str ~escape_char ~on] works similarly to [String.split_gen], with an + additional requirement: only split on literal chars, not escaping or escaped *) + let split_gen str ~escape_char ~on = + let is_delim = match on with + | `char c' -> (fun c -> Char.equal c c') + | `char_list l -> (fun c -> char_list_mem l c) + in + let len = length str in + let rec loop acc status last_pos pos = + if pos = len then + List.rev (sub str ~pos:last_pos ~len:(len - last_pos) :: acc) + else + let status = update_escape_status str ~escape_char pos status in + if (match status with `Literal -> true | `Escaped | `Escaping -> false) + && is_delim str.[pos] + then ( + let sub_str = sub str ~pos:last_pos ~len:(pos - last_pos) in + loop (sub_str :: acc) status (pos + 1) (pos + 1)) + else loop acc status last_pos (pos + 1) + in + loop [] `Literal 0 0 + ;; + + let split str ~on = split_gen str ~on:(`char on) ;; + + let split_on_chars str ~on:chars = + split_gen str ~on:(`char_list chars) + ;; + + let split_at str pos = + sub str ~pos:0 ~len:pos, + sub str ~pos:(pos + 1) ~len:(length str - pos - 1) + ;; + + let lsplit2 str ~on ~escape_char = + Option.map (index str ~escape_char on) ~f:(fun x -> split_at str x) + ;; + + let rsplit2 str ~on ~escape_char = + Option.map (rindex str ~escape_char on) ~f:(fun x -> split_at str x) + ;; + + let lsplit2_exn str ~on ~escape_char = + split_at str (index_exn str ~escape_char on) + ;; + let rsplit2_exn str ~on ~escape_char = + split_at str (rindex_exn str ~escape_char on) + ;; + + (* [last_non_drop_literal] and [first_non_drop_literal] are either both [None] or both + [Some]. If [Some], then the former is >= the latter. *) + let last_non_drop_literal ~drop ~escape_char t = + rfindi t ~f:(fun i c -> + not (drop c) + || is_char_escaping t ~escape_char i + || is_char_escaped t ~escape_char i) + let first_non_drop_literal ~drop ~escape_char t = + lfindi t ~f:(fun i c -> + not (drop c) + || is_char_escaping t ~escape_char i + || is_char_escaped t ~escape_char i) + + let rstrip_literal ?(drop=Char.is_whitespace) t ~escape_char = + match last_non_drop_literal t ~drop ~escape_char with + | None -> "" + | Some i -> + if i = length t - 1 + then t + else prefix t (i + 1) + ;; + + let lstrip_literal ?(drop=Char.is_whitespace) t ~escape_char = + match first_non_drop_literal t ~drop ~escape_char with + | None -> "" + | Some 0 -> t + | Some n -> drop_prefix t n + ;; + + (* [strip t] could be implemented as [lstrip (rstrip t)]. The implementation + below saves (at least) a factor of two allocation, by only allocating the + final result. This also saves some amount of time. *) + let strip_literal ?(drop=Char.is_whitespace) t ~escape_char = + let length = length t in + (* performance hack: avoid copying [t] in common cases *) + if length = 0 || not (drop t.[0] || drop t.[length - 1]) + then t + else + match first_non_drop_literal t ~drop ~escape_char with + | None -> "" + | Some first -> + match last_non_drop_literal t ~drop ~escape_char with + | None -> assert false + | Some last -> sub t ~pos:first ~len:(last - first + 1) + ;; +end + +(* Open replace_polymorphic_compare after including functor instantiations so they do not + shadow its definitions. This is here so that efficient versions of the comparison + functions are available within this module. *) +open! String_replace_polymorphic_compare + +let between t ~low ~high = low <= t && t <= high +let clamp_unchecked t ~min ~max = + if t < min then min else if t <= max then t else max + +let clamp_exn t ~min ~max = + assert (min <= max); + clamp_unchecked t ~min ~max + +let clamp t ~min ~max = + if min > max then + Or_error.error_s + (Sexp.message "clamp requires [min <= max]" + [ "min", T.sexp_of_t min + ; "max", T.sexp_of_t max + ]) + else + Ok (clamp_unchecked t ~min ~max) + +let create = Bytes.create +let fill = Bytes.fill + +(* Include type-specific [Replace_polymorphic_compare] at the end, after + including functor application that could shadow its definitions. This is + here so that efficient versions of the comparison functions are exported by + this module. *) +include String_replace_polymorphic_compare diff --git a/src/string.mli b/src/string.mli new file mode 100644 index 0000000..3bf1d2c --- /dev/null +++ b/src/string.mli @@ -0,0 +1,461 @@ +(** An extension of the standard [StringLabels]. If you [open Base], you'll get these + extensions in the [String] module. *) + +open! Import + +type t = string [@@deriving_inline hash, sexp] +include +sig + [@@@ocaml.warning "-32"] + val hash_fold_t : + Ppx_hash_lib.Std.Hash.state -> t -> Ppx_hash_lib.Std.Hash.state + val hash : t -> Ppx_hash_lib.Std.Hash.hash_value + include Ppx_sexp_conv_lib.Sexpable.S with type t := t +end[@@ocaml.doc "@inline"] +[@@@end] + + +val blit : (t, bytes) Blit.blit [@@deprecated "[since 2017-10] Use [Bytes.blit] instead"] +val blito : (t, bytes) Blit.blito [@@deprecated "[since 2017-10] Use [Bytes.blito] instead"] +val unsafe_blit : (t, bytes) Blit.blit [@@deprecated "[since 2017-10] Use [Bytes.unsafe_blit] instead"] + +val sub : (t, t) Blit.sub +val subo : (t, t) Blit.subo + +include Container.S0 with type t := t with type elt = char +include Identifiable.S with type t := t + +(** Maximum length of a string. *) +val max_length : int + +external length : t -> int = "%string_length" +external get : t -> int -> char = "%string_safe_get" + +(** [unsafe_get t i] is like [get t i] but does not perform bounds checking. The caller + must ensure that it is a memory-safe operation. *) +external unsafe_get : string -> int -> char = "%string_unsafe_get" + +val create : int -> bytes [@@deprecated "[since 2017-10] Use [Bytes.create] instead"] + +val make : int -> char -> t + +(** Assuming you haven't passed -unsafe-string to the compiler, strings are immutable, so + there'd be no motivation to make a copy. *) +val copy : t -> t [@@deprecated "[since 2018-03] Use [Bytes.copy] instead"] + +val init : int -> f:(int -> char) -> t + +val fill : bytes -> pos:int -> len:int -> char -> unit [@@deprecated "[since 2017-10] Use [Bytes.fill] instead"] + +(** String append. Also available unqualified, but re-exported here for documentation + purposes. + + Note that [a ^ b] must copy both [a] and [b] into a newly-allocated result string, so + [a ^ b ^ c ^ ... ^ z] is quadratic in the number of strings. [String.concat] does not + have this problem -- it allocates the result buffer only once. *) +val ( ^ ) : t -> t -> t + +(** Concatenates all strings in the list using separator [sep] (with a default separator + [""]). *) +val concat : ?sep:t -> t list -> t + +(** Special characters are represented by escape sequences, following the lexical + conventions of OCaml. *) +val escaped : t -> t + +val contains : ?pos:int -> ?len:int -> t -> char -> bool + +(** Operates on the whole string using the US-ASCII character set, + e.g. [uppercase "foo" = "FOO"]. *) +val uppercase : t -> t +val lowercase : t -> t + +(** Operates on just the first character using the US-ASCII character set, + e.g. [capitalize "foo" = "Foo"]. *) +val capitalize : t -> t +val uncapitalize : t -> t + +(** [index] gives the index of the first appearance of [char] in the string when + searching from left to right, or [None] if it's not found. [rindex] does the same but + searches from the right. + + For example, [String.index "Foo" 'o'] is [Some 1] while [String.rindex "Foo" 'o'] is + [Some 2]. + + The [_exn] versions return the actual index (instead of an option) when [char] is + found, and throw an exception otherwise. +*) + +(** [Caseless] compares and hashes strings ignoring case, so that for example + [Caseless.equal "OCaml" "ocaml"] and [Caseless.("apple" < "Banana")] are [true], and + [Caseless.Map], [Caseless.Table] lookup and [Caseless.Set] membership is + case-insensitive. + + [Caseless] also provides case-insensitive [is_suffix] and [is_prefix] functions, so + that for example [Caseless.is_suffix "OCaml" ~suffix:"AmL"] and [Caseless.is_prefix + "OCaml" ~prefix:"oc"] are [true]. *) +module Caseless : sig + type nonrec t = t [@@deriving_inline hash, sexp] + include + sig + [@@@ocaml.warning "-32"] + val hash_fold_t : + Ppx_hash_lib.Std.Hash.state -> t -> Ppx_hash_lib.Std.Hash.state + val hash : t -> Ppx_hash_lib.Std.Hash.hash_value + include Ppx_sexp_conv_lib.Sexpable.S with type t := t + end[@@ocaml.doc "@inline"] + [@@@end] + include Comparable.S with type t := t + + val is_suffix : t -> suffix:t -> bool + val is_prefix : t -> prefix:t -> bool +end + +(** [index_exn] and [index_from_exn] raise [Caml.Not_found] or [Not_found_s] when [char] + cannot be found in [s]. *) +val index : t -> char -> int option +val index_exn : t -> char -> int +val index_from : t -> int -> char -> int option +val index_from_exn : t -> int -> char -> int + +(** [rindex_exn] and [rindex_from_exn] raise [Caml.Not_found] or [Not_found_s] when [char] + cannot be found in [s]. *) +val rindex : t -> char -> int option +val rindex_exn : t -> char -> int +val rindex_from : t -> int -> char -> int option +val rindex_from_exn : t -> int -> char -> int + +(** Substring search and replace functions. They use the Knuth-Morris-Pratt algorithm + (KMP) under the hood. + + The functions in the [Search_pattern] module allow the program to preprocess the + searched pattern once and then use it many times without further allocations. *) +module Search_pattern : sig + + type t [@@deriving_inline sexp_of] + include + sig [@@@ocaml.warning "-32"] val sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t + end[@@ocaml.doc "@inline"] + [@@@end] + + (** [create pattern] preprocesses [pattern] as per KMP, building an [int array] of + length [length pattern]. All inputs are valid. *) + val create : string -> t + + (** [matches pat str] returns true if [str] matches [pat] *) + val matches : t -> string -> bool + + (** [pos < 0] or [pos >= length string] result in no match (hence [index] returns + [None] and [index_exn] raises). *) + val index : ?pos:int -> t -> in_:string -> int option + val index_exn : ?pos:int -> t -> in_:string -> int + + (** [may_overlap] determines whether after a successful match, [index_all] should start + looking for another one at the very next position ([~may_overlap:true]), or jump to + the end of that match and continue from there ([~may_overlap:false]), e.g.: + + - [index_all (create "aaa") ~may_overlap:false ~in_:"aaaaBaaaaaa" = [0; 5; 8]] + - [index_all (create "aaa") ~may_overlap:true ~in_:"aaaaBaaaaaa" = [0; 1; 5; 6; 7; + 8]] + + E.g., [replace_all] internally calls [index_all ~may_overlap:false]. *) + val index_all : t -> may_overlap:bool -> in_:string -> int list + + (** Note that the result of [replace_all pattern ~in_:text ~with_:r] may still + contain [pattern], e.g., + + {[ + replace_all (create "bc") ~in_:"aabbcc" ~with_:"cb" = "aabcbc" + ]} *) + val replace_first : ?pos:int -> t -> in_:string -> with_:string -> string + val replace_all : t -> in_:string -> with_:string -> string +end + +(** Substring search and replace convenience functions. They call [Search_pattern.create] + and then forget the preprocessed pattern when the search is complete. [pos < 0] or + [pos >= length t] result in no match (hence [substr_index] returns [None] and + [substr_index_exn] raises). [may_overlap] indicates whether to report overlapping + matches, see [Search_pattern.index_all]. *) +val substr_index : ?pos:int -> t -> pattern:t -> int option +val substr_index_exn : ?pos:int -> t -> pattern:t -> int +val substr_index_all : t -> may_overlap:bool -> pattern:t -> int list + +val substr_replace_first : ?pos:int -> t -> pattern:t -> with_:t -> t + +(** As with [Search_pattern.replace_all], the result may still contain [pattern]. *) +val substr_replace_all : t -> pattern:t -> with_:t -> t + +(** [is_substring ~substring:"bar" "foo bar baz"] is true. *) +val is_substring : t -> substring:t -> bool + +(** [is_substring_at "foo bar baz" ~pos:4 ~substring:"bar"] is true. *) +val is_substring_at : t -> pos:int -> substring:t -> bool + +(** Returns the reversed list of characters contained in a list. *) +val to_list_rev : t -> char list + +(** [rev t] returns [t] in reverse order. *) +val rev : t -> t + +(** [is_suffix s ~suffix] returns [true] if [s] ends with [suffix]. *) +val is_suffix : t -> suffix:t -> bool + +(** [is_prefix s ~prefix] returns [true] if [s] starts with [prefix]. *) +val is_prefix : t -> prefix:t -> bool + +(** If the string [s] contains the character [on], then [lsplit2_exn s ~on] returns a pair + containing [s] split around the first appearance of [on] (from the left). Raises + [Caml.Not_found] or [Not_found_s] when [on] cannot be found in [s]. *) +val lsplit2_exn : t -> on:char -> t * t + +(** If the string [s] contains the character [on], then [rsplit2_exn s ~on] returns a pair + containing [s] split around the first appearance of [on] (from the right). Raises + [Caml.Not_found] or [Not_found_s] when [on] cannot be found in [s]. *) +val rsplit2_exn : t -> on:char -> t * t + +(** [lsplit2 s ~on] optionally returns [s] split into two strings around the + first appearance of [on] from the left. *) +val lsplit2 : t -> on:char -> (t * t) option + +(** [rsplit2 s ~on] optionally returns [s] split into two strings around the first + appearance of [on] from the right. *) +val rsplit2 : t -> on:char -> (t * t) option + +(** [split s ~on] returns a list of substrings of [s] that are separated by [on]. + Consecutive [on] characters will cause multiple empty strings in the result. + Splitting the empty string returns a list of the empty string, not the empty list. *) +val split : t -> on:char -> t list + +(** [split_on_chars s ~on] returns a list of all substrings of [s] that are separated by + one of the chars from [on]. [on] are not grouped. So a grouping of [on] in the + source string will produce multiple empty string splits in the result. *) +val split_on_chars : t -> on:char list -> t list + +(** [split_lines t] returns the list of lines that comprise [t]. The lines do not include + the trailing ["\n"] or ["\r\n"]. *) +val split_lines : t -> t list + +(** [lfindi ?pos t ~f] returns the smallest [i >= pos] such that [f i t.[i]], if there is + such an [i]. By default, [pos = 0]. *) +val lfindi : ?pos : int -> t -> f:(int -> char -> bool) -> int option + +(** [rfindi ?pos t ~f] returns the largest [i <= pos] such that [f i t.[i]], if there is + such an [i]. By default [pos = length t - 1]. *) +val rfindi : ?pos : int -> t -> f:(int -> char -> bool) -> int option + +(** [lstrip ?drop s] returns a string with consecutive chars satisfying [drop] (by default + white space, e.g. tabs, spaces, newlines, and carriage returns) stripped from the + beginning of [s]. *) +val lstrip : ?drop:(char -> bool) -> t -> t + +(** [rstrip ?drop s] returns a string with consecutive chars satisfying [drop] (by default + white space, e.g. tabs, spaces, newlines, and carriage returns) stripped from the end + of [s]. *) +val rstrip : ?drop:(char -> bool) -> t -> t + +(** [strip ?drop s] returns a string with consecutive chars satisfying [drop] (by default + white space, e.g. tabs, spaces, newlines, and carriage returns) stripped from the + beginning and end of [s]. *) +val strip : ?drop:(char -> bool) -> t -> t + +val map : t -> f : (char -> char) -> t + +(** Like [map], but passes each character's index to [f] along with the char. *) +val mapi : t -> f : (int -> char -> char) -> t + +(** [foldi] works similarly to [fold], but also passes the index of each character to + [f]. *) +val foldi : t -> init : 'a -> f : (int -> 'a -> char -> 'a) -> 'a + +(** Like [map], but allows the replacement of a single character with zero or two or more + characters. *) +val concat_map : ?sep:t -> t -> f : (char -> t) -> t + +(** [filter s ~f:predicate] discards characters not satisfying [predicate]. *) +val filter : t -> f : (char -> bool) -> t + +(** [tr ~target ~replacement s] replaces every instance of [target] in [s] with + [replacement]. *) +val tr : target:char -> replacement:char -> t -> t + +(** [tr_inplace ~target ~replacement s] destructively modifies [s] (in place!), replacing + every instance of [target] in [s] with [replacement]. *) +val tr_inplace : target:char -> replacement:char -> bytes -> unit +[@@deprecated "[since 2017-10] Use [Bytes.tr] instead"] + +(** [tr_multi ~target ~replacement] returns a function that replaces every + instance of a character in [target] with the corresponding character in + [replacement]. + + If [replacement] is shorter than [target], it is lengthened by repeating + its last character. Empty [replacement] is illegal unless [target] also is. + + If [target] contains multiple copies of the same character, the last + corresponding [replacement] character is used. Note that character ranges + are {b not} supported, so [~target:"a-z"] means the literal characters ['a'], + ['-'], and ['z']. *) +val tr_multi : target:t -> replacement:t -> (t -> t) Staged.t + +(** [chop_suffix_exn s ~suffix] returns [s] without the trailing [suffix], + raising [Invalid_argument] if [suffix] is not a suffix of [s]. *) +val chop_suffix_exn : t -> suffix:t -> t + +(** [chop_prefix_exn s ~prefix] returns [s] without the leading [prefix], + raising [Invalid_argument] if [prefix] is not a prefix of [s]. *) +val chop_prefix_exn : t -> prefix:t -> t + +val chop_suffix : t -> suffix:t -> t option + +val chop_prefix : t -> prefix:t -> t option + +(** [suffix s n] returns the longest suffix of [s] of length less than or equal to [n]. *) +val suffix : t -> int -> t + +(** [prefix s n] returns the longest prefix of [s] of length less than or equal to [n]. *) +val prefix : t -> int -> t + +(** [drop_suffix s n] drops the longest suffix of [s] of length less than or equal to + [n]. *) +val drop_suffix : t -> int -> t + +(** [drop_prefix s n] drops the longest prefix of [s] of length less than or equal to + [n]. *) +val drop_prefix : t -> int -> t + +(** [concat_array sep ar] like {!String.concat}, but operates on arrays. *) +val concat_array : ?sep : t -> t array -> t + +(** Slightly faster hash function on strings. *) +external hash : t -> int = "Base_hash_string" [@@noalloc] + +(** Fast equality function on strings, doesn't use [compare_val]. *) +val equal : t -> t -> bool + +val of_char : char -> t + +val of_char_list : char list -> t + +(** Operations for escaping and unescaping strings, with parameterized escape and + escapeworthy characters. Escaping/unescaping using this module is more efficient than + using Pcre. Benchmark code can be found in core/benchmarks/string_escaping.ml. *) +module Escaping : sig + (** [escape_gen_exn escapeworthy_map escape_char] returns a function that will escape a + string [s] as follows: if [(c1,c2)] is in [escapeworthy_map], then all occurrences + of [c1] are replaced by [escape_char] concatenated to [c2]. + + Raises an exception if [escapeworthy_map] is not one-to-one. If [escape_char] is + not in [escapeworthy_map], then it will be escaped to itself.*) + val escape_gen_exn + : escapeworthy_map:(char * char) list + -> escape_char:char + -> (string -> string) Staged.t + + val escape_gen + : escapeworthy_map:(char * char) list + -> escape_char:char + -> (string -> string) Or_error.t + + (** [escape ~escapeworthy ~escape_char s] is + {[ + escape_gen_exn ~escapeworthy_map:(List.zip_exn escapeworthy escapeworthy) + ~escape_char + ]} + Duplicates and [escape_char] will be removed from [escapeworthy]. So, no + exception will be raised *) + val escape : escapeworthy:char list -> escape_char:char -> (string -> string) Staged.t + + (** [unescape_gen_exn] is the inverse operation of [escape_gen_exn]. That is, + {[ + let escape = Staged.unstage (escape_gen_exn ~escapeworthy_map ~escape_char) in + let unescape = Staged.unstage (unescape_gen_exn ~escapeworthy_map ~escape_char) in + assert (s = unescape (escape s)) + ]} + always succeed when ~escapeworthy_map is not causing exceptions. *) + val unescape_gen_exn + : escapeworthy_map:(char * char) list + -> escape_char:char + -> (string -> string) Staged.t + + val unescape_gen + : escapeworthy_map:(char * char) list + -> escape_char:char + -> (string -> string) Or_error.t + + (** [unescape ~escape_char] is defined as [unescape_gen_exn ~map:\[\] ~escape_char] *) + val unescape : escape_char:char -> (string -> string) Staged.t + + (** Any char in an escaped string is either escaping, escaped, or literal. For example, + for escaped string ["0_a0__0"] with [escape_char] as ['_'], pos 1 and 4 are + escaping, 2 and 5 are escaped, and the rest are literal. + + [is_char_escaping s ~escape_char pos] returns true if the char at [pos] is escaping, + false otherwise. *) + val is_char_escaping : string -> escape_char:char -> int -> bool + + (** [is_char_escaped s ~escape_char pos] returns true if the char at [pos] is escaped, + false otherwise. *) + val is_char_escaped : string -> escape_char:char -> int -> bool + + (** [is_char_literal s ~escape_char pos] returns true if the char at [pos] is not + escaped or escaping. *) + val is_char_literal : string -> escape_char:char -> int -> bool + + (** [index s ~escape_char char] finds the first literal (not escaped) instance of [char] + in s starting from 0. *) + val index : string -> escape_char:char -> char -> int option + val index_exn : string -> escape_char:char -> char -> int + + (** [rindex s ~escape_char char] finds the first literal (not escaped) instance of + [char] in [s] starting from the end of [s] and proceeding towards 0. *) + val rindex : string -> escape_char:char -> char -> int option + val rindex_exn : string -> escape_char:char -> char -> int + + (** [index_from s ~escape_char pos char] finds the first literal (not escaped) instance + of [char] in [s] starting from [pos] and proceeding towards the end of [s]. *) + val index_from : string -> escape_char:char -> int -> char -> int option + val index_from_exn : string -> escape_char:char -> int -> char -> int + + (** [rindex_from s ~escape_char pos char] finds the first literal (not escaped) + instance of [char] in [s] starting from [pos] and towards 0. *) + val rindex_from : string -> escape_char:char -> int -> char -> int option + val rindex_from_exn : string -> escape_char:char -> int -> char -> int + + (** [split s ~escape_char ~on] returns a list of substrings of [s] that are separated by + literal versions of [on]. Consecutive [on] characters will cause multiple empty + strings in the result. Splitting the empty string returns a list of the empty + string, not the empty list. + + E.g., [split ~escape_char:'_' ~on:',' "foo,bar_,baz" = ["foo"; "bar_,baz"]]. *) + val split : string -> on:char -> escape_char:char -> string list + + (** [split_on_chars s ~on] returns a list of all substrings of [s] that are separated by + one of the literal chars from [on]. [on] are not grouped. So a grouping of [on] in + the source string will produce multiple empty string splits in the result. + + E.g., [split_on_chars ~escape_char:'_' ~on:[',';'|'] "foo_|bar,baz|0" -> + ["foo_|bar"; "baz"; "0"]]. *) + val split_on_chars : string -> on:char list -> escape_char:char -> string list + + (** [lsplit2 s ~on ~escape_char] splits s into a pair on the first literal instance of + [on] (meaning the first unescaped instance) starting from the left. *) + val lsplit2 : string -> on:char -> escape_char:char -> (string * string) option + val lsplit2_exn : string -> on:char -> escape_char:char -> (string * string) + + (** [rsplit2 s ~on ~escape_char] splits [s] into a pair on the first literal + instance of [on] (meaning the first unescaped instance) starting from the + right. *) + val rsplit2 : string -> on:char -> escape_char:char -> (string * string) option + val rsplit2_exn : string -> on:char -> escape_char:char -> (string * string) + + (** These are the same as [lstrip], [rstrip], and [strip] for generic strings, except + that they only drop literal characters -- they do not drop characters that are + escaping or escaped. This makes sense if you're trying to get rid of junk + whitespace (for example), because escaped whitespace seems more likely to be + deliberate and not junk. *) + val lstrip_literal : ?drop:(char -> bool) -> t -> escape_char:char -> t + val rstrip_literal : ?drop:(char -> bool) -> t -> escape_char:char -> t + val strip_literal : ?drop:(char -> bool) -> t -> escape_char:char -> t +end + +val set : bytes -> int -> char -> unit [@@deprecated "[since 2017-10] Use [Bytes.set] instead"] +val unsafe_set : bytes -> int -> char -> unit [@@deprecated "[since 2017-10] Use [Bytes.unsafe_set] instead"] diff --git a/src/string0.ml b/src/string0.ml new file mode 100644 index 0000000..1cb6a9b --- /dev/null +++ b/src/string0.ml @@ -0,0 +1,62 @@ +(* [String0] defines string functions that are primitives or can be simply defined in + terms of [Caml.String]. [String0] is intended to completely express the part of + [Caml.String] that [Base] uses -- no other file in Base other than string0.ml should + use [Caml.String]. [String0] has few dependencies, and so is available early in Base's + build order. + + All Base files that need to use strings, including the subscript syntax + [x.(i)] or [x.(i) <- e] which the OCaml parser desugars into calls to + [String], and come before [Base.String] in build order should do + + {[ + module String = String0] + ]} + + Defining [module String = String0] is also necessary because it prevents + ocamldep from mistakenly causing a file to depend on [Base.String]. *) + +let capitalize = Caml.String.capitalize_ascii +let lowercase = Caml.String.lowercase_ascii +let uncapitalize = Caml.String.uncapitalize_ascii +let uppercase = Caml.String.uppercase_ascii + +open! Import0 + +module Sys = Sys0 + +module String = struct + external get : string -> int -> char = "%string_safe_get" + external length : string -> int = "%string_length" + external unsafe_get : string -> int -> char = "%string_unsafe_get" + + include Bytes_set_primitives +end + +include String + +let max_length = Sys.max_string_length + +let (^) = (^) + +let blit = Caml.String.blit +let compare = Caml.String.compare +let copy = Caml.String.copy [@@warning "-3"] +let escaped = Caml.String.escaped +let index_exn = Caml.String.index +let index_from_exn = Caml.String.index_from +let make = Caml.String.make +let rindex_exn = Caml.String.rindex +let rindex_from_exn = Caml.String.rindex_from +let sub = Caml.String.sub +let unsafe_blit = Caml.String.unsafe_blit + +let concat ?(sep = "") l = + match l with + | [] -> "" + (* The stdlib does not specialize this case because it could break existing projects. *) + | [x] -> x + | l -> Caml.String.concat ~sep l + +(* These are eta expanded in order to permute parameter order to follow Base + conventions. *) +let iter t ~f = Caml.String.iter t ~f diff --git a/src/stringable.ml b/src/stringable.ml new file mode 100644 index 0000000..f5efcd0 --- /dev/null +++ b/src/stringable.ml @@ -0,0 +1,10 @@ +(** Provides type-specific conversion functions to and from [string]. *) + +open! Import + +module type S = sig + type t + + val of_string : string -> t + val to_string : t -> string +end diff --git a/src/sys.ml b/src/sys.ml new file mode 100644 index 0000000..ad91468 --- /dev/null +++ b/src/sys.ml @@ -0,0 +1,3 @@ +open! Import + +include Sys0 diff --git a/src/sys.mli b/src/sys.mli new file mode 100644 index 0000000..5999112 --- /dev/null +++ b/src/sys.mli @@ -0,0 +1,95 @@ +(** Cross-platform system configuration values. *) + +(** The command line arguments given to the process. + The first element is the command name used to invoke the program. + The following elements are the command-line arguments given to the program. + + When running in JavaScript in the browser, it is [[| "a.out" |]]. *) +val argv : string array + +(** [interactive] is set to [true] when being executed in the [ocaml] REPL, and [false] + otherwise. *) +val interactive : bool ref + +(** [os_type] describes the operating system that the OCaml program is running on. Its + value is one of ["Unix"], ["Win32"], or ["Cygwin"]. When running in JavaScript, it is + ["Unix"]. *) +val os_type : string + +(** [unix] is [true] if [os_type = "Unix"]. *) +val unix : bool + +(** [win32] is [true] if [os_type = "Win32"]. *) +val win32 : bool + +(** [cygwin] is [true] if [os_type = "Cygwin"]. *) +val cygwin : bool + +(** Currently, the official distribution only supports [Native] and [Bytecode], + but it can be other backends with alternative compilers, for example, + JavaScript. *) +type backend_type = Sys0.backend_type = + | Native + | Bytecode + | Other of string + +(** Backend type currently executing the OCaml program. *) +val backend_type : backend_type + +(** [word_size_in_bits] is the number of bits in one word on the machine currently + executing the OCaml program. Generally speaking it will be either [32] or [64]. When + running in JavaScript, it will be [32]. *) +val word_size_in_bits : int + +(** [int_size_in_bits] is the number of bits in the [int] type. Generally, on + 32-bit platforms, its value will be [31], and on 64 bit platforms its value + will be [63]. When running in JavaScript, it will be [32]. {!Int.num_bits} + is the same as this value. *) +val int_size_in_bits : int + +(** [big_endian] is true when the program is running on a big-endian + architecture. When running in JavaScript, it will be [false]. *) +val big_endian : bool + +(** [max_string_length] is the maximum allowed length of a [string] or [Bytes.t]. + {!String.max_length} is the same as this value. *) +val max_string_length : int + +(** [max_array_length] is the maximum allowed length of an ['a array]. + {!Array.max_length} is the same as this value. *) +val max_array_length : int + +(** Returns the name of the runtime variant the program is running on. This is normally + the argument given to [-runtime-variant] at compile time, but for byte-code it can be + changed after compilation. When running in JavaScript, it will be [""]. *) +val runtime_variant : unit -> string + +(** Returns the value of the runtime parameters, in the same format as the contents of the + [OCAMLRUNPARAM] environment variable. When running in JavaScript, it will be [""]. *) +val runtime_parameters : unit -> string + +(** [ocaml_version] is the OCaml version with which the program was compiled. It is a + string of the form ["major.minor[.patchlevel][+additional-info]"], where major, minor, + and patchlevel are integers, and additional-info is an arbitrary string. The + [[.patchlevel]] and [[+additional-info]] parts may be absent. *) +val ocaml_version : string + +(** Controls whether the OCaml runtime system can emit warnings on stderr. Currently, the + only supported warning is triggered when a channel created by [open_*] functions is + finalized without being closed. Runtime warnings are enabled by default. *) +val enable_runtime_warnings : bool -> unit + +(** Returns whether runtime warnings are currently enabled. *) +val runtime_warnings_enabled : unit -> bool + +(** For the purposes of optimization, [opaque_identity] behaves like an unknown (and thus + possibly side-effecting) function. At runtime, [opaque_identity] disappears + altogether. A typical use of this function is to prevent pure computations from being + optimized away in benchmarking loops. For example: + + {[ + for _round = 1 to 100_000 do + ignore (Sys.opaque_identity (my_pure_computation ())) + done + ]} *) +external opaque_identity : 'a -> 'a = "%opaque" diff --git a/src/sys0.ml b/src/sys0.ml new file mode 100644 index 0000000..e2f69ac --- /dev/null +++ b/src/sys0.ml @@ -0,0 +1,43 @@ + +(* [Sys0] defines functions that are primitives or can be simply defined in + terms of [Caml.Sys]. [Sys0] is intended to completely express the part of + [Caml.Sys] that [Base] uses -- no other file in Base other than sys.ml + should use [Caml.Sys]. [Sys0] has few dependencies, and so is available + early in Base's build order. All Base files that need to use these + functions and come before [Base.Sys] in build order should do + [module Sys = Sys0]. Defining [module Sys = Sys0] is also necessary because + it prevents ocamldep from mistakenly causing a file to depend on [Base.Sys]. *) + +open! Import0 + +type backend_type = Caml.Sys.backend_type = + | Native + | Bytecode + | Other of string + +let backend_type = Caml.Sys.backend_type + +let interactive = Caml.Sys.interactive +let os_type = Caml.Sys.os_type +let unix = Caml.Sys.unix +let win32 = Caml.Sys.win32 +let cygwin = Caml.Sys.cygwin + +let word_size_in_bits = Caml.Sys.word_size +let int_size_in_bits = Caml.Sys.int_size +let big_endian = Caml.Sys.big_endian +let max_string_length = Caml.Sys.max_string_length +let max_array_length = Caml.Sys.max_array_length +let runtime_variant = Caml.Sys.runtime_variant +let runtime_parameters = Caml.Sys.runtime_parameters + +let argv = Caml.Sys.argv +let getenv = Caml.Sys.getenv + +let ocaml_version = Caml.Sys.ocaml_version +let enable_runtime_warnings = Caml.Sys.enable_runtime_warnings +let runtime_warnings_enabled = Caml.Sys.runtime_warnings_enabled + +external opaque_identity : 'a -> 'a = "%opaque" + +exception Break = Caml.Sys.Break diff --git a/src/t.ml b/src/t.ml new file mode 100644 index 0000000..9dd5970 --- /dev/null +++ b/src/t.ml @@ -0,0 +1,10 @@ +(** This module defines various abstract interfaces that are convenient when one needs a + module that matches a bare signature with just a type. This sometimes occurs in + functor arguments and in interfaces. *) + +open! Import + +module type T = sig type t end +module type T1 = sig type 'a t end +module type T2 = sig type ('a, 'b) t end +module type T3 = sig type ('a, 'b, 'c) t end diff --git a/src/type_equal.ml b/src/type_equal.ml new file mode 100644 index 0000000..a57271b --- /dev/null +++ b/src/type_equal.ml @@ -0,0 +1,172 @@ +open! Import + +type ('a, 'b) t = T : ('a, 'a) t [@@deriving_inline sexp_of] +let sexp_of_t : type a b. + (a -> Ppx_sexp_conv_lib.Sexp.t) -> + (b -> Ppx_sexp_conv_lib.Sexp.t) -> (a, b) t -> Ppx_sexp_conv_lib.Sexp.t + = fun _of_a -> fun _of_b -> function | T -> Ppx_sexp_conv_lib.Sexp.Atom "T" +[@@@end] +type ('a, 'b) equal = ('a, 'b) t + +let refl = T + +let sym (type a) (type b) (T : (a, b) t) = (T : (b, a) t) + +let trans (type a) (type b) (type c) (T : (a, b) t) (T : (b, c) t) = (T : (a, c) t) + +let conv (type a) (type b) (T : (a, b) t) (a : a) = (a : b) + +module Lift (X : sig type 'a t end) = struct + let lift (type a) (type b) (T : (a, b) t) = (T : (a X.t, b X.t) t) +end + +module Lift2 (X : sig type ('a1, 'a2) t end) = struct + let lift (type a1) (type b1) (type a2) (type b2) (T : (a1, b1) t) (T : (a2, b2) t) = + (T : ((a1, a2) X.t, (b1, b2) X.t) t) + ;; +end + +module Lift3 (X : sig type ('a1, 'a2, 'a3) t end) = struct + let lift (type a1) (type b1) (type a2) (type b2) (type a3) (type b3) + (T : (a1, b1) t) (T : (a2, b2) t) (T : (a3, b3) t) = + (T : ((a1, a2, a3) X.t, (b1, b2, b3) X.t) t) + ;; +end + +let detuple2 (type a1) (type a2) (type b1) (type b2) + (T : (a1 * a2, b1 * b2) t) : (a1, b1) t * (a2, b2) t = + T, T +;; + +let tuple2 (type a1) (type a2) (type b1) (type b2) + (T : (a1, b1) t) (T : (a2, b2) t) : (a1 * a2, b1 * b2) t = + T +;; + +module type Injective = sig + type 'a t + val strip : ('a t, 'b t) equal -> ('a, 'b) equal +end + +module type Injective2 = sig + type ('a1, 'a2) t + val strip : (('a1, 'a2) t, ('b1, 'b2) t) equal -> ('a1, 'b1) equal * ('a2, 'b2) equal +end + +module Composition_preserves_injectivity (M1 : Injective) (M2 : Injective) = struct + type 'a t = 'a M1.t M2.t + let strip e = M1.strip (M2.strip e) +end + +module Obj = struct + module Extension_constructor = struct + [@@@ocaml.warning "-3"] + let id = Caml.Obj.extension_id + let of_val = Caml.Obj.extension_constructor + end +end + +module Id = struct + module Uid = Int + + module Witness = struct + module Key = struct + type _ t = .. + + type type_witness_int = [ `type_witness of int ] [@@deriving_inline sexp_of] + let sexp_of_type_witness_int : type_witness_int -> Ppx_sexp_conv_lib.Sexp.t = + function + | `type_witness v0 -> + Ppx_sexp_conv_lib.Sexp.List + [Ppx_sexp_conv_lib.Sexp.Atom "type_witness"; sexp_of_int v0] + [@@@end] + + let sexp_of_t _sexp_of_a t = + (`type_witness (Obj.Extension_constructor.id (Obj.Extension_constructor.of_val t))) + |> sexp_of_type_witness_int + ;; + end + + module type S = sig + type t + type _ Key.t += Key : t Key.t + end + + type 'a t = (module S with type t = 'a) + + let sexp_of_t (type a) sexp_of_a (module M : S with type t = a) = + M.Key |> Key.sexp_of_t sexp_of_a + ;; + + let create (type t) () = + let module M = struct + type nonrec t = t + type _ Key.t += Key : t Key.t + end in + (module M : S with type t = t) + ;; + + let uid (type a) (module M : S with type t = a) = + Obj.Extension_constructor.id (Obj.Extension_constructor.of_val M.Key) + + (* We want a constant allocated once that [same] can return whenever it gets the same + witnesses. If we write the constant inside the body of [same], the native-code + compiler will do the right thing and lift it out. But for clarity and robustness, + we do it ourselves. *) + let some_t = Some T + + let same (type a) (type b) (a : a t) (b : b t) : (a, b) equal option = + let module A = (val a : S with type t = a) in + let module B = (val b : S with type t = b) in + match A.Key with + | B.Key -> some_t + | _ -> None + ;; + end + + + type 'a t = + { witness : 'a Witness.t + ; name : string + ; to_sexp : 'a -> Sexp.t + } + + let sexp_of_t _ { witness; name; to_sexp } : Sexp.t = + if am_testing + then Atom name + else List [ List [ Atom "name"; Atom name ] + ; List [ Atom "witness"; witness |> Witness.sexp_of_t to_sexp ]] + ;; + + let to_sexp t x = t.to_sexp x + let name t = t.name + + let create ~name to_sexp = + { witness = Witness.create () + ; name + ; to_sexp + } + ;; + + let uid t = Witness.uid t.witness + + let hash t = uid t + + let hash_fold_t s t = hash_fold_int s (uid t) + + let same_witness t1 t2 = Witness.same t1.witness t2.witness + + let same t1 t2 = Option.is_some (same_witness t1 t2) + + let same_witness_exn t1 t2 = + match same_witness t1 t2 with + | Some w -> w + | None -> + Error.raise_s + (Sexp.message "Type_equal.Id.same_witness_exn got different ids" + [ "", + sexp_of_pair (sexp_of_t sexp_of_opaque) (sexp_of_t sexp_of_opaque) (t1, t2) + ]) + ;; +end + diff --git a/src/type_equal.mli b/src/type_equal.mli new file mode 100644 index 0000000..eb33bf4 --- /dev/null +++ b/src/type_equal.mli @@ -0,0 +1,236 @@ +(** The purpose of [Type_equal] is to represent type equalities that the type checker + otherwise would not know, perhaps because the type equality depends on dynamic data, + or perhaps because the type system isn't powerful enough. + + A value of type [(a, b) Type_equal.t] represents that types [a] and [b] are equal. + One can think of such a value as a proof of type equality. The [Type_equal] module + has operations for constructing and manipulating such proofs. For example, the + functions [refl], [sym], and [trans] express the usual properties of reflexivity, + symmetry, and transitivity of equality. + + If one has a value [t : (a, b) Type_equal.t] that proves types [a] and [b] are equal, + there are two ways to use [t] to safely convert a value of type [a] to a value of type + [b]: [Type_equal.conv] or pattern matching on [Type_equal.T]: + + {[ + let f (type a) (type b) (t : (a, b) Type_equal.t) (a : a) : b = + Type_equal.conv t a + + let f (type a) (type b) (t : (a, b) Type_equal.t) (a : a) : b = + let Type_equal.T = t in a + ]} + + At runtime, conversion by either means is just the identity -- nothing is changing + about the value. Consistent with this, a value of type [Type_equal.t] is always just + a constructor [Type_equal.T]; the value has no interesting semantic content. + [Type_equal] gets its power from the ability to, in a type-safe way, prove to the type + checker that two types are equal. The [Type_equal.t] value that is passed is + necessary for the type-checker's rules to be correct, but the compiler could, in + principle, not pass around values of type [Type_equal.t] at runtime. +*) + +open! Import +open T + +type ('a, 'b) t = T : ('a, 'a) t [@@deriving_inline sexp_of] +include +sig + [@@@ocaml.warning "-32"] + val sexp_of_t : + ('a -> Ppx_sexp_conv_lib.Sexp.t) -> + ('b -> Ppx_sexp_conv_lib.Sexp.t) -> + ('a, 'b) t -> Ppx_sexp_conv_lib.Sexp.t +end[@@ocaml.doc "@inline"] +[@@@end] +type ('a, 'b) equal = ('a, 'b) t (** just an alias, needed when [t] gets shadowed below *) + +(** [refl], [sym], and [trans] construct proofs that type equality is reflexive, + symmetric, and transitive. *) + +val refl : ('a, 'a) t +val sym : ('a, 'b) t -> ('b, 'a) t +val trans : ('a, 'b) t -> ('b, 'c) t -> ('a, 'c) t + +(** [conv t x] uses the type equality [t : (a, b) t] as evidence to safely cast [x] + from type [a] to type [b]. [conv] is semantically just the identity function. + + In a program that has [t : (a, b) t] where one has a value of type [a] that one wants + to treat as a value of type [b], it is often sufficient to pattern match on + [Type_equal.T] rather than use [conv]. However, there are situations where OCaml's + type checker will not use the type equality [a = b], and one must use [conv]. For + example: + + {[ + module F (M1 : sig type t end) (M2 : sig type t end) : sig + val f : (M1.t, M2.t) equal -> M1.t -> M2.t + end = struct + let f equal (m1 : M1.t) = conv equal m1 + end + ]} + + If one wrote the body of [F] using pattern matching on [T]: + + {[ + let f (T : (M1.t, M2.t) equal) (m1 : M1.t) = (m1 : M2.t) + ]} + + this would give a type error. *) +val conv : ('a, 'b) t -> 'a -> 'b + +(** It is always safe to conclude that if type [a] equals [b], then for any type ['a t], + type [a t] equals [b t]. The OCaml type checker uses this fact when it can. However, + sometimes, e.g., when using [conv], one needs to explicitly use this fact to construct + an appropriate [Type_equal.t]. The [Lift*] functors do this. *) + +module Lift (X : T1) : sig + val lift : ('a, 'b) t -> ('a X.t, 'b X.t) t +end + +module Lift2 (X : T2) : sig + val lift : ('a1, 'b1) t -> ('a2, 'b2) t -> (('a1, 'a2) X.t, ('b1, 'b2) X.t) t +end + +module Lift3 (X : T3) : sig + val lift + : ('a1, 'b1) t + -> ('a2, 'b2) t + -> ('a3, 'b3) t + -> (('a1, 'a2, 'a3) X.t, ('b1, 'b2, 'b3) X.t) t +end + +(** [tuple2] and [detuple2] convert between equality on a 2-tuple and its components. *) + +val detuple2 : ('a1 * 'a2, 'b1 * 'b2) t -> ('a1, 'b1) t * ('a2, 'b2) t +val tuple2 : ('a1, 'b1) t -> ('a2, 'b2) t -> ('a1 * 'a2, 'b1 * 'b2) t + +(** [Injective] is an interface that states that a type is injective, where the type is + viewed as a function from types to other types. The typical usage is: + + {[ + type 'a t + include Injective with type 'a t := 'a t + ]} + + For example, ['a list] is an injective type, because whenever ['a list = 'b list], we + know that ['a] = ['b]. On the other hand, if we define: + + {[ + type 'a t = unit + ]} + + then clearly [t] isn't injective, because, e.g., [int t = bool t], but [int <> bool]. + + If [module M : Injective], then [M.strip] provides a way to get a proof that two types + are equal from a proof that both types transformed by [M.t] are equal. + + OCaml has no built-in language feature to state that a type is injective, which is why + we have [module type Injective]. However, OCaml can infer that a type is injective, + and we can use this to match [Injective]. A typical implementation will look like + this: + + {[ + let strip (type a) (type b) + (Type_equal.T : (a t, b t) Type_equal.t) : (a, b) Type_equal.t = + Type_equal.T + ]} + + This will not type check for all type constructors (certainly not for non-injective + ones!), but it's always safe to try the above implementation if you are unsure. If + OCaml accepts this definition, then the type is injective. On the other hand, if + OCaml doesn't, then the type may or may not be injective. For example, if the + definition of the type depends on abstract types that match [Injective], OCaml will + not automatically use their injectivity, and one will have to write a more complicated + definition of [strip] that causes OCaml to use that fact. For example: + + {[ + module F (M : Type_equal.Injective) : Type_equal.Injective = struct + type 'a t = 'a M.t * int + + let strip (type a) (type b) + (e : (a t, b t) Type_equal.t) : (a, b) Type_equal.t = + let e1, _ = Type_equal.detuple2 e in + M.strip e1 + ;; + end + ]} + + If in the definition of [F] we had written the simpler implementation of [strip] that + didn't use [M.strip], then OCaml would have reported a type error. +*) +module type Injective = sig + type 'a t + val strip : ('a t, 'b t) equal -> ('a, 'b) equal +end + +(** [Injective2] is for a binary type that is injective in both type arguments. *) +module type Injective2 = sig + type ('a1, 'a2) t + val strip : (('a1, 'a2) t, ('b1, 'b2) t) equal -> ('a1, 'b1) equal * ('a2, 'b2) equal +end + +(** [Composition_preserves_injectivity] is a functor that proves that composition of + injective types is injective. *) +module Composition_preserves_injectivity (M1 : Injective) (M2 : Injective) + : Injective with type 'a t = 'a M1.t M2.t + +(** [Id] provides identifiers for types, and the ability to test (via [Id.same]) at + runtime if two identifiers are equal, and if so to get a proof of equality of their + types. Unlike values of type [Type_equal.t], values of type [Id.t] do have semantic + content and must have a nontrivial runtime representation. *) +module Id : sig + type 'a t [@@deriving_inline sexp_of] + include + sig + [@@@ocaml.warning "-32"] + val sexp_of_t : + ('a -> Ppx_sexp_conv_lib.Sexp.t) -> 'a t -> Ppx_sexp_conv_lib.Sexp.t + end[@@ocaml.doc "@inline"] + [@@@end] + + (** Every [Id.t] contains a unique id that is distinct from the [Uid.t] in any other + [Id.t]. *) + module Uid : sig + type t [@@deriving_inline hash] + include + sig + [@@@ocaml.warning "-32"] + val hash_fold_t : + Ppx_hash_lib.Std.Hash.state -> t -> Ppx_hash_lib.Std.Hash.state + val hash : t -> Ppx_hash_lib.Std.Hash.hash_value + end[@@ocaml.doc "@inline"] + [@@@end] + include Sexpable.S with type t := t + include Comparable.S with type t := t + end + + val uid : _ t -> Uid.t + + (** [create ~name] defines a new type identity. Two calls to [create] will result in + two distinct identifiers, even for the same arguments with the same type. If the + type ['a] doesn't support sexp conversion, then a good practice is to have the + converter be [<:sexp_of< _ >>], (or [sexp_of_opaque], if not using pa_sexp). *) + val create + : name:string + -> ('a -> Sexp.t) + -> 'a t + + (** Accessors *) + + val hash : _ t -> int + val name : _ t -> string + val to_sexp : 'a t -> 'a -> Sexp.t + + val hash_fold_t : Hash.state -> _ t -> Hash.state + + (** [same_witness t1 t2] and [same_witness_exn t1 t2] return a type equality proof iff + the two identifiers are the same (i.e., physically equal, resulting from the same + call to [create]). This is a useful way to achieve a sort of dynamic typing. + [same_witness] does not allocate a [Some] every time it is called. + + [same t1 t2 = is_some (same_witness t1 t2)]. + *) + + val same : _ t -> _ t -> bool + val same_witness : 'a t -> 'b t -> ('a, 'b) equal option + val same_witness_exn : 'a t -> 'b t -> ('a, 'b) equal +end diff --git a/src/uchar.ml b/src/uchar.ml new file mode 100644 index 0000000..57bbe4d --- /dev/null +++ b/src/uchar.ml @@ -0,0 +1,81 @@ +open! Import + +let failwithf = Printf.failwithf + +module T = struct + include Uchar0 + + let module_name = "Base.Uchar" + + let hash_fold_t state t = Hash.fold_int state (to_int t) + let hash t = Hash.run hash_fold_t t + + let to_string t = + Printf.sprintf "U+%04X" (to_int t) + (* Do not actually export this. See discussion in the .mli *) + + let sexp_of_t t = Sexp.Atom (to_string t) + let t_of_sexp sexp = + match sexp with + | Sexp.List _ -> of_sexp_error "Uchar.t_of_sexp: atom needed" sexp + | Sexp.Atom s -> + try + Caml.Scanf.sscanf s "U+%X" (fun i -> Uchar0.of_int i) + with _ -> + of_sexp_error "Uchar.t_of_sexp: atom of the form U+XXXX needed" sexp +end + +include T +include Pretty_printer.Register(T) +include Comparable.Make(T) + +(* Open replace_polymorphic_compare after including functor instantiations so they do not + shadow its definitions. This is here so that efficient versions of the comparison + functions are available within this module. *) +open! Uchar_replace_polymorphic_compare + +let int_is_scalar = is_valid + +let succ_exn c = + try Uchar0.succ c + with Invalid_argument msg -> failwithf "Uchar.succ_exn: %s" msg () + +let succ c = + try Some (Uchar0.succ c) + with Invalid_argument _ -> None + +let pred_exn c = + try Uchar0.pred c + with Invalid_argument msg -> failwithf "Uchar.pred_exn: %s" msg () + +let pred c = + try Some (Uchar0.pred c) + with Invalid_argument _ -> None + +let of_scalar i = + if int_is_scalar i + then Some (unsafe_of_int i) + else None + +let of_scalar_exn i = + if int_is_scalar i + then unsafe_of_int i + else failwithf "Uchar.of_int_exn got a invalid Unicode scalar value: %04X" i () + +let to_scalar t = Uchar0.to_int t + +let to_char c = + if is_char c + then Some (unsafe_to_char c) + else None + +let to_char_exn c = + if is_char c + then unsafe_to_char c + else failwithf "Uchar.to_char_exn got a non latin-1 character: U+%04X" (to_int c) () + +(* Include type-specific [Replace_polymorphic_compare] at the end, after + including functor application that could shadow its definitions. This is + here so that efficient versions of the comparison functions are exported by + this module. *) +include Uchar_replace_polymorphic_compare diff --git a/src/uchar.mli b/src/uchar.mli new file mode 100644 index 0000000..541b22f --- /dev/null +++ b/src/uchar.mli @@ -0,0 +1,55 @@ +(** Unicode character operations. *) + +open! Import + +type t = Uchar0.t [@@deriving_inline compare, hash, sexp] +include +sig + [@@@ocaml.warning "-32"] + val compare : t -> t -> int + val hash_fold_t : + Ppx_hash_lib.Std.Hash.state -> t -> Ppx_hash_lib.Std.Hash.state + val hash : t -> Ppx_hash_lib.Std.Hash.hash_value + include Ppx_sexp_conv_lib.Sexpable.S with type t := t +end[@@ocaml.doc "@inline"] +[@@@end] + +include Comparable.S with type t := t +include Pretty_printer.S with type t := t + + +(** [succ_exn t] is the scalar value after [t] in the set of Unicode scalar values, and + raises if [t = max_value]. *) +val succ : t -> t option +val succ_exn : t -> t + +(** [pred_exn t] is the scalar value before [t] in the set of Unicode scalar values, and + raises if [t = min_value]. *) +val pred : t -> t option +val pred_exn : t -> t + +(** [is_char t] is [true] iff [n] is in the latin-1 character set. *) +val is_char : t -> bool + +(** [to_char_exn t] is [t] as a [char] if it is in the latin-1 character set, and raises + otherwise. *) +val to_char : t -> char option +val to_char_exn : t -> char + +(** [of_char c] is [c] as a Unicode character. *) +val of_char : char -> t + +(** [int_is_scalar n] is [true] iff [n] is an Unicode scalar value (i.e., in the ranges + [0x0000]...[0xD7FF] or [0xE000]...[0x10FFFF]). *) +val int_is_scalar : int -> bool + +(** [of_scalar_exn n] is [n] as a Unicode character. Raises if [not (int_is_scalar + i)]. *) +val of_scalar : int -> t option +val of_scalar_exn : int -> t + +(** [to_scalar t] is [t] as an integer scalar value. *) +val to_scalar : t -> int + +val min_value : t +val max_value : t diff --git a/src/uchar0.ml b/src/uchar0.ml new file mode 100644 index 0000000..59baacd --- /dev/null +++ b/src/uchar0.ml @@ -0,0 +1,21 @@ +open! Import0 + +type t = Caml.Uchar.t + +let succ = Caml.Uchar.succ +let pred = Caml.Uchar.pred +let is_valid = Caml.Uchar.is_valid +let is_char = Caml.Uchar.is_char +let unsafe_to_char = Caml.Uchar.unsafe_to_char +let unsafe_of_int = Caml.Uchar.unsafe_of_int + +let of_int = Caml.Uchar.of_int +let to_int = Caml.Uchar.to_int + +let of_char = Caml.Uchar.of_char + +let compare = Caml.Uchar.compare +let equal = Caml.Uchar.equal + +let min_value = Caml.Uchar.min +let max_value = Caml.Uchar.max diff --git a/src/uniform_array.ml b/src/uniform_array.ml new file mode 100644 index 0000000..1e5deae --- /dev/null +++ b/src/uniform_array.ml @@ -0,0 +1,119 @@ +open! Import + +(* WARNING: + We use non-memory-safe things throughout the [Trusted] module. + Most of it is only safe in combination with the type signature (e.g. exposing + [val copy : 'a t -> 'b t] would be a big mistake). *) +module Trusted : sig + + type 'a t + val empty : 'a t + val unsafe_create_uninitialized : len:int -> 'a t + val create_obj_array : len:int -> 'a t + val create : len:int -> 'a -> 'a t + val singleton : 'a -> 'a t + val get : 'a t -> int -> 'a + val set : 'a t -> int -> 'a -> unit + val swap : _ t -> int -> int -> unit + val unsafe_get : 'a t -> int -> 'a + val unsafe_set : 'a t -> int -> 'a -> unit + val unsafe_set_omit_phys_equal_check : 'a t -> int -> 'a -> unit + val unsafe_set_int : 'a t -> int -> int -> unit + val unsafe_set_int_assuming_currently_int : 'a t -> int -> int -> unit + val unsafe_set_assuming_currently_int : 'a t -> int -> 'a -> unit + val length : 'a t -> int + val unsafe_blit : ('a t, 'a t) Blit.blit + val copy : 'a t -> 'a t + val unsafe_truncate : 'a t -> len:int -> unit + val unsafe_clear_if_pointer : _ t -> int -> unit +end = struct + + type 'a t = Obj_array.t + + let empty = Obj_array.empty + + let unsafe_create_uninitialized ~len = Obj_array.create_zero ~len + let create_obj_array ~len = Obj_array.create_zero ~len + let create ~len x = Obj_array.create ~len (Caml.Obj.repr x) + let singleton x = Obj_array.singleton (Caml.Obj.repr x) + + let swap t i j = Obj_array.swap t i j + + let get arr i = Caml.Obj.obj (Obj_array.get arr i) + let set arr i x = Obj_array.set arr i (Caml.Obj.repr x) + let unsafe_get arr i = Caml.Obj.obj (Obj_array.unsafe_get arr i) + let unsafe_set arr i x = + Obj_array.unsafe_set arr i (Caml.Obj.repr x) + let unsafe_set_int arr i x = + Obj_array.unsafe_set_int arr i x + let unsafe_set_int_assuming_currently_int arr i x = + Obj_array.unsafe_set_int_assuming_currently_int arr i x + let unsafe_set_assuming_currently_int arr i x = + Obj_array.unsafe_set_assuming_currently_int arr i (Caml.Obj.repr x) + + let length = Obj_array.length + + let unsafe_blit = Obj_array.unsafe_blit + + let copy = Obj_array.copy + + let unsafe_truncate = Obj_array.truncate + + let unsafe_set_omit_phys_equal_check t i x = + Obj_array.unsafe_set_omit_phys_equal_check t i (Caml.Obj.repr x) + + let unsafe_clear_if_pointer = Obj_array.unsafe_clear_if_pointer +end + +include Trusted + +let init l ~f = + if l < 0 then invalid_arg "Uniform_array.init" + else + let res = unsafe_create_uninitialized ~len:l in + for i = 0 to l - 1 do + unsafe_set res i (f i) + done; + res + +let of_array arr = init ~f:(Array.unsafe_get arr) (Array.length arr) + +let map a ~f = init ~f:(fun i -> f (unsafe_get a i)) (length a) + +let iter a ~f = + for i = 0 to length a - 1 do + f (unsafe_get a i) + done + +let to_list t = List.init ~f:(get t) (length t) + +let of_list l = + let len = List.length l in + let res = unsafe_create_uninitialized ~len in + List.iteri l ~f:(fun i x -> set res i x); + res + +(* It is not safe for [to_array] to be the identity function because we have code that + relies on [float array]s being unboxed, for example in [bin_write_array]. *) +let to_array t = Array.init (length t) ~f:(fun i -> unsafe_get t i) + +include Sexpable.Of_sexpable1(Array)(struct + type nonrec 'a t = 'a t + let to_sexpable = to_array + let of_sexpable = of_array + end) + +include + Blit.Make1 + (struct + type nonrec 'a t = 'a t + let length = length + let create_like ~len t = + if len = 0 + then empty + else (assert (length t > 0); create ~len (get t 0)) + ;; + let unsafe_blit = unsafe_blit + end) + + diff --git a/src/uniform_array.mli b/src/uniform_array.mli new file mode 100644 index 0000000..5219cbf --- /dev/null +++ b/src/uniform_array.mli @@ -0,0 +1,88 @@ +(** Same semantics as ['a Array.t], except it's guaranteed that the representation array + is not tagged with [Double_array_tag], the tag for float arrays. + + This means it's safer to use in the presence of [Obj.magic], but it's slower than + normal [Array] if you use it with floats. + + It can often be faster than [Array] if you use it with non-floats. +*) + +open! Import + +(** See [Base.Array] for comments. *) +type 'a t [@@deriving_inline sexp] +include +sig + [@@@ocaml.warning "-32"] + include Ppx_sexp_conv_lib.Sexpable.S1 with type 'a t := 'a t +end[@@ocaml.doc "@inline"] +[@@@end] + + +val empty : _ t + +val create : len:int -> 'a -> 'a t + +val singleton : 'a -> 'a t + +val init : int -> f:(int -> 'a) -> 'a t + +val length : 'a t -> int + +val get : 'a t -> int -> 'a +val unsafe_get : 'a t -> int -> 'a + +val set : 'a t -> int -> 'a -> unit +val unsafe_set : 'a t -> int -> 'a -> unit + +val swap : _ t -> int -> int -> unit + +(** [unsafe_set_omit_phys_equal_check] is like [unsafe_set], except it doesn't do a + [phys_equal] check to try to skip [caml_modify]. It is safe to call this even if the + values are [phys_equal]. *) +val unsafe_set_omit_phys_equal_check : 'a t -> int -> 'a -> unit + +val map : 'a t -> f:('a -> 'b) -> 'b t +val iter : 'a t -> f:('a -> unit) -> unit + +(** [of_array] and [to_array] return fresh arrays with the same contents rather than + returning a reference to the underlying array. *) +val of_array : 'a array -> 'a t +val to_array : 'a t -> 'a array + +val of_list : 'a list -> 'a t +val to_list : 'a t -> 'a list + +include Blit.S1 with type 'a t := 'a t + +val copy : 'a t -> 'a t + +(** [truncate t ~len] shortens [t]'s length to [len]. It is an error if [len <= 0] or + [len > length t]. It's unsafe to truncate in the middle of iteration. *) +val unsafe_truncate : _ t -> len:int -> unit + +(** {2 Extra lowlevel and unsafe functions} *) + +(** The behavior is undefined if you access an element before setting it. *) +val unsafe_create_uninitialized : len:int -> _ t + +(** New obj array filled with [Obj.repr 0] *) +val create_obj_array : len:int -> Caml.Obj.t t + +(** [unsafe_set_assuming_currently_int t i obj] sets index [i] of [t] to [obj], but only + works correctly if the value there is an immediate, i.e. [Caml.Obj.is_int (get t i)]. + This precondition saves a dynamic check. + + [unsafe_set_int_assuming_currently_int] is similar, except the value being set is an + int. + + [unsafe_set_int] is similar but does not assume anything about the target. *) +val unsafe_set_assuming_currently_int : Caml.Obj.t t -> int -> Caml.Obj.t -> unit +val unsafe_set_int_assuming_currently_int : Caml.Obj.t t -> int -> int -> unit +val unsafe_set_int : Caml.Obj.t t -> int -> int -> unit + +(** [unsafe_clear_if_pointer t i] prevents [t.(i)] from pointing to anything to prevent + space leaks. It does this by setting [t.(i)] to [Caml.Obj.repr 0]. As a performance + hack, it only does this when [not (Caml.Obj.is_int t.(i))]. It is an error to access + the cleared index before setting it again. *) +val unsafe_clear_if_pointer : Caml.Obj.t t -> int -> unit diff --git a/src/unit.ml b/src/unit.ml new file mode 100644 index 0000000..afae207 --- /dev/null +++ b/src/unit.ml @@ -0,0 +1,29 @@ +open! Import + +module T = struct + type t = unit [@@deriving_inline enumerate, hash, sexp] + let all : t list = [()] + let (hash_fold_t : + Ppx_hash_lib.Std.Hash.state -> t -> Ppx_hash_lib.Std.Hash.state) = + hash_fold_unit + and (hash : t -> Ppx_hash_lib.Std.Hash.hash_value) = + let func = hash_unit in fun x -> func x + let t_of_sexp : Ppx_sexp_conv_lib.Sexp.t -> t = unit_of_sexp + let sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t = sexp_of_unit + [@@@end] + + let compare _ _ = 0 + + let of_string = function + | "()" -> () + | _ -> failwith "Base.Unit.of_string: () expected" + + let to_string () = "()" + + let module_name = "Base.Unit" +end + +include T +include Identifiable.Make (T) + +let invariant () = () diff --git a/src/unit.mli b/src/unit.mli new file mode 100644 index 0000000..6a1ecc4 --- /dev/null +++ b/src/unit.mli @@ -0,0 +1,19 @@ +(** Module for the type [unit]. *) + +open! Import + +type t = unit [@@deriving_inline compare, enumerate, hash, sexp] +include +sig + [@@@ocaml.warning "-32"] + val compare : t -> t -> int + val all : t list + val hash_fold_t : + Ppx_hash_lib.Std.Hash.state -> t -> Ppx_hash_lib.Std.Hash.state + val hash : t -> Ppx_hash_lib.Std.Hash.hash_value + include Ppx_sexp_conv_lib.Sexpable.S with type t := t +end[@@ocaml.doc "@inline"] +[@@@end] + +include Identifiable.S with type t := t +include Invariant.S with type t := t diff --git a/src/validate.ml b/src/validate.ml new file mode 100644 index 0000000..81fb0c6 --- /dev/null +++ b/src/validate.ml @@ -0,0 +1,184 @@ +open! Import + +module Int = Int0 +module String = String0 + +(** Each single_error is a path indicating the location within the datastructure in + question that is being validated, along with an error message. *) +type single_error = + { path : string list; + error : Error.t; + } + +type t = single_error list + +type 'a check = 'a -> t + +let pass : t = [] + +let fails message a sexp_of_a = + [ { path = []; + error = Error.create message a sexp_of_a; + } ] +;; + +let fail message = [ { path = []; error = Error.of_string message } ] + +let failf format = Printf.ksprintf fail format + +let fail_s sexp = [ { path = []; error = Error.create_s sexp } ] + +let combine t1 t2 = t1 @ t2 + +let of_list = List.concat + +let name name t = + match t with + | [] -> [] (* when successful, avoid the allocation of a closure for [~f], below *) + | _ -> List.map t ~f:(fun { path; error } -> { path = name :: path; error }) +;; + +let name_list n l = name n (of_list l) + +let fail_fn message _ = fail message + +let pass_bool (_:bool) = pass +let pass_unit (_:unit) = pass + +let protect f v = + try + f v + with exn -> + fail_s + (Sexp.message "Exception raised during validation" [ "", sexp_of_exn exn ]) +;; + +let try_with f = + protect (fun () -> f (); pass) () + +let path_string path = String.concat ~sep:"." path + +let errors t = + List.map t ~f:(fun { path; error } -> + (Error.to_string_hum (Error.tag error ~tag:(path_string path)))) +;; + +let [@inline never] result_fail t = + Or_error.error + "validation errors" + (List.map t ~f:(fun { path; error } -> (path_string path, error))) + (sexp_of_list (sexp_of_pair sexp_of_string Error.sexp_of_t)) +;; + +(** [result] is carefully implemented so that it can be inlined -- calling [result_fail], + which is not inlineable, is key to this. *) +let result t = + if List.is_empty t + then Ok () + else result_fail t +;; + +let maybe_raise t = Or_error.ok_exn (result t) + +let valid_or_error x check = + Or_error.map (result (protect check x)) ~f:(fun () -> x) +;; + +let field record fld f = + let v = Field.get fld record in + let result = protect f v in + name (Field.name fld) result +;; + +let field_folder record check = (); fun acc fld -> field record fld check :: acc + +let field_direct_folder check = + Staged.stage (fun acc fld _record v -> + match protect check v with + | [] -> acc + | result -> name (Field.name fld) result :: acc) +;; + +let all checks v = + let rec loop checks v errs = + match checks with + | [] -> errs + | check :: checks -> + match protect check v with + | [] -> loop checks v errs + | err -> loop checks v (err :: errs) + in + of_list (List.rev (loop checks v [])) +;; + +let of_result f = + protect (fun v -> + match f v with + | Ok () -> pass + | Error error -> fail error) +;; + +let of_error f = + protect (fun v -> + match f v with + | Ok () -> pass + | Error error -> [ { path = []; error } ]) +;; + +let booltest f ~if_false = protect (fun v -> if f v then pass else fail if_false) + +let pair ~fst ~snd (fst_value,snd_value) = + of_list [ name "fst" (protect fst fst_value); + name "snd" (protect snd snd_value); + ] +;; + +let list_indexed check list = + List.mapi list ~f:(fun i el -> + name (Int.to_string (i+1)) (protect check el)) + |> of_list +;; + +let list ~name:extract_name check list = + List.map list ~f:(fun el -> + match protect check el with + | [] -> [] + | t -> + (* extra level of protection in case extract_name throws an exception *) + protect (fun t -> name (extract_name el) t) t) + |> of_list +;; + +let alist ~name f list' = + list (fun (_, x) -> f x) list' + ~name:(fun (key, _) -> name key) +;; + +let first_failure t1 t2 = if List.is_empty t1 then t2 else t1 + +let of_error_opt = function + | None -> pass + | Some error -> fail error +;; + +let bounded ~name ~lower ~upper ~compare x = + match Maybe_bound.compare_to_interval_exn ~lower ~upper ~compare x with + | In_range -> pass + | Below_lower_bound -> + begin + match lower with + | Unbounded -> assert false + | Incl incl -> fail (Printf.sprintf "value %s < bound %s" (name x) (name incl)) + | Excl excl -> fail (Printf.sprintf "value %s <= bound %s" (name x) (name excl)) + end + | Above_upper_bound -> + begin + match upper with + | Unbounded -> assert false + | Incl incl -> fail (Printf.sprintf "value %s > bound %s" (name x) (name incl)) + | Excl excl -> fail (Printf.sprintf "value %s >= bound %s" (name x) (name excl)) + end + +module Infix = struct + let (++) t1 t2 = combine t1 t2 +end diff --git a/src/validate.mli b/src/validate.mli new file mode 100644 index 0000000..aff6022 --- /dev/null +++ b/src/validate.mli @@ -0,0 +1,175 @@ +(** A module for organizing validations of data structures. + + Allows standardized ways of checking for conditions, and keeps track of the location + of errors by keeping a path to each error found. Thus, if you were validating the + following datastructure: + + {[ + { foo = 3; + bar = { snoo = 34.5; + blue = Snoot -6; } + } + ]} + + One might end up with an error with the error path: + + {v bar.blue.Snoot : value -6 <= bound 0 v} + + By convention, the validations for a type defined in module [M] appear in module [M], + and have their name prefixed by [validate_]. E.g., [Int.validate_positive]. + + Here's an example of how you would use [validate] with a record: + + {[ + type t = + { foo: int; + bar: float; + } + [@@deriving_inline fields][@@@end] + + let validate t = + let module V = Validate in + let w check = V.field_folder t check in + V.of_list + (Fields.fold ~init:[] + ~foo:(w Int.validate_positive) + ~bar:(w Float.validate_non_negative) + ) + ]} + + + And here's an example of how you would use it with a variant type: + + {[ + type t = + | Foo of int + | Bar of (float * int) + | Snoo of Floogle.t + + let validate = function + | Foo i -> V.name "Foo" (Int.validate_positive i) + | Bar p -> V.name "Bar" (V.pair + ~fst:Float.validate_positive + ~snd:Int.validate_non_negative p) + | Snoo floogle -> V.name "Snoo" (Floogle.validate floogle) + ]} *) + +open! Import + +(** The result of a validation. This effectively contains the list of errors, qualified + by their location path *) +type t +type 'a check = 'a -> t (** To make function signatures easier to read. *) + +(** A result containing no errors. *) +val pass : t + +(** A result containing a single error. *) +val fail : string -> t + +val fails + : string + -> 'a + -> ('a -> Sexp.t) + -> t +val fail_s : Sexp.t -> t (** This can be used with the [%sexp] extension. *) + +(** Like [sprintf] or [failwithf] but produces a [t] instead of a string or exception. *) +val failf : ('a, unit, string, t) format4 -> 'a + +val combine : t -> t -> t + +(** Combines multiple results, merging errors. *) +val of_list : t list -> t + +(** Extends location path by one name. *) +val name : string -> t -> t + +val name_list : string -> t list -> t + +(** [fail_fn err] returns a function that always returns fail, with [err] as the error + message. (Note that there is no [pass_fn] so as to discourage people from ignoring + the type of the value being passed unconditionally irrespective of type.) *) +val fail_fn : string -> _ check + +(** Checks for unconditionally passing a bool. *) +val pass_bool : bool check + +(** Checks for unconditionally passing a unit. *) +val pass_unit : unit check + +(** [protect f x] applies the validation [f] to [x], catching any exceptions and returning + them as errors. *) +val protect : 'a check -> 'a check + +(** [try_with f] runs [f] catching any exceptions and returning them as errors. *) +val try_with : (unit -> unit) -> t + +val result : t -> unit Or_error.t + +(** Returns a list of formatted error strings, which include both the error message and + the path to the error. *) +val errors : t -> string list + +(** If the result contains any errors, then raises an exception with a formatted error + message containing a message for every error. *) +val maybe_raise : t -> unit + +(** Returns an error if validation fails. *) +val valid_or_error : 'a -> 'a check -> 'a Or_error.t + +(** Used for validating an individual field. *) +val field : 'record -> ([> `Read], 'record, 'a) Field.t_with_perm -> 'a check -> t + +(** Creates a function for use in a [Fields.fold]. *) +val field_folder + : 'record + -> 'a check + -> (t list -> ([> `Read], 'record, 'a) Field.t_with_perm -> t list) + +(** Creates a function for use in a [Fields.Direct.fold]. *) +val field_direct_folder + : 'a check + -> (t list -> ([> `Read], 'record, 'a) Field.t_with_perm -> 'record -> 'a -> t list) + Staged.t + +(** Combines a list of validation functions into one that does all validations. *) +val all : 'a check list -> 'a check + +(** Creates a validation function from a function that produces a [Result.t]. *) +val of_result : ('a -> (unit, string) Result.t) -> 'a check + +val of_error : ('a -> unit Or_error.t) -> 'a check + +(** Creates a validation function from a function that produces a bool. *) +val booltest : ('a -> bool) -> if_false:string -> 'a check + +(** Validation functions for particular data types. *) +val pair : fst:'a check -> snd:'b check -> ('a * 'b) check + +(** Validates a list, naming each element by its position in the list (where the first + position is 1, not 0). *) +val list_indexed : 'a check -> 'a list check + +(** Validates a list, naming each element using a user-defined function for computing the + name. *) +val list : name:('a -> string) -> 'a check -> 'a list check + +val first_failure : t -> t -> t + +val of_error_opt : string option -> t + +(** Validates an association list, naming each element using a user-defined function for + computing the name. *) +val alist : name:('a -> string) -> 'b check -> ('a * 'b) list check + +val bounded + : name : ('a -> string) + -> lower : 'a Maybe_bound.t + -> upper : 'a Maybe_bound.t + -> compare : ('a -> 'a -> int) + -> 'a check + +module Infix : sig + val (++) : t -> t -> t (** Infix operator for [combine] above. *) +end diff --git a/src/variant.ml b/src/variant.ml new file mode 100644 index 0000000..3b7e041 --- /dev/null +++ b/src/variant.ml @@ -0,0 +1,7 @@ +type 'constructor t = { + name : string; + (* the position of the constructor in the type definition, starting from 0 *) + rank : int; + + constructor : 'constructor +} diff --git a/src/variant.mli b/src/variant.mli new file mode 100644 index 0000000..308a58e --- /dev/null +++ b/src/variant.mli @@ -0,0 +1,9 @@ +(** First-class representative of an individual variant in a variant type, used in + [[@@deriving_inline variants][@@@end]]. *) + +type 'constructor t = { + name : string; + (** The position of the constructor in the type definition, starting from 0 *) + rank : int; + constructor : 'constructor +} diff --git a/src/variantslib.ml b/src/variantslib.ml new file mode 100644 index 0000000..95994d5 --- /dev/null +++ b/src/variantslib.ml @@ -0,0 +1,3 @@ +(** This module is for use by ppx_variants_conv, and is thus not in the interface of + Base. *) +module Variant = Variant diff --git a/src/with_return.ml b/src/with_return.ml new file mode 100644 index 0000000..39a1531 --- /dev/null +++ b/src/with_return.ml @@ -0,0 +1,35 @@ +(* belongs in Common, but moved here to avoid circular dependencies *) + +open! Import + +type 'a return = { return : 'b. 'a -> 'b } [@@unboxed] + +let with_return (type a) f = + let module M = struct + (* Raised to indicate ~return was called. Local so that the exception is tied to a + particular call of [with_return]. *) + exception Return of a + end in + let is_alive = ref true in + let return a = + if not !is_alive + then failwith "use of [return] from a [with_return] that already returned"; + Exn.raise_without_backtrace (M.Return a); + in + try + let a = f { return } in + is_alive := false; + a + with exn -> + is_alive := false; + match exn with + | M.Return a -> a + | _ -> raise exn +;; + +let with_return_option f = + with_return (fun return -> + f { return = fun a -> return.return (Some a) }; None) +;; + +let prepend { return } ~f = { return = fun x -> return (f x) } diff --git a/src/with_return.mli b/src/with_return.mli new file mode 100644 index 0000000..6ed770a --- /dev/null +++ b/src/with_return.mli @@ -0,0 +1,54 @@ + +(** [with_return f] allows for something like the return statement in C within [f]. + + There are three ways [f] can terminate: + + + If [f] calls [r.return x], then [x] is returned by [with_return]. + + If [f] evaluates to a value [x], then [x] is returned by [with_return]. + + If [f] raises an exception, it escapes [with_return]. + + Here is a typical example: + + {[ + let find l ~f = + with_return (fun r -> + List.iter l ~f:(fun x -> if f x then r.return (Some x)); + None + ) + ]} + + It is only because of a deficiency of ML types that [with_return] doesn't have type: + + {[ val with_return : 'a. (('a -> ('b. 'b)) -> 'a) -> 'a ]} + + but we can slightly increase the scope of ['b] without changing the meaning of the + type, and then we get: + + {[ + type 'a return = { return : 'b . 'a -> 'b } + val with_return : ('a return -> 'a) -> 'a + ]} + + But the actual reason we chose to use a record type with polymorphic field is that + otherwise we would have to clobber the namespace of functions with [return] and that + is undesirable because [return] would get hidden as soon as we open any monad. We + considered names different than [return] but everything seemed worse than just having + [return] as a record field. We are clobbering the namespace of record fields but that + is much more acceptable. *) + +open! Import + +type -'a return = private { return : 'b. 'a -> 'b } [@@unboxed] + +val with_return : ('a return -> 'a ) -> 'a + +(** Note that [with_return_option] allocates ~5 words more than the equivalent + [with_return] call. *) +val with_return_option : ('a return -> unit) -> 'a option + +(** [prepend a ~f] returns a value [x] such that each call to [x.return] first applies [f] + before applying [a.return]. The call to [f] is "prepended" to the call to the + original [a.return]. A possible use case is to hand [x] over to another function + which returns ['b], a subtype of ['a], or to capture a common transformation [f] + applied to returned values at several call sites. *) +val prepend : 'a return -> f:('b -> 'a) -> 'b return diff --git a/src/word_size.ml b/src/word_size.ml new file mode 100644 index 0000000..3651fe6 --- /dev/null +++ b/src/word_size.ml @@ -0,0 +1,19 @@ +open! Import + +module Sys = Sys0 + +type t = W32 | W64 [@@deriving_inline sexp_of] +let sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t = + function + | W32 -> Ppx_sexp_conv_lib.Sexp.Atom "W32" + | W64 -> Ppx_sexp_conv_lib.Sexp.Atom "W64" +[@@@end] + +let num_bits = function W32 -> 32 | W64 -> 64 + +let word_size = + match Sys.word_size_in_bits with + | 32 -> W32 + | 64 -> W64 + | _ -> failwith "unknown word size" +;; diff --git a/src/word_size.mli b/src/word_size.mli new file mode 100644 index 0000000..11f7a53 --- /dev/null +++ b/src/word_size.mli @@ -0,0 +1,14 @@ +(** For determining the word size that the program is using. *) + +open! Import + +type t = W32 | W64 [@@deriving_inline sexp_of] +include +sig [@@@ocaml.warning "-32"] val sexp_of_t : t -> Ppx_sexp_conv_lib.Sexp.t +end[@@ocaml.doc "@inline"] +[@@@end] + +val num_bits : t -> int + +(** Returns the word size of this program, not necessarily of the OS. *) +val word_size : t diff --git a/test/avltree_unit_tests.ml b/test/avltree_unit_tests.ml new file mode 100644 index 0000000..443e343 --- /dev/null +++ b/test/avltree_unit_tests.ml @@ -0,0 +1,282 @@ +open! Import + +let%test_module _ = + (module (struct + + open Avltree + + type ('k, 'v) t = ('k, 'v) Avltree.t = private + | Empty + | Node of { mutable left : ('k, 'v) t + ; key : 'k + ; mutable value : 'v + ; mutable height : int + ; mutable right : ('k, 'v) t + } + | Leaf of { key : 'k + ; mutable value : 'v + } + + module For_quickcheck = struct + + module Key = struct + include Quickcheck.Int + let quickcheck_generator = Quickcheck.Generator.small_non_negative_int + end + module Data = struct + include Quickcheck.String + let quickcheck_generator = gen' Quickcheck.Char.gen_lowercase + end + + let compare = Key.compare + + open Quickcheck + open Generator + + module Constructor = struct + + type t = + | Add of Key.t * Data.t + | Replace of Key.t * Data.t + | Remove of Key.t + [@@deriving sexp_of] + + let add_gen = + Key.quickcheck_generator >>= fun key -> + Data.quickcheck_generator >>| fun data -> + Add (key, data) + + let replace_gen = + Key.quickcheck_generator >>= fun key -> + Data.quickcheck_generator >>| fun data -> + Replace (key, data) + + let remove_gen = + Key.quickcheck_generator >>| fun key -> + Remove key + + let quickcheck_generator = union [ add_gen ; replace_gen ; remove_gen ] + + let apply_to_tree t tree = + match t with + | Add (key, data) -> + add tree ~key ~data ~compare ~added:(ref false) ~replace:false + | Replace (key, data) -> + add tree ~key ~data ~compare ~added:(ref false) ~replace:true + | Remove key -> + remove tree key ~compare ~removed:(ref false) + + let apply_to_map t map = + match t with + | Add (key, data) -> + if Map.mem map key + then map + else Map.set map ~key ~data + | Replace (key, data) -> + Map.set map ~key ~data + | Remove key -> + Map.remove map key + + end + + let constructors_gen = List.quickcheck_generator Constructor.quickcheck_generator + + let reify constructors = + List.fold constructors + ~init:(empty, Key.Map.empty) + ~f:(fun (t, map) constructor -> + Constructor.apply_to_tree constructor t, + Constructor.apply_to_map constructor map) + + let merge map1 map2 = + Map.merge map1 map2 ~f:(fun ~key variant -> + match variant with + | `Left data | `Right data -> Some data + | `Both (data1, data2) -> + Error.raise_s ( + [%message + "duplicate data for key" (key : Key.t) (data1 : Data.t) (data2 : Data.t)])) + let rec to_map = function + | Empty -> Key.Map.empty + | Leaf { key; value = data } -> Key.Map.singleton key data + | Node { left; key; value = data; height = _; right } -> + merge (Key.Map.singleton key data) + (merge (to_map left) (to_map right)) + + end + + open For_quickcheck + + let empty = empty + + let%test_unit _ = + match empty with + | Empty -> () + | _ -> assert false + + let invariant = invariant + + let%test_unit _ = + Quickcheck.test + constructors_gen + ~sexp_of:[%sexp_of: Constructor.t list] + ~f:(fun constructors -> + let t, map = reify constructors in + invariant t ~compare; + [%test_result: Data.t Key.Map.t] (to_map t) ~expect:map) + + let add = add + + let%test_unit _ = + Quickcheck.test + (Quickcheck.Generator.tuple4 constructors_gen Key.quickcheck_generator Data.quickcheck_generator Quickcheck.Bool.quickcheck_generator) + ~sexp_of:[%sexp_of: Constructor.t list * Key.t * Data.t * bool] + ~f:(fun (constructors, key, data, replace) -> + let t, map = reify constructors in + (* test [added], other aspects of [add] are tested via [reify] in the + [invariant] test above *) + let added = ref false in + let _ = add t ~key ~data ~compare ~added ~replace in + [%test_result: bool] + !added + ~expect:(not (Map.mem map key))) + + let remove = remove + + let%test_unit _ = + Quickcheck.test + (Quickcheck.Generator.tuple2 constructors_gen Key.quickcheck_generator) + ~sexp_of:[%sexp_of: Constructor.t list * Key.t] + ~f:(fun (constructors, key) -> + let t, map = reify constructors in + (* test [removed], other aspects of [remove] are tested via [reify] in the + [invariant] test above *) + let removed = ref false in + let _ = remove t key ~compare ~removed in + [%test_result: bool] + !removed + ~expect:(Map.mem map key)) + + let find = find + + let%test_unit _ = + Quickcheck.test + (Quickcheck.Generator.tuple2 constructors_gen Key.quickcheck_generator) + ~sexp_of:[%sexp_of: Constructor.t list * Key.t] + ~f:(fun (constructors, key) -> + let t, map = reify constructors in + [%test_result: Data.t option] + (find t key ~compare) + ~expect:(Map.find map key)) + + let mem = mem + + let%test_unit _ = + Quickcheck.test + (Quickcheck.Generator.tuple2 constructors_gen Key.quickcheck_generator) + ~sexp_of:[%sexp_of: Constructor.t list * Key.t] + ~f:(fun (constructors, key) -> + let t, map = reify constructors in + [%test_result: bool] + (mem t key ~compare) + ~expect:(Map.mem map key)) + + let first = first + + let%test_unit _ = + Quickcheck.test + constructors_gen + ~sexp_of:[%sexp_of: Constructor.t list] + ~f:(fun constructors -> + let t, map = reify constructors in + [%test_result: (Key.t * Data.t) option] + (first t) + ~expect:(Map.min_elt map)) + + let last = last + + let%test_unit _ = + Quickcheck.test + constructors_gen + ~sexp_of:[%sexp_of: Constructor.t list] + ~f:(fun constructors -> + let t, map = reify constructors in + [%test_result: (Key.t * Data.t) option] + (last t) + ~expect:(Map.max_elt map)) + + let find_and_call = find_and_call + + let%test_unit _ = + Quickcheck.test + (Quickcheck.Generator.tuple2 constructors_gen Key.quickcheck_generator) + ~sexp_of:[%sexp_of: Constructor.t list * Key.t] + ~f:(fun (constructors, key) -> + let t, map = reify constructors in + [%test_result: [ `Found of Data.t | `Not_found of Key.t ]] + (find_and_call t key ~compare + ~if_found: (fun data -> `Found data) + ~if_not_found: (fun key -> `Not_found key)) + ~expect:(match Map.find map key with + | None -> `Not_found key + | Some data -> `Found data)) + + let findi_and_call = findi_and_call + + let%test_unit _ = + Quickcheck.test + (Quickcheck.Generator.tuple2 constructors_gen Key.quickcheck_generator) + ~sexp_of:[%sexp_of: Constructor.t list * Key.t] + ~f:(fun (constructors, key) -> + let t, map = reify constructors in + [%test_result: [ `Found of (Key.t * Data.t) | `Not_found of Key.t ]] + (findi_and_call t key ~compare + ~if_found: (fun ~key ~data -> `Found (key, data)) + ~if_not_found: (fun key -> `Not_found key)) + ~expect:(match Map.find map key with + | None -> `Not_found key + | Some data -> `Found (key, data))) + + let iter = iter + + let%test_unit _ = + Quickcheck.test + constructors_gen + ~sexp_of:[%sexp_of: Constructor.t list] + ~f:(fun constructors -> + let t, map = reify constructors in + [%test_result: (Key.t * Data.t) list] + (let q = Queue.create () in + iter t ~f:(fun ~key ~data -> + Queue.enqueue q (key, data)); + Queue.to_list q) + ~expect:(Map.to_alist map)) + + let mapi_inplace = mapi_inplace + + let%test_unit _ = + Quickcheck.test + constructors_gen + ~sexp_of:[%sexp_of: Constructor.t list] + ~f:(fun constructors -> + let t, map = reify constructors in + [%test_result: (Key.t * Data.t) list] + (mapi_inplace t ~f:(fun ~key:_ ~data -> data ^ data); + (fold t ~init:[] ~f:(fun ~key ~data acc -> (key, data) :: acc))) + ~expect:(Map.map map ~f:(fun data -> data ^ data) + |> Map.to_alist |> List.rev)) + + let fold = fold + + let%test_unit _ = + Quickcheck.test + constructors_gen + ~sexp_of:[%sexp_of: Constructor.t list] + ~f:(fun constructors -> + let t, map = reify constructors in + [%test_result: (Key.t * Data.t) list] + (fold t ~init:[] ~f:(fun ~key ~data acc -> + (key, data) :: acc)) + ~expect:(Map.to_alist map |> List.rev)) + + end : module type of Avltree)) diff --git a/test/avltree_unit_tests.mli b/test/avltree_unit_tests.mli new file mode 100644 index 0000000..8b9cdbb --- /dev/null +++ b/test/avltree_unit_tests.mli @@ -0,0 +1 @@ +(* intentionally blank *) diff --git a/test/dune b/test/dune new file mode 100644 index 0000000..734364c --- /dev/null +++ b/test/dune @@ -0,0 +1,4 @@ +(library (name base_test) + (libraries base core_kernel.base_for_tests caml sexplib num + expect_test_helpers_kernel stdio) + (preprocess (pps ppx_jane -dont-apply=pipebang)))
\ No newline at end of file diff --git a/test/hashtbl_tests.ml b/test/hashtbl_tests.ml new file mode 100644 index 0000000..6fa1f04 --- /dev/null +++ b/test/hashtbl_tests.ml @@ -0,0 +1,360 @@ +open! Base + + +module type Hashtbl_for_testing = sig + include Hashtbl.Accessors with type 'key key = 'key + include Invariant.S2 with type ('key, 'data) t := ('key, 'data) t + + (* we don't define [module Poly : Hashtbl.S_poly] because we want to require only + the minimal number of constructors necessary to implement the tests, and also avoid + conflicting with any existing names. *) + val create_poly : ?size:int -> unit -> ('key, 'data) t + + val of_alist_poly_exn : ('key * 'data) list -> ('key, 'data) t + val of_alist_poly_or_error : ('key * 'data) list -> ('key, 'data) t Or_error.t +end + +module Make (Hashtbl : Hashtbl_for_testing) = struct + open Poly + + let test_data = [("a",1);("b",2);("c",3)] + + let test_hash = begin + let h = Hashtbl.create_poly () ~size:10 in + List.iter test_data ~f:(fun (k,v) -> + Hashtbl.set h ~key:k ~data:v + ); + h + end + + (* This is a very strong notion of equality on hash tables *) + let equal t t' equal_data = + let subtable t t' = + try + List.for_all (Hashtbl.keys t) ~f:(fun key -> + equal_data (Hashtbl.find_exn t key) (Hashtbl.find_exn t' key)) + with + | Invalid_argument _ -> false + in + subtable t t' && subtable t' t + + let%test "find" = + let found = Hashtbl.find test_hash "a" in + let not_found = Hashtbl.find test_hash "A" in + Hashtbl.invariant ignore ignore test_hash; + match found,not_found with + | Some _, None -> true + | _ -> false + ;; + + let%test "findi_and_call" = + let our_hash = Hashtbl.copy test_hash in + let test_string = "test string" in + Hashtbl.add_exn our_hash ~key:test_string ~data:10; + let test_string' = "test " ^ "string" in + assert (not (phys_equal test_string test_string')); + Hashtbl.findi_and_call our_hash test_string' + ~if_found:(fun ~key ~data -> phys_equal test_string key && data = 10) + ~if_not_found:(fun _ -> false) + ;; + + let%test_unit "add" = + let our_hash = Hashtbl.copy test_hash in + let duplicate = Hashtbl.add our_hash ~key:"a" ~data:4 in + let no_duplicate = Hashtbl.add our_hash ~key:"d" ~data:5 in + assert (Hashtbl.find our_hash "a" = Some 1); + assert (Hashtbl.find our_hash "d" = Some 5); + Hashtbl.invariant ignore ignore our_hash; + assert (match duplicate, no_duplicate with + | `Duplicate, `Ok -> true + | _ -> false) + ;; + + let%test "iter" = + let predicted = List.sort ~compare:Int.descending ( + List.map test_data ~f:(fun (_,v) -> v)) + in + let found = + let found = ref [] in + Hashtbl.iter test_hash ~f:(fun v -> found := v :: !found); + !found + |> List.sort ~compare:Int.descending + in + List.equal Int.equal predicted found + ;; + + let%test "iter_keys" = + let predicted = List.sort ~compare:String.descending ( + List.map test_data ~f:(fun (k,_) -> k)) + in + let found = + let found = ref [] in + Hashtbl.iter_keys test_hash ~f:(fun k -> found := k :: !found); + !found + |> List.sort ~compare:String.descending + in + List.equal String.equal predicted found + ;; + + let%test_module "of_alist" = + (module struct + + let%test "size" = + let predicted = List.length test_data in + let found = Hashtbl.length (Hashtbl.of_alist_poly_exn test_data) in + predicted = found + ;; + + let%test "right keys" = + let predicted = List.map test_data ~f:(fun (k,_) -> k) in + let found = Hashtbl.keys (Hashtbl.of_alist_poly_exn test_data) in + let sp = List.sort ~compare:Poly.ascending predicted in + let sf = List.sort ~compare:Poly.ascending found in + sp = sf + ;; + end) + + let%test_module "of_alist_or_error" = + (module struct + + let%test "unique" = + Result.is_ok (Hashtbl.of_alist_poly_or_error test_data) + + let%test "duplicate" = + Result.is_error (Hashtbl.of_alist_poly_or_error (test_data @ test_data)) + + end) + + let%test "size and right keys" = + let predicted = List.map test_data ~f:(fun (k,_) -> k) in + let found = Hashtbl.keys test_hash in + let sp = List.sort ~compare:Poly.ascending predicted in + let sf = List.sort ~compare:Poly.ascending found in + sp = sf + ;; + + let%test "size and right data" = + let predicted = List.map test_data ~f:(fun (_,v) -> v) in + let found = Hashtbl.data test_hash in + let sp = List.sort ~compare:Poly.ascending predicted in + let sf = List.sort ~compare:Poly.ascending found in + sp = sf + ;; + + let%test "map" = + let add1 x = x + 1 in + let predicted_data = + List.sort ~compare:Poly.ascending + (List.map test_data ~f:(fun (k,v) -> (k,add1 v))) + in + let found_alist = + Hashtbl.map test_hash ~f:add1 + |> Hashtbl.to_alist + |> List.sort ~compare:Poly.ascending + in + List.equal Poly.equal predicted_data found_alist + ;; + + let%test_unit "filter_map" = + let f x = Some x in + let result = Hashtbl.filter_map test_hash ~f in + assert (equal test_hash result Int.(=)); + let is_even x = x % 2 = 0 in + let add1_to_even x = if is_even x then Some (x + 1) else None in + let predicted_data = List.filter_map test_data ~f:(fun (k,v) -> + if is_even v then Some (k, v+1) else None) + in + let found = Hashtbl.filter_map test_hash ~f:add1_to_even in + let found_alist = + List.sort ~compare:Poly.ascending (Hashtbl.to_alist found) + in + assert (List.equal Poly.equal predicted_data found_alist ) + ;; + + let%test "filter_inplace" = + let f x = x <> 2 in + let predicted_data = + List.sort ~compare:Poly.ascending + (List.filter test_data ~f:(fun (_,v) -> f v)) + in + let test_hash = Hashtbl.copy test_hash in + Hashtbl.filter_inplace test_hash ~f; + let found_alist = + Hashtbl.to_alist test_hash + |> List.sort ~compare:Poly.ascending + in + List.equal Poly.equal predicted_data found_alist + ;; + + let%test "filter_keys_inplace" = + let f x = x = "c" in + let predicted_data = + List.sort ~compare:Poly.ascending + (List.filter test_data ~f:(fun (k,_) -> f k)) + in + let test_hash = Hashtbl.copy test_hash in + Hashtbl.filter_keys_inplace test_hash ~f; + let found_alist = + Hashtbl.to_alist test_hash + |> List.sort ~compare:Poly.ascending + in + List.equal Poly.equal predicted_data found_alist + ;; + + let%test "filter_map_inplace" = + let f x = if x = 1 then None else Some (x * 2) in + let predicted_data = + List.sort ~compare:Poly.ascending + (List.filter_map test_data ~f:(fun (k,v) -> Option.map (f v) ~f:(fun x -> (k,x)))) + in + let test_hash = Hashtbl.copy test_hash in + Hashtbl.filter_map_inplace test_hash ~f; + let found_alist = + Hashtbl.to_alist test_hash + |> List.sort ~compare:Poly.ascending + in + List.equal Poly.equal predicted_data found_alist + ;; + + let%test "map_inplace" = + let f x = x + 3 in + let predicted_data = + List.sort ~compare:Poly.ascending + (List.map test_data ~f:(fun (k,v) -> (k,f v))) + in + let test_hash = Hashtbl.copy test_hash in + Hashtbl.map_inplace test_hash ~f; + let found_alist = + Hashtbl.to_alist test_hash + |> List.sort ~compare:Poly.ascending + in + List.equal Poly.equal predicted_data found_alist + ;; + + let%test_unit "insert-find-remove" = + let t = Hashtbl.create_poly () ~size:1 in + let inserted = ref [] in + Random.init 123; + let verify_inserted t = + let missing = + List.fold !inserted ~init:[] ~f:(fun acc (key, data) -> + match Hashtbl.find t key with + | None -> `Missing key :: acc + | Some d -> + if data = d then acc + else `Wrong_data (key, data) :: acc) + in + match missing with + | [] -> () + | _ -> + raise_s + [%message + "some inserts are missing" + (missing : [`Missing of int | `Wrong_data of int * int ] list) ] + in + let equal = Int.equal in + let rec loop i t = + if i < 2000 then begin + let k = Random.int 10_000 in + inserted := List.Assoc.add (List.Assoc.remove !inserted ~equal k) ~equal k i; + Hashtbl.set t ~key:k ~data:i; + Hashtbl.invariant ignore ignore t; + verify_inserted t; + loop (i + 1) t + end + in + loop 0 t; + List.iter !inserted ~f:(fun (x, _) -> + Hashtbl.remove t x; + Hashtbl.invariant ignore ignore t; + begin match Hashtbl.find t x with + | None -> () + | Some _ -> failwith (Printf.sprintf "present after removal: %d" x) + end; + inserted := List.Assoc.remove !inserted ~equal x; + verify_inserted t) + ;; + + let%test_unit "clear" = + let t = Hashtbl.create_poly () ~size:1 in + let l = List.range 0 100 in + let verify_present l = List.for_all l ~f:(Hashtbl.mem t) in + let verify_not_present l = + List.for_all l ~f:(fun i -> not (Hashtbl.mem t i)) + in + List.iter l ~f:(fun i -> Hashtbl.set t ~key:i ~data:(i * i)); + List.iter l ~f:(fun i -> Hashtbl.set t ~key:i ~data:(i * i)); + assert (Hashtbl.length t = 100); + assert (verify_present l); + Hashtbl.clear t; + Hashtbl.invariant ignore ignore t; + assert (Hashtbl.length t = 0); + assert (verify_not_present l); + let l = List.take l 42 in + List.iter l ~f:(fun i -> Hashtbl.set t ~key:i ~data:(i * i)); + assert (Hashtbl.length t = 42); + assert (verify_present l); + Hashtbl.invariant ignore ignore t + ;; + + let%test_unit "mem" = + let t = Hashtbl.create_poly () ~size:1 in + Hashtbl.invariant ignore ignore t; + assert (not (Hashtbl.mem t "Fred")); + Hashtbl.invariant ignore ignore t; + Hashtbl.set t ~key:"Fred" ~data:"Wilma"; + Hashtbl.invariant ignore ignore t; + assert (Hashtbl.mem t "Fred"); + Hashtbl.invariant ignore ignore t; + Hashtbl.remove t "Fred"; + Hashtbl.invariant ignore ignore t; + assert (not (Hashtbl.mem t "Fred")); + Hashtbl.invariant ignore ignore t + ;; + + let%test_unit "exists" = + let t = Hashtbl.create_poly () in + assert (not (Hashtbl.exists t ~f:(fun _ -> failwith "can't be called"))); + assert (not (Hashtbl.existsi t ~f:(fun ~key:_ ~data:_ -> failwith "can't be called"))); + Hashtbl.set t ~key:7 ~data:3; + assert (not (Hashtbl.exists t ~f:(Int.equal 4))); + Hashtbl.set t ~key:8 ~data:4; + assert (Hashtbl.exists t ~f:(Int.equal 4)); + Hashtbl.set t ~key:9 ~data:5; + assert (Hashtbl.existsi t ~f:(fun ~key ~data -> key + data = 14)) + + let%test_unit "for_all" = + let t = Hashtbl.create_poly () in + assert (Hashtbl.for_all t ~f:(fun _ -> failwith "can't be called")); + assert (Hashtbl.for_alli t ~f:(fun ~key:_ ~data:_ -> failwith "can't be called")); + Hashtbl.set t ~key:7 ~data:3; + assert (Hashtbl.for_all t ~f:(fun x -> Int.equal x 3)); + Hashtbl.set t ~key:8 ~data:4; + assert (not (Hashtbl.for_all t ~f:(fun x -> Int.equal x 3))); + Hashtbl.set t ~key:9 ~data:5; + assert (Hashtbl.for_alli t ~f:(fun ~key ~data -> key - 4 = data)) + + let%test_unit "count" = + let t = Hashtbl.create_poly () in + assert (Hashtbl.count t ~f:(fun _ -> failwith "can't be called") = 0); + assert (Hashtbl.counti t ~f:(fun ~key:_ ~data:_ -> failwith "can't be called") = 0); + Hashtbl.set t ~key:7 ~data:3; + assert (Hashtbl.count t ~f:(fun x -> Int.equal x 3) = 1); + Hashtbl.set t ~key:8 ~data:4; + assert (Hashtbl.count t ~f:(fun x -> Int.equal x 3) = 1); + Hashtbl.set t ~key:9 ~data:5; + assert (Hashtbl.counti t ~f:(fun ~key ~data -> key - 4 = data) = 3) + + let%test_unit "merge" = + let make alist = Hashtbl.of_alist_poly_exn alist in + let t1 = make [ 1, 111 ; 2, 222 ; 3, 333 ] in + let t2 = make [ 1, 123 ; 2, 222 ; 4, 444 ] in + [%test_result: (int * [`Left of int|`Right of int|`Both of int*int]) List.t] + (Hashtbl.merge t1 t2 ~f:(fun ~key:_ -> function + | `Left x -> Some (`Left x) + | `Right y -> Some (`Right y) + | `Both (x, y) -> if x=y then None else Some (`Both (x, y))) + |> Hashtbl.to_alist + |> List.sort ~compare:(fun (x,_) (y,_) -> Int.compare x y)) + ~expect:[ 1, `Both (111,123) ; 3, `Left 333 ; 4, `Right 444 ] +end diff --git a/test/hashtbl_tests.mli b/test/hashtbl_tests.mli new file mode 100644 index 0000000..9109c90 --- /dev/null +++ b/test/hashtbl_tests.mli @@ -0,0 +1,13 @@ +open! Base + +module type Hashtbl_for_testing = sig + include Hashtbl.Accessors with type 'key key = 'key + include Invariant.S2 with type ('key, 'data) t := ('key, 'data) t + + val create_poly : ?size:int -> unit -> ('key, 'data) t + + val of_alist_poly_exn : ('key * 'data) list -> ('key, 'data) t + val of_alist_poly_or_error : ('key * 'data) list -> ('key, 'data) t Or_error.t +end + +module Make (Hashtbl : Hashtbl_for_testing) : sig end diff --git a/test/import.ml b/test/import.ml new file mode 100644 index 0000000..3858c5e --- /dev/null +++ b/test/import.ml @@ -0,0 +1,56 @@ +include Base +include Stdio +include Base_for_tests +include Expect_test_helpers_kernel + +module Quickcheck = struct + include Core_kernel.Quickcheck + + module Bool = Core_kernel.Bool + module Char = Core_kernel.Char + module Int = Core_kernel.Int + module Int32 = Core_kernel.Int32 + module Int64 = Core_kernel.Int64 + module List = Core_kernel.List + module Nativeint = Core_kernel.Nativeint + module String = Core_kernel.String +end + +module Core_kernel = struct +end [@@deprecated "[since 1970-01] Don't use Core_kernel in Base tests. Use Base."] + +let () = Base.Not_exposed_properly.Int_conversions.sexp_of_int_style := `Underscores + +let is_none = Option.is_none +let is_some = Option.is_some +let ok_exn = Or_error.ok_exn +let stage = Staged.stage +let unstage = Staged.unstage + +module type Hash = sig + type t [@@deriving hash, sexp_of] +end + +let check_hash_coherence (type t) here (module T : Hash with type t = t) ts = + List.iter ts ~f:(fun t -> + let hash1 = T.hash t in + let hash2 = [%hash: T.t] t in + require here (hash1 = hash2) ~cr:CR_soon + ~if_false_then_print_s: + (lazy [%message "" ~value:(t : T.t) (hash1 : int) (hash2 : int)])); +;; + +module type Int_hash = sig + include Hash + val of_int_exn : int -> t + val min_value : t + val max_value : t +end + +let check_int_hash_coherence (type t) here (module I : Int_hash with type t = t) = + check_hash_coherence here (module I) + [ I.min_value + ; I.of_int_exn 0 + ; I.of_int_exn 37 + ; I.max_value ]; +;; diff --git a/test/int_math_tests.ml b/test/int_math_tests.ml new file mode 100644 index 0000000..6d84011 --- /dev/null +++ b/test/int_math_tests.ml @@ -0,0 +1,39 @@ +let%test_module "overflow_bounds" = + (module struct + module Pow_overflow_bounds = Base.Not_exposed_properly.Pow_overflow_bounds + let%test _ = Pow_overflow_bounds.overflow_bound_max_int_value = Caml.max_int + let%test _ = Pow_overflow_bounds.overflow_bound_max_int64_value = Int64.max_int + + module Big_int = struct + include Big_int + let (>) = gt_big_int + let (=) = eq_big_int + let (^) = power_big_int_positive_int + let (+) = add_big_int + let one = unit_big_int + let to_string = string_of_big_int + end + + let test_overflow_table tbl conv max_val = + assert (Array.length tbl = 64); + let max_val = conv max_val in + StdLabels.Array.iteri tbl ~f:(fun i max_base -> + let max_base = conv max_base in + let overflows b = Big_int.((b ^ i) > max_val) in + let is_ok = + if i = 0 then Big_int.(max_base = max_val) + else + not (overflows max_base) && overflows Big_int.(max_base + one) + in + if not is_ok then + Base.Printf.failwithf + "overflow table check failed for %s (index %d)" + (Big_int.to_string max_base) i ()) + ;; + + let%test_unit _ = test_overflow_table Pow_overflow_bounds.int_positive_overflow_bounds + Big_int.big_int_of_int Caml.max_int + + let%test_unit _ = test_overflow_table Pow_overflow_bounds.int64_positive_overflow_bounds + Big_int.big_int_of_int64 Int64.max_int + end) diff --git a/test/interfaces_tests.ml b/test/interfaces_tests.ml new file mode 100644 index 0000000..0b07c4d --- /dev/null +++ b/test/interfaces_tests.ml @@ -0,0 +1,48 @@ +open Base + +let () = + let module M : sig + open Set + + type ('a, 'b) t + + include Accessors2 + with type ('a, 'b) t := ('a, 'b) t + with type ('a, 'b) tree := ('a, 'b) Set.Using_comparator.Tree.t + with type ('a, 'b) named := ('a, 'b) Set.Named.t + + include Creators_generic + with type ('a, 'b, 'c) options := ('a, 'b, 'c) With_first_class_module.t + with type ('a, 'b) set := ('a, 'b) t + with type ('a, 'b) t := ('a, 'b) t + with type ('a, 'b) tree := ('a, 'b) Set.Using_comparator.Tree.t + end = struct + type 'a elt = 'a + type _ cmp + include Set + let of_tree _ = assert false + let to_tree _ = assert false + end in () + +let () = + let module M : sig + open Map + + type ('a, 'b, 'c) t + + include Accessors3 + with type ('a, 'b, 'c) t := ('a, 'b, 'c) t + with type ('a, 'b, 'c) tree := ('a, 'b, 'c) Map.Using_comparator.Tree.t + + include Creators_generic + with type ('a, 'b, 'c) options := ('a, 'b, 'c) With_first_class_module.t + with type ('a, 'b, 'c) t := ('a, 'b, 'c) t + with type ('a, 'b, 'c) tree := ('a, 'b, 'c) Map.Using_comparator.Tree.t + end = struct + type 'a key = 'a + include Map + let of_tree _ = assert false + let to_tree _ = assert false + end + in + () diff --git a/test/test_am_testing.ml b/test/test_am_testing.ml new file mode 100644 index 0000000..70fb5d1 --- /dev/null +++ b/test/test_am_testing.ml @@ -0,0 +1,8 @@ +open! Base +open! Import + +let%expect_test _ = + print_s [%sexp (Exported_for_specific_uses.am_testing : bool)]; + [%expect {| + true |}]; +;; diff --git a/test/test_am_testing.mli b/test/test_am_testing.mli new file mode 100644 index 0000000..74bb729 --- /dev/null +++ b/test/test_am_testing.mli @@ -0,0 +1 @@ +(*_ This signature is deliberately empty. *) diff --git a/test/test_am_testing.mlt b/test/test_am_testing.mlt new file mode 100644 index 0000000..3eaebd2 --- /dev/null +++ b/test/test_am_testing.mlt @@ -0,0 +1,7 @@ +open! Base +open! Expect_test_helpers_base + +let () = print_s [%sexp (Exported_for_specific_uses.am_testing : bool)]; +[%%expect {| +true +|}];; diff --git a/test/test_applicative.ml b/test/test_applicative.ml new file mode 100644 index 0000000..06aaa6d --- /dev/null +++ b/test/test_applicative.ml @@ -0,0 +1,383 @@ +open! Import + +let%test_module "Make" = + (module struct + module A = + Applicative.Make (struct + type 'a t = 'a Or_error.t + let return = Or_error.return + let apply = Or_error.apply + let map = `Define_using_apply + end) + + let error = Or_error.error_string + + module Tests : module type of A = struct + let return = A.return + let%expect_test _ = + print_s [%sexp (return "okay" : string Or_error.t)]; + [%expect {| (Ok okay) |}]; + ;; + + let apply = A.apply + let%expect_test _ = + let test x y = print_s [%sexp (apply x y : string Or_error.t)] in + test (Ok String.capitalize) (Ok "okay"); + [%expect {| (Ok Okay) |}]; + test (error "not okay") (Ok "okay"); + [%expect {| (Error "not okay") |}]; + test (Ok String.capitalize) (error "not okay"); + [%expect {| (Error "not okay") |}]; + test (error "no fun") (error "no arg"); + [%expect {| (Error ("no fun" "no arg")) |}]; + ;; + + let ( <*> ) = A.( <*> ) + let%expect_test _ = + let test x y = print_s [%sexp (( <*> ) x y : string Or_error.t)] in + test (Ok String.capitalize) (Ok "okay"); + [%expect {| (Ok Okay) |}]; + test (error "not okay") (Ok "okay"); + [%expect {| (Error "not okay") |}]; + test (Ok String.capitalize) (error "not okay"); + [%expect {| (Error "not okay") |}]; + test (error "no fun") (error "no arg"); + [%expect {| (Error ("no fun" "no arg")) |}]; + ;; + + let ( *> ) = A.( *> ) + let%expect_test _ = + let test x y = print_s [%sexp (( *> ) x y : string Or_error.t)] in + test (Ok ()) (Ok "kay"); + [%expect {| (Ok kay) |}]; + test (error "not okay") (Ok "kay"); + [%expect {| (Error "not okay") |}]; + test (Ok ()) (error "not okay"); + [%expect {| (Error "not okay") |}]; + test (error "no fst") (error "no snd"); + [%expect {| (Error ("no fst" "no snd")) |}]; + ;; + + let ( <* ) = A.( <* ) + let%expect_test _ = + let test x y = print_s [%sexp (( <* ) x y : string Or_error.t)] in + test (Ok "okay") (Ok ()); + [%expect {| (Ok okay) |}]; + test (error "not okay") (Ok ()); + [%expect {| (Error "not okay") |}]; + test (Ok "okay") (error "not okay"); + [%expect {| (Error "not okay") |}]; + test (error "no fst") (error "no snd"); + [%expect {| (Error ("no fst" "no snd")) |}]; + ;; + + let both = A.both + let%expect_test _ = + let test x y = print_s [%sexp (both x y : (string * string) Or_error.t)] in + test (Ok "o") (Ok "kay"); + [%expect {| (Ok (o kay)) |}]; + test (error "not okay") (Ok "kay"); + [%expect {| (Error "not okay") |}]; + test (Ok "o") (error "not okay"); + [%expect {| (Error "not okay") |}]; + test (error "no fst") (error "no snd"); + [%expect {| (Error ("no fst" "no snd")) |}]; + ;; + + let map = A.map + let%expect_test _ = + let test x = print_s [%sexp (map x ~f:String.capitalize : string Or_error.t)] in + test (Ok "okay"); + [%expect {| (Ok Okay) |}]; + test (error "not okay"); + [%expect {| (Error "not okay") |}]; + ;; + + let ( >>| ) = A.( >>| ) + let%expect_test _ = + let test x = print_s [%sexp (x >>| String.capitalize : string Or_error.t)] in + test (Ok "okay"); + [%expect {| (Ok Okay) |}]; + test (error "not okay"); + [%expect {| (Error "not okay") |}]; + ;; + + let map2 = A.map2 + let%expect_test _ = + let test x y = print_s [%sexp (map2 x y ~f:(^) : string Or_error.t)] in + test (Ok "o") (Ok "kay"); + [%expect {| (Ok okay) |}]; + test (error "not okay") (Ok "kay"); + [%expect {| (Error "not okay") |}]; + test (Ok "o") (error "not okay"); + [%expect {| (Error "not okay") |}]; + test (error "no fst") (error "no snd"); + [%expect {| (Error ("no fst" "no snd")) |}]; + ;; + + let map3 = A.map3 + let%expect_test _ = + let test x y z = + print_s [%sexp (map3 x y z ~f:(fun a b c -> a ^ b ^ c) : string Or_error.t)] + in + test (Ok "o") (Ok "k") (Ok "ay"); + [%expect {| (Ok okay) |}]; + test (error "not okay") (Ok "k") (Ok "ay"); + [%expect {| (Error "not okay") |}]; + test (Ok "o") (error "not okay") (Ok "ay"); + [%expect {| (Error "not okay") |}]; + test (Ok "o") (Ok "k") (error "not okay"); + [%expect {| (Error "not okay") |}]; + test (error "no 1st") (error "no 2nd") (error "no 3rd"); + [%expect {| (Error ("no 1st" "no 2nd" "no 3rd")) |}]; + ;; + + let all = A.all + let%expect_test _ = + let test list = print_s [%sexp (all list : string list Or_error.t)] in + test []; + [%expect {| (Ok ()) |}]; + test [Ok "okay"]; + [%expect {| (Ok (okay)) |}]; + test [Ok "o"; Ok "kay"]; + [%expect {| (Ok (o kay)) |}]; + test [Ok "o"; Ok "k"; Ok "ay"]; + [%expect {| (Ok (o k ay)) |}]; + test [error "oh no!"]; + [%expect {| (Error "oh no!") |}]; + test [error "oh no!"; Ok "okay"]; + [%expect {| (Error "oh no!") |}]; + test [Ok "okay"; error "oh no!"]; + [%expect {| (Error "oh no!") |}]; + test [error "oh no!"; Ok "o"; Ok "kay"]; + [%expect {| (Error "oh no!") |}]; + test [Ok "o"; error "oh no!"; Ok "aay"]; + [%expect {| (Error "oh no!") |}]; + test [Ok "o"; Ok "kay"; error "oh no!"]; + [%expect {| (Error "oh no!") |}]; + test [error "oh"; error "no"; error "!"]; + [%expect {| (Error (oh no !)) |}]; + ;; + + let all_unit = A.all_unit + let all_ignore = all_unit + let%expect_test _ = + let test list = print_s [%sexp (all_unit list : unit Or_error.t)] in + test []; + [%expect {| (Ok ()) |}]; + test [Ok ()]; + [%expect {| (Ok ()) |}]; + test [Ok (); Ok ()]; + [%expect {| (Ok ()) |}]; + test [Ok (); Ok (); Ok ()]; + [%expect {| (Ok ()) |}]; + test [error "oh no!"]; + [%expect {| (Error "oh no!") |}]; + test [error "oh no!"; Ok ()]; + [%expect {| (Error "oh no!") |}]; + test [Ok (); error "oh no!"]; + [%expect {| (Error "oh no!") |}]; + test [error "oh no!"; Ok (); Ok ()]; + [%expect {| (Error "oh no!") |}]; + test [Ok (); error "oh no!"; Ok ()]; + [%expect {| (Error "oh no!") |}]; + test [Ok (); Ok (); error "oh no!"]; + [%expect {| (Error "oh no!") |}]; + test [error "oh"; error "no"; error "!"]; + [%expect {| (Error (oh no !)) |}]; + ;; + + module Applicative_infix = A.Applicative_infix + end + end) + +let%test_module "Make_using_map2" = + (module struct + module A = + Applicative.Make_using_map2 (struct + type 'a t = 'a Or_error.t + let return = Or_error.return + let map2 = Or_error.map2 + let map = `Define_using_map2 + end) + + let error = Or_error.error_string + + module Tests : module type of A = struct + let return = A.return + let%expect_test _ = + print_s [%sexp (return "okay" : string Or_error.t)]; + [%expect {| (Ok okay) |}]; + ;; + + let apply = A.apply + let%expect_test _ = + let test x y = print_s [%sexp (apply x y : string Or_error.t)] in + test (Ok String.capitalize) (Ok "okay"); + [%expect {| (Ok Okay) |}]; + test (error "not okay") (Ok "okay"); + [%expect {| (Error "not okay") |}]; + test (Ok String.capitalize) (error "not okay"); + [%expect {| (Error "not okay") |}]; + test (error "no fun") (error "no arg"); + [%expect {| (Error ("no fun" "no arg")) |}]; + ;; + + let ( <*> ) = A.( <*> ) + let%expect_test _ = + let test x y = print_s [%sexp (( <*> ) x y : string Or_error.t)] in + test (Ok String.capitalize) (Ok "okay"); + [%expect {| (Ok Okay) |}]; + test (error "not okay") (Ok "okay"); + [%expect {| (Error "not okay") |}]; + test (Ok String.capitalize) (error "not okay"); + [%expect {| (Error "not okay") |}]; + test (error "no fun") (error "no arg"); + [%expect {| (Error ("no fun" "no arg")) |}]; + ;; + + let ( *> ) = A.( *> ) + let%expect_test _ = + let test x y = print_s [%sexp (( *> ) x y : string Or_error.t)] in + test (Ok ()) (Ok "kay"); + [%expect {| (Ok kay) |}]; + test (error "not okay") (Ok "kay"); + [%expect {| (Error "not okay") |}]; + test (Ok ()) (error "not okay"); + [%expect {| (Error "not okay") |}]; + test (error "no fst") (error "no snd"); + [%expect {| (Error ("no fst" "no snd")) |}]; + ;; + + let ( <* ) = A.( <* ) + let%expect_test _ = + let test x y = print_s [%sexp (( <* ) x y : string Or_error.t)] in + test (Ok "okay") (Ok ()); + [%expect {| (Ok okay) |}]; + test (error "not okay") (Ok ()); + [%expect {| (Error "not okay") |}]; + test (Ok "okay") (error "not okay"); + [%expect {| (Error "not okay") |}]; + test (error "no fst") (error "no snd"); + [%expect {| (Error ("no fst" "no snd")) |}]; + ;; + + let both = A.both + let%expect_test _ = + let test x y = print_s [%sexp (both x y : (string * string) Or_error.t)] in + test (Ok "o") (Ok "kay"); + [%expect {| (Ok (o kay)) |}]; + test (error "not okay") (Ok "kay"); + [%expect {| (Error "not okay") |}]; + test (Ok "o") (error "not okay"); + [%expect {| (Error "not okay") |}]; + test (error "no fst") (error "no snd"); + [%expect {| (Error ("no fst" "no snd")) |}]; + ;; + + let map = A.map + let%expect_test _ = + let test x = print_s [%sexp (map x ~f:String.capitalize : string Or_error.t)] in + test (Ok "okay"); + [%expect {| (Ok Okay) |}]; + test (error "not okay"); + [%expect {| (Error "not okay") |}]; + ;; + + let ( >>| ) = A.( >>| ) + let%expect_test _ = + let test x = print_s [%sexp (x >>| String.capitalize : string Or_error.t)] in + test (Ok "okay"); + [%expect {| (Ok Okay) |}]; + test (error "not okay"); + [%expect {| (Error "not okay") |}]; + ;; + + let map2 = A.map2 + let%expect_test _ = + let test x y = print_s [%sexp (map2 x y ~f:(^) : string Or_error.t)] in + test (Ok "o") (Ok "kay"); + [%expect {| (Ok okay) |}]; + test (error "not okay") (Ok "kay"); + [%expect {| (Error "not okay") |}]; + test (Ok "o") (error "not okay"); + [%expect {| (Error "not okay") |}]; + test (error "no fst") (error "no snd"); + [%expect {| (Error ("no fst" "no snd")) |}]; + ;; + + let map3 = A.map3 + let%expect_test _ = + let test x y z = + print_s [%sexp (map3 x y z ~f:(fun a b c -> a ^ b ^ c) : string Or_error.t)] + in + test (Ok "o") (Ok "k") (Ok "ay"); + [%expect {| (Ok okay) |}]; + test (error "not okay") (Ok "k") (Ok "ay"); + [%expect {| (Error "not okay") |}]; + test (Ok "o") (error "not okay") (Ok "ay"); + [%expect {| (Error "not okay") |}]; + test (Ok "o") (Ok "k") (error "not okay"); + [%expect {| (Error "not okay") |}]; + test (error "no 1st") (error "no 2nd") (error "no 3rd"); + [%expect {| (Error ("no 1st" "no 2nd" "no 3rd")) |}]; + ;; + + let all = A.all + let%expect_test _ = + let test list = print_s [%sexp (all list : string list Or_error.t)] in + test []; + [%expect {| (Ok ()) |}]; + test [Ok "okay"]; + [%expect {| (Ok (okay)) |}]; + test [Ok "o"; Ok "kay"]; + [%expect {| (Ok (o kay)) |}]; + test [Ok "o"; Ok "k"; Ok "ay"]; + [%expect {| (Ok (o k ay)) |}]; + test [error "oh no!"]; + [%expect {| (Error "oh no!") |}]; + test [error "oh no!"; Ok "okay"]; + [%expect {| (Error "oh no!") |}]; + test [Ok "okay"; error "oh no!"]; + [%expect {| (Error "oh no!") |}]; + test [error "oh no!"; Ok "o"; Ok "kay"]; + [%expect {| (Error "oh no!") |}]; + test [Ok "o"; error "oh no!"; Ok "aay"]; + [%expect {| (Error "oh no!") |}]; + test [Ok "o"; Ok "kay"; error "oh no!"]; + [%expect {| (Error "oh no!") |}]; + test [error "oh"; error "no"; error "!"]; + [%expect {| (Error (oh no !)) |}]; + ;; + + let all_unit = A.all_unit + let all_ignore = all_unit + let%expect_test _ = + let test list = print_s [%sexp (all_unit list : unit Or_error.t)] in + test []; + [%expect {| (Ok ()) |}]; + test [Ok ()]; + [%expect {| (Ok ()) |}]; + test [Ok (); Ok ()]; + [%expect {| (Ok ()) |}]; + test [Ok (); Ok (); Ok ()]; + [%expect {| (Ok ()) |}]; + test [error "oh no!"]; + [%expect {| (Error "oh no!") |}]; + test [error "oh no!"; Ok ()]; + [%expect {| (Error "oh no!") |}]; + test [Ok (); error "oh no!"]; + [%expect {| (Error "oh no!") |}]; + test [error "oh no!"; Ok (); Ok ()]; + [%expect {| (Error "oh no!") |}]; + test [Ok (); error "oh no!"; Ok ()]; + [%expect {| (Error "oh no!") |}]; + test [Ok (); Ok (); error "oh no!"]; + [%expect {| (Error "oh no!") |}]; + test [error "oh"; error "no"; error "!"]; + [%expect {| (Error (oh no !)) |}]; + ;; + + module Applicative_infix = A.Applicative_infix + end + end) diff --git a/test/test_applicative.mli b/test/test_applicative.mli new file mode 100644 index 0000000..74bb729 --- /dev/null +++ b/test/test_applicative.mli @@ -0,0 +1 @@ +(*_ This signature is deliberately empty. *) diff --git a/test/test_array.ml b/test/test_array.ml new file mode 100644 index 0000000..a6937b4 --- /dev/null +++ b/test/test_array.ml @@ -0,0 +1,288 @@ +open! Import +open! Array + +let%test_module "Binary_searchable" = + (module Test_binary_searchable.Test1 + (struct + include Array + module For_test = struct + let of_array = Fn.id + end + end)) + +let%test_module "Blit" = + (module Test_blit.Test1 + (struct + type 'a z = 'a + include Array + let create_bool ~len = create ~len false + end) + (Array)) + +let%test_module "Sort" = + (module struct + open Private.Sort + + let%test_module "Intro_sort.five_element_sort" = + (module struct + (* run [five_element_sort] on all permutations of an array of five elements *) + + let rec sprinkle x xs = + (x :: xs) :: begin + match xs with + | [] -> [] + | x' :: xs' -> + List.map (sprinkle x xs') ~f:(fun sprinkled -> x' :: sprinkled) + end + + let rec permutations = function + | [] -> [[]] + | x :: xs -> + List.concat_map (permutations xs) ~f:(fun perms -> sprinkle x perms) + + let all_perms = permutations [1;2;3;4;5] + let%test _ = List.length all_perms = 120 + let%test _ = not (List.contains_dup ~compare:[%compare: int list] all_perms) + + let%test _ = + List.for_all all_perms ~f:(fun l -> + let arr = Array.of_list l in + Intro_sort.five_element_sort arr ~compare:[%compare: int] 0 1 2 3 4; + [%compare.equal: int t] arr [|1;2;3;4;5|]) + end) + + module Test (M : Private.Sort.Sort) = struct + let random_data ~length ~range = + let arr = Array.create ~len:length 0 in + for i = 0 to length - 1 do + arr.(i) <- Random.int range; + done; + arr + ;; + + let assert_sorted arr = + M.sort arr ~left:0 ~right:(Array.length arr - 1) ~compare:[%compare: int]; + let len = Array.length arr in + let rec loop i prev = + if i = len then true + else if arr.(i) < prev then false + else loop (i + 1) arr.(i) + in + loop 0 (-1) + ;; + + let%test _ = assert_sorted (random_data ~length:0 ~range:100) + let%test _ = assert_sorted (random_data ~length:1 ~range:100) + let%test _ = assert_sorted (random_data ~length:100 ~range:1_000) + let%test _ = assert_sorted (random_data ~length:1_000 ~range:1) + let%test _ = assert_sorted (random_data ~length:1_000 ~range:10) + let%test _ = assert_sorted (random_data ~length:1_000 ~range:1_000_000) + end + + let%test_module _ = (module Test (Insertion_sort)) + let%test_module _ = (module Test (Heap_sort)) + let%test_module _ = (module Test (Intro_sort)) + + let%expect_test "Array.sort [||] only allocates when computing bounds" = + require_allocation_does_not_exceed (Minor_words 3) [%here] + (fun () -> Array.sort ~compare:Int.compare [||]); + [%expect {||}] + ;; + + let%expect_test "Array.sort [| 5; 2; 3; 4; 1 |] only allocates when computing bounds" = + let arr = [| 5; 2; 3; 4; 1 |] in + require_allocation_does_not_exceed (Minor_words 3) [%here] + (fun () -> Array.sort ~compare:Int.compare arr); + [%expect {||}] + ;; + end) + +let%test _ = is_sorted [||] ~compare:[%compare: int] +let%test _ = is_sorted [|0|] ~compare:[%compare: int] +let%test _ = is_sorted [|0;1;2;2;4|] ~compare:[%compare: int] +let%test _ = not (is_sorted [|0;1;2;3;2|] ~compare:[%compare: int]) + +let%test_unit _ = + List.iter + ~f:(fun (t, expect) -> + assert (Bool.equal expect (is_sorted_strictly (of_list t) ~compare:[%compare: int]))) + [ [] , true; + [ 1 ] , true; + [ 1; 2 ] , true; + [ 1; 1 ] , false; + [ 2; 1 ] , false; + [ 1; 2; 3 ], true; + [ 1; 1; 3 ], false; + [ 1; 2; 2 ], false; + ] +;; + +let%test _ = foldi [||] ~init:13 ~f:(fun _ _ _ -> failwith "bad") = 13 +let%test _ = foldi [| 13 |] ~init:17 ~f:(fun i ac x -> ac + i + x) = 30 +let%test _ = foldi [| 13; 17 |] ~init:19 ~f:(fun i ac x -> ac + i + x) = 50 + +let%test _ = counti [|0;1;2;3;4|] ~f:(fun idx x -> idx = x) = 5 +let%test _ = counti [|0;1;2;3;4|] ~f:(fun idx x -> idx = 4-x) = 1 + +let%test_unit _ = + for i = 0 to 5 do + let l1 = List.init i ~f:Fn.id in + let l2 = List.rev (to_list (of_list_rev l1)) in + assert ([%compare.equal: int list] l1 l2); + done +;; + +let%test_unit _ = + List.iter + ~f:(fun (t, len) -> + assert (Exn.does_raise (fun () -> unsafe_truncate t ~len))) + [ [| |] , -1 + ; [| |] , 0 + ; [| |] , 1 + ; [| 1 |], -1 + ; [| 1 |], 0 + ; [| 1 |], 2 + ] +;; + +let%test_unit _ = + for orig_len = 1 to 5 do + for new_len = 1 to orig_len do + let t = init orig_len ~f:Fn.id in + unsafe_truncate t ~len:new_len; + assert (length t = new_len); + for i = 0 to new_len - 1 do + assert (t.(i) = i); + done; + done; + done +;; + +let%test_unit _ = [%test_result: int array] (filter_opt [|Some 1; None; Some 2; None; Some 3|]) ~expect:[|1; 2; 3|] +let%test_unit _ = [%test_result: int array] (filter_opt [|Some 1; None; Some 2|]) ~expect:[|1; 2|] +let%test_unit _ = [%test_result: int array] (filter_opt [|Some 1|]) ~expect:[|1|] +let%test_unit _ = [%test_result: int array] (filter_opt [|None|]) ~expect:[||] +let%test_unit _ = [%test_result: int array] (filter_opt [||]) ~expect:[||] + +let%test_unit _ = + [%test_result: int] + (fold2_exn [||] [||] ~init:13 ~f:(fun _ -> failwith "fail")) + ~expect:13 +let%test_unit _ = + [%test_result: (int * string) list] + (fold2_exn [| 1 |] [| "1" |] ~init:[] ~f:(fun ac a b -> (a, b) :: ac)) + ~expect:[ 1, "1" ] + +let%test_unit _ = [%test_result: int array] (filter [| 0; 1 |] ~f:(fun n -> n < 2)) ~expect:[| 0; 1 |] +let%test_unit _ = [%test_result: int array] (filter [| 0; 1 |] ~f:(fun n -> n < 1)) ~expect:[| 0 |] +let%test_unit _ = [%test_result: int array] (filter [| 0; 1 |] ~f:(fun n -> n < 0)) ~expect:[||] + +let%test_unit _ = [%test_result: bool] (exists [||] ~f:(fun _ -> true)) ~expect:false +let%test_unit _ = [%test_result: bool] (exists [|0;1;2;3|] ~f:(fun x -> 4 = x)) ~expect:false +let%test_unit _ = [%test_result: bool] (exists [|0;1;2;3|] ~f:(fun x -> 2 = x)) ~expect:true + +let%test_unit _ = [%test_result: bool] (existsi [||] ~f:(fun _ _ -> true)) ~expect:false +let%test_unit _ = [%test_result: bool] (existsi [|0;1;2;3|] ~f:(fun i x -> i <> x)) ~expect:false +let%test_unit _ = [%test_result: bool] (existsi [|0;1;3;3|] ~f:(fun i x -> i <> x)) ~expect:true + +let%test_unit _ = [%test_result: bool] (for_all [||] ~f:(fun _ -> false)) ~expect:true +let%test_unit _ = [%test_result: bool] (for_all [|1;2;3|] ~f:Int.is_positive ) ~expect:true +let%test_unit _ = [%test_result: bool] (for_all [|0;1;3;3|] ~f:Int.is_positive ) ~expect:false + +let%test_unit _ = [%test_result: bool] (for_alli [||] ~f:(fun _ _ -> false)) ~expect:true +let%test_unit _ = [%test_result: bool] (for_alli [|0;1;2;3|] ~f:(fun i x -> i = x)) ~expect:true +let%test_unit _ = [%test_result: bool] (for_alli [|0;1;3;3|] ~f:(fun i x -> i = x)) ~expect:false + +let%test_unit _ = [%test_result: bool] (exists2_exn [||] [||] ~f:(fun _ _ -> true)) ~expect:false +let%test_unit _ = [%test_result: bool] (exists2_exn [|0;2;4;6|] [|0;2;4;6|] ~f:(fun x y -> x <> y)) ~expect:false +let%test_unit _ = [%test_result: bool] (exists2_exn [|0;2;4;8|] [|0;2;4;6|] ~f:(fun x y -> x <> y)) ~expect:true +let%test_unit _ = [%test_result: bool] (exists2_exn [|2;2;4;6|] [|0;2;4;6|] ~f:(fun x y -> x <> y)) ~expect:true +let%test_unit _ = [%test_result: bool] (for_all2_exn [||] [||] ~f:(fun _ _ -> false)) ~expect:true +let%test_unit _ = [%test_result: bool] (for_all2_exn [|0;2;4;6|] [|0;2;4;6|] ~f:(fun x y -> x = y)) ~expect:true +let%test_unit _ = [%test_result: bool] (for_all2_exn [|0;2;4;8|] [|0;2;4;6|] ~f:(fun x y -> x = y)) ~expect:false +let%test_unit _ = [%test_result: bool] (for_all2_exn [|2;2;4;6|] [|0;2;4;6|] ~f:(fun x y -> x = y)) ~expect:false + +let%test_unit _ = [%test_result: bool] (equal (=) [||] [||]) ~expect:true +let%test_unit _ = [%test_result: bool] (equal (=) [| 1 |] [| 1 |]) ~expect:true +let%test_unit _ = [%test_result: bool] (equal (=) [| 1; 2 |] [| 1; 2 |]) ~expect:true +let%test_unit _ = [%test_result: bool] (equal (=) [||] [| 1 |]) ~expect:false +let%test_unit _ = [%test_result: bool] (equal (=) [| 1 |] [||]) ~expect:false +let%test_unit _ = [%test_result: bool] (equal (=) [| 1 |] [| 1; 2 |]) ~expect:false +let%test_unit _ = [%test_result: bool] (equal (=) [| 1; 2 |] [| 1; 3 |]) ~expect:false + +let%test_unit _ = + [%test_result: (int * int) option] + (findi [|1;2;3;4|] ~f:(fun i x -> i = 2*x)) + ~expect:None +let%test_unit _ = + [%test_result: (int * int) option] + (findi [|1;2;1;4|] ~f:(fun i x -> i = 2*x)) + ~expect:(Some (2, 1)) + +let%test_unit _ = [%test_result: int option] (find_mapi [|0;5;2;1;4|] ~f:(fun i x -> if i = x then Some (i+x) else None)) ~expect:(Some 0) +let%test_unit _ = [%test_result: int option] (find_mapi [|3;5;2;1;4|] ~f:(fun i x -> if i = x then Some (i+x) else None)) ~expect:(Some 4) +let%test_unit _ = [%test_result: int option] (find_mapi [|3;5;1;1;4|] ~f:(fun i x -> if i = x then Some (i+x) else None)) ~expect:(Some 8) +let%test_unit _ = [%test_result: int option] (find_mapi [|3;5;1;1;2|] ~f:(fun i x -> if i = x then Some (i+x) else None)) ~expect:None + +let%test_unit _ = + List.iter + ~f:(fun (l, expect) -> + let t = of_list l in + assert (Poly.equal expect (find_consecutive_duplicate t ~equal:Poly.equal))) + [ [] , None + ; [ 1 ] , None + ; [ 1; 1 ] , Some (1, 1) + ; [ 1; 2 ] , None + ; [ 1; 2; 1 ] , None + ; [ 1; 2; 2 ] , Some (2, 2) + ; [ 1; 1; 2; 2 ], Some (1, 1) + ] +;; + +let%test_unit _ = [%test_result: int option] (random_element [| |]) ~expect:None +let%test_unit _ = [%test_result: int option] (random_element [| 0 |]) ~expect:(Some 0) + +let%test_unit _ = + List.iter + [ [||] + ; [| 1 |] + ; [| 1; 2; 3; 4; 5 |] + ] + ~f:(fun t -> + [%test_result: int array] + (Sequence.to_array (to_sequence t)) + ~expect:t) +;; + +let test_fold_map array ~init ~f ~expect = + [%test_result: int array] (folding_map array ~init ~f) ~expect:(snd expect); + [%test_result: int * int array] (fold_map array ~init ~f) ~expect + +let test_fold_mapi array ~init ~f ~expect = + [%test_result: int array] (folding_mapi array ~init ~f) ~expect:(snd expect); + [%test_result: int * int array] (fold_mapi array ~init ~f) ~expect + +let%test_unit _ = test_fold_map [|1;2;3;4|] ~init:0 + ~f:(fun acc x -> let y = acc+x in y,y) + ~expect:(10, [|1;3;6;10|]) +let%test_unit _ = test_fold_map [||] ~init:0 + ~f:(fun acc x -> let y = acc+x in y,y) + ~expect:(0, [||]) +let%test_unit _ = test_fold_mapi [|1;2;3;4|] ~init:0 + ~f:(fun i acc x -> let y = acc+i*x in y,y) + ~expect:(20, [|0;2;8;20|]) +let%test_unit _ = test_fold_mapi [||] ~init:0 + ~f:(fun i acc x -> let y = acc+i*x in y,y) + ~expect:(0, [||]) + +let%test "equal does not allocate" = + let arr1 = [|1;2;3;4|] in + let arr2 = [|1;2;4;3|] in + require_no_allocation [%here] (fun () -> + not (equal Int.equal arr1 arr2)) + +let%test "foldi does not allocate" = + let arr = [|1;2;3;4|] in + let f = fun i x y -> i + x + y in + require_no_allocation [%here] (fun () -> + 16 = (foldi ~init:0 ~f arr)) diff --git a/test/test_array.mli b/test/test_array.mli new file mode 100644 index 0000000..74bb729 --- /dev/null +++ b/test/test_array.mli @@ -0,0 +1 @@ +(*_ This signature is deliberately empty. *) diff --git a/test/test_backtrace.ml b/test/test_backtrace.ml new file mode 100644 index 0000000..9f7860d --- /dev/null +++ b/test/test_backtrace.ml @@ -0,0 +1,14 @@ +open! Import +open! Backtrace + +let%test_unit _ [@tags "no-js"] = + let t = get () in + assert (String.length (to_string t) > 0) +;; + +let%expect_test _ = + Stdio.Out_channel.(output_string stdout) + (Sexp.to_string (sexp_of_t (Exn.with_recording false ~f:Exn.most_recent))); + [%expect {| + ("<backtrace elided in test>") |}]; +;; diff --git a/test/test_backtrace.mli b/test/test_backtrace.mli new file mode 100644 index 0000000..74bb729 --- /dev/null +++ b/test/test_backtrace.mli @@ -0,0 +1 @@ +(*_ This signature is deliberately empty. *) diff --git a/test/test_base.ml b/test/test_base.ml new file mode 100644 index 0000000..eecf934 --- /dev/null +++ b/test/test_base.ml @@ -0,0 +1,17 @@ +open! Import + +let%expect_test _ = + let f x = x * 2 in + let g x = x + 3 in + print_s [%sexp (f @@ 5 : int)]; + [%expect {| 10 |}]; + print_s [%sexp (g @@ f @@ 5 : int)]; + [%expect {| 13 |}]; + print_s [%sexp (f @@ g @@ 5 : int)]; + [%expect {| 16 |}]; +;; + +let%expect_test "exp is present at the toplevel" = + print_s [%sexp (2 ** 8 : int)]; + [%expect {| 256 |}] +;; diff --git a/test/test_base.mli b/test/test_base.mli new file mode 100644 index 0000000..74bb729 --- /dev/null +++ b/test/test_base.mli @@ -0,0 +1 @@ +(*_ This signature is deliberately empty. *) diff --git a/test/test_blit.ml b/test/test_blit.ml new file mode 100644 index 0000000..b62c7fb --- /dev/null +++ b/test/test_blit.ml @@ -0,0 +1,81 @@ +open! Import +open! Blit + +(* This unit test checks that when [blit] calls [unsafe_blit], the slices are valid. + It also checks that [blit] doesn't call [unsafe_blit] when there is a range error. *) +let%test_module _ = + (module struct + + let blit_was_called = ref false + + let slices_are_valid = ref (Ok ()) + + module B = + Make + (struct + type t = bool array + let create ~len = Array.create false ~len + let length = Array.length + let unsafe_blit ~src ~src_pos ~dst ~dst_pos ~len = + blit_was_called := true; + slices_are_valid := + Or_error.try_with (fun () -> + assert (len >= 0); + assert (src_pos >= 0); + assert (src_pos + len <= Array.length src); + assert (dst_pos >= 0); + assert (dst_pos + len <= Array.length dst)); + Array.blit ~src ~src_pos ~dst ~dst_pos ~len; + ;; + end) + ;; + + let%test_module "Bool" = + (module Test_blit.Test + (struct + type t = bool + let equal = Bool.equal + let of_bool = Fn.id + end) + (struct + type t = bool array [@@deriving sexp_of] + let create ~len = Array.create false ~len + let length = Array.length + let get = Array.get + let set = Array.set + end) + (B)) + ;; + + let%test_unit _ = + let opts = [ None; Some (-1); Some 0; Some 1; Some 2 ] in + List.iter [ 0; 1; 2 ] ~f:(fun src -> + List.iter [ 0; 1; 2 ] ~f:(fun dst -> + List.iter opts ~f:(fun src_pos -> + List.iter opts ~f:(fun src_len -> + List.iter opts ~f:(fun dst_pos -> + try begin + let check f = + blit_was_called := false; + slices_are_valid := Ok (); + match Or_error.try_with f with + | Error _ -> assert (not !blit_was_called); + | Ok () -> ok_exn !slices_are_valid + in + check (fun () -> + B.blito + ~src:(Array.create ~len:src false) ?src_pos ?src_len + ~dst:(Array.create ~len:dst false) ?dst_pos + ()); + check (fun () -> + ignore (B.subo (Array.create ~len:src false) ?pos:src_pos ?len:src_len + : bool array)); + end + with exn -> + raise_s [%message + "failure" + (exn : exn) + (src : int) (src_pos : int option) (src_len : int option) + (dst : int) (dst_pos : int option)]))))) + ;; + end) diff --git a/test/test_blit.mli b/test/test_blit.mli new file mode 100644 index 0000000..74bb729 --- /dev/null +++ b/test/test_blit.mli @@ -0,0 +1 @@ +(*_ This signature is deliberately empty. *) diff --git a/test/test_bool.ml b/test/test_bool.ml new file mode 100644 index 0000000..3c67b70 --- /dev/null +++ b/test/test_bool.ml @@ -0,0 +1,32 @@ +open! Import + +let%expect_test "hash coherence" = + check_hash_coherence [%here] (module Bool) [ false; true ]; + [%expect {| |}] +;; + +let%expect_test "Bool.Non_short_circuiting.(||)" = + let (||) = Bool.Non_short_circuiting.(||) in + assert (true || true); + assert (true || false); + assert (false || true); + assert (not (false || false)); + + assert (true || (print_endline "rhs"; true)); + [%expect {|rhs|}]; + assert (false || (print_endline "rhs"; true)); + [%expect {|rhs|}]; +;; + +let%expect_test "Bool.Non_short_circuiting.(&&)" = + let (&&) = Bool.Non_short_circuiting.(&&) in + assert (true && true); + assert (not (true && false)); + assert (not (false && true)); + assert (not (false && false)); + + assert (true && (print_endline "rhs"; true)); + [%expect {|rhs|}]; + assert (not (false && (print_endline "rhs"; true))); + [%expect {|rhs|}]; +;; diff --git a/test/test_bool.mli b/test/test_bool.mli new file mode 100644 index 0000000..74bb729 --- /dev/null +++ b/test/test_bool.mli @@ -0,0 +1 @@ +(*_ This signature is deliberately empty. *) diff --git a/test/test_bytes.ml b/test/test_bytes.ml new file mode 100644 index 0000000..a2fd53b --- /dev/null +++ b/test/test_bytes.ml @@ -0,0 +1,14 @@ +open! Import +open! Bytes + +let%test_module "Blit" = + (module Test_blit.Test + (struct + include Char + let of_bool b = if b then 'a' else 'b' + end) + (struct + include Bytes + let create ~len = create len + end) + (Bytes)) diff --git a/test/test_char.ml b/test/test_char.ml new file mode 100644 index 0000000..a80b90c --- /dev/null +++ b/test/test_char.ml @@ -0,0 +1,534 @@ +open! Import +open! Char + +let%test _ = not (is_whitespace '\008') (* backspace *) +let%test _ = is_whitespace '\009' (* '\t': horizontal tab *) +let%test _ = is_whitespace '\010' (* '\n': line feed *) +let%test _ = is_whitespace '\011' (* '\v': vertical tab *) +let%test _ = is_whitespace '\012' (* '\f': form feed *) +let%test _ = is_whitespace '\013' (* '\r': carriage return *) +let%test _ = not (is_whitespace '\014') (* shift out *) +let%test _ = is_whitespace '\032' (* space *) + +let%expect_test "hash coherence" = + check_hash_coherence [%here] (module Char) [ min_value; 'a'; max_value ]; + [%expect {| |}]; +;; + +let%test_module "int to char conversion" = + (module struct + + let%test_unit "of_int bounds" = + let bounds_check i = + [%test_result: t option] + (of_int i) + ~expect:None + ~message:(Int.to_string i) + in + for i = 1 to 100 do + bounds_check (-i); + bounds_check (255 + i); + done + + let%test_unit "of_int_exn vs of_int" = + for i = -100 to 300 do + [%test_eq: t option] + (of_int i) + (Option.try_with (fun () -> of_int_exn i)) + ~message:(Int.to_string i) + done + + let%test_unit "unsafe_of_int vs of_int_exn" = + for i = 0 to 255 do + [%test_eq: t] + (unsafe_of_int i) + (of_int_exn i) + ~message:(Int.to_string i) + done + + end) + +let%expect_test "all" = + print_s [%sexp (all : t list)]; + [%expect {| + ("\000" + "\001" + "\002" + "\003" + "\004" + "\005" + "\006" + "\007" + "\b" + "\t" + "\n" + "\011" + "\012" + "\r" + "\014" + "\015" + "\016" + "\017" + "\018" + "\019" + "\020" + "\021" + "\022" + "\023" + "\024" + "\025" + "\026" + "\027" + "\028" + "\029" + "\030" + "\031" + " " + ! + "\"" + # + $ + % + & + ' + "(" + ")" + * + + + , + - + . + / + 0 + 1 + 2 + 3 + 4 + 5 + 6 + 7 + 8 + 9 + : + ";" + < + = + > + ? + @ + A + B + C + D + E + F + G + H + I + J + K + L + M + N + O + P + Q + R + S + T + U + V + W + X + Y + Z + [ + "\\" + ] + ^ + _ + ` + a + b + c + d + e + f + g + h + i + j + k + l + m + n + o + p + q + r + s + t + u + v + w + x + y + z + { + | + } + ~ + "\127" + "\128" + "\129" + "\130" + "\131" + "\132" + "\133" + "\134" + "\135" + "\136" + "\137" + "\138" + "\139" + "\140" + "\141" + "\142" + "\143" + "\144" + "\145" + "\146" + "\147" + "\148" + "\149" + "\150" + "\151" + "\152" + "\153" + "\154" + "\155" + "\156" + "\157" + "\158" + "\159" + "\160" + "\161" + "\162" + "\163" + "\164" + "\165" + "\166" + "\167" + "\168" + "\169" + "\170" + "\171" + "\172" + "\173" + "\174" + "\175" + "\176" + "\177" + "\178" + "\179" + "\180" + "\181" + "\182" + "\183" + "\184" + "\185" + "\186" + "\187" + "\188" + "\189" + "\190" + "\191" + "\192" + "\193" + "\194" + "\195" + "\196" + "\197" + "\198" + "\199" + "\200" + "\201" + "\202" + "\203" + "\204" + "\205" + "\206" + "\207" + "\208" + "\209" + "\210" + "\211" + "\212" + "\213" + "\214" + "\215" + "\216" + "\217" + "\218" + "\219" + "\220" + "\221" + "\222" + "\223" + "\224" + "\225" + "\226" + "\227" + "\228" + "\229" + "\230" + "\231" + "\232" + "\233" + "\234" + "\235" + "\236" + "\237" + "\238" + "\239" + "\240" + "\241" + "\242" + "\243" + "\244" + "\245" + "\246" + "\247" + "\248" + "\249" + "\250" + "\251" + "\252" + "\253" + "\254" + "\255") |}] + +let%expect_test "predicates" = + print_s [%sexp (List.filter all ~f:is_digit : t list)]; + [%expect {| (0 1 2 3 4 5 6 7 8 9) |}]; + print_s [%sexp (List.filter all ~f:is_lowercase : t list)]; + [%expect {| (a b c d e f g h i j k l m n o p q r s t u v w x y z) |}]; + print_s [%sexp (List.filter all ~f:is_uppercase : t list)]; + [%expect {| (A B C D E F G H I J K L M N O P Q R S T U V W X Y Z) |}]; + print_s [%sexp (List.filter all ~f:is_alpha : t list)]; + [%expect {| + (A + B + C + D + E + F + G + H + I + J + K + L + M + N + O + P + Q + R + S + T + U + V + W + X + Y + Z + a + b + c + d + e + f + g + h + i + j + k + l + m + n + o + p + q + r + s + t + u + v + w + x + y + z) |}]; + print_s [%sexp (List.filter all ~f:is_alphanum : t list)]; + [%expect {| + (0 + 1 + 2 + 3 + 4 + 5 + 6 + 7 + 8 + 9 + A + B + C + D + E + F + G + H + I + J + K + L + M + N + O + P + Q + R + S + T + U + V + W + X + Y + Z + a + b + c + d + e + f + g + h + i + j + k + l + m + n + o + p + q + r + s + t + u + v + w + x + y + z) |}]; + print_s [%sexp (List.filter all ~f:is_print : t list)]; + [%expect {| + (" " + ! + "\"" + # + $ + % + & + ' + "(" + ")" + * + + + , + - + . + / + 0 + 1 + 2 + 3 + 4 + 5 + 6 + 7 + 8 + 9 + : + ";" + < + = + > + ? + @ + A + B + C + D + E + F + G + H + I + J + K + L + M + N + O + P + Q + R + S + T + U + V + W + X + Y + Z + [ + "\\" + ] + ^ + _ + ` + a + b + c + d + e + f + g + h + i + j + k + l + m + n + o + p + q + r + s + t + u + v + w + x + y + z + { + | + } + ~) |}]; + print_s [%sexp (List.filter all ~f:is_whitespace : t list)]; + [%expect {| ("\t" "\n" "\011" "\012" "\r" " ") |}]; diff --git a/test/test_char.mli b/test/test_char.mli new file mode 100644 index 0000000..74bb729 --- /dev/null +++ b/test/test_char.mli @@ -0,0 +1 @@ +(*_ This signature is deliberately empty. *) diff --git a/test/test_compare.ml b/test/test_compare.ml new file mode 100644 index 0000000..f52f612 --- /dev/null +++ b/test/test_compare.ml @@ -0,0 +1,97 @@ +open! Base +open Expect_test_helpers_kernel + +module type S = sig + type t [@@deriving sexp_of] + include Comparable.Polymorphic_compare with type t := t +end + +(* Test the consistency of derived comparison operators with [compare] because many of + them are hand-optimized in [Base]. *) +let test (type a) here (module T : S with type t = a) list = + let op (type b) (module Result : S with type t = b) operator ~actual ~expect = + With_return.with_return (fun failed -> + List.iter list ~f:(fun arg1 -> + List.iter list ~f:(fun arg2 -> + let actual = actual arg1 arg2 in + let expect = expect arg1 arg2 in + if not (Result.compare actual expect = 0) then begin + print_cr here [%message + "comparison failed" + (operator : string) + (arg1 : T.t) + (arg2 : T.t) + (actual : Result.t) + (expect : Result.t)]; + failed.return () + end))) + in + let module C = Comparable.Make (T) in + op (module Bool) "equal" ~actual:T.equal ~expect:C.equal; + op (module T) "min" ~actual:T.min ~expect:C.min; + op (module T) "max" ~actual:T.max ~expect:C.max; + op (module Bool) "(=)" ~actual:T.(=) ~expect:C.(=); + op (module Bool) "(<)" ~actual:T.(<) ~expect:C.(<); + op (module Bool) "(>)" ~actual:T.(>) ~expect:C.(>); + op (module Bool) "(<>)" ~actual:T.(<>) ~expect:C.(<>); + op (module Bool) "(<=)" ~actual:T.(<=) ~expect:C.(<=); + op (module Bool) "(>=)" ~actual:T.(>=) ~expect:C.(>=); +;; + +let%expect_test "Base" = + test [%here] + (module struct include Base type t = int [@@deriving sexp_of] end) + Int.([min_value; minus_one; zero; one; max_value]); + [%expect {||}]; +;; + +let%expect_test "Unit" = + test [%here] (module Unit) Unit.all; + [%expect {||}]; +;; + +let%expect_test "Bool" = + test [%here] (module Bool) Bool.all; + [%expect {||}]; +;; + +let%expect_test "Char" = + test [%here] (module Char) Char.all; + [%expect {||}]; +;; + +let%expect_test "Float" = + test [%here] (module Float) + Float.([min_value; minus_one; zero; one; max_value]); + [%expect {||}]; +;; + +let%expect_test "Int" = + test [%here] (module Int) + Int.([min_value; minus_one; zero; one; max_value]); + [%expect {||}]; +;; + +let%expect_test "Int32" = + test [%here] (module Int32) + Int32.([min_value; minus_one; zero; one; max_value]); + [%expect {||}]; +;; + +let%expect_test "Int64" = + test [%here] (module Int64) + Int64.([min_value; minus_one; zero; one; max_value]); + [%expect {||}]; +;; + +let%expect_test "Nativeint" = + test [%here] (module Nativeint) + Nativeint.([min_value; minus_one; zero; one; max_value]); + [%expect {||}]; +;; + +let%expect_test "Int63" = + test [%here] (module Int63) + Int63.([min_value; minus_one; zero; one; max_value]); + [%expect {||}]; +;; diff --git a/test/test_compare.mli b/test/test_compare.mli new file mode 100644 index 0000000..74bb729 --- /dev/null +++ b/test/test_compare.mli @@ -0,0 +1 @@ +(*_ This signature is deliberately empty. *) diff --git a/test/test_container.ml b/test/test_container.ml new file mode 100644 index 0000000..2e791f5 --- /dev/null +++ b/test/test_container.ml @@ -0,0 +1,174 @@ +open! Import +open! Container + +module Test_generic + (Elt : sig + type 'a t + val of_int : int -> int t + val to_int : int t -> int + end) + (Container : sig + type 'a t [@@deriving sexp] + include Generic + with type 'a t := 'a t + with type 'a elt := 'a Elt.t + val mem : 'a t -> 'a Elt.t -> bool + val of_list : 'a Elt.t list -> [`Ok of 'a t | `Skip_test ] + end) + (* This signature constraint reminds us to add unit tests when functions are added to + [Generic]. *) + : sig + type 'a t [@@deriving sexp] + include Generic with type 'a t := 'a t + val mem : 'a t -> 'a Elt.t -> bool + end + with type 'a t := 'a Container.t + with type 'a elt := 'a Elt.t += struct + + open Container + + let find = find + let find_map = find_map + let fold = fold + let is_empty = is_empty + let iter = iter + let length = length + let mem = mem + let sexp_of_t = sexp_of_t + let t_of_sexp = t_of_sexp + let to_array = to_array + let to_list = to_list + let fold_result = fold_result + let fold_until = fold_until + + let%test_unit _ = + let ( = ) = Poly.equal in + let compare = Poly.compare in + List.iter [ 0; 1; 2; 3; 4; 8; 1024 ] ~f:(fun n -> + let list = List.init n ~f:Elt.of_int in + match Container.of_list list with + | `Skip_test -> () + | `Ok c -> + let sort l = List.sort l ~compare in + let sorts_are_equal l1 l2 = sort l1 = sort l2 in + assert (n = Container.length c); + assert ((n = 0) = Container.is_empty c); + assert (sorts_are_equal list + (Container.fold c ~init:[] ~f:(fun ac e -> e :: ac))); + assert (sorts_are_equal list (Container.to_list c)); + assert (sorts_are_equal list (Array.to_list (Container.to_array c))); + assert (n > 0 = is_some (Container.find c ~f:(fun e -> Elt.to_int e = 0))); + assert (n > 0 = is_some (Container.find c ~f:(fun e -> Elt.to_int e = n - 1))); + assert (is_none (Container.find c ~f:(fun e -> Elt.to_int e = n))); + assert (n > 0 = Container.mem c (Elt.of_int 0)); + assert (n > 0 = Container.mem c (Elt.of_int (n - 1))); + assert (not (Container.mem c (Elt.of_int n))); + assert (n > 0 = is_some (Container.find_map c ~f:(fun e -> + if Elt.to_int e = 0 then Some () else None))); + assert (n > 0 = is_some (Container.find_map c ~f:(fun e -> + if Elt.to_int e = n - 1 then Some () else None))); + assert (is_none (Container.find_map c ~f:(fun e -> + if Elt.to_int e = n then Some () else None))); + let r = ref 0 in + Container.iter c ~f:(fun e -> r := !r + Elt.to_int e); + assert (!r = List.fold list ~init:0 ~f:(fun n e -> n + Elt.to_int e)); + assert (!r = sum (module Int) c ~f:Elt.to_int); + let c2 = [%of_sexp: int Container.t] ([%sexp_of: int Container.t] c) in + assert (sorts_are_equal list (Container.to_list c2)); + let compare_elt a b = Int.compare (Elt.to_int a) (Elt.to_int b) in + if n = 0 then begin + assert (!r = 0); + assert (min_elt ~compare:compare_elt c = None); + assert (max_elt ~compare:compare_elt c = None); + end else begin + assert (!r = (n * (n-1) / 2)); + assert (Option.map ~f:Elt.to_int (min_elt ~compare:compare_elt c) = Some 0); + assert (Option.map ~f:Elt.to_int (max_elt ~compare:compare_elt c) = Some (Int.pred n)); + end; + let mid = Container.length c / 2 in + match + Container.fold_result c + ~init:0 + ~f:(fun count _elt -> if count = mid then Error count else Ok (count + 1)) + with + | Ok 0 -> assert (Container.length c = 0) + | Ok _ -> failwith "Expected fold to stop early" + | Error x -> assert (mid = x) + ) + ;; + + let min_elt = min_elt + let max_elt = max_elt + + let count = count + let sum = sum + let exists = exists + let for_all = for_all + + let%test_unit _ = + List.iter [ []; + [true]; + [false]; + [false; false]; + [true; false]; + [false; true]; + [true; true]; + ] + ~f:(fun bools -> + let count_should_be = + List.fold bools ~init:0 ~f:(fun n b -> if b then n + 1 else n) + in + let forall_should_be = List.fold bools ~init:true ~f:(fun ac b -> b && ac) in + let exists_should_be = List.fold bools ~init:false ~f:(fun ac b -> b || ac) in + match + Container.of_list + (List.map bools ~f:(fun b -> Elt.of_int (if b then 1 else 0))) + with + | `Skip_test -> () + | `Ok container -> + let is_one e = Elt.to_int e = 1 in + let ( = ) = Poly.equal in + assert (forall_should_be = Container.for_all container ~f:is_one); + assert (exists_should_be = Container.exists container ~f:is_one); + assert (count_should_be = Container.count container ~f:is_one); + ) + ;; + +end + +module Test_S1_allow_skipping_tests + (Container : sig + type 'a t [@@deriving sexp] + include Container.S1 with type 'a t := 'a t + val of_list : 'a list -> [`Ok of 'a t | `Skip_test] + end) = struct + + include + Test_generic + (struct + type 'a t = 'a + let of_int = Fn.id + let to_int = Fn.id + end) + (struct + include Container + let mem t a = mem t a ~equal:Poly.equal + end) + + let mem = Container.mem +end + +module Test_S1 + (Container : sig + type 'a t [@@deriving sexp] + include Container.S1 with type 'a t := 'a t + val of_list : 'a list -> 'a t + end) = Test_S1_allow_skipping_tests (struct + include Container + let of_list l = `Ok (of_list l) + end) + +include (Test_S1 (Array) : sig end) +include (Test_S1 (List) : sig end) +include (Test_S1 (Queue) : sig end) diff --git a/test/test_error.ml b/test/test_error.ml new file mode 100644 index 0000000..a2056c0 --- /dev/null +++ b/test/test_error.ml @@ -0,0 +1,23 @@ +open! Base +open! Import + +let errors = + [ Error.of_string "ABC" + ; Error.tag ~tag:"DEF" (Error.of_thunk (fun () -> "GHI")) + ; Error.create_s ([%message "foo" ~bar:(31:int)]) + ] + +let%expect_test _ = + List.iter errors ~f:(fun error -> show_raise (fun () -> Error.raise error)); + [%expect {| + (raised ABC) + (raised (DEF GHI)) + (raised (foo (bar 31))) |}] + +let%expect_test _ = + List.iter errors ~f:(fun error -> + show_raise (fun () -> Error.raise_s [%sexp (error : Error.t)])); + [%expect {| + (raised ABC) + (raised (DEF GHI)) + (raised (foo (bar 31))) |}] diff --git a/test/test_error.mli b/test/test_error.mli new file mode 100644 index 0000000..74bb729 --- /dev/null +++ b/test/test_error.mli @@ -0,0 +1 @@ +(*_ This signature is deliberately empty. *) diff --git a/test/test_exn.ml b/test/test_exn.ml new file mode 100644 index 0000000..286861e --- /dev/null +++ b/test/test_exn.ml @@ -0,0 +1,18 @@ +open! Import +open! Exn + +let%expect_test "[create_s]" = + print_s [%sexp (create_s [%message "foo"] : t)]; + [%expect {| + foo |}]; + print_s [%sexp (create_s [%message "foo" "bar"] : t)]; + [%expect {| + (foo bar) |}]; + let sexp = [%message "foo"] in + print_s [%sexp (phys_equal sexp (sexp_of_t (create_s sexp)) : bool)]; + [%expect {| + true |}]; +;; + +let%test _ = not (does_raise Fn.ignore) +let%test _ = does_raise (fun () -> failwith "foo") diff --git a/test/test_exn.mli b/test/test_exn.mli new file mode 100644 index 0000000..74bb729 --- /dev/null +++ b/test/test_exn.mli @@ -0,0 +1 @@ +(*_ This signature is deliberately empty. *) diff --git a/test/test_exported_int_conversions.ml b/test/test_exported_int_conversions.ml new file mode 100644 index 0000000..65fa77a --- /dev/null +++ b/test/test_exported_int_conversions.ml @@ -0,0 +1,235 @@ +open! Import + +module type S = sig + type t [@@deriving compare, sexp_of] + + val num_bits : int + + val min_value : t + val minus_one : t + val zero : t + val one : t + val max_value : t + + val to_int64 : t -> int64 + + val shift_right : t -> int -> t + + val random : Random.State.t -> t -> t -> t +end + +module I : S with type t = int = struct + include Int + let random = Random.State.int_incl +end + +module Native : S with type t = nativeint = struct + include Nativeint + let random = Random.State.nativeint_incl +end + +module I32 : S with type t = int32 = struct + include Int32 + let random = Random.State.int32_incl +end + +module I64 : S with type t = int64 = struct + include Int64 + let random = Random.State.int64_incl +end + +module I63 : S with type t = Int63.t = struct + include Int63 + let random state lo hi = Int63.random_incl ~state lo hi +end + +let iter (type a) (module M : S with type t = a) ~f = + let state = Random.State.make [| 0; 1; 2; 3; 4; 5 |] in + List.iter ~f [M.min_value; M.minus_one; M.zero; M.one; M.max_value]; + for _ = 1 to 10_000 do + (* skew toward low numbers of bits so that, e.g., choosing a random int64 does + frequently find a value that can be converted to int32. *) + let strip_bits = Random.State.int_incl state 0 (M.num_bits - 1) in + let lo = M.shift_right M.min_value strip_bits in + let hi = M.shift_right M.max_value strip_bits in + f (M.random state lo hi) + done + +let try_with f x = Option.try_with (fun () -> f x) + +(* Checks that a conversion from [A.t] to [B.t] is total using [of] and [to]. *) +let test_total (type a) (type b) + (module A : S with type t = a) + (module B : S with type t = b) + ~of_:b_of_a ~to_:a_to_b + = + iter (module A) ~f:(fun a -> + require_compare_equal [%here] (module B) (b_of_a a) (a_to_b a); + require_compare_equal [%here] (module Int64) (A.to_int64 a) (B.to_int64 (b_of_a a))) + +let truncate int64 ~num_bits = + Int64.shift_right (Int64.shift_left int64 (64 - num_bits)) (64 - num_bits) + +(* Checks that a conversion from [A.t] to [B.t] is partial using [of] and [to], and the + [_exn] equivalents. In the case where the conversion fails, ensure that the value, + converted to an [Int64.t] is outside the representable range of [B.t] converted to an + [Int64.t] as well. *) +let test_partial (type a) (type b) + (module A : S with type t = a) + (module B : S with type t = b) + ~of_:b_of_a ~of_exn:b_of_a_exn ~of_trunc:b_of_a_trunc + ~to_:a_to_b ~to_exn:a_to_b_exn ~to_trunc:a_to_b_trunc + = + let module B_option = struct type t = B.t option [@@deriving compare, sexp_of] end in + let convertible_count = ref 0 in + iter (module A) ~f:(fun a -> + require_compare_equal [%here] (module B_option) (b_of_a a) (a_to_b a); + require_compare_equal [%here] (module B_option) (b_of_a a) (try_with b_of_a_exn a); + require_compare_equal [%here] (module B_option) (a_to_b a) (try_with a_to_b_exn a); + match b_of_a a with + | Some b -> + Int.incr convertible_count; + require_compare_equal [%here] (module B) b (b_of_a_trunc a); + require_compare_equal [%here] (module B) b (a_to_b_trunc a); + require_compare_equal [%here] (module Int64) (A.to_int64 a) (B.to_int64 b) + | None -> + let trunc = truncate (A.to_int64 a) ~num_bits:B.num_bits in + require_compare_equal [%here] (module Int64) trunc (B.to_int64 (b_of_a_trunc a)); + require_compare_equal [%here] (module Int64) trunc (B.to_int64 (a_to_b_trunc a)); + require [%here] (Int64.( > ) (A.to_int64 a) (B.to_int64 B.max_value) || + Int64.( < ) (A.to_int64 a) (B.to_int64 B.min_value)) + ~if_false_then_print_s:(lazy [%message "failed to convert" ~_:(a : A.t)])); + (* Make sure we stress the conversion a nontrivial number of times. This makes sure the + random generation is useful and we aren't just testing the hard-coded examples. *) + require [%here] (!convertible_count > 100) + ~if_false_then_print_s: + (lazy [%message + "did not test successful conversion often enough" + (convertible_count : int ref)]) + +let%expect_test "int <-> nativeint" = + test_total (module I) (module Native) + ~of_:Nativeint.of_int + ~to_:Int.to_nativeint; + [%expect {| |}]; + test_partial (module Native) (module I) + ~of_:Int.of_nativeint ~of_exn:Int.of_nativeint_exn ~of_trunc:Int.of_nativeint_trunc + ~to_:Nativeint.to_int ~to_exn:Nativeint.to_int_exn ~to_trunc:Nativeint.to_int_trunc; + [%expect {| |}]; +;; + +let%expect_test "int <-> int32" = + test_partial (module I) (module I32) + ~of_:Int32.of_int ~of_exn:Int32.of_int_exn ~of_trunc:Int32.of_int_trunc + ~to_:Int.to_int32 ~to_exn:Int.to_int32_exn ~to_trunc:Int.to_int32_trunc; + [%expect {| |}]; + test_partial (module I32) (module I) + ~of_:Int.of_int32 ~of_exn:Int.of_int32_exn ~of_trunc:Int.of_int32_trunc + ~to_:Int32.to_int ~to_exn:Int32.to_int_exn ~to_trunc:Int32.to_int_trunc; + [%expect {| |}]; +;; + +let%expect_test "nativeint <-> int32" = + test_partial (module Native) (module I32) + ~of_: Int32.of_nativeint + ~of_exn: Int32.of_nativeint_exn + ~of_trunc: Int32.of_nativeint_trunc + ~to_: Nativeint.to_int32 + ~to_exn: Nativeint.to_int32_exn + ~to_trunc: Nativeint.to_int32_trunc; + [%expect {| |}]; + test_total (module I32) (module Native) + ~of_:Nativeint.of_int32 + ~to_:Int32.to_nativeint; + [%expect {| |}]; +;; + +let%expect_test "int <-> int64" = + test_total (module I) (module I64) + ~of_:Int64.of_int + ~to_:Int.to_int64; + [%expect {| |}]; + test_partial (module I64) (module I) + ~of_:Int.of_int64 ~of_exn:Int.of_int64_exn ~of_trunc:Int.of_int64_trunc + ~to_:Int64.to_int ~to_exn:Int64.to_int_exn ~to_trunc:Int64.to_int_trunc; + [%expect {| |}]; +;; + +let%expect_test "nativeint <-> int64" = + test_total (module Native) (module I64) + ~of_:Int64.of_nativeint + ~to_:Nativeint.to_int64; + [%expect {| |}]; + test_partial (module I64) (module Native) + ~of_: Nativeint.of_int64 + ~of_exn: Nativeint.of_int64_exn + ~of_trunc: Nativeint.of_int64_trunc + ~to_: Int64.to_nativeint + ~to_exn: Int64.to_nativeint_exn + ~to_trunc: Int64.to_nativeint_trunc; + [%expect {| |}]; +;; + +let%expect_test "int32 <-> int64" = + test_total (module I32) (module I64) + ~of_:Int64.of_int32 + ~to_:Int32.to_int64; + [%expect {| |}]; + test_partial (module I64) (module I32) + ~of_:Int32.of_int64 ~of_exn:Int32.of_int64_exn ~of_trunc:Int32.of_int64_trunc + ~to_:Int64.to_int32 ~to_exn:Int64.to_int32_exn ~to_trunc:Int64.to_int32_trunc; + [%expect {| |}]; +;; + + +let%expect_test "int <-> int63" = + test_total (module I) (module I63) + ~of_:Int63.of_int + ~to_:Int63.of_int; + [%expect {| |}]; + test_partial (module I63) (module I) + ~of_:Int63.to_int ~of_exn:Int63.to_int_exn ~of_trunc:Int63.to_int_trunc + ~to_:Int63.to_int ~to_exn:Int63.to_int_exn ~to_trunc:Int63.to_int_trunc; + [%expect {| |}]; +;; + +let%expect_test "nativeint <-> int63" = + test_partial (module Native) (module I63) + ~of_: Int63.of_nativeint + ~of_exn: Int63.of_nativeint_exn + ~of_trunc: Int63.of_nativeint_trunc + ~to_: Int63.of_nativeint + ~to_exn: Int63.of_nativeint_exn + ~to_trunc: Int63.of_nativeint_trunc; + [%expect {| |}]; + test_partial (module I63) (module Native) + ~of_: Int63.to_nativeint + ~of_exn: Int63.to_nativeint_exn + ~of_trunc: Int63.to_nativeint_trunc + ~to_: Int63.to_nativeint + ~to_exn: Int63.to_nativeint_exn + ~to_trunc: Int63.to_nativeint_trunc; + [%expect {| |}]; +;; + +let%expect_test "int32 <-> int63" = + test_total (module I32) (module I63) + ~of_:Int63.of_int32 + ~to_:Int63.of_int32; + [%expect {| |}]; + test_partial (module I63) (module I32) + ~of_:Int63.to_int32 ~of_exn:Int63.to_int32_exn ~of_trunc:Int63.to_int32_trunc + ~to_:Int63.to_int32 ~to_exn:Int63.to_int32_exn ~to_trunc:Int63.to_int32_trunc; + [%expect {| |}]; +;; + +let%expect_test "int64 <-> int63" = + test_partial (module I64) (module I63) + ~of_:Int63.of_int64 ~of_exn:Int63.of_int64_exn ~of_trunc:Int63.of_int64_trunc + ~to_:Int63.of_int64 ~to_exn:Int63.of_int64_exn ~to_trunc:Int63.of_int64_trunc; + [%expect {| |}]; + test_total (module I63) (module I64) + ~of_:Int63.to_int64 + ~to_:Int63.to_int64; + [%expect {| |}]; +;; diff --git a/test/test_exported_int_conversions.mli b/test/test_exported_int_conversions.mli new file mode 100644 index 0000000..74bb729 --- /dev/null +++ b/test/test_exported_int_conversions.mli @@ -0,0 +1 @@ +(*_ This signature is deliberately empty. *) diff --git a/test/test_float.ml b/test/test_float.ml new file mode 100644 index 0000000..74baca9 --- /dev/null +++ b/test/test_float.ml @@ -0,0 +1,1025 @@ +open! Import +open! Float +open! Float.Private + +let%expect_test "hash coherence" [@tags "64-bits-only"] = + check_hash_coherence [%here] (module Float) [ min_value; 0.; 37.; max_value ]; + [%expect {| |}]; +;; + +let exponent_bits = 11 +let mantissa_bits = 52 + +let exponent_mask64 = Int64.((shift_left one exponent_bits) - one) +let exponent_mask = Int64.to_int_exn exponent_mask64 +let mantissa_mask = Int63.((shift_left one mantissa_bits) - one) +let _mantissa_mask64 = Int63.to_int64 mantissa_mask + +let%test_unit "upper/lower_bound_for_int" = + assert ( + [%compare.equal: (int * t * t) list] + ([8; 16; 31; 32; 52; 53; 54; 62; 63; 64] + |> List.map ~f:(fun x -> (x, lower_bound_for_int x, upper_bound_for_int x))) + [( 8 + , -128.99999999999997 + , 127.99999999999999) + ; (16 + , -32768.999999999993 + , 32767.999999999996) + ; (31 + , -1073741824.9999998 + , 1073741823.9999999) + ; (32 + , -2147483648.9999995 + , 2147483647.9999998) + ; (52 + , -2251799813685248.5 + , 2251799813685247.8) + ; (53 + , -4503599627370496. + , 4503599627370495.5) + ; (54 + , -9007199254740992. + , 9007199254740991.) + ; (62 + , -2305843009213693952. + , 2305843009213693696.) + ; (63 + , -4611686018427387904. + , 4611686018427387392.) + ; (64 + , -9223372036854775808. + , 9223372036854774784.) + ]) +;; + +let%test_unit _ = + (* on 64-bit platform ppx_hash hashes floats exactly the same as polymorphic hash *) + match Word_size.word_size with + | W32 -> () + | W64 -> + List.iter ~f:(fun float -> + let hash1 = Caml.Hashtbl.hash float in + let hash2 = [%hash: float] float in + let hash3 = specialized_hash float in + if (not Int.(hash1 = hash2 && hash1 = hash3)) + then raise_s [%message "bad" + (hash1 : Int.Hex.t) + (hash2 : Int.Hex.t) + (hash3 : Int.Hex.t)] + ) + [ 0.926038888360971146 + ; 34.1638588598232076 + ] +;; + +let test_both_ways (a : t) (b : int64) = + Int64.(=) (to_int64_preserve_order_exn a) b && Float.(=) (of_int64_preserve_order b) a +;; + +let%test _ = test_both_ways 0. 0L +let%test _ = test_both_ways (-0.) 0L +let%test _ = test_both_ways 1. Int64.(shift_left 1023L 52) +let%test _ = test_both_ways (-2.) Int64.(neg (shift_left 1024L 52)) +let%test _ = test_both_ways infinity Int64.(shift_left 2047L 52) +let%test _ = test_both_ways neg_infinity Int64.(neg (shift_left 2047L 52)) + +let%test _ = one_ulp `Down infinity = max_finite_value +let%test _ = is_nan (one_ulp `Up infinity) +let%test _ = is_nan (one_ulp `Down neg_infinity) +let%test _ = one_ulp `Up neg_infinity = ~-. max_finite_value + +(* Some tests to make sure that the compiler is generating code for handling subnormal + numbers at runtime accurately. *) +let x () = min_positive_subnormal_value +let y () = min_positive_normal_value + +let%test _ = test_both_ways (x ()) 1L +let%test _ = test_both_ways (y ()) Int64.(shift_left 1L 52) + +let%test _ = x () > 0. +let%test_unit _ = [%test_result: float] (x () /. 2.) ~expect:0. + +let%test _ = one_ulp `Up 0. = x () +let%test _ = one_ulp `Down 0. = ~-. (x ()) + +let are_one_ulp_apart a b = one_ulp `Up a = b + +let%test _ = are_one_ulp_apart (x ()) (2. *. x ()) +let%test _ = are_one_ulp_apart (2. *. x ()) (3. *. x ()) + +let one_ulp_below_y () = y () -. x () +let%test _ = one_ulp_below_y () < y () +let%test _ = y () -. one_ulp_below_y () = x () +let%test _ = are_one_ulp_apart (one_ulp_below_y ()) (y ()) + +let one_ulp_above_y () = y () +. x () +let%test _ = y () < one_ulp_above_y () +let%test _ = one_ulp_above_y () -. y () = x () +let%test _ = are_one_ulp_apart (y ()) (one_ulp_above_y ()) + +let%test _ = not (are_one_ulp_apart (one_ulp_below_y ()) (one_ulp_above_y ())) + +(* [2 * min_positive_normal_value] is where the ulp increases for the first time. *) +let z () = 2. *. y () +let one_ulp_below_z () = z () -. x () +let%test _ = one_ulp_below_z () < z () +let%test _ = z () -. one_ulp_below_z () = x () +let%test _ = are_one_ulp_apart (one_ulp_below_z ()) (z ()) + +let one_ulp_above_z () = z () +. 2. *. x () +let%test _ = z () < one_ulp_above_z () +let%test _ = one_ulp_above_z () -. z () = 2. *. x () +let%test _ = are_one_ulp_apart (z ()) (one_ulp_above_z ()) + +let%test_module "clamp" = + (module struct + let%test _ = clamp_exn 1.0 ~min:2. ~max:3. = 2. + let%test _ = clamp_exn 2.5 ~min:2. ~max:3. = 2.5 + let%test _ = clamp_exn 3.5 ~min:2. ~max:3. = 3. + + let%test_unit "clamp" = + [%test_result: float Or_error.t] (clamp 3.5 ~min:2. ~max:3.) ~expect:(Ok 3.) + + let%test_unit "clamp nan" = + [%test_result: float Or_error.t] (clamp nan ~min:2. ~max:3.) ~expect:(Ok nan) + + let%test "clamp bad" = Or_error.is_error (clamp 2.5 ~min:3. ~max:2.) + end) + +let%test_unit _ = + [%test_result: Int64.t] + (Int64.bits_of_float 1.1235582092889474E+307) ~expect:0x7fb0000000000000L + +let%test_module "IEEE" = + (module struct + (* Note: IEEE 754 defines NaN values to be those where the exponent is all 1s and the + mantissa is nonzero. test_result<t> sees nan values as equal because it is based + on [compare] rather than [=]. (If [x] and [x'] are nan, [compare x x'] returns 0, + whereas [x = x'] returns [false]. This is the case regardless of whether or not + [x] and [x'] are bit-identical values of nan.) *) + let f (t : t) (negative : bool) (exponent : int) (mantissa : Int63.t) : unit = + let str = to_string t in + let is_nan = is_nan t in + (* the sign doesn't matter when nan *) + if not is_nan then + [%test_result: bool] ~message:("ieee_negative " ^ str) + (ieee_negative t) ~expect:negative; + [%test_result: int] ~message:("ieee_exponent " ^ str) + (ieee_exponent t) ~expect:exponent; + if is_nan + then assert (Int63.(zero <> ieee_mantissa t)) + else [%test_result: Int63.t] ~message:("ieee_mantissa " ^ str) + (ieee_mantissa t) ~expect:mantissa; + [%test_result: t] + ~message:(Printf.sprintf !"create_ieee ~negative:%B ~exponent:%d ~mantissa:%{Int63}" + negative exponent mantissa) + (create_ieee_exn ~negative ~exponent ~mantissa) + ~expect:t + + let%test_unit _ = + let (!!) x = Int63.of_int x in + f zero false 0 (!! 0); + f min_positive_subnormal_value false 0 (!! 1); + f min_positive_normal_value false 1 (!! 0); + f epsilon_float false Int.(1023 - mantissa_bits) (!! 0); + f one false 1023 (!! 0); + f minus_one true 1023 (!! 0); + f max_finite_value false Int.(exponent_mask - 1) mantissa_mask; + f infinity false exponent_mask (!! 0); + f neg_infinity true exponent_mask (!! 0); + f nan false exponent_mask (!! 1) + + (* test the normalized case, that is, 1 <= exponent <= 2046 *) + let%test_unit _ = + let g ~negative ~exponent ~mantissa = + assert (create_ieee_exn ~negative ~exponent + ~mantissa:(Int63.of_int64_exn mantissa) + = + (if negative then -1. else 1.) + * 2. **. (Float.of_int exponent - 1023.) + * (1. + (2. **. -52.) * Int64.to_float mantissa)) + in + g ~negative:false ~exponent:1 ~mantissa:147L; + g ~negative:true ~exponent:137 ~mantissa:13L; + g ~negative:false ~exponent:1015 ~mantissa:1370001L; + g ~negative:true ~exponent:2046 ~mantissa:137000100945L + end) + +let%test_module _ = + (module struct + let test f expect = + let actual = to_padded_compact_string f in + if String.(actual <> expect) + then raise_s + [%message "failure" + (f : t ) + (expect : string) + (actual : string)] + + let both f expect = + assert (f > 0.); + test f expect; + test (~-.f) ("-"^expect); + ;; + + let decr = one_ulp `Down + let incr = one_ulp `Up + + let boundary f ~closer_to_zero ~at = + assert (f > 0.); + (* If [f] looks like an odd multiple of 0.05, it might be slightly under-represented + as a float, e.g. + + 1. -. 0.95 = 0.0500000000000000444 + + In such case, sadly, the right way for [to_padded_compact_string], just as for + [sprintf "%.1f"], is to round it down. However, the next representable number + should be rounded up: + + # let x = 0.95 in sprintf "%.0f / %.1f / %.2f / %.3f / %.20f" x x x x x;; + - : string = "1 / 0.9 / 0.95 / 0.950 / 0.94999999999999995559" + + # let x = incr 0.95 in sprintf "%.0f / %.1f / %.2f / %.3f / %.20f" x x x x x ;; + - : string = "1 / 1.0 / 0.95 / 0.950 / 0.95000000000000006661" + + *) + let f = + if f >= 1000. then + f + else + let x = Printf.sprintf "%.20f" f in + let spot = String.index_exn x '.' in + (* the following condition is only meant to work for small multiples of 0.05 *) + let (+) = Int.(+) in + let (=) = Char.(=) in + if x.[spot + 2] = '4' && x.[spot + 3] = '9' && x.[spot + 4] = '9' then + (* something like 0.94999999999999995559 *) + incr f + else + f + in + both (decr f) closer_to_zero; + both f at; + ;; + + let%test_unit _ = test nan "nan " + let%test_unit _ = test 0.0 "0 " + let%test_unit _ = both min_positive_subnormal_value "0 " + let%test_unit _ = both infinity "inf " + + let%test_unit _ = boundary 0.05 ~closer_to_zero: "0 " ~at: "0.1" + let%test_unit _ = boundary 0.15 ~closer_to_zero: "0.1" ~at: "0.2" + (* glibc printf resolves ties to even, cf. + http://www.exploringbinary.com/inconsistent-rounding-of-printed-floating-point-numbers/ + Ties are resolved differently in JavaScript - mark some tests as no running with JavaScript. + *) + let%test_unit _ [@tags "no-js"] = + boundary (* tie *) 0.25 ~closer_to_zero: "0.2" ~at: "0.2" + let%test_unit _ [@tags "no-js"] = + boundary (incr 0.25)~closer_to_zero: "0.2" ~at: "0.3" + let%test_unit _ = boundary 0.35 ~closer_to_zero: "0.3" ~at: "0.4" + let%test_unit _ = boundary 0.45 ~closer_to_zero: "0.4" ~at: "0.5" + let%test_unit _ = both 0.50 "0.5" + let%test_unit _ = boundary 0.55 ~closer_to_zero: "0.5" ~at: "0.6" + let%test_unit _ = boundary 0.65 ~closer_to_zero: "0.6" ~at: "0.7" + (* this time tie-to-even means round away from 0 *) + let%test_unit _ = boundary (* tie *) 0.75 ~closer_to_zero: "0.7" ~at: "0.8" + let%test_unit _ = boundary 0.85 ~closer_to_zero: "0.8" ~at: "0.9" + let%test_unit _ = boundary 0.95 ~closer_to_zero: "0.9" ~at: "1 " + let%test_unit _ = boundary 1.05 ~closer_to_zero: "1 " ~at: "1.1" + let%test_unit _ [@tags "no-js"] = + boundary 3.25 ~closer_to_zero: "3.2" ~at: "3.2" + let%test_unit _ [@tags "no-js"] = + boundary (incr 3.25)~closer_to_zero: "3.2" ~at: "3.3" + let%test_unit _ = boundary 3.75 ~closer_to_zero: "3.7" ~at: "3.8" + let%test_unit _ = boundary 9.95 ~closer_to_zero: "9.9" ~at: "10 " + let%test_unit _ = boundary 10.05 ~closer_to_zero: "10 " ~at: "10.1" + let%test_unit _ = boundary 100.05 ~closer_to_zero:"100 " ~at: "100.1" + let%test_unit _ [@tags "no-js"] = + boundary (* tie *) 999.25 ~closer_to_zero:"999.2" ~at: "999.2" + let%test_unit _ [@tags "no-js"] = + boundary (incr 999.25)~closer_to_zero:"999.2" ~at: "999.3" + let%test_unit _ = boundary 999.75 ~closer_to_zero:"999.7" ~at: "999.8" + let%test_unit _ = boundary 999.95 ~closer_to_zero:"999.9" ~at: "1k " + let%test_unit _ = both 1000. "1k " + + (* some ties which we resolve manually in [iround_ratio_exn] *) + let%test_unit _ = boundary 1050. ~closer_to_zero: "1k " ~at: "1k " + let%test_unit _ = boundary (incr 1050.) ~closer_to_zero: "1k " ~at: "1k1" + let%test_unit _ = boundary 1950. ~closer_to_zero: "1k9" ~at: "2k " + let%test_unit _ = boundary 3250. ~closer_to_zero: "3k2" ~at: "3k2" + let%test_unit _ = boundary (incr 3250.) ~closer_to_zero: "3k2" ~at: "3k3" + let%test_unit _ = boundary 9950. ~closer_to_zero: "9k9" ~at: "10k " + let%test_unit _ = boundary 33_250. ~closer_to_zero: "33k2" ~at: "33k2" + let%test_unit _ = boundary (incr 33_250.) ~closer_to_zero: "33k2" ~at: "33k3" + let%test_unit _ = boundary 33_350. ~closer_to_zero: "33k3" ~at: "33k4" + let%test_unit _ = boundary 33_750. ~closer_to_zero: "33k7" ~at: "33k8" + let%test_unit _ = boundary 333_250. ~closer_to_zero:"333k2" ~at: "333k2" + let%test_unit _ = boundary (incr 333_250.) ~closer_to_zero:"333k2" ~at: "333k3" + let%test_unit _ = boundary 333_750. ~closer_to_zero:"333k7" ~at: "333k8" + let%test_unit _ = boundary 999_850. ~closer_to_zero:"999k8" ~at: "999k8" + let%test_unit _ = boundary (incr 999_850.) ~closer_to_zero:"999k8" ~at: "999k9" + let%test_unit _ = boundary 999_950. ~closer_to_zero:"999k9" ~at: "1m " + let%test_unit _ = boundary 1_050_000. ~closer_to_zero: "1m " ~at: "1m " + let%test_unit _ = boundary (incr 1_050_000.) ~closer_to_zero: "1m " ~at: "1m1" + + let%test_unit _ = boundary 999_950_000. ~closer_to_zero:"999m9" ~at: "1g " + let%test_unit _ = boundary 999_950_000_000. ~closer_to_zero:"999g9" ~at: "1t " + let%test_unit _ = boundary 999_950_000_000_000. ~closer_to_zero:"999t9" ~at: "1p " + let%test_unit _ = boundary 999_950_000_000_000_000. ~closer_to_zero:"999p9" ~at:"1.0e+18" + + (* Test the boundary between the subnormals and the normals. *) + let%test_unit _ = boundary min_positive_normal_value ~closer_to_zero:"0 " ~at:"0 " + end) + +let%test "int_pow" = + let tol = 1e-15 in + let test (x, n) = + let reference_value = x **. of_int n in + let relative_error = (int_pow x n -. reference_value) /. reference_value in + abs relative_error < tol + in + List.for_all ~f:test + [(1.5, 17); (1.5, 42); (0.99, 64); (2., -5); (2., -1) + ; (-1.3, 2); (-1.3, -1); (-1.3, -2); (5., 0) + ; (nan, 0); (0., 0); (infinity, 0) + ] + +let%test "int_pow misc" = + int_pow 0. (-1) = infinity + && int_pow (-0.) (-1) = neg_infinity + && int_pow (-0.) (-2) = infinity + && int_pow 1.5 5000 = infinity + && int_pow 1.5 (-5000) = 0. + && int_pow (-1.) Int.max_value = -1. + && int_pow (-1.) Int.min_value = 1. + +(* some ugly corner cases with extremely large exponents and some serious precision loss *) +let%test "int_pow bad cases" [@tags "64-bits-only"] = + let a = one_ulp `Down 1. in + let b = one_ulp `Up 1. in + let large = 1 lsl 61 in + let small = Int.neg large in + (* this huge discrepancy comes from the fact that [1 / a = b] but this is a very poor + approximation, and in particular [1 / b = one_ulp `Down a = a * a]. *) + a **. of_int small = 1.5114276650041252e+111 + && int_pow a small = 2.2844048619719663e+222 + && int_pow b large = 2.2844048619719663e+222 + && b **. of_int large = 2.2844135865396268e+222 + +let%test_unit "sign_exn" = + List.iter ~f:(fun (input,expect) -> assert (Sign.equal (sign_exn input) expect)) + [ (1e-30, Sign.Pos) + ; (-0., Zero) + ; (0., Zero) + ; (neg_infinity, Neg) + ] + +let%test _ = + match sign_exn nan with + | Neg | Zero | Pos -> false + | exception _ -> true + +let%test_unit "sign_or_nan" = + List.iter ~f:(fun (input,expect) -> assert (Sign_or_nan.equal (sign_or_nan input) expect)) + [ (1e-30, Sign_or_nan.Pos) + ; (-0., Zero) + ; (0., Zero) + ; (neg_infinity, Neg) + ; (nan, Nan) + ] + +let%test_module _ = + (module struct + let check v expect = + match Validate.result v, expect with + | Ok (), `Ok | Error _, `Error -> () + | r, expect -> + raise_s + [%message "mismatch" (r : unit Or_error.t) (expect : [ `Ok | `Error ])] + ;; + + let%test_unit _ = check (validate_lbound ~min:(Incl 0.) nan) `Error + let%test_unit _ = check (validate_lbound ~min:(Incl 0.) infinity) `Error + let%test_unit _ = check (validate_lbound ~min:(Incl 0.) neg_infinity) `Error + let%test_unit _ = check (validate_lbound ~min:(Incl 0.) (-1.)) `Error + let%test_unit _ = check (validate_lbound ~min:(Incl 0.) 0.) `Ok + let%test_unit _ = check (validate_lbound ~min:(Incl 0.) 1.) `Ok + + let%test_unit _ = check (validate_ubound ~max:(Incl 0.) nan) `Error + let%test_unit _ = check (validate_ubound ~max:(Incl 0.) infinity) `Error + let%test_unit _ = check (validate_ubound ~max:(Incl 0.) neg_infinity) `Error + let%test_unit _ = check (validate_ubound ~max:(Incl 0.) (-1.)) `Ok + let%test_unit _ = check (validate_ubound ~max:(Incl 0.) 0.) `Ok + let%test_unit _ = check (validate_ubound ~max:(Incl 0.) 1.) `Error + + (* Some of the following tests used to live in lib_test/core_float_test.ml. *) + + let () = Random.init 137 + + (* round: + ... <-)[-><-)[-><-)[-><-)[-><-)[-><-)[-> ... + ... -+-----+-----+-----+-----+-----+-----+- ... + ... -3 -2 -1 0 1 2 3 ... + so round x -. x should be in (-0.5,0.5] + *) + let round_test x = + let y = round x in + -0.5 < y -. x && y -. x <= 0.5 + + let iround_up_vs_down_test x = + let expected_difference = + if Parts.fractional (modf x) = 0. then + 0 + else + 1 + in + match iround_up x, iround_down x with + | Some x, Some y -> Int.(x - y = expected_difference) + | _, _ -> true + + let test_all_six x + ~specialized_iround ~specialized_iround_exn ~float_rounding + ~dir ~validate = + let result1 = iround x ~dir in + let result2 = Option.try_with (fun () -> iround_exn x ~dir) in + let result3 = specialized_iround x in + let result4 = Option.try_with (fun () -> specialized_iround_exn x) in + let result5 = Option.try_with (fun () -> Int.of_float (float_rounding x)) in + let result6 = Option.try_with (fun () -> Int.of_float (round ~dir x)) in + let (=) = Caml.(=) in + if result1 = result2 && result2 = result3 && result3 = result4 + && result4 = result5 && result5 = result6 then + validate result1 + else + false + + (* iround ~dir:`Nearest built so this should always be true *) + let iround_nearest_test x = + test_all_six x + ~specialized_iround:iround_nearest + ~specialized_iround_exn:iround_nearest_exn + ~float_rounding:round_nearest + ~dir:`Nearest + ~validate:(function + | None -> true + | Some y -> + let y = of_int y in + -0.5 < y -. x && y -. x <= 0.5) + + (* iround_down: + ... )[<---)[<---)[<---)[<---)[<---)[<---)[ ... + ... -+-----+-----+-----+-----+-----+-----+- ... + ... -3 -2 -1 0 1 2 3 ... + so x -. iround_down x should be in [0,1) + *) + let iround_down_test x = + test_all_six x + ~specialized_iround:iround_down + ~specialized_iround_exn:iround_down_exn + ~float_rounding:round_down + ~dir:`Down + ~validate:(function + | None -> true + | Some y -> + let y = of_int y in + 0. <= x -. y && x -. y < 1.) + + (* iround_up: + ... ](--->](--->](--->](--->](--->](--->]( ... + ... -+-----+-----+-----+-----+-----+-----+- ... + ... -3 -2 -1 0 1 2 3 ... + so iround_up x -. x should be in [0,1) + *) + let iround_up_test x = + test_all_six x + ~specialized_iround:iround_up + ~specialized_iround_exn:iround_up_exn + ~float_rounding:round_up + ~dir:`Up + ~validate:(function + | None -> true + | Some y -> + let y = of_int y in + 0. <= y -. x && y -. x < 1.) + + (* iround_towards_zero: + ... ](--->](--->](---><--->)[<---)[<---)[ ... + ... -+-----+-----+-----+-----+-----+-----+- ... + ... -3 -2 -1 0 1 2 3 ... + so abs x -. abs (iround_towards_zero x) should be in [0,1) + *) + let iround_towards_zero_test x = + test_all_six x + ~specialized_iround:iround_towards_zero + ~specialized_iround_exn:iround_towards_zero_exn + ~float_rounding:round_towards_zero + ~dir:`Zero + ~validate:(function + | None -> true + | Some y -> + let x = abs x in + let y = abs (of_int y) in + 0. <= x -. y && x -. y < 1. && (Sign.(sign_exn x = sign_exn y) || y = 0.0)) + + (* Easy cases that used to live inline with the code above. *) + let%test_unit _ = [%test_result: int option] (iround_up (-3.4)) ~expect:(Some (-3)) + let%test_unit _ = [%test_result: int option] (iround_up 0.0) ~expect:(Some 0) + let%test_unit _ = [%test_result: int option] (iround_up 3.4) ~expect:(Some 4) + + let%test_unit _ = [%test_result: int] (iround_up_exn (-3.4)) ~expect:(-3) + let%test_unit _ = [%test_result: int] (iround_up_exn 0.0) ~expect:0 + let%test_unit _ = [%test_result: int] (iround_up_exn 3.4) ~expect:4 + + let%test_unit _ = [%test_result: int option] (iround_down (-3.4)) ~expect:(Some (-4)) + let%test_unit _ = [%test_result: int option] (iround_down 0.0) ~expect:(Some 0) + let%test_unit _ = [%test_result: int option] (iround_down 3.4) ~expect:(Some 3) + + let%test_unit _ = [%test_result: int] (iround_down_exn (-3.4)) ~expect:(-4) + let%test_unit _ = [%test_result: int] (iround_down_exn 0.0) ~expect:0 + let%test_unit _ = [%test_result: int] (iround_down_exn 3.4) ~expect:3 + + let%test_unit _ = [%test_result: int option] (iround_towards_zero (-3.4)) ~expect:(Some (-3)) + let%test_unit _ = [%test_result: int option] (iround_towards_zero 0.0) ~expect:(Some 0) + let%test_unit _ = [%test_result: int option] (iround_towards_zero 3.4) ~expect:(Some 3) + + let%test_unit _ = [%test_result: int] (iround_towards_zero_exn (-3.4)) ~expect:(-3) + let%test_unit _ = [%test_result: int] (iround_towards_zero_exn 0.0) ~expect:0 + let%test_unit _ = [%test_result: int] (iround_towards_zero_exn 3.4) ~expect:3 + + let%test_unit _ = [%test_result: int option] (iround_nearest (-3.6)) ~expect:(Some (-4)) + let%test_unit _ = [%test_result: int option] (iround_nearest (-3.5)) ~expect:(Some (-3)) + let%test_unit _ = [%test_result: int option] (iround_nearest (-3.4)) ~expect:(Some (-3)) + let%test_unit _ = [%test_result: int option] (iround_nearest 0.0) ~expect:(Some 0) + let%test_unit _ = [%test_result: int option] (iround_nearest 3.4) ~expect:(Some 3) + let%test_unit _ = [%test_result: int option] (iround_nearest 3.5) ~expect:(Some 4) + let%test_unit _ = [%test_result: int option] (iround_nearest 3.6) ~expect:(Some 4) + + let%test_unit _ = [%test_result: int] (iround_nearest_exn (-3.6)) ~expect:(-4) + let%test_unit _ = [%test_result: int] (iround_nearest_exn (-3.5)) ~expect:(-3) + let%test_unit _ = [%test_result: int] (iround_nearest_exn (-3.4)) ~expect:(-3) + let%test_unit _ = [%test_result: int] (iround_nearest_exn 0.0 ) ~expect:0 + let%test_unit _ = [%test_result: int] (iround_nearest_exn 3.4 ) ~expect:3 + let%test_unit _ = [%test_result: int] (iround_nearest_exn 3.5 ) ~expect:4 + let%test_unit _ = [%test_result: int] (iround_nearest_exn 3.6 ) ~expect:4 + + let special_values_test () = + [%test_result: float] (round (-.1.50001)) ~expect:(-.2.); + [%test_result: float] (round (-.1.5)) ~expect:(-.1.); + [%test_result: float] (round (-.0.50001)) ~expect:(-.1.); + [%test_result: float] (round (-.0.5)) ~expect:0.; + [%test_result: float] (round 0.49999) ~expect:0.; + [%test_result: float] (round 0.5) ~expect:1.; + [%test_result: float] (round 1.49999) ~expect:1.; + [%test_result: float] (round 1.5) ~expect:2.; + [%test_result: int] (iround_exn ~dir:`Up (-.2.)) ~expect:(-2); + [%test_result: int] (iround_exn ~dir:`Up (-.1.9999)) ~expect:(-1); + [%test_result: int] (iround_exn ~dir:`Up (-.1.)) ~expect:(-1); + [%test_result: int] (iround_exn ~dir:`Up (-.0.9999)) ~expect:0; + [%test_result: int] (iround_exn ~dir:`Up 0.) ~expect:0; + [%test_result: int] (iround_exn ~dir:`Up 0.00001) ~expect:1; + [%test_result: int] (iround_exn ~dir:`Up 1.) ~expect:1; + [%test_result: int] (iround_exn ~dir:`Up 1.00001) ~expect:2; + [%test_result: int] (iround_up_exn (-.2.)) ~expect:(-2); + [%test_result: int] (iround_up_exn (-.1.9999)) ~expect:(-1); + [%test_result: int] (iround_up_exn (-.1.)) ~expect:(-1); + [%test_result: int] (iround_up_exn (-.0.9999)) ~expect:0; + [%test_result: int] (iround_up_exn 0.) ~expect:0; + [%test_result: int] (iround_up_exn 0.00001) ~expect:1; + [%test_result: int] (iround_up_exn 1.) ~expect:1; + [%test_result: int] (iround_up_exn 1.00001) ~expect:2; + [%test_result: int] (iround_exn ~dir:`Down (-.1.00001)) ~expect:(-2); + [%test_result: int] (iround_exn ~dir:`Down (-.1.)) ~expect:(-1); + [%test_result: int] (iround_exn ~dir:`Down (-.0.00001)) ~expect:(-1); + [%test_result: int] (iround_exn ~dir:`Down 0.) ~expect:0; + [%test_result: int] (iround_exn ~dir:`Down 0.99999) ~expect:0; + [%test_result: int] (iround_exn ~dir:`Down 1.) ~expect:1; + [%test_result: int] (iround_exn ~dir:`Down 1.99999) ~expect:1; + [%test_result: int] (iround_exn ~dir:`Down 2.) ~expect:2; + [%test_result: int] (iround_down_exn (-.1.00001)) ~expect:(-2); + [%test_result: int] (iround_down_exn (-.1.)) ~expect:(-1); + [%test_result: int] (iround_down_exn (-.0.00001)) ~expect:(-1); + [%test_result: int] (iround_down_exn 0.) ~expect:0; + [%test_result: int] (iround_down_exn 0.99999) ~expect:0; + [%test_result: int] (iround_down_exn 1.) ~expect:1; + [%test_result: int] (iround_down_exn 1.99999) ~expect:1; + [%test_result: int] (iround_down_exn 2.) ~expect:2; + [%test_result: int] (iround_exn ~dir:`Zero (-.2.)) ~expect:(-2); + [%test_result: int] (iround_exn ~dir:`Zero (-.1.99999)) ~expect:(-1); + [%test_result: int] (iround_exn ~dir:`Zero (-.1.)) ~expect:(-1); + [%test_result: int] (iround_exn ~dir:`Zero (-.0.99999)) ~expect:0; + [%test_result: int] (iround_exn ~dir:`Zero 0.99999) ~expect:0; + [%test_result: int] (iround_exn ~dir:`Zero 1.) ~expect:1; + [%test_result: int] (iround_exn ~dir:`Zero 1.99999) ~expect:1; + [%test_result: int] (iround_exn ~dir:`Zero 2.) ~expect:2 + + let is_64_bit_platform = of_int Int.max_value >= 2. **. 60. + + (* Tests for values close to [iround_lbound] and [iround_ubound]. *) + let extremities_test ~round = + let (+) = Int.(+) in + let (-) = Int.(-) in + if is_64_bit_platform then ( + (* 64 bits *) + [%test_result: int option] (round (2.0 **. 62. -. 512.)) ~expect:(Some (Int.max_value - 511)); + [%test_result: int option] (round (2.0 **. 62. -. 1024.)) ~expect:(Some (Int.max_value - 1023)); + [%test_result: int option] (round (-. (2.0 **. 62.))) ~expect:(Some Int.min_value); + [%test_result: int option] (round (-. (2.0 **. 62. -. 512.))) ~expect:(Some (Int.min_value + 512)); + [%test_result: int option] (round (2.0 **. 62.)) ~expect:None; + [%test_result: int option] (round (-. (2.0 **. 62. +. 1024.))) ~expect:None) + else ( + let int_size_minus_one = of_int (Int.num_bits - 1) in + (* 32 bits *) + [%test_result: int option] (round (2.0 **. int_size_minus_one -. 1.)) ~expect:(Some Int.max_value); + [%test_result: int option] (round (2.0 **. int_size_minus_one -. 2.)) ~expect:(Some (Int.max_value - 1)); + [%test_result: int option] (round (-. (2.0 **. int_size_minus_one))) ~expect:(Some Int.min_value); + [%test_result: int option] (round (-. (2.0 **. int_size_minus_one -. 1.))) ~expect:(Some (Int.min_value + 1)); + [%test_result: int option] (round (2.0 **. int_size_minus_one)) ~expect:None; + [%test_result: int option] (round (-. (2.0 **. int_size_minus_one +. 1.))) ~expect:None) + + let%test_unit _ = extremities_test ~round:iround_down + let%test_unit _ = extremities_test ~round:iround_up + let%test_unit _ = extremities_test ~round:iround_nearest + let%test_unit _ = extremities_test ~round:iround_towards_zero + + (* test values beyond the integers range *) + let large_value_test x = + [%test_result: int option] (iround_down x) ~expect:None; + [%test_result: int option] (iround ~dir:`Down x) ~expect:None; + [%test_result: int option] (iround_up x) ~expect:None; + [%test_result: int option] (iround ~dir:`Up x) ~expect:None; + [%test_result: int option] (iround_towards_zero x) ~expect:None; + [%test_result: int option] (iround ~dir:`Zero x) ~expect:None; + [%test_result: int option] (iround_nearest x) ~expect:None; + [%test_result: int option] (iround ~dir:`Nearest x) ~expect:None; + + assert (Exn.does_raise (fun () -> iround_down_exn x)); + assert (Exn.does_raise (fun () -> iround_exn ~dir:`Down x)); + assert (Exn.does_raise (fun () -> iround_up_exn x)); + assert (Exn.does_raise (fun () -> iround_exn ~dir:`Up x)); + assert (Exn.does_raise (fun () -> iround_towards_zero_exn x)); + assert (Exn.does_raise (fun () -> iround_exn ~dir:`Zero x)); + assert (Exn.does_raise (fun () -> iround_nearest_exn x)); + assert (Exn.does_raise (fun () -> iround_exn ~dir:`Nearest x)); + + [%test_result: float] (round_down x) ~expect:x; + [%test_result: float] (round ~dir:`Down x) ~expect:x; + [%test_result: float] (round_up x) ~expect:x; + [%test_result: float] (round ~dir:`Up x) ~expect:x; + [%test_result: float] (round_towards_zero x) ~expect:x; + [%test_result: float] (round ~dir:`Zero x) ~expect:x; + [%test_result: float] (round_nearest x) ~expect:x; + [%test_result: float] (round ~dir:`Nearest x) ~expect:x + + let large_numbers = + let (+) = Int.(+) in + let (-) = Int.(-) in + List.concat ( + List.init (1024 - 64) ~f:(fun x -> + let x = of_int (x + 64) in + let y = + [2. **. x; + 2. **. x -. 2. **. (x -. 53.); (* one ulp down *) + 2. **. x +. 2. **. (x -. 52.)] (* one ulp up *) + in + y @ (List.map y ~f:neg))) + @ + [infinity; + neg_infinity] + + let%test_unit _ = List.iter large_numbers ~f:large_value_test + + let numbers_near_powers_of_two = + List.concat ( + List.init 64 ~f:(fun i -> + let pow2 = 2. **. of_int i in + let x = + [ pow2; + one_ulp `Down (pow2 +. 0.5); + pow2 +. 0.5; + one_ulp `Down (pow2 +. 1.0); + pow2 +. 1.0; + one_ulp `Down (pow2 +. 1.5); + pow2 +. 1.5; + one_ulp `Down (pow2 +. 2.0); + pow2 +. 2.0; + one_ulp `Down (pow2 *. 2.0 -. 1.0); + one_ulp `Down pow2; + one_ulp `Up pow2 + ] + in + x @ (List.map x ~f:neg) + )) + + let%test _ = List.for_all numbers_near_powers_of_two ~f:iround_up_vs_down_test + let%test _ = List.for_all numbers_near_powers_of_two ~f:iround_nearest_test + let%test _ = List.for_all numbers_near_powers_of_two ~f:iround_down_test + let%test _ = List.for_all numbers_near_powers_of_two ~f:iround_up_test + let%test _ = List.for_all numbers_near_powers_of_two ~f:iround_towards_zero_test + let%test _ = List.for_all numbers_near_powers_of_two ~f:round_test + + (* code for generating random floats on which to test functions *) + let rec absirand () = + let open Int.O in + let rec aux acc cnt = + if cnt = 0 then + acc + else + let bit = if Random.bool () then 1 else 0 in + aux (2 * acc + bit) (cnt - 1) + in + let result = aux 0 (if is_64_bit_platform then 62 else 30) in + if result >= Int.max_value - 255 then + (* On a 64-bit box, [float x > Int.max_value] when [x >= Int.max_value - 255], so + [iround (float x)] would be out of bounds. So we try again. This branch of code + runs with probability 6e-17 :-) As such, we have some fixed tests in + [extremities_test] above, to ensure that we do always check some examples in + that range. *) + absirand () + else + result + + (* -Int.max_value <= frand () <= Int.max_value *) + let frand () = + let x = (of_int (absirand ())) +. Random.float 1.0 in + if Random.bool () then + -1.0 *. x + else + x + + let randoms = List.init ~f:(fun _ -> frand ()) 10_000 + + let%test _ = List.for_all randoms ~f:iround_up_vs_down_test + let%test _ = List.for_all randoms ~f:iround_nearest_test + let%test _ = List.for_all randoms ~f:iround_down_test + let%test _ = List.for_all randoms ~f:iround_up_test + let%test _ = List.for_all randoms ~f:iround_towards_zero_test + let%test _ = List.for_all randoms ~f:round_test + let%test_unit _ = special_values_test () + let%test _ = iround_nearest_test (of_int Int.max_value) + let%test _ = iround_nearest_test (of_int Int.min_value) + end) + +module Test_bounds ( + I : sig + type t + val num_bits : int + val of_float : float -> t + val to_int64 : t -> Int64.t + val max_value : t + val min_value : t + end + ) = struct + open I + + let float_lower_bound = lower_bound_for_int num_bits + let float_upper_bound = upper_bound_for_int num_bits + + let%test_unit "lower bound is valid" = ignore (of_float float_lower_bound : t) + let%test_unit "upper bound is valid" = ignore (of_float float_upper_bound : t) + + let%test "smaller than lower bound is not valid" = + Exn.does_raise (fun () -> of_float (one_ulp `Down float_lower_bound)) + let%test "bigger than upper bound is not valid" = + Exn.does_raise (fun () -> of_float (one_ulp `Up float_upper_bound)) + + (* We use [Caml.Int64.of_float] in the next two tests because [Int64.of_float] rejects + out-of-range inputs, whereas [Caml.Int.of_float] simply overflows (returns + [Int64.min_int]). *) + + let%test "smaller than lower bound overflows" = + let lower_bound = Int64.of_float float_lower_bound in + let lower_bound_minus_epsilon = Caml.Int64.of_float (one_ulp `Down float_lower_bound) in + let min_value = to_int64 min_value in + if Int.(=) num_bits 64 + (* We cannot detect overflow because on Intel overflow results in min_value. *) + then true + else begin + assert (Int64.(<=) lower_bound_minus_epsilon lower_bound); + (* a value smaller than min_value would overflow if converted to [t] *) + Int64.(<) lower_bound_minus_epsilon min_value + end + + let%test "bigger than upper bound overflows" = + let upper_bound = Int64.of_float float_upper_bound in + let upper_bound_plus_epsilon = Caml.Int64.of_float (one_ulp `Up float_upper_bound) in + let max_value = to_int64 max_value in + if Int.(=) num_bits 64 + (* upper_bound_plus_epsilon is not representable as a Int64.t, it has overflowed *) + then Int64.(<) upper_bound_plus_epsilon upper_bound + else begin + assert (Int64.(>=) upper_bound_plus_epsilon upper_bound); + (* a value greater than max_value would overflow if converted to [t] *) + Int64.(>) upper_bound_plus_epsilon max_value + end +end + +let%test_module "Int" = (module Test_bounds(Int)) +let%test_module "Int32" = (module Test_bounds(Int32)) +let%test_module "Int63" = (module Test_bounds(Int63)) +let%test_module "Int63_emul" = (module Test_bounds(Base.Not_exposed_properly.Int63_emul)) +let%test_module "Int64" = (module Test_bounds(Int64)) +let%test_module "Nativeint" = (module Test_bounds(Nativeint)) + +let%test_unit _ = [%test_result: string] (to_string 3.14) ~expect:"3.14" +let%test_unit _ = [%test_result: string] (to_string 3.1400000000000001) ~expect:"3.14" +let%test_unit _ = [%test_result: string] (to_string 3.1400000000000004) ~expect:"3.1400000000000006" +let%test_unit _ = [%test_result: string] (to_string 8.000000000000002) ~expect:"8.0000000000000018" +let%test_unit _ = [%test_result: string] (to_string 9.992) ~expect:"9.992" +let%test_unit _ = [%test_result: string] (to_string (2.**.63. *. (1. +. 2.**. (-52.)))) ~expect:"9.2233720368547779e+18" +let%test_unit _ = [%test_result: string] (to_string (-3.)) ~expect:"-3." +let%test_unit _ = [%test_result: string] (to_string nan) ~expect:"nan" +let%test_unit _ = [%test_result: string] (to_string infinity) ~expect:"inf" +let%test_unit _ = [%test_result: string] (to_string neg_infinity) ~expect:"-inf" +let%test_unit _ = [%test_result: string] (to_string 3e100) ~expect:"3e+100" +let%test_unit _ = [%test_result: string] (to_string max_finite_value) ~expect:"1.7976931348623157e+308" +let%test_unit _ = [%test_result: string] (to_string min_positive_subnormal_value) ~expect:"4.94065645841247e-324" + +let%test _ = epsilon_float = (one_ulp `Up 1.) -. 1. + +let%test _ = one_ulp_less_than_half = 0.49999999999999994 + +let%test _ = + round_down 3.6 = 3. + && round_down (-3.6) = -4. + +let%test _ = + round_up 3.6 = 4. + && round_up (-3.6) = -3. + +let%test _ = + round_towards_zero 3.6 = 3. + && round_towards_zero (-3.6) = -3. + +let%test _ = round_nearest_half_to_even 0. = 0. +let%test _ = round_nearest_half_to_even 0.5 = 0. +let%test _ = round_nearest_half_to_even (-0.5) = 0. +let%test _ = round_nearest_half_to_even (one_ulp `Up 0.5) = 1. +let%test _ = round_nearest_half_to_even (one_ulp `Down 0.5) = 0. +let%test _ = round_nearest_half_to_even (one_ulp `Up (-0.5)) = 0. +let%test _ = round_nearest_half_to_even (one_ulp `Down (-0.5)) = -1. +let%test _ = round_nearest_half_to_even 3.5 = 4. +let%test _ = round_nearest_half_to_even 4.5 = 4. +let%test _ = round_nearest_half_to_even (one_ulp `Up (-5.5)) = -5. +let%test _ = round_nearest_half_to_even 5.5 = 6. +let%test _ = round_nearest_half_to_even 6.5 = 6. +let%test _ = round_nearest_half_to_even (one_ulp `Up (-. (2. **. 52.))) = -. (2. **. 52.) +let%test _ = round_nearest (one_ulp `Up (-. (2. **. 52.))) = 1. -. (2. **. 52.) + +let%test_module _ = + (module struct + (* check we raise on invalid input *) + let must_fail f x = Exn.does_raise (fun () -> f x) + let must_succeed f x = ignore (f x); true + let%test _ = must_fail int63_round_nearest_portable_alloc_exn nan + let%test _ = must_fail int63_round_nearest_portable_alloc_exn max_value + let%test _ = must_fail int63_round_nearest_portable_alloc_exn min_value + let%test _ = must_fail int63_round_nearest_portable_alloc_exn (2. **. 63.) + let%test _ = must_fail int63_round_nearest_portable_alloc_exn (~-. (2. **. 63.)) + let%test _ = must_succeed int63_round_nearest_portable_alloc_exn (2. **. 62. -. 512.) + let%test _ = must_fail int63_round_nearest_portable_alloc_exn (2. **. 62.) + let%test _ = must_fail int63_round_nearest_portable_alloc_exn (~-. (2. **. 62.) -. 1024.) + let%test _ = must_succeed int63_round_nearest_portable_alloc_exn (~-. (2. **. 62.)) + end) + +let%test _ = + round_nearest 3.6 = 4. + && round_nearest (-3.6) = -4. + + +(* The redefinition of [sexp_of_t] in float.ml assumes sexp conversion uses E rather than + e. *) +let%test_unit "e vs E" = [%test_result: Sexp.t] [%sexp (1.4e100 : t)] ~expect:(Atom "1.4E+100") + +let%test_module _ = + (module struct + let test ?delimiter ~decimals f s s_strip_zero = + let s' = to_string_hum ?delimiter ~decimals ~strip_zero:false f in + if String.(s' <> s) then + raise_s + [%message + "to_string_hum ~strip_zero:false" + ~input:(f : float) + (decimals : int) + ~got:(s' : string) + ~expected:(s : string) + ]; + let s_strip_zero' = to_string_hum ?delimiter ~decimals ~strip_zero:true f in + if String.(s_strip_zero' <> s_strip_zero) then + raise_s + [%message + "to_string_hum ~strip_zero:true" + ~input:(f : float) + (decimals : int) + ~got:(s_strip_zero : string) + ~expected:(s_strip_zero' : string) + ]; + ;; + + let%test_unit _ = test ~decimals:3 0.99999 "1.000" "1" + let%test_unit _ = test ~decimals:3 0.00001 "0.000" "0" + let%test_unit _ = test ~decimals:3 ~-.12345.1 "-12_345.100" "-12_345.1" + let%test_unit _ = test ~delimiter:',' ~decimals:3 ~-.12345.1 "-12,345.100" "-12,345.1" + let%test_unit _ = test ~decimals:0 0.99999 "1" "1" + let%test_unit _ = test ~decimals:0 0.00001 "0" "0" + let%test_unit _ = test ~decimals:0 ~-.12345.1 "-12_345" "-12_345" + let%test_unit _ = test ~decimals:0 (5.0 /. 0.0) "inf" "inf" + let%test_unit _ = test ~decimals:0 (-5.0 /. 0.0) "-inf" "-inf" + let%test_unit _ = test ~decimals:0 (0.0 /. 0.0) "nan" "nan" + let%test_unit _ = test ~decimals:2 (5.0 /. 0.0) "inf" "inf" + let%test_unit _ = test ~decimals:2 (-5.0 /. 0.0) "-inf" "-inf" + let%test_unit _ = test ~decimals:2 (0.0 /. 0.0) "nan" "nan" + let%test_unit _ = test ~decimals:5 (10_000.0 /. 3.0) "3_333.33333" "3_333.33333" + let%test_unit _ = test ~decimals:2 ~-.0.00001 "-0.00" "-0" + + let rand_test n = + let go () = + let f = Random.float 1_000_000.0 -. 500_000.0 in + let repeatable to_str = + let s = to_str f in + if String.(<>) (String.split s ~on:',' |> String.concat |> of_string |> to_str) s + then raise_s [%message "failed" (f : t)] + in + repeatable (to_string_hum ~decimals:3 ~strip_zero:false); + in + try + for _ = 0 to Int.(-) n 1 do go () done; + true + with e -> + eprintf "%s\n%!" (Exn.to_string e); + false + ;; + + let%test _ = rand_test 10_000 + ;; + end) +;; + +let%test_module "Hexadecimal syntax" = + (module struct + + let should_fail str = Exn.does_raise (fun () -> Caml.float_of_string str) + + let test_equal str g = Caml.float_of_string str = g + let%test _ = should_fail "0x" + let%test _ = should_fail "0x.p0" + let%test _ = test_equal "0x0" 0. + let%test _ = test_equal "0x1.b7p-1" 0.857421875 + let%test _ = test_equal "0x1.999999999999ap-4" 0.1 + + end) +;; + +let%expect_test "square" = + printf "%f\n" (square 1.5); + printf "%f\n" (square (-2.5)); + [%expect {| + 2.250000 + 6.250000 |}] +;; + +let%expect_test "mathematical constants" = + (* Compare to the from-string conversion of numbers from Wolfram Alpha *) + let eq x s = assert (x = of_string s) in + eq pi "3.141592653589793238462643383279502884197169399375105820974"; + eq sqrt_pi "1.772453850905516027298167483341145182797549456122387128213"; + eq sqrt_2pi "2.506628274631000502415765284811045253006986740609938316629"; + eq euler "0.577215664901532860606512090082402431042159335939923598805"; + (* Check size of diff from ordinary computation. *) + printf "sqrt pi diff : %.20f\n" (sqrt_pi - sqrt pi); + printf "sqrt 2pi diff : %.20f\n" (sqrt_2pi - sqrt (2. * pi)); + [%expect {| + sqrt pi diff : 0.00000000000000022204 + sqrt 2pi diff : 0.00000000000000044409 |}] + +let%test _ = not (is_negative Float.nan) +let%test _ = not (is_non_positive Float.nan) +let%test _ = is_non_negative (-0.) + + +let%test_unit "int to float conversion consistency" = + let test_int63 x = + [%test_result: float] (Float.of_int63 x) ~expect:(Float.of_int64 (Int63.to_int64 x)) + in + let test_int x = + [%test_result: float] (Float.of_int x) ~expect:(Float.of_int63 (Int63.of_int x)); + test_int63 (Int63.of_int x) + in + test_int 0; + test_int 35; + test_int (-1); + test_int Int.max_value; + test_int Int.min_value; + + test_int63 Int63.zero; + test_int63 Int63.min_value; + test_int63 Int63.max_value; + + let rand = Random.State.make [| Hashtbl.hash "int to float conversion consistency" |] in + for _i = 0 to 100 do + let x = Random.State.int rand Int.max_value in + test_int x; + done; + () +;; diff --git a/test/test_float.mli b/test/test_float.mli new file mode 100644 index 0000000..74bb729 --- /dev/null +++ b/test/test_float.mli @@ -0,0 +1 @@ +(*_ This signature is deliberately empty. *) diff --git a/test/test_fn.ml b/test/test_fn.ml new file mode 100644 index 0000000..7e866ed --- /dev/null +++ b/test/test_fn.ml @@ -0,0 +1,12 @@ +open! Import +open! Fn + +(* enforce that we're testing [Fn.(|>)] and not ppx_pipebang. *) +let (_ : 'a -> ('a -> 'b) -> 'b) = (|>) + +let%test _ = 1 |> fun x -> x = 1 +let%test _ = 1 |> fun x -> x + 1 |> fun y -> y = 2 + +let%test _ = 0 = apply_n_times ~n:0 (fun _ -> assert false) 0 +let%test _ = 0 = apply_n_times ~n:(-3) (fun _ -> assert false) 0 +let%test _ = 10 = apply_n_times ~n:10 ((+) 1) 0 diff --git a/test/test_fn.mli b/test/test_fn.mli new file mode 100644 index 0000000..74bb729 --- /dev/null +++ b/test/test_fn.mli @@ -0,0 +1 @@ +(*_ This signature is deliberately empty. *) diff --git a/test/test_hash_set.ml b/test/test_hash_set.ml new file mode 100644 index 0000000..d6c2731 --- /dev/null +++ b/test/test_hash_set.ml @@ -0,0 +1,64 @@ +open! Import +open! Hash_set + +let%test_module "Set Intersection" = + (module struct + let run_test first_contents second_contents ~expect = + let of_list lst = + let s = create (module String) in + List.iter lst ~f:(add s); + s + in + let s1 = of_list first_contents in + let s2 = of_list second_contents in + let expect = of_list expect in + let result = inter s1 s2 in + iter result ~f:(fun x -> assert (mem expect x)); + iter expect ~f:(fun x -> assert (mem result x)); + let equal x y = 0 = String.compare x y in + assert (List.equal equal (to_list result) (to_list expect)); + assert ((length result) = (length expect)); + (* Make sure the sets are unmodified by the inter *) + assert ((List.length first_contents) = length s1); + assert ((List.length second_contents) = length s2) + ;; + + let%test_unit "First smaller" = + run_test ["0"; "3"; "99"] + ["0";"1";"2";"3"] + ~expect:["0"; "3"] + + let%test_unit "Second smaller" = + run_test ["a";"b";"c";"d"] + [ "b"; "d"] + ~expect:[ "b"; "d"] + + let%test_unit "No intersection" = + run_test ~expect:[] ["a";"b";"c";"d"] ["1";"2";"3";"4"] + end) + +let%expect_test "sexp" = + let ints = List.init 20 ~f:(fun x -> x * x) in + let int_hash_set = Hash_set.of_list (module Int) ints in + print_s [%sexp (int_hash_set : int Hash_set.t)]; + [%expect {| (0 1 4 9 16 25 36 49 64 81 100 121 144 169 196 225 256 289 324 361) |}]; + let strs = List.init 20 ~f:(fun x -> Int.to_string x) in + let str_hash_set = Hash_set.of_list (module String) strs in + print_s [%sexp (str_hash_set : string Hash_set.t)]; + [%expect {| (0 1 10 11 12 13 14 15 16 17 18 19 2 3 4 5 6 7 8 9) |}]; +;; + +let%expect_test "to_array" = + let empty_array = to_array (Hash_set.of_list (module Int) []) in + print_s [%sexp (empty_array : int Array.t)]; + [%expect {| () |}]; + let array_from_to_array = to_array (Hash_set.of_list (module Int) [1; 2; 3; 4; 5;]) in + print_s [%sexp (array_from_to_array : int Array.t)]; + [%expect {| (1 3 2 4 5) |}]; + let array_via_to_list = + to_list (Hash_set.of_list (module Int) [1; 2; 3; 4; 5;]) |> Array.of_list + in + print_s [%sexp (array_via_to_list : int Array.t)]; + [%expect {| (1 3 2 4 5) |}]; +;; + diff --git a/test/test_hash_set.mli b/test/test_hash_set.mli new file mode 100644 index 0000000..74bb729 --- /dev/null +++ b/test/test_hash_set.mli @@ -0,0 +1 @@ +(*_ This signature is deliberately empty. *) diff --git a/test/test_hashtbl.ml b/test/test_hashtbl.ml new file mode 100644 index 0000000..3e1eda7 --- /dev/null +++ b/test/test_hashtbl.ml @@ -0,0 +1,24 @@ +open! Base + +type int_hashtbl = int Hashtbl.M(Int).t [@@deriving sexp] + +let%test "Hashtbl.merge succeeds with first-class-module interface" = + let t1 = Hashtbl.create (module Int) in + let t2 = Hashtbl.create (module Int) in + let result = + Hashtbl.merge t1 t2 ~f:(fun ~key:_ -> function + | `Left x -> x + | `Right x -> x + | `Both _ -> assert false) + |> Hashtbl.to_alist + in + List.equal Poly.equal result [] + +let%test_module _ = (module Hashtbl_tests.Make(struct + include Hashtbl + + let create_poly ?size () = Poly.create ?size () + + let of_alist_poly_exn l = Poly.of_alist_exn l + let of_alist_poly_or_error l = Poly.of_alist_or_error l + end)) diff --git a/test/test_hashtbl.mli b/test/test_hashtbl.mli new file mode 100644 index 0000000..74bb729 --- /dev/null +++ b/test/test_hashtbl.mli @@ -0,0 +1 @@ +(*_ This signature is deliberately empty. *) diff --git a/test/test_identifiable.ml b/test/test_identifiable.ml new file mode 100644 index 0000000..2045baa --- /dev/null +++ b/test/test_identifiable.ml @@ -0,0 +1,17 @@ +open! Import +open! Identifiable + +module T = struct + type t = string + include + Make (struct + let module_name = "test" + include String + end) +end + +let%expect_test "hash coherence" [@tags "64-bits-only"] = + check_hash_coherence [%here] (module T) + ([ ""; "a"; "foo" ] |> List.map ~f:T.of_string); + [%expect {| |}]; +;; diff --git a/test/test_identifiable.mli b/test/test_identifiable.mli new file mode 100644 index 0000000..74bb729 --- /dev/null +++ b/test/test_identifiable.mli @@ -0,0 +1 @@ +(*_ This signature is deliberately empty. *) diff --git a/test/test_indexed_container.ml b/test/test_indexed_container.ml new file mode 100644 index 0000000..a695c04 --- /dev/null +++ b/test/test_indexed_container.ml @@ -0,0 +1,181 @@ +open! Import + +module type S = Indexed_container.S1 with type 'a t = 'a list + +module This_list : S = struct + include List + include (Indexed_container.Make (struct + type 'a t = 'a list + let fold = List.fold + let iter = `Custom List.iter + let length = `Custom List.length + let foldi = `Define_using_fold + let iteri = `Define_using_fold + end)) +end + +module That_list : S = List + +let examples = + [ [] + ; [1] + ; [2; 3] + ; [4; 5; 1] + ; List.init 8 ~f:(fun i -> i*i) + ] +;; + +module type Output = sig + type t [@@deriving compare, sexp_of] +end + +module Int_list = struct + type t = int list [@@deriving compare, sexp_of] +end + +module Int_pair_option = struct + type t = (int * int) option [@@deriving compare, sexp_of] +end + +module Int_option = struct + type t = int option [@@deriving compare, sexp_of] +end + +let check (type a) + here + examples + ~actual + ~expect + (module Output : Output with type t = a) = + List.iter examples ~f:(fun example -> + let actual = actual example in + let expect = expect example in + require here (Output.compare actual expect = 0) + ~if_false_then_print_s:(lazy [%message (expect : Output.t)]); + print_s [%sexp (actual : Output.t)]); +;; + +let%expect_test "foldi" = + let f i acc elt = + if i % 2 = 0 then elt :: acc else acc + in + check [%here] examples (module Int_list) + ~actual:(fun list -> This_list.foldi list ~init:[] ~f) + ~expect:(fun list -> That_list.foldi list ~init:[] ~f); + [%expect {| + () + (1) + (2) + (1 4) + (36 16 4 0) |}] +;; + +let%expect_test "findi" = + let check f = + check [%here] examples (module Int_pair_option) + ~actual:(fun list -> This_list.findi list ~f) + ~expect:(fun list -> That_list.findi list ~f); + in + check (fun i _elt -> i = 0); + [%expect {| + () + ((0 1)) + ((0 2)) + ((0 4)) + ((0 0)) |}]; + check (fun _i elt -> elt = 1); + [%expect {| + () + ((0 1)) + () + ((2 1)) + ((1 1)) |}]; +;; + +let%expect_test "find_mapi" = + let f i elt = + if elt = 1 then Some (i * 100 + elt) else None + in + check [%here] examples (module Int_option) + ~actual:(fun list -> This_list.find_mapi list ~f) + ~expect:(fun list -> That_list.find_mapi list ~f); + [%expect {| + () + (1) + () + (201) + (101) |}]; +;; + +let%expect_test "iteri" = + let go iteri = + let acc = ref [] in + iteri ~f:(fun i elt -> acc := i :: elt :: !acc); + !acc + in + check [%here] examples (module Int_list) + ~actual:(fun list -> go (This_list.iteri list)) + ~expect:(fun list -> go (That_list.iteri list)); + [%expect {| + () + (0 1) + (1 3 0 2) + (2 1 1 5 0 4) + (7 49 6 36 5 25 4 16 3 9 2 4 1 1 0 0) |}]; +;; + +let bool_examples = + [ []; + [true]; + [false]; + [false; false]; + [true; false]; + [false; true]; + [true; true]; + ] +;; + +let%expect_test "for_alli" = + let f _i elt = elt in + check [%here] bool_examples (module Bool) + ~actual:(fun list -> This_list.for_alli list ~f) + ~expect:(fun list -> That_list.for_alli list ~f); + [%expect {| + true + true + false + false + false + false + true |}]; +;; + +let%expect_test "existsi" = + let f _i elt = elt in + check [%here] bool_examples (module Bool) + ~actual:(fun list -> This_list.existsi list ~f) + ~expect:(fun list -> That_list.existsi list ~f); + [%expect {| + false + true + false + false + true + true + true |}]; +;; + +let%expect_test "counti" = + let f _i elt = elt in + check [%here] bool_examples (module Int) + ~actual:(fun list -> This_list.counti list ~f) + ~expect:(fun list -> That_list.counti list ~f); + [%expect {| + 0 + 1 + 0 + 0 + 1 + 1 + 2 |}]; +;; diff --git a/test/test_indexed_container.mli b/test/test_indexed_container.mli new file mode 100644 index 0000000..74bb729 --- /dev/null +++ b/test/test_indexed_container.mli @@ -0,0 +1 @@ +(*_ This signature is deliberately empty. *) diff --git a/test/test_info.ml b/test/test_info.ml new file mode 100644 index 0000000..81c9501 --- /dev/null +++ b/test/test_info.ml @@ -0,0 +1,62 @@ +open! Import +open! Info + +let%test_unit _ = + [%test_result: string] (to_string_hum (of_exn (Failure "foo"))) + ~expect:"(Failure foo)" +;; + +let%test_unit _ = + [%test_result: string] (to_string_hum (tag (of_string "b") ~tag:"a")) + ~expect:"(a b)" +;; + +let%test_unit _ = + [%test_result: string] + (to_string_hum (of_list (List.map ~f:of_string [ "a"; "b"; "c" ]))) + ~expect:"(a b c)" +;; + +let of_strings strings = of_list (List.map ~f:of_string strings) + +let nested = + of_list + (List.map ~f:of_strings + [ [ "a"; "b"; "c" ] + ; [ "d"; "e"; "f" ] + ; [ "g"; "h"; "i" ] + ]) +;; + +let%test_unit _ = + [%test_result: string] (to_string_hum nested) ~expect:"(a b c d e f g h i)" +;; + +let%test_unit _ = + [%test_result: Sexp.t] (sexp_of_t nested) + ~expect:(sexp_of_t (of_strings [ "a"; "b"; "c" + ; "d"; "e"; "f" + ; "g"; "h"; "i" ])) +;; + +let%test_unit _ = + match to_exn (of_exn (Failure "foo")) with + | Failure "foo" -> () + | exn -> raise_s [%sexp { got : exn = exn; expected = Failure "foo" }] +;; + +let round t = + let sexp = sexp_of_t t in + Sexp.(=) sexp (sexp_of_t (t_of_sexp sexp)) +;; + +let%test _ = round (of_string "hello") +let%test _ = round (of_thunk (fun () -> "hello")) +let%test _ = round (create "tag" 13 [%sexp_of: int]) +let%test _ = round (tag (of_string "hello") ~tag:"tag") +let%test _ = round (tag_arg (of_string "hello") "tag" 13 + [%sexp_of: int]) +let%test _ = round (of_list [ of_string "hello"; of_string "goodbye" ]) +let%test _ = round (t_of_sexp (Sexplib.Sexp.of_string "((random sexp 1)(b 2)((c (1 2 3))))")) + +let%test _ = String.equal (to_string_hum (of_string "a\nb")) "a\nb" diff --git a/test/test_info.mli b/test/test_info.mli new file mode 100644 index 0000000..74bb729 --- /dev/null +++ b/test/test_info.mli @@ -0,0 +1 @@ +(*_ This signature is deliberately empty. *) diff --git a/test/test_int.ml b/test/test_int.ml new file mode 100644 index 0000000..e380bfe --- /dev/null +++ b/test/test_int.ml @@ -0,0 +1,117 @@ +open! Import +open! Int + +let%expect_test "hash coherence" [@tags "64-bits-only"] = + check_int_hash_coherence [%here] (module Int); + [%expect {| |}]; +;; + +let%expect_test "[max_value_30_bits]" = + print_s [%sexp (max_value_30_bits : t)]; + [%expect {| + 1_073_741_823 |}]; +;; + +let%expect_test "hex" = + let test x = + let n = Or_error.try_with (fun () -> Int.Hex.of_string x) in + print_s [%message (n : int Or_error.t)] + in + test "0x1c5f"; + [%expect {| + (n (Ok 7_263)) + |}]; + test "0x1c5f NON-HEX-GARBAGE"; + [%expect {| + (n ( + Error ( + Failure + "Base.Int.Hex.of_string: invalid input \"0x1c5f NON-HEX-GARBAGE\""))) + |}] + +let%test_module "Hex" = + (module struct + + let f (i,s_hum) = + let s = String.filter s_hum ~f:(fun c -> not (Char.equal c '_')) in + let sexp_hum = Sexp.Atom s_hum in + let sexp = Sexp.Atom s in + [%test_result: Sexp.t] ~message:"sexp_of_t" ~expect:sexp (Hex.sexp_of_t i); + [%test_result: int] ~message:"t_of_sexp" ~expect:i (Hex.t_of_sexp sexp); + [%test_result: int] ~message:"t_of_sexp[human]" ~expect:i (Hex.t_of_sexp sexp_hum); + [%test_result: string] ~message:"to_string" ~expect:s (Hex.to_string i); + [%test_result: string] ~message:"to_string_hum" ~expect:s_hum (Hex.to_string_hum i); + [%test_result: int] ~message:"of_string" ~expect:i (Hex.of_string s); + [%test_result: int] ~message:"of_string[human]" ~expect:i (Hex.of_string s_hum); + ;; + + let%test_unit _ = + List.iter ~f + [ 0, "0x0" + ; 1, "0x1" + ; 2, "0x2" + ; 5, "0x5" + ; 10, "0xa" + ; 16, "0x10" + ; 254, "0xfe" + ; 65_535, "0xffff" + ; 65_536, "0x1_0000" + ; 1_000_000, "0xf_4240" + ; -1, "-0x1" + ; -2, "-0x2" + ; -1_000_000, "-0xf_4240" + ; max_value, + (match num_bits with + | 31 -> "0x3fff_ffff" + | 32 -> "0x7fff_ffff" + | 63 -> "0x3fff_ffff_ffff_ffff" + | _ -> assert false) + ; min_value, + (match num_bits with + | 31 -> "-0x4000_0000" + | 32 -> "-0x8000_0000" + | 63 -> "-0x4000_0000_0000_0000" + | _ -> assert false) + ] + + let%test_unit _ = + [%test_result: int] (Hex.of_string "0XA") ~expect:10 + + let%test_unit _ = + match Option.try_with (fun () -> Hex.of_string "0") with + | None -> () + | Some _ -> failwith "Hex must always have a 0x prefix." + + let%test_unit _ = + match Option.try_with (fun () -> Hex.of_string "0x_0") with + | None -> () + | Some _ -> failwith "Hex may not have '_' before the first digit." + + end) + +let%test _ = (neg 5 + 5 = 0) + +let%test _ = pow min_value 1 = min_value +let%test _ = pow max_value 1 = max_value + +let%test "comparisons" = + let original_compare (x : int) y = Caml.compare x y in + let valid_compare x y = + let result = compare x y in + let expect = original_compare x y in + assert (Bool.(=) (result < 0) (expect < 0)); + assert (Bool.(=) (result > 0) (expect > 0)); + assert (Bool.(=) (result = 0) (expect = 0)); + assert (result = expect); + in + (valid_compare min_value min_value); + (valid_compare min_value (-1)); + (valid_compare (-1) min_value); + (valid_compare min_value 0); + (valid_compare 0 min_value); + (valid_compare max_value (-1)); + (valid_compare (-1) max_value); + (valid_compare max_value min_value); + (valid_compare max_value max_value); + true +;; diff --git a/test/test_int.mli b/test/test_int.mli new file mode 100644 index 0000000..74bb729 --- /dev/null +++ b/test/test_int.mli @@ -0,0 +1 @@ +(*_ This signature is deliberately empty. *) diff --git a/test/test_int32.ml b/test/test_int32.ml new file mode 100644 index 0000000..2be6461 --- /dev/null +++ b/test/test_int32.ml @@ -0,0 +1,8 @@ + +open! Import +open! Int32 + +let%expect_test "hash coherence" = + check_int_hash_coherence [%here] (module Int32); + [%expect {| |}]; +;; diff --git a/test/test_int32.mli b/test/test_int32.mli new file mode 100644 index 0000000..74bb729 --- /dev/null +++ b/test/test_int32.mli @@ -0,0 +1 @@ +(*_ This signature is deliberately empty. *) diff --git a/test/test_int32_pow2.ml b/test/test_int32_pow2.ml new file mode 100644 index 0000000..90ad146 --- /dev/null +++ b/test/test_int32_pow2.ml @@ -0,0 +1,114 @@ +open! Import +open! Int32 + +let of_ints = List.map ~f:(of_int_exn);; + +let examples = + of_ints + [ -1 + ; 0 + ; 1 + ; 2 + ; 3 + ; 4 + ; 5 + ; 7 + ; 8 + ; 9 + ; 63 + ; 64 + ; 65] +;; + +let examples_64_bit = + [ min_value + ; succ min_value + ; pred max_value + ; max_value ] +;; + +let print_for ints f = + List.iter ints ~f:(fun i -> + print_s [%message + "" + ~_:(i : int32) + ~_:(Or_error.try_with (fun () -> f i) : int Or_error.t)]) +;; + +let%expect_test "[floor_log2]" = + print_for examples floor_log2; + [%expect {| + (-1 (Error ("[Int32.floor_log2] got invalid input" -1))) + (0 (Error ("[Int32.floor_log2] got invalid input" 0))) + (1 (Ok 0)) + (2 (Ok 1)) + (3 (Ok 1)) + (4 (Ok 2)) + (5 (Ok 2)) + (7 (Ok 2)) + (8 (Ok 3)) + (9 (Ok 3)) + (63 (Ok 5)) + (64 (Ok 6)) + (65 (Ok 6)) |}]; +;; + +let%expect_test "[floor_log2]" [@tags "64-bits-only"] = + print_for examples_64_bit floor_log2; + [%expect {| + (-2_147_483_648 (Error ("[Int32.floor_log2] got invalid input" -2147483648))) + (-2_147_483_647 (Error ("[Int32.floor_log2] got invalid input" -2147483647))) + (2_147_483_646 (Ok 30)) + (2_147_483_647 (Ok 30)) |}]; +;; + +let%expect_test "[ceil_log2]" = + print_for examples ceil_log2; + [%expect {| + (-1 (Error ("[Int32.ceil_log2] got invalid input" -1))) + (0 (Error ("[Int32.ceil_log2] got invalid input" 0))) + (1 (Ok 0)) + (2 (Ok 1)) + (3 (Ok 2)) + (4 (Ok 2)) + (5 (Ok 3)) + (7 (Ok 3)) + (8 (Ok 3)) + (9 (Ok 4)) + (63 (Ok 6)) + (64 (Ok 6)) + (65 (Ok 7)) |}]; +;; + +let%expect_test "[ceil_log2]" [@tags "64-bits-only"] = + print_for examples_64_bit ceil_log2; + [%expect {| + (-2_147_483_648 (Error ("[Int32.ceil_log2] got invalid input" -2147483648))) + (-2_147_483_647 (Error ("[Int32.ceil_log2] got invalid input" -2147483647))) + (2_147_483_646 (Ok 31)) + (2_147_483_647 (Ok 31)) |}]; +;; + +let%test_module "int_math" = + (module struct + + let test_cases () = + of_ints [ 0b10101010; 0b1010101010101010; 0b101010101010101010101010; + 0b10000000; 0b1000000000001000; 0b100000000000000000001000; ] + ;; + + let%test_unit "ceil_pow2" = + List.iter (test_cases ()) + ~f:(fun x -> let p2 = ceil_pow2 x in + assert( (is_pow2 p2) && (p2 >= x && x >= (p2 / of_int_exn 2)) ) + ) + ;; + + let%test_unit "floor_pow2" = + List.iter (test_cases ()) + ~f:(fun x -> let p2 = floor_pow2 x in + assert( (is_pow2 p2) && ((of_int_exn 2 * p2) >= x && x >= p2) ) + ) + ;; + + end) diff --git a/test/test_int32_pow2.mli b/test/test_int32_pow2.mli new file mode 100644 index 0000000..74bb729 --- /dev/null +++ b/test/test_int32_pow2.mli @@ -0,0 +1 @@ +(*_ This signature is deliberately empty. *) diff --git a/test/test_int63.ml b/test/test_int63.ml new file mode 100644 index 0000000..6c90ae7 --- /dev/null +++ b/test/test_int63.ml @@ -0,0 +1,76 @@ +open! Import +open! Int63 + +let%expect_test "hash coherence" [@tags "64-bits-only"] = + check_int_hash_coherence [%here] (module Int63); + [%expect {| |}]; +;; + +let%test_unit _ = [%test_result: t] max_value ~expect:(of_int64_exn 4611686018427387903L) +let%test_unit _ = [%test_result: t] min_value ~expect:(of_int64_exn (-4611686018427387904L)) + +let%test_unit _ = + [%test_result: t] (of_int32_exn Int32.min_value) ~expect:(of_int32 Int32.min_value) +let%test_unit _ = + [%test_result: t] (of_int32_exn Int32.max_value) ~expect:(of_int32 Int32.max_value) + +let%test "typical random 0" = Exn.does_raise (fun () -> random zero) + +let%test_module "Overflow_exn" = + (module struct + open Overflow_exn + + let%test_module "( + )" = + (module struct + let test t = Exn.does_raise (fun () -> t + t) + let%test "max_value / 2 + 1" = test (succ (max_value / of_int 2)) + let%test "min_value / 2 - 1" = test (pred (min_value / of_int 2)) + let%test "min_value + min_value" = test min_value + let%test "max_value + max_value" = test max_value + end) + ;; + + let%test_module "( - )" = + (module struct + let%test "min_value - 1" = Exn.does_raise (fun () -> min_value - one) + let%test "max_value - -1" = Exn.does_raise (fun () -> max_value - neg one) + let%test "min_value / 2 - max_value / 2 - 2" = + Exn.does_raise (fun () -> min_value / of_int 2 - max_value / of_int 2 - of_int 2) + let%test "min_value - max_value" = Exn.does_raise (fun () -> min_value - max_value) + let%test "max_value - min_value" = Exn.does_raise (fun () -> max_value - min_value) + let%test "max_value - -max_value" = + Exn.does_raise (fun () -> max_value - neg max_value) + end) + ;; + end) + +let%expect_test "[floor_log2]" = + let floor_log2 t = print_s [%sexp (floor_log2 t : int)] in + show_raise (fun () -> floor_log2 zero); + [%expect {| + (raised ("[Int.floor_log2] got invalid input" 0)) |}]; + floor_log2 one; + [%expect {| + 0 |}]; + for i = 1 to 8 do + floor_log2 (i |> of_int); + done; + [%expect {| + 0 + 1 + 1 + 2 + 2 + 2 + 2 + 3 |}]; + floor_log2 (one lsl 61 - one); + [%expect {| + 60 |}]; + floor_log2 (one lsl 61); + [%expect {| + 61 |}]; + floor_log2 max_value; + [%expect {| + 61 |}]; +;; diff --git a/test/test_int63.mli b/test/test_int63.mli new file mode 100644 index 0000000..74bb729 --- /dev/null +++ b/test/test_int63.mli @@ -0,0 +1 @@ +(*_ This signature is deliberately empty. *) diff --git a/test/test_int63_emul.ml b/test/test_int63_emul.ml new file mode 100644 index 0000000..16b7031 --- /dev/null +++ b/test/test_int63_emul.ml @@ -0,0 +1,14 @@ +open! Core_kernel +open! Expect_test_helpers_kernel + +module Int63_emul = Base.Not_exposed_properly.Int63_emul + +let%expect_test _ = + let s63 = Int63.( Hex.to_string min_value) in + let s63_emul = Int63_emul.(Hex.to_string min_value) in + print_s [%message (s63 : string) (s63_emul : string)]; + require [%here] (String.equal s63 s63_emul); + [%expect {| + ((s63 -0x4000000000000000) + (s63_emul -0x4000000000000000)) |}]; +;; diff --git a/test/test_int63_emul.mli b/test/test_int63_emul.mli new file mode 100644 index 0000000..74bb729 --- /dev/null +++ b/test/test_int63_emul.mli @@ -0,0 +1 @@ +(*_ This signature is deliberately empty. *) diff --git a/test/test_int64.ml b/test/test_int64.ml new file mode 100644 index 0000000..38b6521 --- /dev/null +++ b/test/test_int64.ml @@ -0,0 +1,7 @@ +open! Import +open! Int64 + +let%expect_test "hash coherence" = + check_int_hash_coherence [%here] (module Int64); + [%expect {| |}]; +;; diff --git a/test/test_int64.mli b/test/test_int64.mli new file mode 100644 index 0000000..74bb729 --- /dev/null +++ b/test/test_int64.mli @@ -0,0 +1 @@ +(*_ This signature is deliberately empty. *) diff --git a/test/test_int64_pow2.ml b/test/test_int64_pow2.ml new file mode 100644 index 0000000..aa37f72 --- /dev/null +++ b/test/test_int64_pow2.ml @@ -0,0 +1,123 @@ +open! Import +open! Int64 + +let examples = + [ -1L + ; 0L + ; 1L + ; 2L + ; 3L + ; 4L + ; 5L + ; 7L + ; 8L + ; 9L + ; 63L + ; 64L + ; 65L + ] +;; + +let examples_64_bit = + [ min_value + ; succ min_value + ; pred max_value + ; max_value ] +;; + +let print_for ints f = + List.iter ints ~f:(fun i -> + print_s [%message + "" + ~_:(i : int64) + ~_:(Or_error.try_with (fun () -> f i) : int Or_error.t)]) +;; + +let%expect_test "[floor_log2]" = + print_for examples floor_log2; + [%expect {| + (-1 (Error ("[Int64.floor_log2] got invalid input" -1))) + (0 (Error ("[Int64.floor_log2] got invalid input" 0))) + (1 (Ok 0)) + (2 (Ok 1)) + (3 (Ok 1)) + (4 (Ok 2)) + (5 (Ok 2)) + (7 (Ok 2)) + (8 (Ok 3)) + (9 (Ok 3)) + (63 (Ok 5)) + (64 (Ok 6)) + (65 (Ok 6)) |}]; +;; + +let%expect_test "[floor_log2]" [@tags "64-bits-only"] = + print_for examples_64_bit floor_log2; + [%expect {| + (-9_223_372_036_854_775_808 ( + Error ("[Int64.floor_log2] got invalid input" -9223372036854775808))) + (-9_223_372_036_854_775_807 ( + Error ("[Int64.floor_log2] got invalid input" -9223372036854775807))) + (9_223_372_036_854_775_806 (Ok 62)) + (9_223_372_036_854_775_807 (Ok 62)) |}]; +;; + +let%expect_test "[ceil_log2]" = + print_for examples ceil_log2; + [%expect {| + (-1 (Error ("[Int64.ceil_log2] got invalid input" -1))) + (0 (Error ("[Int64.ceil_log2] got invalid input" 0))) + (1 (Ok 0)) + (2 (Ok 1)) + (3 (Ok 2)) + (4 (Ok 2)) + (5 (Ok 3)) + (7 (Ok 3)) + (8 (Ok 3)) + (9 (Ok 4)) + (63 (Ok 6)) + (64 (Ok 6)) + (65 (Ok 7)) |}]; +;; + +let%expect_test "[ceil_log2]" [@tags "64-bits-only"] = + print_for examples_64_bit ceil_log2; + [%expect {| + (-9_223_372_036_854_775_808 ( + Error ("[Int64.ceil_log2] got invalid input" -9223372036854775808))) + (-9_223_372_036_854_775_807 ( + Error ("[Int64.ceil_log2] got invalid input" -9223372036854775807))) + (9_223_372_036_854_775_806 (Ok 63)) + (9_223_372_036_854_775_807 (Ok 63)) |}]; +;; + +let%test_module "int64_math" = + (module struct + + let test_cases () = + let cases = + [ 0b10101010L; 0b1010101010101010L; 0b101010101010101010101010L; + 0b10000000L; 0b1000000000001000L; 0b100000000000000000001000L; ] + in + let cases = + cases @ [ (0b1010101010101010L lsl 16) lor 0b1010101010101010L; + (0b1000000000000000L lsl 16) lor 0b0000000000001000L; ] + in + let added_cases = List.map cases ~f:(fun x -> x lsl 16) in + List.concat [ cases; added_cases ] + ;; + + let%test_unit "ceil_pow2" = + List.iter (test_cases ()) + ~f:(fun x -> let p2 = ceil_pow2 x in + assert( (is_pow2 p2) && (p2 >= x && x >= (p2 / 2L)) ) + ) + ;; + + let%test_unit "floor_pow2" = + List.iter (test_cases ()) + ~f:(fun x -> let p2 = floor_pow2 x in + assert( (is_pow2 p2) && ((2L * p2) >= x && x >= p2) ) + ) + ;; + end) diff --git a/test/test_int64_pow2.mli b/test/test_int64_pow2.mli new file mode 100644 index 0000000..74bb729 --- /dev/null +++ b/test/test_int64_pow2.mli @@ -0,0 +1 @@ +(*_ This signature is deliberately empty. *) diff --git a/test/test_int_conversions.ml b/test/test_int_conversions.ml new file mode 100644 index 0000000..32b1af4 --- /dev/null +++ b/test/test_int_conversions.ml @@ -0,0 +1,170 @@ +open! Import +open! Base.Not_exposed_properly.Int_conversions + +let%test_module "pretty" = + (module struct + + let check input output = + List.for_all [""; "+"; "-"] ~f:(fun prefix -> + let input = prefix ^ input in + let output = prefix ^ output in + [%compare.equal: string] output (insert_underscores input)) + + let%test _ = check "1" "1" + let%test _ = check "12" "12" + let%test _ = check "123" "123" + let%test _ = check "1234" "1_234" + let%test _ = check "12345" "12_345" + let%test _ = check "123456" "123_456" + let%test _ = check "1234567" "1_234_567" + let%test _ = check "12345678" "12_345_678" + let%test _ = check "123456789" "123_456_789" + let%test _ = check "1234567890" "1_234_567_890" + + end) + +let%test_module "conversions" = + (module struct + module type S = sig + include Int.S + val module_name : string + end + + let test_conversion (type a) (type b) loc ma mb + a_to_b_or_error + a_to_b_trunc + b_to_a_trunc + = + let (module A : S with type t = a) = ma in + let (module B : S with type t = b) = mb in + let examples = + [ A.min_value + ; A.minus_one + ; A.zero + ; A.one + ; A.max_value + ; B.min_value |> b_to_a_trunc + ; B.max_value |> b_to_a_trunc + ] + |> List.concat_map ~f:(fun a -> [ A.pred a; a; A.succ a ]) + |> List.dedup_and_sort ~compare:A.compare + |> List.sort ~compare:A.compare + in + List.iter examples ~f:(fun a -> + let b' = a_to_b_trunc a in + let a' = b_to_a_trunc b' in + match a_to_b_or_error a with + | Ok b -> + require loc (B.equal b b') + ~if_false_then_print_s: + (lazy [%message + "conversion produced wrong value" + ~from: (A.module_name : string) + ~to_: (B.module_name : string) + ~input: (a : A.t) + ~output: (b : B.t) + ~expected: (b' : B.t)]); + require loc (A.equal a a') + ~if_false_then_print_s: + (lazy + [%message + "conversion does not round-trip" + ~from: (A.module_name : string) + ~to_: (B.module_name : string) + ~input: (a : A.t) + ~output: (b : B.t) + ~round_trip: (a' : A.t)]) + | Error error -> + require loc (not (A.equal a a')) + ~if_false_then_print_s: + (lazy + [%message + "conversion failed" + ~from: (A.module_name : string) + ~to_: (B.module_name : string) + ~input: (a : A.t) + ~expected_output: (b' : B.t) + ~error: (error : Error.t)])) + + let test loc ma mb + (a_to_b_trunc, a_to_b_or_error) + (b_to_a_trunc, b_to_a_or_error) + = + test_conversion loc ma mb a_to_b_or_error a_to_b_trunc b_to_a_trunc; + test_conversion loc mb ma b_to_a_or_error b_to_a_trunc a_to_b_trunc + + module Int = struct include Int let module_name = "Int" end + module Int32 = struct include Int32 let module_name = "Int32" end + module Int64 = struct include Int64 let module_name = "Int64" end + module Nativeint = struct include Nativeint let module_name = "Nativeint" end + + let with_exn f x = Or_error.try_with (fun () -> f x) + let optional f x = Or_error.try_with (fun () -> Option.value_exn (f x)) + let alwaysok f x = Ok (f x) + + let%expect_test "int <-> int32" = + test [%here] (module Int) (module Int32) + (Caml.Int32.of_int, with_exn int_to_int32_exn) + (Caml.Int32.to_int, with_exn int32_to_int_exn); + [%expect {| |}]; + test [%here] (module Int) (module Int32) + (Caml.Int32.of_int, optional int_to_int32) + (Caml.Int32.to_int, optional int32_to_int); + [%expect {| |}]; + ;; + + let%expect_test "int <-> int64" = + test [%here] (module Int) (module Int64) + (Caml.Int64.of_int, alwaysok int_to_int64) + (Caml.Int64.to_int, with_exn int64_to_int_exn); + [%expect {| |}]; + test [%here] (module Int) (module Int64) + (Caml.Int64.of_int, alwaysok int_to_int64) + (Caml.Int64.to_int, optional int64_to_int); + [%expect {| |}]; + ;; + + let%expect_test "int <-> nativeint" = + test [%here] (module Int) (module Nativeint) + (Caml.Nativeint.of_int, alwaysok int_to_nativeint) + (Caml.Nativeint.to_int, with_exn nativeint_to_int_exn); + [%expect {| |}]; + test [%here] (module Int) (module Nativeint) + (Caml.Nativeint.of_int, alwaysok int_to_nativeint) + (Caml.Nativeint.to_int, optional nativeint_to_int); + [%expect {| |}]; + ;; + + let%expect_test "int32 <-> int64" = + test [%here] (module Int32) (module Int64) + (Caml.Int64.of_int32, alwaysok int32_to_int64) + (Caml.Int64.to_int32, with_exn int64_to_int32_exn); + [%expect {| |}]; + test [%here] (module Int32) (module Int64) + (Caml.Int64.of_int32, alwaysok int32_to_int64) + (Caml.Int64.to_int32, optional int64_to_int32); + [%expect {| |}]; + ;; + + let%expect_test "int32 <-> nativeint" = + test [%here] (module Int32) (module Nativeint) + (Caml.Nativeint.of_int32, alwaysok int32_to_nativeint) + (Caml.Nativeint.to_int32, with_exn nativeint_to_int32_exn); + [%expect {| |}]; + test [%here] (module Int32) (module Nativeint) + (Caml.Nativeint.of_int32, alwaysok int32_to_nativeint) + (Caml.Nativeint.to_int32, optional nativeint_to_int32); + [%expect {| |}]; + ;; + + let%expect_test "int64 <-> nativeint" = + test [%here] (module Int64) (module Nativeint) + (Caml.Int64.to_nativeint, with_exn int64_to_nativeint_exn) + (Caml.Int64.of_nativeint, alwaysok nativeint_to_int64); + [%expect {| |}]; + test [%here] (module Int64) (module Nativeint) + (Caml.Int64.to_nativeint, optional int64_to_nativeint) + (Caml.Int64.of_nativeint, alwaysok nativeint_to_int64); + [%expect {| |}]; + ;; + end) diff --git a/test/test_int_conversions.mli b/test/test_int_conversions.mli new file mode 100644 index 0000000..74bb729 --- /dev/null +++ b/test/test_int_conversions.mli @@ -0,0 +1 @@ +(*_ This signature is deliberately empty. *) diff --git a/test/test_int_hash.ml b/test/test_int_hash.ml new file mode 100644 index 0000000..6ad8fa0 --- /dev/null +++ b/test/test_int_hash.ml @@ -0,0 +1,8 @@ +open! Base +open! Import + +let%expect_test "int hash is not ident" [@tags "64-bits-only"] = + print_s [%message "hash of 10" + (Int.hash 10 : int) ]; + [%expect {| ("hash of 10" ("Int.hash 10" 1_579_120_067_278_557_813)) |}]; +;; diff --git a/test/test_int_hash.mli b/test/test_int_hash.mli new file mode 100644 index 0000000..1b56838 --- /dev/null +++ b/test/test_int_hash.mli @@ -0,0 +1 @@ +(* this interface is deliberately empty *) diff --git a/test/test_int_math.ml b/test/test_int_math.ml new file mode 100644 index 0000000..99b9b3e --- /dev/null +++ b/test/test_int_math.ml @@ -0,0 +1,192 @@ +open! Import +open! Base.Not_exposed_properly.Int_math + +let%test_unit _ = + let x = match Word_size.word_size with W32 -> 9 | W64 -> 10 in + for i = 0 to x do + for j = 0 to x do + assert (int_pow i j + = Caml.(int_of_float ((float_of_int i) ** (float_of_int j)))) + done + done + +module Make (X : T) : sig end = struct + open X + include Make(X) + + let%test_module "integer-rounding" = + (module struct + + let check dir ~range:(lower, upper) ~modulus expected = + let modulus = of_int_exn modulus in + let expected = of_int_exn expected in + for i = lower to upper do + let observed = round ~dir ~to_multiple_of:modulus (of_int_exn i) in + if observed <> expected then + raise_s [%message "invalid result" (i : int)] + done + ;; + + let%test_unit _ = check ~modulus:10 `Down ~range:( 10, 19) 10 + let%test_unit _ = check ~modulus:10 `Down ~range:( 0, 9) 0 + let%test_unit _ = check ~modulus:10 `Down ~range:(-10, -1) (-10) + let%test_unit _ = check ~modulus:10 `Down ~range:(-20, -11) (-20) + + let%test_unit _ = check ~modulus:10 `Up ~range:( 11, 20) 20 + let%test_unit _ = check ~modulus:10 `Up ~range:( 1, 10) 10 + let%test_unit _ = check ~modulus:10 `Up ~range:( -9, 0) 0 + let%test_unit _ = check ~modulus:10 `Up ~range:(-19, -10) (-10) + + let%test_unit _ = check ~modulus:10 `Zero ~range:( 10, 19) 10 + let%test_unit _ = check ~modulus:10 `Zero ~range:( -9, 9) 0 + let%test_unit _ = check ~modulus:10 `Zero ~range:(-19, -10) (-10) + + let%test_unit _ = check ~modulus:10 `Nearest ~range:( 15, 24) 20 + let%test_unit _ = check ~modulus:10 `Nearest ~range:( 5, 14) 10 + let%test_unit _ = check ~modulus:10 `Nearest ~range:( -5, 4) 0 + let%test_unit _ = check ~modulus:10 `Nearest ~range:(-15, -6) (-10) + let%test_unit _ = check ~modulus:10 `Nearest ~range:(-25, -16) (-20) + + let%test_unit _ = check ~modulus:5 `Nearest ~range:( 8, 12) 10 + let%test_unit _ = check ~modulus:5 `Nearest ~range:( 3, 7) 5 + let%test_unit _ = check ~modulus:5 `Nearest ~range:( -2, 2) 0 + let%test_unit _ = check ~modulus:5 `Nearest ~range:( -7, -3) (-5) + let%test_unit _ = check ~modulus:5 `Nearest ~range:(-12, -8) (-10) + end) + + let%test_module "remainder-and-modulus" = + (module struct + + let one = of_int_exn 1 + + let check_integers x y = + let sexp_of_t t = sexp_of_string (to_string t) in + let check_raises f what = + match f () with + | exception _ -> () + | z -> + raise_s + [%message "produced result instead of raising" + (what : string) + (x : t) + (y : t) + (z : t)] + in + let check_true cond what = + if not cond + then raise_s + [%message "failed" + (what : string) + (x : t) + (y : t)] + in + if y = zero + then + begin + check_raises (fun () -> x / y) "division by zero"; + check_raises (fun () -> rem x y) "rem _ zero"; + check_raises (fun () -> x % y) "_ % zero"; + check_raises (fun () -> x /% y) "_ /% zero"; + end + else + begin + if x < zero + then check_true (rem x y <= zero) "non-positive remainder" + else check_true (rem x y >= zero) "non-negative remainder"; + check_true (abs (rem x y) <= abs y - one) "range of remainder"; + if y < zero then begin + check_raises (fun () -> x % y) "_ % negative"; + check_raises (fun () -> x /% y) "_ /% negative" + end + else begin + check_true (x = (x /% y) * y + (x % y)) "(/%) and (%) identity"; + check_true (x = (x / y) * y + (rem x y)) "(/) and rem identity"; + check_true (x % y >= zero) "non-negative (%)"; + check_true (x % y <= y - one) "range of (%)"; + if x > zero && y > zero + then begin + check_true (x /% y = x / y) "(/%) and (/) identity"; + check_true (x % y = rem x y) "(%) and rem identity" + end; + end + end + ;; + + let check_natural_numbers x y = + List.iter [ x ; -x ; x+one ; -(x + one) ] ~f:(fun x -> + List.iter [ y ; -y ; y+one ; -(y + one) ] ~f:(fun y -> + check_integers x y)) + + let%test_unit "deterministic" = + let big1 = of_int_exn 118_310_344 in + let big2 = of_int_exn 828_172_408 in + (* Important to test the case where one value is a multiple of the other. Note that + the [x + one] and [y + one] cases in [check_natural_numbers] ensure that we also + test non-multiple cases. *) + assert (big2 = big1 * of_int_exn 7); + let values = [ zero ; one ; big1 ; big2 ] in + List.iter values ~f:(fun x -> + List.iter values ~f:(fun y -> + check_natural_numbers x y)) + + let%test_unit "random" = + let rand = Random.State.make [| 8; 67; -5_309 |] in + for _ = 0 to 1_000 do + let max_value = 1_000_000_000 in + let x = of_int_exn (Random.State.int rand max_value) in + let y = of_int_exn (Random.State.int rand max_value) in + check_natural_numbers x y + done + end) +end + +include Make (Int) +include Make (Int32) +include Make (Int63) +include Make (Int64) +include Make (Nativeint) + +let%test_module "pow" = + (module struct + let%test _ = int_pow 0 0 = 1 + let%test _ = int_pow 0 1 = 0 + let%test _ = int_pow 10 1 = 10 + let%test _ = int_pow 10 2 = 100 + let%test _ = int_pow 10 3 = 1_000 + let%test _ = int_pow 10 4 = 10_000 + let%test _ = int_pow 10 5 = 100_000 + let%test _ = int_pow 2 10 = 1024 + + let%test _ = int_pow 0 1_000_000 = 0 + let%test _ = int_pow 1 1_000_000 = 1 + let%test _ = int_pow (-1) 1_000_000 = 1 + let%test _ = int_pow (-1) 1_000_001 = -1 + + let (=) = Int64.(=) + let%test _ = int64_pow 0L 0L = 1L + let%test _ = int64_pow 0L 1_000_000L = 0L + let%test _ = int64_pow 1L 1_000_000L = 1L + let%test _ = int64_pow (-1L) 1_000_000L = 1L + let%test _ = int64_pow (-1L) 1_000_001L = -1L + + let%test _ = int64_pow 10L 1L = 10L + let%test _ = int64_pow 10L 2L = 100L + let%test _ = int64_pow 10L 3L = 1_000L + let%test _ = int64_pow 10L 4L = 10_000L + let%test _ = int64_pow 10L 5L = 100_000L + let%test _ = int64_pow 2L 10L = 1_024L + let%test _ = int64_pow 5L 27L = 7450580596923828125L + + let exception_thrown pow b e = Exn.does_raise (fun () -> pow b e) + + let%test _ = exception_thrown int_pow 10 60 + let%test _ = exception_thrown int64_pow 10L 60L + let%test _ = exception_thrown int_pow 10 (-1) + let%test _ = exception_thrown int64_pow 10L (-1L) + + let%test _ = exception_thrown int64_pow 2L 63L + let%test _ = not (exception_thrown int64_pow 2L 62L) + + let%test _ = exception_thrown int64_pow (-2L) 63L + let%test _ = not (exception_thrown int64_pow (-2L) 62L) + end) diff --git a/test/test_int_math.mli b/test/test_int_math.mli new file mode 100644 index 0000000..74bb729 --- /dev/null +++ b/test/test_int_math.mli @@ -0,0 +1 @@ +(*_ This signature is deliberately empty. *) diff --git a/test/test_int_pow2.ml b/test/test_int_pow2.ml new file mode 100644 index 0000000..9f7d76d --- /dev/null +++ b/test/test_int_pow2.ml @@ -0,0 +1,126 @@ +open! Import +open! Int + +let examples = + [ -1 + ; 0 + ; 1 + ; 2 + ; 3 + ; 4 + ; 5 + ; 7 + ; 8 + ; 9 + ; 63 + ; 64 + ; 65 ] +;; + +let examples_64_bit = + [ Int.min_value + ; Int.min_value + 1 + ; Int.max_value - 1 + ; Int.max_value ] +;; + +let print_for ints f = + List.iter ints ~f:(fun i -> + print_s [%message + "" + ~_:(i : int) + ~_:(Or_error.try_with (fun () -> f i) : int Or_error.t)]) +;; + +let%expect_test "[floor_log2]" = + print_for examples floor_log2; + [%expect {| + (-1 (Error ("[Int.floor_log2] got invalid input" -1))) + (0 (Error ("[Int.floor_log2] got invalid input" 0))) + (1 (Ok 0)) + (2 (Ok 1)) + (3 (Ok 1)) + (4 (Ok 2)) + (5 (Ok 2)) + (7 (Ok 2)) + (8 (Ok 3)) + (9 (Ok 3)) + (63 (Ok 5)) + (64 (Ok 6)) + (65 (Ok 6)) |}]; +;; + +let%expect_test "[floor_log2]" [@tags "64-bits-only"] = + print_for examples_64_bit floor_log2; + [%expect {| + (-4_611_686_018_427_387_904 ( + Error ("[Int.floor_log2] got invalid input" -4611686018427387904))) + (-4_611_686_018_427_387_903 ( + Error ("[Int.floor_log2] got invalid input" -4611686018427387903))) + (4_611_686_018_427_387_902 (Ok 61)) + (4_611_686_018_427_387_903 (Ok 61)) |}]; +;; + +let%expect_test "[ceil_log2]" = + print_for examples ceil_log2; + [%expect {| + (-1 (Error ("[Int.ceil_log2] got invalid input" -1))) + (0 (Error ("[Int.ceil_log2] got invalid input" 0))) + (1 (Ok 0)) + (2 (Ok 1)) + (3 (Ok 2)) + (4 (Ok 2)) + (5 (Ok 3)) + (7 (Ok 3)) + (8 (Ok 3)) + (9 (Ok 4)) + (63 (Ok 6)) + (64 (Ok 6)) + (65 (Ok 7)) |}]; +;; + +let%expect_test "[ceil_log2]" [@tags "64-bits-only"] = + print_for examples_64_bit ceil_log2; + [%expect {| + (-4_611_686_018_427_387_904 ( + Error ("[Int.ceil_log2] got invalid input" -4611686018427387904))) + (-4_611_686_018_427_387_903 ( + Error ("[Int.ceil_log2] got invalid input" -4611686018427387903))) + (4_611_686_018_427_387_902 (Ok 62)) + (4_611_686_018_427_387_903 (Ok 62)) |}]; +;; + +let%test_module "int_math" = + (module struct + + let test_cases () = + let cases = + [ 0b10101010; 0b1010101010101010; 0b101010101010101010101010; + 0b10000000; 0b1000000000001000; 0b100000000000000000001000; ] + in + match Word_size.word_size with + | W64 -> (* create some >32 bit values... *) + (* We can't use literals directly because the compiler complains on 32 bits. *) + let cases = + cases @ [ (0b1010101010101010 lsl 16) lor 0b1010101010101010; + (0b1000000000000000 lsl 16) lor 0b0000000000001000; ] + in + let added_cases = List.map cases ~f:(fun x -> x lsl 16) in + List.concat [ cases; added_cases ] + | W32 -> cases + ;; + + let%test_unit "ceil_pow2" = + List.iter (test_cases ()) + ~f:(fun x -> let p2 = ceil_pow2 x in + assert( (is_pow2 p2) && (p2 >= x && x >= (p2 / 2)) ) + ) + ;; + + let%test_unit "floor_pow2" = + List.iter (test_cases ()) + ~f:(fun x -> let p2 = floor_pow2 x in + assert( (is_pow2 p2) && ((2 * p2) >= x && x >= p2) ) + ) + ;; + end) diff --git a/test/test_int_pow2.mli b/test/test_int_pow2.mli new file mode 100644 index 0000000..74bb729 --- /dev/null +++ b/test/test_int_pow2.mli @@ -0,0 +1 @@ +(*_ This signature is deliberately empty. *) diff --git a/test/test_lazy.ml b/test/test_lazy.ml new file mode 100644 index 0000000..55603a0 --- /dev/null +++ b/test/test_lazy.ml @@ -0,0 +1,55 @@ +open! Import +open! Lazy + +let%test_unit _ = + let r = ref 0 in + let t = return () >>= fun () -> Int.incr r; return () in + assert (!r = 0); + force t; + assert (!r = 1); + force t; + assert (!r = 1) +;; + +let%test_unit _ = + let r = ref 0 in + let t = return () >>= fun () -> lazy (Int.incr r) in + assert (!r = 0); + force t; + assert (!r = 1); + force t; + assert (!r = 1) +;; + +let%test_module _ = + (module struct + + module M1 = struct + type nonrec t = { x : int t } [@@deriving sexp_of] + end + + module M2 = struct + type t = { x : int T_unforcing.t } [@@deriving sexp_of] + end + + let%test_unit _ = + let v = lazy 42 in + let (_ : int) = + (* no needed, but the purpose of this test is not to test this compiler + optimization *) + force v + in + assert (is_val v); + let t1 = { M1. x = v } in + let t2 = { M2. x = v } in + assert (Sexp.equal (M1.sexp_of_t t1) (M2.sexp_of_t t2)) + ;; + + let%test_unit _ = + let t1 = { M1. x = lazy (40 + 2) } in + let t2 = { M2. x = lazy (40 + 2) } in + assert (not (Sexp.equal (M1.sexp_of_t t1) (M2.sexp_of_t t2))); + assert (is_val t1.x); + assert (not (is_val t2.x)) + ;; + end) diff --git a/test/test_lazy.mli b/test/test_lazy.mli new file mode 100644 index 0000000..74bb729 --- /dev/null +++ b/test/test_lazy.mli @@ -0,0 +1 @@ +(*_ This signature is deliberately empty. *) diff --git a/test/test_list.ml b/test/test_list.ml new file mode 100644 index 0000000..b6d8181 --- /dev/null +++ b/test/test_list.ml @@ -0,0 +1,610 @@ +open! Import +open! List + +let%test_module "reduce_balanced" = + (module struct + let test expect list = + [%test_result: string option] ~expect + (reduce_balanced ~f:(fun a b -> "(" ^ a ^ "+" ^ b ^ ")") list) + + let%test_unit "length 0" = + test None [] + + let%test_unit "length 1" = + test (Some "a") ["a"] + + let%test_unit "length 2" = + test (Some "(a+b)") ["a"; "b"] + + let%test_unit "length 6" = + test (Some "(((a+b)+(c+d))+(e+f))") ["a";"b";"c";"d";"e";"f"] + + let%test_unit "longer" = + (* pairs (index, number of times f called on me) to check: + 1. f called on results in index order + 2. total number of calls on any element is low + called on 2^n + 1 to demonstrate lack of balance (most elements are distance 7 from + the tree root, but one is distance 1) *) + let data = map (range 0 65) ~f:(fun i -> [(i, 0)]) in + let f x y = map (x @ y) ~f:(fun (ix, cx) -> (ix, cx + 1)) in + match reduce_balanced data ~f with + | None -> failwith "None" + | Some l -> + [%test_result: int] ~expect:65 (List.length l); + iteri l ~f:(fun actual_index (computed_index, num_f) -> + let expected_num_f = if actual_index = 64 then 1 else 7 in + [%test_result: int * int] + ~expect:(actual_index, expected_num_f) (computed_index, num_f)) + end) + +let%test_module "range symmetries" = + (module struct + + let basic ~stride ~start ~stop ~start_n ~stop_n ~result = + [%compare.equal: int t] (range ~stride ~start ~stop start_n stop_n) result + + let test stride (start_n, start) (stop_n, stop) result = + basic ~stride ~start ~stop ~start_n ~stop_n ~result + && (* works for negative [start] and [stop] *) + basic ~stride:(-stride) + ~start_n:(-start_n) + ~stop_n:(-stop_n) + ~start + ~stop + ~result:(List.map result ~f:(fun x -> -x)) + + let%test _ = test 1 ( 3, `inclusive) ( 1, `exclusive) [] + let%test _ = test 1 ( 3, `inclusive) ( 3, `exclusive) [] + let%test _ = test 1 ( 3, `inclusive) ( 4, `exclusive) [3] + let%test _ = test 1 ( 3, `inclusive) ( 8, `exclusive) [3;4;5;6;7] + let%test _ = test 3 ( 4, `inclusive) (10, `exclusive) [4;7] + let%test _ = test 3 ( 4, `inclusive) (11, `exclusive) [4;7;10] + let%test _ = test 3 ( 4, `inclusive) (12, `exclusive) [4;7;10] + let%test _ = test 3 ( 4, `inclusive) (13, `exclusive) [4;7;10] + let%test _ = test 3 ( 4, `inclusive) (14, `exclusive) [4;7;10;13] + + let%test _ = test (-1) ( 1, `inclusive) ( 3, `exclusive) [] + let%test _ = test (-1) ( 3, `inclusive) ( 3, `exclusive) [] + let%test _ = test (-1) ( 4, `inclusive) ( 3, `exclusive) [4] + let%test _ = test (-1) ( 8, `inclusive) ( 3, `exclusive) [8;7;6;5;4] + let%test _ = test (-3) (10, `inclusive) ( 4, `exclusive) [10;7] + let%test _ = test (-3) (10, `inclusive) ( 3, `exclusive) [10;7;4] + let%test _ = test (-3) (10, `inclusive) ( 2, `exclusive) [10;7;4] + let%test _ = test (-3) (10, `inclusive) ( 1, `exclusive) [10;7;4] + let%test _ = test (-3) (10, `inclusive) ( 0, `exclusive) [10;7;4;1] + + let%test _ = test 1 ( 3, `exclusive) ( 1, `exclusive) [] + let%test _ = test 1 ( 3, `exclusive) ( 3, `exclusive) [] + let%test _ = test 1 ( 3, `exclusive) ( 4, `exclusive) [] + let%test _ = test 1 ( 3, `exclusive) ( 8, `exclusive) [4;5;6;7] + let%test _ = test 3 ( 4, `exclusive) (10, `exclusive) [7] + let%test _ = test 3 ( 4, `exclusive) (11, `exclusive) [7;10] + let%test _ = test 3 ( 4, `exclusive) (12, `exclusive) [7;10] + let%test _ = test 3 ( 4, `exclusive) (13, `exclusive) [7;10] + let%test _ = test 3 ( 4, `exclusive) (14, `exclusive) [7;10;13] + + let%test _ = test (-1) ( 1, `exclusive) ( 3, `exclusive) [] + let%test _ = test (-1) ( 3, `exclusive) ( 3, `exclusive) [] + let%test _ = test (-1) ( 4, `exclusive) ( 3, `exclusive) [] + let%test _ = test (-1) ( 8, `exclusive) ( 3, `exclusive) [7;6;5;4] + let%test _ = test (-3) (10, `exclusive) ( 4, `exclusive) [7] + let%test _ = test (-3) (10, `exclusive) ( 3, `exclusive) [7;4] + let%test _ = test (-3) (10, `exclusive) ( 2, `exclusive) [7;4] + let%test _ = test (-3) (10, `exclusive) ( 1, `exclusive) [7;4] + let%test _ = test (-3) (10, `exclusive) ( 0, `exclusive) [7;4;1] + + let%test _ = test 1 ( 3, `inclusive) ( 1, `inclusive) [] + let%test _ = test 1 ( 3, `inclusive) ( 3, `inclusive) [3] + let%test _ = test 1 ( 3, `inclusive) ( 4, `inclusive) [3;4] + let%test _ = test 1 ( 3, `inclusive) ( 8, `inclusive) [3;4;5;6;7;8] + let%test _ = test 3 ( 4, `inclusive) (10, `inclusive) [4;7;10] + let%test _ = test 3 ( 4, `inclusive) (11, `inclusive) [4;7;10] + let%test _ = test 3 ( 4, `inclusive) (12, `inclusive) [4;7;10] + let%test _ = test 3 ( 4, `inclusive) (13, `inclusive) [4;7;10;13] + let%test _ = test 3 ( 4, `inclusive) (14, `inclusive) [4;7;10;13] + + let%test _ = test (-1) ( 1, `inclusive) ( 3, `inclusive) [] + let%test _ = test (-1) ( 3, `inclusive) ( 3, `inclusive) [3] + let%test _ = test (-1) ( 4, `inclusive) ( 3, `inclusive) [4;3] + let%test _ = test (-1) ( 8, `inclusive) ( 3, `inclusive) [8;7;6;5;4;3] + let%test _ = test (-3) (10, `inclusive) ( 4, `inclusive) [10;7;4] + let%test _ = test (-3) (10, `inclusive) ( 3, `inclusive) [10;7;4] + let%test _ = test (-3) (10, `inclusive) ( 2, `inclusive) [10;7;4] + let%test _ = test (-3) (10, `inclusive) ( 1, `inclusive) [10;7;4;1] + let%test _ = test (-3) (10, `inclusive) ( 0, `inclusive) [10;7;4;1] + + let%test _ = test 1 ( 3, `exclusive) ( 1, `inclusive) [] + let%test _ = test 1 ( 3, `exclusive) ( 3, `inclusive) [] + let%test _ = test 1 ( 3, `exclusive) ( 4, `inclusive) [4] + let%test _ = test 1 ( 3, `exclusive) ( 8, `inclusive) [4;5;6;7;8] + let%test _ = test 3 ( 4, `exclusive) (10, `inclusive) [7;10] + let%test _ = test 3 ( 4, `exclusive) (11, `inclusive) [7;10] + let%test _ = test 3 ( 4, `exclusive) (12, `inclusive) [7;10] + let%test _ = test 3 ( 4, `exclusive) (13, `inclusive) [7;10;13] + let%test _ = test 3 ( 4, `exclusive) (14, `inclusive) [7;10;13] + + let%test _ = test (-1) ( 1, `exclusive) ( 3, `inclusive) [] + let%test _ = test (-1) ( 3, `exclusive) ( 3, `inclusive) [] + let%test _ = test (-1) ( 4, `exclusive) ( 3, `inclusive) [3] + let%test _ = test (-1) ( 8, `exclusive) ( 3, `inclusive) [7;6;5;4;3] + let%test _ = test (-3) (10, `exclusive) ( 4, `inclusive) [7;4] + let%test _ = test (-3) (10, `exclusive) ( 3, `inclusive) [7;4] + let%test _ = test (-3) (10, `exclusive) ( 2, `inclusive) [7;4] + let%test _ = test (-3) (10, `exclusive) ( 1, `inclusive) [7;4;1] + let%test _ = test (-3) (10, `exclusive) ( 0, `inclusive) [7;4;1] + + let test_start_inc_exc stride start (stop, stop_inc_exc) result = + test stride (start, `inclusive) (stop, stop_inc_exc) result + && begin + match result with + | [] -> true + | head :: tail -> + head = start && test stride (start, `exclusive) (stop, stop_inc_exc) tail + end + + let test_inc_exc stride start stop result = + test_start_inc_exc stride start (stop, `inclusive) result + && begin + match List.rev result with + | [] -> true + | last :: all_but_last -> + let all_but_last = List.rev all_but_last in + if last = stop then + test_start_inc_exc stride start (stop, `exclusive) all_but_last + else + true + end + + let%test _ = test_inc_exc 1 4 10 [4;5;6;7;8;9;10] + let%test _ = test_inc_exc 3 4 10 [4;7;10] + let%test _ = test_inc_exc 3 4 11 [4;7;10] + let%test _ = test_inc_exc 3 4 12 [4;7;10] + let%test _ = test_inc_exc 3 4 13 [4;7;10;13] + let%test _ = test_inc_exc 3 4 14 [4;7;10;13] + + end) + +module Test_values = struct + let long1 = + let v = lazy (range 1 100_000) in + fun () -> Lazy.force v + + let long2 = + let v = lazy (range 2 100_001) in + fun () -> Lazy.force v + + let l1 = [1;2;3;4;5;6;7;8;9;10] +end + +let%test_unit _ = [%test_result: int list] (rev_append [1;2;3] [4;5;6]) ~expect:[3;2;1;4;5;6] +let%test_unit _ = [%test_result: int list] (rev_append [] [4;5;6]) ~expect:[4;5;6] +let%test_unit _ = [%test_result: int list] (rev_append [1;2;3] []) ~expect:[3;2;1] +let%test_unit _ = [%test_result: int list] (rev_append [1] [2;3]) ~expect:[1;2;3] +let%test_unit _ = [%test_result: int list] (rev_append [1;2] [3]) ~expect:[2;1;3] +let%test_unit _ = + let long = Test_values.long1 () in + ignore (rev_append long long:int list) + +let%test_unit _ = + let long1 = Test_values.long1 () in + let long2 = Test_values.long2 () in + [%test_result: int list] (map long1 ~f:(fun x -> x + 1)) ~expect:long2 + +let test_ordering n = + let l = List.range 0 n in + let r = ref [] in + let _ : unit list = List.map l ~f:(fun x -> r := x :: !r) in + [%test_eq: int list] l (List.rev !r) + +let%test_unit _ = test_ordering 10 +let%test_unit _ = test_ordering 1000 +let%test_unit _ = test_ordering 1_000_000 + +let%test _ = for_all2_exn [] [] ~f:(fun _ _ -> assert false) + +let%test_unit _ = [%test_result: int option] (find_mapi [0;5;2;1;4] ~f:(fun i x -> if i = x then Some (i+x) else None)) ~expect:(Some 0) +let%test_unit _ = [%test_result: int option] (find_mapi [3;5;2;1;4] ~f:(fun i x -> if i = x then Some (i+x) else None)) ~expect:(Some 4) +let%test_unit _ = [%test_result: int option] (find_mapi [3;5;1;1;4] ~f:(fun i x -> if i = x then Some (i+x) else None)) ~expect:(Some 8) +let%test_unit _ = [%test_result: int option] (find_mapi [3;5;1;1;2] ~f:(fun i x -> if i = x then Some (i+x) else None)) ~expect:None + +let%test_unit _ = [%test_result: bool] (for_alli [] ~f:(fun _ _ -> false)) ~expect:true +let%test_unit _ = [%test_result: bool] (for_alli [0;1;2;3] ~f:(fun i x -> i = x)) ~expect:true +let%test_unit _ = [%test_result: bool] (for_alli [0;1;3;3] ~f:(fun i x -> i = x)) ~expect:false +let%test_unit _ = [%test_result: bool] (existsi [] ~f:(fun _ _ -> true)) ~expect:false +let%test_unit _ = [%test_result: bool] (existsi [0;1;2;3] ~f:(fun i x -> i <> x)) ~expect:false +let%test_unit _ = [%test_result: bool] (existsi [0;1;3;3] ~f:(fun i x -> i <> x)) ~expect:true + +let%test_unit _ = [%test_result: int list] (append [1;2;3] [4;5;6]) ~expect:[1;2;3;4;5;6] +let%test_unit _ = [%test_result: int list] (append [] [4;5;6]) ~expect:[4;5;6] +let%test_unit _ = [%test_result: int list] (append [1;2;3] []) ~expect:[1;2;3] +let%test_unit _ = [%test_result: int list] (append [1] [2;3]) ~expect:[1;2;3] +let%test_unit _ = [%test_result: int list] (append [1;2] [3]) ~expect:[1;2;3] +let%test_unit _ = + let long = Test_values.long1 () in + ignore (append long long:int list) + +let%test_unit _ = [%test_result: int list] (map ~f:Fn.id Test_values.l1) ~expect:Test_values.l1 +let%test_unit _ = [%test_result: int list] (map ~f:Fn.id []) ~expect:[] +let%test_unit _ = [%test_result: float list] (map ~f:(fun x -> x +. 5.) [1.;2.;3.]) ~expect:[6.;7.;8.] +let%test_unit _ = + ignore (map ~f:Fn.id (Test_values.long1 ()):int list) + +let%test_unit _ = + [%test_result: (int * char) list] + (map2_exn ~f:(fun a b -> a, b) [1;2;3] ['a';'b';'c']) + ~expect:[(1,'a'); (2,'b'); (3,'c')] +let%test_unit _ = [%test_result: _ list] (map2_exn ~f:(fun _ _ -> ()) [] []) ~expect:[] +let%test_unit _ = + let long = Test_values.long1 () in + ignore (map2_exn ~f:(fun _ _ -> ()) long long:unit list) + +let%test_unit _ = [%test_result: int list] (rev_map_append [1;2;3;4;5] [6] ~f:Fn.id) ~expect:[5;4;3;2;1;6] +let%test_unit _ = [%test_result: int list] (rev_map_append [1;2;3;4;5] [6] ~f:(fun x -> 2 * x)) ~expect:[10;8;6;4;2;6] +let%test_unit _ = [%test_result: int list] (rev_map_append [] [6] ~f:(fun _ -> failwith "bug!")) ~expect:[6] + +let%test_unit _ = + [%test_result: int list] + (fold_right ~f:(fun e acc -> e :: acc) Test_values.l1 ~init:[]) + ~expect:Test_values.l1 +let%test_unit _ = [%test_result: string] (fold_right ~f:(fun e acc -> e ^ acc) ["1";"2"] ~init:"3") ~expect:"123" +let%test_unit _ = [%test_result: unit] (fold_right ~f:(fun _ _ -> ()) [] ~init:()) ~expect:() +let%test_unit _ = + let long = Test_values.long1 () in + ignore (fold_right ~f:(fun e acc -> e :: acc) long ~init:[]) + +let%test_unit _ = + let l1 = Test_values.l1 in + [%test_result: int list * int list] (unzip (zip_exn l1 (List.rev l1))) ~expect:(l1, List.rev l1) +;; + +let%test_unit _ = + let long = Test_values.long1 () in + ignore (unzip (zip_exn long long)) +;; + +let%test_unit _ = [%test_result: int list * int list] (unzip [(1,2) ; (4,5) ]) ~expect:([1; 4], [2; 5] ) +let%test_unit _ = [%test_result: int list * int list * int list] (unzip3 [(1,2,3); (4,5,6)]) ~expect:([1; 4], [2; 5], [3; 6]) + +let%test_unit _ = [%test_result: (int * int) list Or_unequal_lengths.t] (zip [1;2;3] [4;5;6]) ~expect:(Ok [1,4;2,5;3,6]) +let%test_unit _ = [%test_result: (int * int) list Or_unequal_lengths.t] (zip [1] [4;5;6] ) ~expect:Unequal_lengths + +let%test_unit _ = [%test_result: (int * int) list] (zip_exn [1;2;3] [4;5;6]) ~expect:[1,4;2,5;3,6] + +let%expect_test _ = + show_raise (fun () -> zip_exn [1] [4;5;6]); + [%expect {| (raised (Invalid_argument "length mismatch in zip_exn: 1 <> 3 ")) |}] +;; + +let%test_unit _ = + [%test_result: (int * string) list] + (mapi ~f:(fun i x -> (i,x)) ["one";"two";"three";"four"]) + ~expect:[0,"one";1,"two";2,"three";3,"four"] +let%test_unit _ = [%test_result: (int * _) list] (mapi ~f:(fun i x -> (i,x)) []) ~expect:[] + +let%test_module "group" = + (module struct + let%test_unit _ = [%test_result: int list list] (group [1;2;3;4] ~break:(fun _ x -> x = 3)) ~expect:[[1;2];[3;4]] + + let%test_unit _ = [%test_result: int list list] (group [] ~break:(fun _ -> assert false)) ~expect:[] + + let mis = ['M';'i';'s';'s';'i';'s';'s';'i';'p';'p';'i'] + let equal_letters = + [['M'];['i'];['s';'s'];['i'];['s';'s'];['i'];['p';'p'];['i']] + let single_letters = + [['M';'i';'s';'s';'i';'s';'s';'i';'p';'p';'i']] + let every_three = + [['M'; 'i'; 's']; ['s'; 'i'; 's']; ['s'; 'i'; 'p']; ['p'; 'i' ]] + + let%test_unit _ = [%test_result: char list list] (group ~break:Char.(<>) mis) ~expect:equal_letters + let%test_unit _ = [%test_result: char list list] (group ~break:(fun _ _ -> false) mis) ~expect:single_letters + let%test_unit _ = [%test_result: char list list] (groupi ~break:(fun i _ _ -> i % 3 = 0) mis) ~expect:every_three + end) + +let%test_module "chunks_of" = + (module struct + + let test length break_every = + let l = List.init length ~f:Fn.id in + let b = chunks_of l ~length:break_every in + [%test_eq: int list] (List.concat b) l; + List.iter b ~f:([%test_pred: int list] (fun batch -> + List.length batch <= break_every)); + ;; + + let expect_exn length break_every = + match test length break_every with + | exception _ -> () + | () -> raise_s [%message "Didn't raise." (length : int) (break_every : int)] + ;; + + let%test_unit _ = + for n = 0 to 10 do + for k = n + 2 downto 1 do + test n k + done + done; + expect_exn 1 0; + expect_exn 1 (-1); + ;; + + let%test_unit _ = [%test_result: _ list list] (chunks_of [] ~length:1) ~expect:[] + end) + +let%test _ = last_exn [1;2;3] = 3 +let%test _ = last_exn [1] = 1 +let%test _ = last_exn (Test_values.long1 ()) = 99_999 + +let%test _ = is_prefix [] ~prefix:[] ~equal:(=) +let%test _ = is_prefix [1] ~prefix:[] ~equal:(=) +let%test _ = is_prefix [1] ~prefix:[1] ~equal:(=) +let%test _ = not (is_prefix [1] ~prefix:[1;2] ~equal:(=)) +let%test _ = not (is_prefix [1;3] ~prefix:[1;2] ~equal:(=)) +let%test _ = is_prefix [1;2;3] ~prefix:[1;2] ~equal:(=) + +let%test_unit _ = + List.iter ~f:(fun (t, expect) -> + assert (Poly.equal expect (find_consecutive_duplicate t ~equal:Poly.equal))) + [ [] , None + ; [ 1 ] , None + ; [ 1; 1 ] , Some (1, 1) + ; [ 1; 2 ] , None + ; [ 1; 2; 1 ] , None + ; [ 1; 2; 2 ] , Some (2, 2) + ; [ 1; 1; 2; 2 ], Some (1, 1) + ] +;; + +let%test_unit _ = + [%test_result: ((int * char) * (int * char)) option] + (find_consecutive_duplicate [(0,'a');(1,'b');(2,'b')] + ~equal:(fun (_, a) (_, b) -> Char.(=) a b)) + ~expect:(Some ((1, 'b'), (2, 'b'))) +;; + +let%test_unit _ = [%test_result: int list] + (remove_consecutive_duplicates ~which_to_keep:`Last [] + ~equal:Int.(=)) ~expect:[] +let%test_unit _ = [%test_result: int list] + (remove_consecutive_duplicates ~which_to_keep:`Last [5;5;5;5;5] + ~equal:Int.(=)) ~expect:[5] +let%test_unit _ = [%test_result: int list] + (remove_consecutive_duplicates ~which_to_keep:`Last [5;6;5;6;5;6] + ~equal:Int.(=)) ~expect:[5;6;5;6;5;6] +let%test_unit _ = [%test_result: int list] + (remove_consecutive_duplicates ~which_to_keep:`Last [5;5;6;6;5;5;8;8] + ~equal:Int.(=)) ~expect:[5;6;5;8] +let%test_unit _ = [%test_result: (int * int) list] + (remove_consecutive_duplicates ~which_to_keep:`Last [(0,1);(0,2);(2,2);(4,1)] + ~equal:(fun (a,_) (b,_) -> Int.(=) a b)) ~expect:[ (0,2);(2,2);(4,1)] +let%test_unit _ = [%test_result: (int * int) list] + (remove_consecutive_duplicates ~which_to_keep:`Last [(0,1);(2,2);(0,2);(4,1)] + ~equal:(fun (a,_) (b,_) -> Int.(=) a b)) ~expect:[(0,1);(2,2);(0,2);(4,1)] +let%test_unit _ = [%test_result: (int * int) list] + (remove_consecutive_duplicates ~which_to_keep:`Last [(0,1);(2,1);(0,2);(4,2)] + ~equal:(fun (_,a) (_,b) -> Int.(=) a b)) ~expect:[ (2,1); (4,2)] +let%test_unit _ = [%test_result: (int * int) list] + (remove_consecutive_duplicates ~which_to_keep:`Last [(0,1);(2,2);(0,2);(4,1)] + ~equal:(fun (_,a) (_,b) -> Int.(=) a b)) ~expect:[(0,1); (0,2);(4,1)] + +let%test_unit _ = [%test_result: int list] + (remove_consecutive_duplicates ~which_to_keep:`First [] + ~equal:Int.(=)) ~expect:[] +let%test_unit _ = [%test_result: int list] + (remove_consecutive_duplicates ~which_to_keep:`First [5;5;5;5;5] + ~equal:Int.(=)) ~expect:[5] +let%test_unit _ = [%test_result: int list] + (remove_consecutive_duplicates ~which_to_keep:`First [5;6;5;6;5;6] + ~equal:Int.(=)) ~expect:[5;6;5;6;5;6] +let%test_unit _ = [%test_result: int list] + (remove_consecutive_duplicates ~which_to_keep:`First [5;5;6;6;5;5;8;8] + ~equal:Int.(=)) ~expect:[5;6;5;8] +let%test_unit _ = [%test_result: (int * int) list] + (remove_consecutive_duplicates ~which_to_keep:`First [(0,1);(0,2);(2,2);(4,1)] + ~equal:(fun (a,_) (b,_) -> Int.(=) a b)) ~expect:[(0,1); (2,2);(4,1)] +let%test_unit _ = [%test_result: (int * int) list] + (remove_consecutive_duplicates ~which_to_keep:`First [(0,1);(2,2);(0,2);(4,1)] + ~equal:(fun (a,_) (b,_) -> Int.(=) a b)) ~expect:[(0,1);(2,2);(0,2);(4,1)] +let%test_unit _ = [%test_result: (int * int) list] + (remove_consecutive_duplicates ~which_to_keep:`First [(0,1);(2,1);(0,2);(4,2)] + ~equal:(fun (_,a) (_,b) -> Int.(=) a b)) ~expect:[(0,1); (0,2); ] +let%test_unit _ = [%test_result: (int * int) list] + (remove_consecutive_duplicates ~which_to_keep:`First [(0,1);(2,2);(0,2);(4,1)] + ~equal:(fun (_,a) (_,b) -> Int.(=) a b)) ~expect:[(0,1);(2,2); (4,1)] + +let%test_unit _ = [%test_result: int list] (dedup_and_sort ~compare:Int.compare []) ~expect:[] +let%test_unit _ = [%test_result: int list] (dedup_and_sort ~compare:Int.compare [5;5;5;5;5]) ~expect:[5] +let%test_unit _ = [%test_result: int] (length (dedup_and_sort ~compare:Int.compare [2;1;5;3;4])) ~expect:5 +let%test_unit _ = [%test_result: int] (length (dedup_and_sort ~compare:Int.compare [2;3;5;3;4])) ~expect:4 +let%test_unit _ = [%test_result: int] (length (dedup_and_sort [(0,1);(2,2);(0,2);(4,1)] ~compare:(fun (a,_) (b,_) -> Int.compare a b))) ~expect:3 +let%test_unit _ = [%test_result: int] (length (dedup_and_sort [(0,1);(2,2);(0,2);(4,1)] ~compare:(fun (_,a) (_,b) -> Int.compare a b))) ~expect:2 + +let%test_unit _ = [%test_result: int option] (find_a_dup ~compare:Int.compare []) ~expect:None +let%test_unit _ = [%test_result: int option] (find_a_dup ~compare:Int.compare [3]) ~expect:None +let%test_unit _ = [%test_result: int option] (find_a_dup ~compare:Int.compare [3;4]) ~expect:None +let%test_unit _ = [%test_result: int option] (find_a_dup ~compare:Int.compare [3;3]) ~expect:(Some 3) +let%test_unit _ = [%test_result: int option] (find_a_dup ~compare:Int.compare [3;5;4;6;12]) ~expect:None +let%test_unit _ = [%test_result: int option] (find_a_dup ~compare:Int.compare [3;5;4;5;12]) ~expect:(Some 5) +let%test_unit _ = [%test_result: int option] (find_a_dup ~compare:Int.compare [3;5;12;5;12]) ~expect:(Some 5) +let%test_unit _ = + [%test_result: (int * int) option] + (find_a_dup ~compare:[%compare: int * int] [(0,1);(2,2);(0,2);(4,1)]) + ~expect:None +let%test _ = (find_a_dup [(0,1);(2,2);(0,2);(4,1)] + ~compare:(fun (_,a) (_,b) -> Int.compare a b)) + |> Option.is_some +let%test _ = let dup = find_a_dup [(0,1);(2,2);(0,2);(4,1)] + ~compare:(fun (a,_) (b,_) -> Int.compare a b) + in + match dup with + | Some (0, _) -> true + | _ -> false + +let%test_unit _ = [%test_result: bool] (contains_dup ~compare:Int.compare []) ~expect:false +let%test_unit _ = [%test_result: bool] (contains_dup ~compare:Int.compare [3]) ~expect:false +let%test_unit _ = [%test_result: bool] (contains_dup ~compare:Int.compare [3;4]) ~expect:false +let%test_unit _ = [%test_result: bool] (contains_dup ~compare:Int.compare [3;3]) ~expect:true +let%test_unit _ = [%test_result: bool] (contains_dup ~compare:Int.compare [3;5;4;6;12]) ~expect:false +let%test_unit _ = [%test_result: bool] (contains_dup ~compare:Int.compare [3;5;4;5;12]) ~expect:true +let%test_unit _ = [%test_result: bool] (contains_dup ~compare:Int.compare [3;5;12;5;12]) ~expect:true +let%test_unit _ = [%test_result: bool] (contains_dup ~compare:[%compare: int * int] [(0,1);(2,2);(0,2);(4,1)]) ~expect:false +let%test_unit _ = [%test_result: bool] (contains_dup [(0,1);(2,2);(0,2);(4,1)] ~compare:(fun (_,a) (_,b) -> Int.compare a b)) ~expect:true + +let%test_unit _ = [%test_result: int list] (find_all_dups ~compare:Int.compare []) ~expect:[] +let%test_unit _ = [%test_result: int list] (find_all_dups ~compare:Int.compare [3]) ~expect:[] +let%test_unit _ = [%test_result: int list] (find_all_dups ~compare:Int.compare [3;4]) ~expect:[] +let%test_unit _ = [%test_result: int list] (find_all_dups ~compare:Int.compare [3;3]) ~expect:[3] +let%test_unit _ = [%test_result: int list] (find_all_dups ~compare:Int.compare [3;5;4;6;12]) ~expect:[] +let%test_unit _ = [%test_result: int list] (find_all_dups ~compare:Int.compare [3;5;4;5;12]) ~expect:[5] +let%test_unit _ = [%test_result: int list] (find_all_dups ~compare:Int.compare [3;5;12;5;12]) ~expect:[5;12] +let%test_unit _ = [%test_result: (int * int) list] (find_all_dups ~compare:[%compare: int * int] [(0,1);(2,2);(0,2);(4,1)]) ~expect:[] +let%test_unit _ = [%test_result: int] (length (find_all_dups [(0,1);(2,2);(0,2);(4,1)] ~compare:(fun (_,a) (_,b) -> Int.compare a b))) ~expect:2 +let%test_unit _ = [%test_result: int] (length (find_all_dups [(0,1);(2,2);(0,2);(4,1)] ~compare:(fun (a,_) (b,_) -> Int.compare a b))) ~expect:1 + +let%test_unit _ = [%test_result: int] (counti [0;1;2;3;4] ~f:(fun idx x -> idx = x)) ~expect:5 +let%test_unit _ = [%test_result: int] (counti [0;1;2;3;4] ~f:(fun idx x -> idx = 4-x)) ~expect:1 + +let%test_unit _ = [%test_result: int list] (filter_map ~f:(fun x -> Some x) Test_values.l1) ~expect:Test_values.l1 +let%test_unit _ = [%test_result: int list] (filter_map ~f:(fun x -> Some x) []) ~expect:[] +let%test_unit _ = [%test_result: int list] (filter_map ~f:(fun _x -> None) [1.;2.;3.]) ~expect:[] +let%test_unit _ = [%test_result: int list] (filter_map ~f:(fun x -> if (x > 0) then Some x else None) [1;-1;3]) ~expect:[1;3] + +let%test_unit _ = [%test_result: int list] (filter_mapi ~f:(fun _i x -> Some x) Test_values.l1) ~expect:Test_values.l1 +let%test_unit _ = [%test_result: int list] (filter_mapi ~f:(fun _i x -> Some x) []) ~expect:[] +let%test_unit _ = [%test_result: int list] (filter_mapi ~f:(fun _i _x -> None) [1.;2.;3.]) ~expect:[] +let%test_unit _ = [%test_result: int list] (filter_mapi ~f:(fun _i x -> if (x > 0) then Some x else None) [1;-1;3]) ~expect:[1;3] +let%test_unit _ = [%test_result: int list] (filter_mapi ~f:(fun i x -> if (i % 2=0) then Some x else None) [1;-1;3]) ~expect:[1;3] + +let%test_unit _ = [%test_result: (int list * int list)] (split_n [1;2;3;4;5;6] 3) ~expect:([1;2;3],[4;5;6]) +let%test_unit _ = [%test_result: (int list * int list)] (split_n [1;2;3;4;5;6] 100) ~expect:([1;2;3;4;5;6],[]) +let%test_unit _ = [%test_result: (int list * int list)] (split_n [1;2;3;4;5;6] 0) ~expect:([],[1;2;3;4;5;6]) +let%test_unit _ = [%test_result: (int list * int list)] (split_n [1;2;3;4;5;6] (-5)) ~expect:([],[1;2;3;4;5;6]) + +let%test_unit _ = [%test_result: int list] (take [1;2;3;4;5;6] 3) ~expect:[1;2;3] +let%test_unit _ = [%test_result: int list] (take [1;2;3;4;5;6] 100) ~expect:[1;2;3;4;5;6] +let%test_unit _ = [%test_result: int list] (take [1;2;3;4;5;6] 0) ~expect:[] +let%test_unit _ = [%test_result: int list] (take [1;2;3;4;5;6] (-5)) ~expect:[] + +let%test_unit _ = [%test_result: int list] (drop [1;2;3;4;5;6] 3) ~expect:[4;5;6] +let%test_unit _ = [%test_result: int list] (drop [1;2;3;4;5;6] 100) ~expect:[] +let%test_unit _ = [%test_result: int list] (drop [1;2;3;4;5;6] 0) ~expect:[1;2;3;4;5;6] +let%test_unit _ = [%test_result: int list] (drop [1;2;3;4;5;6] (-5)) ~expect:[1;2;3;4;5;6] + +let%test_module "{take,drop,split}_while" = + (module struct + + let pred = function + | '0' .. '9' -> true + | _ -> false + + let test xs prefix suffix = + let (prefix1, suffix1) = split_while ~f:pred xs in + let prefix2 = take_while xs ~f:pred in + let suffix2 = drop_while xs ~f:pred in + [%test_eq: char list] xs (prefix @ suffix); + [%test_result: char list] ~expect:prefix prefix1; + [%test_result: char list] ~expect:prefix prefix2; + [%test_result: char list] ~expect:suffix suffix1; + [%test_result: char list] ~expect:suffix suffix2 + + let%test_unit _ = test ['1';'2';'3';'a';'b';'c'] ['1';'2';'3'] ['a';'b';'c'] + let%test_unit _ = test ['1';'2'; 'a';'b';'c'] ['1';'2' ] ['a';'b';'c'] + let%test_unit _ = test ['1'; 'a';'b';'c'] ['1' ] ['a';'b';'c'] + let%test_unit _ = test [ 'a';'b';'c'] [ ] ['a';'b';'c'] + let%test_unit _ = test ['1';'2';'3' ] ['1';'2';'3'] [ ] + let%test_unit _ = test [ ] [ ] [ ] + + end) + +let%test_unit _ = [%test_result: int list] (concat []) ~expect:[] +let%test_unit _ = [%test_result: int list] (concat [[]]) ~expect:[] +let%test_unit _ = [%test_result: int list] (concat [[3]]) ~expect:[3] +let%test_unit _ = [%test_result: int list] (concat [[1;2;3;4]]) ~expect:[1;2;3;4] +let%test_unit _ = [%test_result: int list] + (concat [[1;2;3;4];[5;6;7];[8;9;10];[];[11;12]]) + ~expect:[1;2;3;4;5;6;7;8;9;10;11;12] + +let%test_unit _ = [%test_result: bool] (is_sorted [] ~compare:Int.compare) ~expect:true +let%test_unit _ = [%test_result: bool] (is_sorted [1] ~compare:Int.compare) ~expect:true +let%test_unit _ = [%test_result: bool] (is_sorted [1; 2; 3; 4] ~compare:Int.compare) ~expect:true +let%test_unit _ = [%test_result: bool] (is_sorted [2; 1] ~compare:Int.compare) ~expect:false +let%test_unit _ = [%test_result: bool] (is_sorted [1; 3; 2] ~compare:Int.compare) ~expect:false + +let%test_unit _ = + List.iter + ~f:(fun (t, expect) -> [%test_result: bool] ~expect (is_sorted_strictly t ~compare:Int.compare)) + [ [] , true; + [ 1 ] , true; + [ 1; 2 ] , true; + [ 1; 1 ] , false; + [ 2; 1 ] , false; + [ 1; 2; 3 ], true; + [ 1; 1; 3 ], false; + [ 1; 2; 2 ], false; + ] +;; + +let%test_unit _ = [%test_result: int option] (random_element []) ~expect:None +let%test_unit _ = [%test_result: int option] (random_element [0]) ~expect:(Some 0) + +let%test_module "transpose" = + (module struct + + let round_trip a b = + [%test_result: int list list option] (transpose a) ~expect:(Some b); + [%test_result: int list list option] (transpose b) ~expect:(Some a) + + let%test_unit _ = round_trip [] [] + + let%test_unit _ = [%test_result: int list list option] (transpose [[]]) ~expect:(Some []) + let%test_unit _ = [%test_result: int list list option] (transpose [[]; []]) ~expect:(Some []) + let%test_unit _ = [%test_result: int list list option] (transpose [[]; []; []]) ~expect:(Some []) + + let%test_unit _ = round_trip [[1]] [[1]] + + let%test_unit _ = round_trip [[1]; + [2]] [[1; 2]] + + let%test_unit _ = round_trip [[1]; + [2]; + [3]] [[1; 2; 3]] + + let%test_unit _ = round_trip [[1; 2]; + [3; 4]] [[1; 3]; + [2; 4]] + + let%test_unit _ = round_trip [[1; 2; 3]; + [4; 5; 6]] [[1; 4]; + [2; 5]; + [3; 6]] + + let%test_unit _ = [%test_result: int list list option] (transpose [[]; [1]]) ~expect:None + + let%test_unit _ = [%test_result: int list list option] (transpose [[1;2];[3]]) ~expect:None + + end) + +let%test_unit _ = [%test_result: int list] (intersperse [1;2;3] ~sep:0) ~expect:[1;0;2;0;3] +let%test_unit _ = [%test_result: int list] (intersperse [1;2] ~sep:0) ~expect:[1;0;2] +let%test_unit _ = [%test_result: int list] (intersperse [1] ~sep:0) ~expect:[1] +let%test_unit _ = [%test_result: int list] (intersperse [] ~sep:0) ~expect:[] + +let test_fold_map list ~init ~f ~expect = + [%test_result: int list] (folding_map list ~init ~f) ~expect:(snd expect); + [%test_result: _ * int list] (fold_map list ~init ~f) ~expect + +let test_fold_mapi list ~init ~f ~expect = + [%test_result: int list] (folding_mapi list ~init ~f) ~expect:(snd expect); + [%test_result: _ * int list] (fold_mapi list ~init ~f) ~expect + +let%test_unit _ = test_fold_map [1;2;3;4] ~init:0 + ~f:(fun acc x -> let y = acc+x in y,y) + ~expect:(10, [1;3;6;10]) +let%test_unit _ = test_fold_map [] ~init:0 + ~f:(fun acc x -> let y = acc+x in y,y) + ~expect:(0, []) +let%test_unit _ = test_fold_mapi [1;2;3;4] ~init:0 + ~f:(fun i acc x -> let y = acc+i*x in y,y) + ~expect:(20, [0;2;8;20]) +let%test_unit _ = test_fold_mapi [] ~init:0 + ~f:(fun i acc x -> let y = acc+i*x in y,y) + ~expect:(0, []) diff --git a/test/test_list.mli b/test/test_list.mli new file mode 100644 index 0000000..74bb729 --- /dev/null +++ b/test/test_list.mli @@ -0,0 +1 @@ +(*_ This signature is deliberately empty. *) diff --git a/test/test_map.ml b/test/test_map.ml new file mode 100644 index 0000000..d433795 --- /dev/null +++ b/test/test_map.ml @@ -0,0 +1,62 @@ +open! Import +open! Map + +let%test _ = + invariants (of_increasing_iterator_unchecked (module Int) ~len:20 ~f:(fun x -> x,x)) + +let%test _ = + invariants (Poly.of_increasing_iterator_unchecked ~len:20 ~f:(fun x -> x,x)) + +module M = M + +let add12 t = add_exn t ~key:1 ~data:2 + +type int_map = int Map.M(Int).t [@@deriving compare, hash, sexp] + +let%expect_test "[add_exn] success" = + print_s [%sexp (add12 (empty (module Int)) : int_map)]; + [%expect {| ((1 2)) |}] +;; + +let%expect_test "[add_exn] failure" = + show_raise (fun () -> add12 (add12 (empty (module Int)))); + [%expect {| (raised ("[Map.add_exn] got key already present" (key 1))) |}] +;; + +let%expect_test "[add] success" = + print_s [%sexp ( + add (empty (module Int)) ~key:1 ~data:2 : int_map Or_duplicate.t)]; + [%expect {| (Ok ((1 2))) |}] +;; + +let%expect_test "[add] duplicate" = + print_s [%sexp ( + add (add12 (empty (module Int))) ~key:1 ~data:2 : int_map Or_duplicate.t)]; + [%expect {| Duplicate |}] +;; + +let%expect_test "[Map.of_alist_multi] preserves value ordering" = + print_s [%sexp ( + Map.of_alist_multi (module String) ["a", 1; "a", 2; "b", 1; "b", 3] + : int list Map.M(String).t)]; + [%expect {| + ((a (1 2)) + (b (1 3))) |}] +;; + +module Poly = struct + let%test _ = + length Poly.empty = 0 + ;; + + let%test _ = + let a = Poly.of_alist_exn [] in + Poly.equal Base.Poly.equal a Poly.empty + ;; + + let%test _ = + let a = Poly.of_alist_exn [("a", 1)] in + let b = Poly.of_alist_exn [(1, "b")] in + length a = length b + ;; +end diff --git a/test/test_map.mlt b/test/test_map.mlt new file mode 100644 index 0000000..03ea12b --- /dev/null +++ b/test/test_map.mlt @@ -0,0 +1,6 @@ +open Base + +let _ = Map.add + +[%%expect {| +|}];; diff --git a/test/test_maybe_bound.ml b/test/test_maybe_bound.ml new file mode 100644 index 0000000..4b134d1 --- /dev/null +++ b/test/test_maybe_bound.ml @@ -0,0 +1,115 @@ +open! Import +open! Maybe_bound + +let%test_unit "bounds_crossed" = + let a, b, c, d = Incl 1, Excl 1, Incl 3, Excl 3 in + let cases = [ + a, a, false; + a, b, false; + a, c, false; + a, d, false; + b, a, false; + b, b, false; + b, c, false; + b, d, false; + c, a, true; + c, b, true; + c, c, false; + c, d, false; + d, a, true; + d, b, true; + d, c, false; + d, d, false; + ] in + List.iter cases ~f:(fun (lower, upper, expect) -> + let actual = bounds_crossed ~lower ~upper ~compare in + assert ([%compare.equal: bool] expect actual)); +;; + +let%test_module "is_lower_bound" = + (module struct + let compare = Int.compare + + let%test _ = is_lower_bound Unbounded ~of_:Int.min_value ~compare + + let%test _ = not (is_lower_bound (Incl 2) ~of_:1 ~compare) + let%test _ = is_lower_bound (Incl 2) ~of_:2 ~compare + let%test _ = is_lower_bound (Incl 2) ~of_:3 ~compare + + let%test _ = not (is_lower_bound (Excl 2) ~of_:1 ~compare) + let%test _ = not (is_lower_bound (Excl 2) ~of_:2 ~compare) + let%test _ = is_lower_bound (Excl 2) ~of_:3 ~compare + end) + +let%test_module "is_upper_bound" = + (module struct + let compare = Int.compare + + let%test _ = is_upper_bound Unbounded ~of_:Int.max_value ~compare + + let%test _ = is_upper_bound (Incl 2) ~of_:1 ~compare + let%test _ = is_upper_bound (Incl 2) ~of_:2 ~compare + let%test _ = not (is_upper_bound (Incl 2) ~of_:3 ~compare) + + let%test _ = is_upper_bound (Excl 2) ~of_:1 ~compare + let%test _ = not (is_upper_bound (Excl 2) ~of_:2 ~compare) + let%test _ = not (is_upper_bound (Excl 2) ~of_:3 ~compare) + end) + +let%test_module "check_range" = + (module struct + let compare = Int.compare + + let tests (lower, upper) cases = + List.iter cases ~f:(fun (n, comparison) -> + [%test_result: interval_comparison] + ~expect:comparison + (compare_to_interval_exn n ~lower ~upper ~compare); + [%test_result: bool] + ~expect:(match comparison with In_range -> true | _ -> false) + (interval_contains_exn n ~lower ~upper ~compare)) + + let%test_unit _ = + tests (Unbounded, Unbounded) + [ (Int.min_value, In_range) + ; (0, In_range) + ; (Int.max_value, In_range) + ] + + let%test_unit _ = + tests (Incl 2, Incl 4) + [ (1, Below_lower_bound) + ; (2, In_range) + ; (3, In_range) + ; (4, In_range) + ; (5, Above_upper_bound) + ] + + let%test_unit _ = + tests (Incl 2, Excl 4) + [ (1, Below_lower_bound) + ; (2, In_range) + ; (3, In_range) + ; (4, Above_upper_bound) + ; (5, Above_upper_bound) + ] + + let%test_unit _ = + tests (Excl 2, Incl 4) + [ (1, Below_lower_bound) + ; (2, Below_lower_bound) + ; (3, In_range) + ; (4, In_range) + ; (5, Above_upper_bound) + ] + + let%test_unit _ = + tests (Excl 2, Excl 4) + [ (1, Below_lower_bound) + ; (2, Below_lower_bound) + ; (3, In_range) + ; (4, Above_upper_bound) + ; (5, Above_upper_bound) + ] + end) + diff --git a/test/test_maybe_bound.mli b/test/test_maybe_bound.mli new file mode 100644 index 0000000..74bb729 --- /dev/null +++ b/test/test_maybe_bound.mli @@ -0,0 +1 @@ +(*_ This signature is deliberately empty. *) diff --git a/test/test_nativeint.ml b/test/test_nativeint.ml new file mode 100644 index 0000000..9062421 --- /dev/null +++ b/test/test_nativeint.ml @@ -0,0 +1,7 @@ +open! Import +open! Nativeint + +let%expect_test "hash coherence" = + check_int_hash_coherence [%here] (module Nativeint); + [%expect {| |}]; +;; diff --git a/test/test_nativeint.mli b/test/test_nativeint.mli new file mode 100644 index 0000000..74bb729 --- /dev/null +++ b/test/test_nativeint.mli @@ -0,0 +1 @@ +(*_ This signature is deliberately empty. *) diff --git a/test/test_nativeint_pow2.ml b/test/test_nativeint_pow2.ml new file mode 100644 index 0000000..7a9774d --- /dev/null +++ b/test/test_nativeint_pow2.ml @@ -0,0 +1,127 @@ +open! Import +open! Nativeint + +let examples = + [ -1n + ; 0n + ; 1n + ; 2n + ; 3n + ; 4n + ; 5n + ; 7n + ; 8n + ; 9n + ; 63n + ; 64n + ; 65n + ] +;; + +let examples_64_bit = + [ min_value + ; succ min_value + ; pred max_value + ; max_value ] +;; + +let print_for ints f = + List.iter ints ~f:(fun i -> + print_s [%message + "" + ~_:(i : nativeint) + ~_:(Or_error.try_with (fun () -> f i) : int Or_error.t)]) +;; + +let%expect_test "[floor_log2]" = + print_for examples floor_log2; + [%expect {| + (-1 (Error ("[Nativeint.floor_log2] got invalid input" -1))) + (0 (Error ("[Nativeint.floor_log2] got invalid input" 0))) + (1 (Ok 0)) + (2 (Ok 1)) + (3 (Ok 1)) + (4 (Ok 2)) + (5 (Ok 2)) + (7 (Ok 2)) + (8 (Ok 3)) + (9 (Ok 3)) + (63 (Ok 5)) + (64 (Ok 6)) + (65 (Ok 6)) |}]; +;; + +let%expect_test "[floor_log2]" [@tags "64-bits-only"] = + print_for examples_64_bit floor_log2; + [%expect {| + (-9_223_372_036_854_775_808 ( + Error ("[Nativeint.floor_log2] got invalid input" -9223372036854775808))) + (-9_223_372_036_854_775_807 ( + Error ("[Nativeint.floor_log2] got invalid input" -9223372036854775807))) + (9_223_372_036_854_775_806 (Ok 62)) + (9_223_372_036_854_775_807 (Ok 62)) |}]; +;; + +let%expect_test "[ceil_log2]" = + print_for examples ceil_log2; + [%expect {| + (-1 (Error ("[Nativeint.ceil_log2] got invalid input" -1))) + (0 (Error ("[Nativeint.ceil_log2] got invalid input" 0))) + (1 (Ok 0)) + (2 (Ok 1)) + (3 (Ok 2)) + (4 (Ok 2)) + (5 (Ok 3)) + (7 (Ok 3)) + (8 (Ok 3)) + (9 (Ok 4)) + (63 (Ok 6)) + (64 (Ok 6)) + (65 (Ok 7)) |}]; +;; + +let%expect_test "[ceil_log2]" [@tags "64-bits-only"] = + print_for examples_64_bit ceil_log2; + [%expect {| + (-9_223_372_036_854_775_808 ( + Error ("[Nativeint.ceil_log2] got invalid input" -9223372036854775808))) + (-9_223_372_036_854_775_807 ( + Error ("[Nativeint.ceil_log2] got invalid input" -9223372036854775807))) + (9_223_372_036_854_775_806 (Ok 63)) + (9_223_372_036_854_775_807 (Ok 63)) |}]; +;; + +let%test_module "nativeint_math" = + (module struct + + let test_cases () = + let cases = + [ 0b10101010n; 0b1010101010101010n; 0b101010101010101010101010n; + 0b10000000n; 0b1000000000001000n; 0b100000000000000000001000n; ] + in + match Word_size.word_size with + | W64 -> (* create some >32 bit values... *) + (* We can't use literals directly because the compiler complains on 32 bits. *) + let cases = + cases @ [ (0b1010101010101010n lsl 16) lor 0b1010101010101010n; + (0b1000000000000000n lsl 16) lor 0b0000000000001000n; ] + in + let added_cases = List.map cases ~f:(fun x -> x lsl 16) in + List.concat [ cases; added_cases ] + | W32 -> cases + ;; + + let%test_unit "ceil_pow2" = + List.iter (test_cases ()) + ~f:(fun x -> let p2 = ceil_pow2 x in + assert( (is_pow2 p2) && (p2 >= x && x >= (p2 / (of_int 2))) ) + ) + ;; + + let%test_unit "floor_pow2" = + List.iter (test_cases ()) + ~f:(fun x -> let p2 = floor_pow2 x in + assert( (is_pow2 p2) && (((of_int 2) * p2) >= x && x >= p2) ) + ) + ;; + end) diff --git a/test/test_not_found.mlt b/test/test_not_found.mlt new file mode 100644 index 0000000..dfeef17 --- /dev/null +++ b/test/test_not_found.mlt @@ -0,0 +1,18 @@ +open Base +open Expect_test_helpers_kernel +;; + +print_s [%sexp (Not_found_s [%message "foo"] : exn)];; +[%%expect {| (Not_found_s foo) |}];; + +Not_found;; +[%%expect {| +Line _, characters 0-9: +Error (Warning 3): deprecated: Not_found +[2016-09] this element comes from the stdlib distributed with OCaml. +Instead of raising [Not_found], consider using [raise_s] with an informative error +message. If code needs to distinguish [Not_found] from other exceptions, please change +it to handle both [Not_found] and [Not_found_s]. Then, instead of raising [Not_found], +raise [Not_found_s] with an informative error message. +|}];; + diff --git a/test/test_obj_array.ml b/test/test_obj_array.ml new file mode 100644 index 0000000..7a45b24 --- /dev/null +++ b/test/test_obj_array.ml @@ -0,0 +1,108 @@ +open! Import + +module Obj_array = Not_exposed_properly.Obj_array +open Obj_array + +let does_raise = Exn.does_raise + +let zero_obj = Caml.Obj.repr (0 : int) + +include + Test_blit.Test + (struct + type t = Caml.Obj.t + let equal = phys_equal + let of_bool b = Caml.Obj.repr (if b then 1 else 2 : int) + end) + (struct + type nonrec t = t [@@deriving sexp_of] + let create = create_zero + let get = get + let set = set + let length = length + end) + (Obj_array) +;; + +(* [create_zero] *) +let%test_unit _ = + let t = create_zero ~len:0 in + assert (length t = 0) +;; + +(* [create] *) +let%test_unit _ = + let str = Caml.Obj.repr "foo" in + let t = create ~len:2 str in + assert (phys_equal (get t 0) str); + assert (phys_equal (get t 1) str) +;; + +let%test_unit _ = + let float = Caml.Obj.repr 3.5 in + let t = create ~len:2 float in + assert (Caml.Obj.tag (Caml.Obj.repr t) = 0); (* not a double array *) + assert (phys_equal (get t 0) float); + assert (phys_equal (get t 1) float); + set t 1 (Caml.Obj.repr 4.); + assert (Float.(=) (Caml.Obj.obj (get t 1)) 4.); +;; + +(* [empty] *) +let%test _ = length empty = 0 + +let%test _ = does_raise (fun () -> get empty 0) + +(* [singleton] *) +let%test _ = length (singleton zero_obj) = 1 + +let%test _ = phys_equal (get (singleton zero_obj) 0) zero_obj + +let%test _ = does_raise (fun () -> get (singleton zero_obj) 1) + +let%test_unit _ = + let f = 13. in + let t = singleton (Caml.Obj.repr f) in + invariant t; + assert (Poly.equal (Caml.Obj.repr f) (get t 0)) +;; + +(* [get], [unsafe_get], [set], [unsafe_set], [unsafe_set_assuming_currently_int] *) +let%test_unit _ = + let t = create_zero ~len:1 in + assert (length t = 1); + assert (phys_equal (get t 0) zero_obj); + assert (phys_equal (unsafe_get t 0) zero_obj); + let one_obj = Caml.Obj.repr (1 : int) in + let check_get expect = + assert (phys_equal (get t 0) expect); + assert (phys_equal (unsafe_get t 0) expect); + in + set t 0 one_obj; + check_get one_obj; + unsafe_set t 0 zero_obj; + check_get zero_obj; + unsafe_set_assuming_currently_int t 0 one_obj; + check_get one_obj +;; + +(* [truncate] *) +let%test _ = does_raise (fun () -> truncate empty ~len:0) +let%test _ = does_raise (fun () -> truncate empty ~len:1) +let%test _ = does_raise (fun () -> truncate empty ~len:(-1)) +let%test _ = does_raise (fun () -> truncate (create_zero ~len:1) ~len:0) +let%test _ = does_raise (fun () -> truncate (create_zero ~len:1) ~len:2) + +let%test_unit _ = + let t = create_zero ~len:1 in + truncate t ~len:1; + assert (length t = 1) +;; + +let%test_unit _ = + let t = create_zero ~len:3 in + truncate t ~len:2; + assert (length t = 2); + truncate t ~len:1; + assert (length t = 1) +;; diff --git a/test/test_option.ml b/test/test_option.ml new file mode 100644 index 0000000..8a92649 --- /dev/null +++ b/test/test_option.ml @@ -0,0 +1,8 @@ +open! Import +open! Option + +let f = (+) +let%test _ = [%compare.equal: int t] (merge None None ~f) None +let%test _ = [%compare.equal: int t] (merge (Some 3) None ~f) (Some 3) +let%test _ = [%compare.equal: int t] (merge None (Some 3) ~f) (Some 3) +let%test _ = [%compare.equal: int t] (merge (Some 1) (Some 3) ~f) (Some 4) diff --git a/test/test_option.mli b/test/test_option.mli new file mode 100644 index 0000000..74bb729 --- /dev/null +++ b/test/test_option.mli @@ -0,0 +1 @@ +(*_ This signature is deliberately empty. *) diff --git a/test/test_option_array.ml b/test/test_option_array.ml new file mode 100644 index 0000000..a39e732 --- /dev/null +++ b/test/test_option_array.ml @@ -0,0 +1,69 @@ +open! Import +open Option_array + + +let%test_module "Cheap_option" = ( + module struct + open For_testing.Unsafe_cheap_option + + let roundtrip_via_cheap_option (type a) (x : a) = + let opt : a t = some x in + assert (is_some opt); + assert (phys_equal (value_exn opt) x) + + let%test_unit _ = + roundtrip_via_cheap_option 0 + let%test_unit _ = + roundtrip_via_cheap_option 1 + let%test_unit _ = + roundtrip_via_cheap_option (ref 0) + let%test_unit _ = + roundtrip_via_cheap_option `x6e8ee3478e1d7449 + let%test_unit _ = + roundtrip_via_cheap_option 0.0 + let%test _ = + not (is_some none) + + let%test_unit "memory corruption" = + let make_list () = + List.init ~f:(fun i -> Some i) 5 + in + Caml.Gc.minor (); + let x = value_unsafe (some (make_list ())) in + Caml.Gc.minor (); + let _ = List.init ~f:(fun i -> Some (i*100)) 10000 in + [%test_result: Int.t Option.t List.t] + ~expect:(make_list ()) x + end) + +module Sequence = struct + let length = length + let get = get + let set = set +end + +include Base_for_tests.Test_blit.Test1_generic(struct + include Option + + let equal a b = Option.equal Bool.equal a b + let of_bool b = Some b + end) (struct + type nonrec 'a t = 'a t [@@deriving sexp] + type 'a z = 'a + include Sequence + + let create_bool ~len = init_some len ~f:(fun _ -> false) + end)(Option_array) + + +let%test_unit "floats are not re-boxed" = + let one = 1.0 in + let array = init_some 1 ~f:(fun _ -> one) in + assert (phys_equal one (get_some_exn array 0)) + +let%test_unit "segfault does not happen" = + (* if [Core_array] is used instead of [Uniform_array], this dies with a segfault *) + let _array = init 2 ~f:(fun i -> + if i = 0 then Some 1.0 else None) + in + () diff --git a/test/test_or_error.ml b/test/test_or_error.ml new file mode 100644 index 0000000..340bea4 --- /dev/null +++ b/test/test_or_error.ml @@ -0,0 +1,26 @@ +open! Import +open! Or_error + +let%test _ = [%compare.equal: string t] (errorf "foo %d" 13) (error_string "foo 13") + +let%test_unit _ = + for i = 0 to 10; do + assert ([%compare.equal: unit list t] + (combine_errors (List.init i ~f:(fun _ -> Ok ()))) + (Ok (List.init i ~f:(fun _ -> ())))); + done +let%test _ = Result.is_error (combine_errors [ error_string "" ]) +let%test _ = Result.is_error (combine_errors [ Ok (); error_string "" ]) + +let (=) = [%compare.equal: unit t] +let%test _ = combine_errors_unit [Ok (); Ok ()] = Ok () +let%test _ = combine_errors_unit [] = Ok () +let%test _ = + let a = Error.of_string "a" and b = Error.of_string "b" in + match combine_errors_unit [Ok (); Error a; Ok (); Error b] with + | Ok _ -> false + | Error e -> String.equal + (Error.to_string_hum e) + (Error.to_string_hum (Error.of_list [a;b])) +;; + diff --git a/test/test_or_error.mli b/test/test_or_error.mli new file mode 100644 index 0000000..74bb729 --- /dev/null +++ b/test/test_or_error.mli @@ -0,0 +1 @@ +(*_ This signature is deliberately empty. *) diff --git a/test/test_ordered_collection_common.ml b/test/test_ordered_collection_common.ml new file mode 100644 index 0000000..c5607bb --- /dev/null +++ b/test/test_ordered_collection_common.ml @@ -0,0 +1,66 @@ +open! Import +open! Ordered_collection_common + +let%test_unit "fast check_pos_len_exn is correct" = + let n_vals = + [ 0 + ; 1 + ; 2 + ; 10 + ; 100 + ; Int.max_value / 2 - 2 + ; Int.max_value / 2 - 1 + ; Int.max_value / 2 + ; Int.max_value - 2 + ; Int.max_value - 1 + ; Int.max_value + ] + in + let z_vals = + [ Int.min_value + ; Int.min_value + 1 + ; Int.min_value + 2 + ; Int.min_value / 2 + ; Int.min_value / 2 + 1 + ; Int.min_value / 2 + 2 + ; -100 + ; -10 + ; -2 + ; -1 + ] @ n_vals + in + List.iter z_vals ~f:(fun pos -> + List.iter z_vals ~f:(fun len -> + List.iter n_vals ~f:(fun total_length -> + assert + (Bool.equal + (Exn.does_raise (fun () -> Private.slow_check_pos_len_exn ~pos ~len ~total_length)) + (Exn.does_raise (fun () -> check_pos_len_exn ~pos ~len ~total_length)))))) +;; + +let%test_unit _ = + let vals = [ -1; 0; 1; 2; 3 ] in + List.iter [ 0; 1; 2 ] ~f:(fun total_length -> + List.iter vals ~f:(fun pos -> + List.iter vals ~f:(fun len -> + let result = Result.try_with (fun () -> check_pos_len_exn ~pos ~len ~total_length) in + let valid = pos >= 0 && len >= 0 && len <= total_length - pos in + assert (Bool.equal valid (Result.is_ok result))))) +;; + +let%test_unit _ = + let opts = [ None; Some (-1); Some 0; Some 1; Some 2 ] in + List.iter [ 0; 1; 2 ] ~f:(fun total_length -> + List.iter opts ~f:(fun pos -> + List.iter opts ~f:(fun len -> + let result = Result.try_with (fun () -> get_pos_len_exn () ?pos ?len ~total_length) in + let pos = match pos with Some x -> x | None -> 0 in + let len = match len with Some x -> x | None -> total_length - pos in + let valid = pos >= 0 && len >= 0 && len <= total_length - pos in + match result with + | Error _ -> assert (not valid); + | Ok (pos', len') -> + assert (pos' = pos); + assert (len' = len); + assert valid))) +;; diff --git a/test/test_ordered_collection_common.mli b/test/test_ordered_collection_common.mli new file mode 100644 index 0000000..74bb729 --- /dev/null +++ b/test/test_ordered_collection_common.mli @@ -0,0 +1 @@ +(*_ This signature is deliberately empty. *) diff --git a/test/test_ordering.ml b/test/test_ordering.ml new file mode 100644 index 0000000..b9dfce9 --- /dev/null +++ b/test/test_ordering.ml @@ -0,0 +1,15 @@ +open! Import +open! Ordering + +let%test _ = equal (of_int (-10)) Less +let%test _ = equal (of_int (-1) ) Less +let%test _ = equal (of_int 0 ) Equal +let%test _ = equal (of_int 1 ) Greater +let%test _ = equal (of_int 10 ) Greater + +let%test _ = equal (of_int (Int.compare 0 1)) Less +let%test _ = equal (of_int (Int.compare 1 1)) Equal +let%test _ = equal (of_int (Int.compare 1 0)) Greater + +let%test _ = List.for_all all ~f:(fun t -> equal t (t |> to_int |> of_int)) +let%test _ = List.for_all [ -1; 0; 1 ] ~f:(fun i -> i = (i |> of_int |> to_int)) diff --git a/test/test_popcount.ml b/test/test_popcount.ml new file mode 100644 index 0000000..8304384 --- /dev/null +++ b/test/test_popcount.ml @@ -0,0 +1,43 @@ +open! Import + +module type T = sig + type t [@@deriving compare, sexp_of] + + (* for implementing popcount_naive *) + val zero : t + val one : t + val (+) : t -> t -> t + val (lsr) : t -> int -> t + val (land) : t -> t -> t + + val quickcheck_generator : t Quickcheck.Generator.t + val to_int_exn : t -> int + + val popcount : t -> int +end + +module Make(Int : T) = struct + let popcount_naive (int : Int.t) : int = + let open Int in + let rec loop n count = + if Int.compare n zero <> 0 + then (loop (n lsr 1) (count + (n land one))) + else count + in + loop int zero + |> to_int_exn + ;; + + let%test_unit _ = + Quickcheck.test Int.quickcheck_generator + ~sexp_of:[%sexp_of: Int.t] + ~f:(fun int -> + let expect = popcount_naive int in + [%test_result: int] ~expect (Int.popcount int)) + ;; +end + +include Make(Quickcheck.Int) +include Make(Quickcheck.Int32) +include Make(Quickcheck.Int64) +include Make(Quickcheck.Nativeint) diff --git a/test/test_popcount.mli b/test/test_popcount.mli new file mode 100644 index 0000000..74bb729 --- /dev/null +++ b/test/test_popcount.mli @@ -0,0 +1 @@ +(*_ This signature is deliberately empty. *) diff --git a/test/test_ppx_compare_lib.ml b/test/test_ppx_compare_lib.ml new file mode 100644 index 0000000..d418d3a --- /dev/null +++ b/test/test_ppx_compare_lib.ml @@ -0,0 +1,104 @@ +open! Import +open! Ppx_compare_lib + +module Unit = struct + type t = unit [@@deriving compare, sexp_of] +end + +module type T = sig + type t [@@deriving compare, sexp_of] +end + +let test (type a) (module T : T with type t = a) ordered = + List.iteri ordered ~f:(fun i ti -> + List.iteri ordered ~f:(fun j tj -> + require [%here] + (Ordering.equal + (Ordering.of_int (T.compare ti tj)) + (Ordering.of_int (Int.compare i j))) + ~if_false_then_print_s:( + lazy [%message "" ~_:(ti : T.t) ~_:(tj : T.t)]))); +;; + +let%expect_test "bool, char, unit" = + test (module Bool) [ false; true ]; + test (module Char) [ '\000'; 'a'; 'b' ]; + [%expect {| |}]; + test (module Unit) [ () ]; + [%expect {| |}]; +;; + +module type Min_zero_max = sig + include T + val min_value : t + val max_value : t + val zero : t +end + +let test_min_zero_max (type a) (module T : Min_zero_max with type t = a)= + test (module T) + [ T.min_value + ; T.zero + ; T.max_value ] +;; + +let%expect_test _ = + test_min_zero_max (module Float); + test_min_zero_max (module Int); + test_min_zero_max (module Int32); + test_min_zero_max (module Int64); + test_min_zero_max (module Nativeint); +;; + +let%expect_test "option" = + test + (module struct + type t = int option [@@deriving compare, sexp_of] + end) + [ None + ; Some 0 + ; Some 1 ] +;; +let%expect_test "ref" = + test + (module struct + type t = int ref [@@deriving compare, sexp_of] + end) + ([ -1; 0; 1 ] |> List.map ~f:ref); +;; + +module type Sequence = sig + type 'a t [@@deriving compare, sexp_of] + val of_list : 'a list -> 'a t +end + +let test_sequence (module T : Sequence) ordered = + test + (module struct + type t = int T.t [@@deriving compare, sexp_of] + end) + (ordered |> List.map ~f:T.of_list) +;; + +let%expect_test "array, list" = + test_sequence (module Array) + [ [ ] + ; [ 1 ] + ; [ 2 ] + ; [ 1; 2 ] + ; [ 2; 1 ]]; + test_sequence (module List) + [ [ ] + ; [ 1 ] + ; [ 1; 2 ] + ; [ 2 ] + ; [ 2; 1 ]]; +;; + +let%expect_test "[compare_abstract]" = + show_raise (fun () -> compare_abstract ~type_name:"TY" () ()); + [%expect {| + (raised ( + Failure + "Compare called on the type TY, which is abstract in an implementation.")) |}]; +;; diff --git a/test/test_ppx_compare_lib.mli b/test/test_ppx_compare_lib.mli new file mode 100644 index 0000000..74bb729 --- /dev/null +++ b/test/test_ppx_compare_lib.mli @@ -0,0 +1 @@ +(*_ This signature is deliberately empty. *) diff --git a/test/test_queue.ml b/test/test_queue.ml new file mode 100644 index 0000000..7345577 --- /dev/null +++ b/test/test_queue.ml @@ -0,0 +1,902 @@ +open! Core_kernel + +let%test_module _ = + (module (struct + + open Queue + + module type S = S + + let does_raise = Exn.does_raise + + type nonrec 'a t = 'a t [@@deriving bin_io, sexp] + + let capacity = capacity + let set_capacity = set_capacity + + let%test_unit _ = + let t = create () in + assert (capacity t = 1); + enqueue t 1; + assert (capacity t = 1); + enqueue t 2; + assert (capacity t = 2); + enqueue t 3; + assert (capacity t = 4); + set_capacity t 0; + assert (capacity t = 4); + set_capacity t 3; + assert (capacity t = 4); + set_capacity t 100; + assert (capacity t = 128); + enqueue t 4; + enqueue t 5; + set_capacity t 0; + assert (capacity t = 8); + set_capacity t (-1); + assert (capacity t = 8); + ;; + + + let round_trip_sexp t = + let sexp = sexp_of_t Int.sexp_of_t t in + let t' = t_of_sexp Int.t_of_sexp sexp in + assert (to_list t = to_list t') + ;; + let%test_unit _ = round_trip_sexp (of_list [1; 2; 3; 4]) + let%test_unit _ = round_trip_sexp (create ()) + let%test_unit _ = round_trip_sexp (of_list []) + + let invariant = invariant + + let create = create + let%test_unit _ = + let t = create () in + assert (length t = 0); + assert (capacity t = 1); + ;; + let%test_unit _ = + let t = create ~capacity:0 () in + assert (length t = 0); + assert (capacity t = 1); + ;; + let%test_unit _ = + let t = create ~capacity:6 () in + assert (length t = 0); + assert (capacity t = 8); + ;; + let%test_unit _ = + assert (does_raise (fun () -> (create ~capacity:(-1) () : _ Queue.t))) + ;; + + let singleton = singleton + let%test_unit _ = + let t = singleton 7 in + assert (length t = 1); + assert (capacity t = 1); + assert (dequeue t = Some 7); + assert (dequeue t = None) + ;; + + let init = init + let%test_unit _ = + let t = init 0 ~f:(fun _ -> assert false) in + assert (length t = 0); + assert (capacity t = 1); + assert (dequeue t = None); + ;; + let%test_unit _ = + let t = init 3 ~f:(fun i -> i * 2) in + assert (length t = 3); + assert (capacity t = 4); + assert (dequeue t = Some 0); + assert (dequeue t = Some 2); + assert (dequeue t = Some 4); + assert (dequeue t = None); + ;; + let%test_unit _ = + assert (does_raise (fun () -> (init (-1) ~f:(fun _ -> ()) : unit Queue.t))) + ;; + + let get = get + let set = set + let%test_unit _ = + let t = create () in + let get_opt t i = Option.try_with (fun () -> get t i) in + assert (get_opt t 0 = None); + assert (get_opt t (-1) = None); + assert (get_opt t 10 = None); + List.iter [ -1; 0; 1 ] ~f:(fun i -> assert (does_raise (fun () -> set t i 0))); + enqueue t 0; + enqueue t 1; + enqueue t 2; + assert (get_opt t 0 = Some 0); + assert (get_opt t 1 = Some 1); + assert (get_opt t 2 = Some 2); + assert (get_opt t 3 = None); + ignore (dequeue_exn t); + assert (get_opt t 0 = Some 1); + assert (get_opt t 1 = Some 2); + assert (get_opt t 2 = None); + set t 0 3; + assert (get_opt t 0 = Some 3); + assert (get_opt t 1 = Some 2); + List.iter [ -1; 2 ] ~f:(fun i -> assert (does_raise (fun () -> set t i 0))) + ;; + + let map = map + let%test_unit _ = + for i = 0 to 5 do + let l = List.init i ~f:Fn.id in + let t = of_list l in + let f x = x * 2 in + let t' = map t ~f in + assert (to_list t' = List.map l ~f); + done + ;; + + let%test_unit _ = + let t = create () in + let t' = map t ~f:(fun x -> x * 2) in + assert (length t' = length t); + assert (length t' = 0); + assert (to_list t' = []) + ;; + + let mapi = mapi + let%test_unit _ = + for i = 0 to 5 do + let l = List.init i ~f:Fn.id in + let t = of_list l in + let f i x = (i, x * 2) in + let t' = mapi t ~f in + assert (to_list t' = List.mapi l ~f); + done + + let%test_unit _ = + let t = create () in + let t' = mapi t ~f:(fun i x -> (i, x * 2)) in + assert (length t' = length t); + assert (length t' = 0); + assert (to_list t' = []) + ;; + + include Test_container.Test_S1 (Queue) + + let dequeue_exn = dequeue_exn + let enqueue = enqueue + let peek = peek + let peek_exn = peek_exn + let last = last + let last_exn = last_exn + let%test_unit _ = + let t = create () in + assert (is_none (peek t)); + assert (is_none (last t)); + enqueue t 1; + enqueue t 2; + assert (peek t = Some 1); + assert (peek_exn t = 1); + assert (last t = Some 2); + assert (last_exn t = 2); + assert (dequeue_exn t = 1); + assert (dequeue_exn t = 2); + assert (does_raise (fun () -> dequeue_exn t)); + assert (does_raise (fun () -> peek_exn t)); + assert (does_raise (fun () -> last_exn t)) + ;; + + let enqueue_all = enqueue_all + let%test_unit _ = + let t = create () in + enqueue_all t [1; 2; 3]; + assert (dequeue_exn t = 1); + assert (dequeue_exn t = 2); + assert (last t = Some 3); + enqueue_all t [4; 5]; + assert (last t = Some 5); + assert (dequeue_exn t = 3); + assert (dequeue_exn t = 4); + assert (dequeue_exn t = 5); + assert (does_raise (fun () -> dequeue_exn t)); + enqueue_all t []; + assert (does_raise (fun () -> dequeue_exn t)); + ;; + + let of_list = of_list + let to_list = to_list + + let%test_unit _ = + for i = 0 to 4 do + let list = List.init i ~f:Fn.id in + assert (Poly.equal (to_list (of_list list)) list); + done + ;; + + let%test _ = + let t = create () in + begin + for i = 1 to 5 do enqueue t i done; + to_list t = [1;2;3;4;5] + end + ;; + + let of_array = of_array + let to_array = to_array + + let%test_unit _ = + for len = 0 to 4 do + let array = Array.init len ~f:Fn.id in + assert (Poly.equal (to_array (of_array array)) array); + done + ;; + + let compare = compare + let equal = equal + + let%test_module "comparisons" = + (module struct + + let sign x = if x < 0 then ~-1 else if x > 0 then 1 else 0 + + let test t1 t2 = + [%test_result: bool] + (equal Int.equal t1 t2) + ~expect:(List.equal Int.equal (to_list t1) (to_list t2)); + [%test_result: int] + (sign (compare Int.compare t1 t2)) + ~expect:(sign (List.compare Int.compare (to_list t1) (to_list t2))) + ;; + + let lists = + [ [] + ; [ 1 ] + ; [ 2 ] + ; [ 1; 1 ] + ; [ 1; 2 ] + ; [ 2; 1 ] + ; [ 1; 1; 1 ] + ; [ 1; 2; 3 ] + ; [ 1; 2; 4 ] + ; [ 1; 2; 4; 8 ] + ; [ 1; 2; 3; 4; 5 ] + ] + ;; + + let%test_unit _ = (* [phys_equal] inputs *) + List.iter lists ~f:(fun list -> + let t = of_list list in + test t t) + ;; + + let%test_unit _ = + List.iter lists ~f:(fun list1 -> + List.iter lists ~f:(fun list2 -> + test (of_list list1) (of_list list2))) + ;; + end) + + let clear = clear + + let blit_transfer = blit_transfer + + let%test_unit _ = + let q_list = [1; 2; 3; 4] in + let q = of_list q_list in + let q' = create () in + blit_transfer ~src:q ~dst:q' (); + assert (to_list q' = q_list); + assert (to_list q = []) + ;; + + let%test_unit _ = + let q = of_list [1; 2; 3; 4] in + let q' = create () in + blit_transfer ~src:q ~dst:q' ~len:2 (); + assert (to_list q' = [1; 2]); + assert (to_list q = [3; 4]) + ;; + + let%test_unit "blit_transfer on wrapped queues" = + let list = [1; 2; 3; 4] in + let q = of_list list in + let q' = copy q in + ignore (dequeue_exn q); + ignore (dequeue_exn q); + ignore (dequeue_exn q'); + ignore (dequeue_exn q'); + ignore (dequeue_exn q'); + enqueue q 5; + enqueue q 6; + blit_transfer ~src:q ~dst:q' ~len:3 (); + assert (to_list q' = [4; 3; 4; 5]); + assert (to_list q = [6]) + ;; + + let copy = copy + let dequeue = dequeue + let filter = filter + let filteri = filteri + let filter_inplace = filter_inplace + let filteri_inplace = filteri_inplace + let concat_map = concat_map + let concat_mapi = concat_mapi + let filter_map = filter_map + let filter_mapi = filter_mapi + let counti = counti + let existsi = existsi + let for_alli = for_alli + let iter = iter + let iteri = iteri + let foldi = foldi + let findi = findi + let find_mapi = find_mapi + + let%test_module "Linked_queue bisimulation" = + (module struct + module type Queue_intf = sig + type 'a t [@@deriving sexp_of] + + val create : unit -> 'a t + val enqueue : 'a t -> 'a -> unit + val dequeue : 'a t -> 'a option + val to_array : 'a t -> 'a array + val fold : 'a t -> init:'b -> f:('b -> 'a -> 'b) -> 'b + val foldi : 'a t -> init:'b -> f:(int -> 'b -> 'a -> 'b) -> 'b + val iter : 'a t -> f:('a -> unit) -> unit + val iteri : 'a t -> f:(int -> 'a -> unit) -> unit + val length : 'a t -> int + val clear : 'a t -> unit + val concat_map : 'a t -> f:( 'a -> 'b list) -> 'b t + val concat_mapi : 'a t -> f:(int -> 'a -> 'b list) -> 'b t + val filter_map : 'a t -> f:( 'a -> 'b option) -> 'b t + val filter_mapi : 'a t -> f:(int -> 'a -> 'b option) -> 'b t + val filter : 'a t -> f:( 'a -> bool) -> 'a t + val filteri : 'a t -> f:(int -> 'a -> bool) -> 'a t + val filter_inplace : 'a t -> f:( 'a -> bool) -> unit + val filteri_inplace : 'a t -> f:(int -> 'a -> bool) -> unit + val map : 'a t -> f:( 'a -> 'b) -> 'b t + val mapi : 'a t -> f:(int -> 'a -> 'b) -> 'b t + val counti : 'a t -> f:(int -> 'a -> bool) -> int + val existsi : 'a t -> f:(int -> 'a -> bool) -> bool + val for_alli : 'a t -> f:(int -> 'a -> bool) -> bool + val findi : 'a t -> f:(int -> 'a -> bool) -> (int * 'a) option + val find_mapi : 'a t -> f:(int -> 'a -> 'b option) -> 'b option + val transfer : src:'a t -> dst:'a t -> unit + val copy : 'a t -> 'a t + end + + module That_queue : Queue_intf = Linked_queue + + module This_queue : Queue_intf = struct + include Queue + let create () = create () + let transfer ~src ~dst = blit_transfer ~src ~dst () + end + + let this_to_string this_t = + Sexp.to_string (this_t |> [%sexp_of: int This_queue.t]) + ;; + + let that_to_string that_t = + Sexp.to_string (that_t |> [%sexp_of: int That_queue.t]) + ;; + + let array_string arr = + Sexp.to_string (arr |> [%sexp_of: int array]) + ;; + + let create () = (This_queue.create (), That_queue.create ()) + + + let enqueue (t_a, t_b) v = + let start_a = This_queue.to_array t_a in + let start_b = That_queue.to_array t_b in + This_queue.enqueue t_a v; + That_queue.enqueue t_b v; + let end_a = This_queue.to_array t_a in + let end_b = That_queue.to_array t_b in + if end_a <> end_b + then failwithf "enqueue transition failure of: %s -> %s vs. %s -> %s" + (array_string start_a) + (array_string end_a) + (array_string start_b) + (array_string end_b) + () + ;; + + let iter (t_a, t_b) = + let r_a, r_b = ref 0, ref 0 in + This_queue.iter t_a ~f:(fun x -> r_a := !r_a + x); + That_queue.iter t_b ~f:(fun x -> r_b := !r_b + x); + if !r_a <> !r_b + then failwithf "error in iter: %s (from %s) <> %s (from %s)" + (Int.to_string !r_a) + (this_to_string t_a) + (Int.to_string !r_b) + (that_to_string t_b) + () + ;; + + let iteri (t_a, t_b) = + let r_a, r_b = ref 0, ref 0 in + This_queue.iteri t_a ~f:(fun i x -> r_a := !r_a + x lxor i); + That_queue.iteri t_b ~f:(fun i x -> r_b := !r_b + x lxor i); + if !r_a <> !r_b + then failwithf "error in iteri: %s (from %s) <> %s (from %s)" + (Int.to_string !r_a) + (this_to_string t_a) + (Int.to_string !r_b) + (that_to_string t_b) + () + ;; + + let dequeue (t_a, t_b) = + let start_a = This_queue.to_array t_a in + let start_b = That_queue.to_array t_b in + let a, b = This_queue.dequeue t_a, That_queue.dequeue t_b in + let end_a = This_queue.to_array t_a in + let end_b = That_queue.to_array t_b in + if a <> b || end_a <> end_b + then failwithf "error in dequeue: %s (%s -> %s) <> %s (%s -> %s)" + (Option.value ~default:"None" (Option.map a ~f:Int.to_string)) + (array_string start_a) + (array_string end_a) + (Option.value ~default:"None" (Option.map b ~f:Int.to_string)) + (array_string start_b) + (array_string end_b) + () + ;; + + let clear (t_a, t_b) = + This_queue.clear t_a; + That_queue.clear t_b; + ;; + + let is_even x = (x land 1) = 0 + + let filter (t_a, t_b) = + let t_a' = This_queue.filter t_a ~f:is_even in + let t_b' = That_queue.filter t_b ~f:is_even in + if This_queue.to_array t_a' <> That_queue.to_array t_b' + then failwithf "error in filter: %s -> %s vs. %s -> %s" + (this_to_string t_a) + (this_to_string t_a') + (that_to_string t_b) + (that_to_string t_b') + () + ;; + + let filteri (t_a, t_b) = + let t_a' = This_queue.filteri t_a ~f:(fun i j -> is_even i = is_even j) in + let t_b' = That_queue.filteri t_b ~f:(fun i j -> is_even i = is_even j) in + if This_queue.to_array t_a' <> That_queue.to_array t_b' + then failwithf "error in filteri: %s -> %s vs. %s -> %s" + (this_to_string t_a) + (this_to_string t_a') + (that_to_string t_b) + (that_to_string t_b') + () + ;; + + let filter_inplace (t_a, t_b) = + let start_a = This_queue.to_array t_a in + let start_b = That_queue.to_array t_b in + This_queue.filter_inplace t_a ~f:is_even; + That_queue.filter_inplace t_b ~f:is_even; + let end_a = This_queue.to_array t_a in + let end_b = That_queue.to_array t_b in + if end_a <> end_b + then failwithf "error in filter_inplace: %s -> %s vs. %s -> %s" + (array_string start_a) + (array_string end_a) + (array_string start_b) + (array_string end_b) + () + ;; + + let filteri_inplace (t_a, t_b) = + let start_a = This_queue.to_array t_a in + let start_b = That_queue.to_array t_b in + let f i x = is_even i = is_even x in + This_queue.filteri_inplace t_a ~f; + That_queue.filteri_inplace t_b ~f; + let end_a = This_queue.to_array t_a in + let end_b = That_queue.to_array t_b in + if end_a <> end_b + then failwithf "error in filteri_inplace: %s -> %s vs. %s -> %s" + (array_string start_a) + (array_string end_a) + (array_string start_b) + (array_string end_b) + () + ;; + + let concat_map (t_a, t_b) = + let f x = [x; x + 1; x + 2] in + let t_a' = This_queue.concat_map t_a ~f in + let t_b' = That_queue.concat_map t_b ~f in + if (This_queue.to_array t_a') <> (That_queue.to_array t_b') + then failwithf "error in concat_map: %s (for %s) <> %s (for %s)" + (this_to_string t_a') + (this_to_string t_a) + (that_to_string t_b') + (that_to_string t_b) + () + ;; + + let concat_mapi (t_a, t_b) = + let f i x = [x; x + 1; x + 2; x + i] in + let t_a' = This_queue.concat_mapi t_a ~f in + let t_b' = That_queue.concat_mapi t_b ~f in + if (This_queue.to_array t_a') <> (That_queue.to_array t_b') + then failwithf "error in concat_mapi: %s (for %s) <> %s (for %s)" + (this_to_string t_a') + (this_to_string t_a) + (that_to_string t_b') + (that_to_string t_b) + () + ;; + + let filter_map (t_a, t_b) = + let f x = if is_even x then None else Some (x + 1) in + let t_a' = This_queue.filter_map t_a ~f in + let t_b' = That_queue.filter_map t_b ~f in + if (This_queue.to_array t_a') <> (That_queue.to_array t_b') + then failwithf "error in filter_map: %s (for %s) <> %s (for %s)" + (this_to_string t_a') + (this_to_string t_a) + (that_to_string t_b') + (that_to_string t_b) + () + ;; + + let filter_mapi (t_a, t_b) = + let f i x = if is_even i = is_even x then None else Some (x + 1 + i) in + let t_a' = This_queue.filter_mapi t_a ~f in + let t_b' = That_queue.filter_mapi t_b ~f in + if (This_queue.to_array t_a') <> (That_queue.to_array t_b') + then failwithf "error in filter_mapi: %s (for %s) <> %s (for %s)" + (this_to_string t_a') + (this_to_string t_a) + (that_to_string t_b') + (that_to_string t_b) + () + ;; + + let map (t_a, t_b) = + let f x = x * 7 in + let t_a' = This_queue.map t_a ~f in + let t_b' = That_queue.map t_b ~f in + if (This_queue.to_array t_a') <> (That_queue.to_array t_b') + then failwithf "error in map: %s (for %s) <> %s (for %s)" + (this_to_string t_a') + (this_to_string t_a) + (that_to_string t_b') + (that_to_string t_b) + () + ;; + + let mapi (t_a, t_b) = + let f i x = (x + 3) lxor i in + let t_a' = This_queue.mapi t_a ~f in + let t_b' = That_queue.mapi t_b ~f in + if (This_queue.to_array t_a') <> (That_queue.to_array t_b') + then failwithf "error in mapi: %s (for %s) <> %s (for %s)" + (this_to_string t_a') + (this_to_string t_a) + (that_to_string t_b') + (that_to_string t_b) + () + ;; + + let counti (t_a, t_b) = + let f i x = i < 7 && (i % 7 = x % 7) in + let a' = This_queue.counti t_a ~f in + let b' = That_queue.counti t_b ~f in + if a' <> b' + then failwithf "error in counti: %d (for %s) <> %d (for %s)" + (a') + (this_to_string t_a) + (b') + (that_to_string t_b) + () + ;; + + let existsi (t_a, t_b) = + let f i x = i < 7 && (i % 7 = x % 7) in + let a' = This_queue.existsi t_a ~f in + let b' = That_queue.existsi t_b ~f in + if a' <> b' + then failwithf "error in existsi: %b (for %s) <> %b (for %s)" + (a') + (this_to_string t_a) + (b') + (that_to_string t_b) + () + ;; + + let for_alli (t_a, t_b) = + let f i x = i >= 7 || (i % 7 <> x % 7) in + let a' = This_queue.for_alli t_a ~f in + let b' = That_queue.for_alli t_b ~f in + if a' <> b' + then failwithf "error in for_alli: %b (for %s) <> %b (for %s)" + (a') + (this_to_string t_a) + (b') + (that_to_string t_b) + () + ;; + + let findi (t_a, t_b) = + let f i x = i < 7 && (i % 7 = x % 7) in + let a' = This_queue.findi t_a ~f in + let b' = That_queue.findi t_b ~f in + if a' <> b' + then failwithf "error in findi: %s (for %s) <> %s (for %s)" + (Sexp.to_string ([%sexp_of: (int * int) option] a')) + (this_to_string t_a) + (Sexp.to_string ([%sexp_of: (int * int) option] b')) + (that_to_string t_b) + () + ;; + + let find_mapi (t_a, t_b) = + let f i x = if i < 7 && (i % 7 = x % 7) then Some (i + x) else None in + let a' = This_queue.find_mapi t_a ~f in + let b' = That_queue.find_mapi t_b ~f in + if a' <> b' + then failwithf "error in find_mapi: %s (for %s) <> %s (for %s)" + (Sexp.to_string ([%sexp_of: int option] a')) + (this_to_string t_a) + (Sexp.to_string ([%sexp_of: int option] b')) + (that_to_string t_b) + () + ;; + + let copy (t_a, t_b) = + let copy_a = This_queue.copy t_a in + let copy_b = That_queue.copy t_b in + let start_a = This_queue.to_array t_a in + let start_b = That_queue.to_array t_b in + let end_a = This_queue.to_array copy_a in + let end_b = That_queue.to_array copy_b in + if end_a <> end_b + then failwithf "error in copy: %s -> %s vs. %s -> %s" + (array_string start_a) + (array_string end_a) + (array_string start_b) + (array_string end_b) + () + ;; + + let transfer (t_a, t_b) = + let dst_a = This_queue.create () in + let dst_b = That_queue.create () in + (* sometimes puts some elements in the destination queues *) + if Random.bool () + then begin + List.iter [ 1; 2; 3; 4; 5 ] ~f:(fun elem -> + This_queue.enqueue dst_a elem; + That_queue.enqueue dst_b elem); + end; + let start_a = This_queue.to_array t_a in + let start_b = That_queue.to_array t_b in + This_queue.transfer ~src:t_a ~dst:dst_a; + That_queue.transfer ~src:t_b ~dst:dst_b; + let end_a = This_queue.to_array t_a in + let end_b = That_queue.to_array t_b in + let end_a' = This_queue.to_array dst_a in + let end_b' = That_queue.to_array dst_b in + if end_a' <> end_b' || end_a <> end_b + then failwithf "error in transfer: %s -> (%s, %s) vs. %s -> (%s, %s)" + (array_string start_a) + (array_string end_a) + (array_string end_a') + (array_string start_b) + (array_string end_b) + (array_string end_b) + () + ;; + + let fold_check (t_a, t_b) = + let make_list fold t = + fold t ~init:[] ~f:(fun acc x -> x :: acc) + in + let this_l = make_list This_queue.fold t_a in + let that_l = make_list That_queue.fold t_b in + if this_l <> that_l + then failwithf "error in fold: %s (from %s) <> %s (from %s)" + (Sexp.to_string (this_l |> [%sexp_of: int list])) + (this_to_string t_a) + (Sexp.to_string (that_l |> [%sexp_of: int list])) + (that_to_string t_b) + () + ;; + + let foldi_check (t_a, t_b) = + let make_list foldi t = + foldi t ~init:[] ~f:(fun i acc x -> (i,x) :: acc) + in + let this_l = make_list This_queue.foldi t_a in + let that_l = make_list That_queue.foldi t_b in + if this_l <> that_l + then failwithf "error in foldi: %s (from %s) <> %s (from %s)" + (Sexp.to_string (this_l |> [%sexp_of: (int * int) list])) + (this_to_string t_a) + (Sexp.to_string (that_l |> [%sexp_of: (int * int) list])) + (that_to_string t_b) + () + ;; + + let length_check (t_a, t_b) = + let this_len = This_queue.length t_a in + let that_len = That_queue.length t_b in + if this_len <> that_len + then failwithf "error in length: %i (for %s) <> %i (for %s)" + this_len (this_to_string t_a) + that_len (that_to_string t_b) + () + ;; + + let%test_unit _ = + let t = create () in + let rec loop ~all_ops ~non_empty_ops = + if all_ops <= 0 && non_empty_ops <= 0 + then begin + let (t_a, t_b) = t in + let arr_a = This_queue.to_array t_a in + let arr_b = That_queue.to_array t_b in + if arr_a <> arr_b + then failwithf "queue final states not equal: %s vs. %s" + (array_string arr_a) + (array_string arr_b) + () + end else begin + let queue_was_empty = This_queue.length (fst t) = 0 in + let r = Random.int 195 in + begin + if r < 60 + then enqueue t (Random.int 10_000) + else if r < 65 + then dequeue t + else if r < 70 + then clear t + else if r < 80 + then iter t + else if r < 85 + then iteri t + else if r < 90 + then fold_check t + else if r < 95 + then foldi_check t + else if r < 100 + then filter t + else if r < 105 + then filteri t + else if r < 110 + then concat_map t + else if r < 115 + then concat_mapi t + else if r < 120 + then transfer t + else if r < 130 + then filter_map t + else if r < 135 + then filter_mapi t + else if r < 140 + then copy t + else if r < 150 + then filter_inplace t + else if r < 155 + then for_alli t + else if r < 160 + then existsi t + else if r < 165 + then counti t + else if r < 170 + then findi t + else if r < 175 + then find_mapi t + else if r < 180 + then map t + else if r < 185 + then mapi t + else if r < 190 + then filteri_inplace t + else if r < 195 + then length_check t + else failwith "Impossible: We did [Random.int 195] above" + end; + loop + ~all_ops:(all_ops - 1) + ~non_empty_ops:(if queue_was_empty then non_empty_ops else non_empty_ops - 1) + end + in + loop ~all_ops:30_000 ~non_empty_ops:20_000 + ;; + end) + + let binary_search = binary_search + let binary_search_segmented = binary_search_segmented + + let%test_unit "modification-during-iteration" = + let x = `A 0 in + let t = of_list [x; x] in + let f (`A n) = ignore n; clear t in + assert (does_raise (fun () -> iter t ~f)) + ;; + + let%test_unit "more-modification-during-iteration" = + let nested_iter_okay = ref false in + let t = of_list [ `iter; `clear ] in + assert (does_raise (fun () -> + iter t ~f:(function + | `iter -> iter t ~f:ignore; nested_iter_okay := true + | `clear -> clear t))); + assert !nested_iter_okay + ;; + + let%test_unit "modification-during-filter" = + let reached_unreachable = ref false in + let t = of_list [`clear; `unreachable] in + let f x = + match x with + | `clear -> clear t; false + | `unreachable -> reached_unreachable := true; false + in + assert (does_raise (fun () -> filter t ~f)); + assert (not !reached_unreachable) + ;; + + let%test_unit "modification-during-filter-inplace" = + let reached_unreachable = ref false in + let t = of_list [`drop_this; `enqueue_new_element; `unreachable] in + let f x = + begin match x with + | `drop_this | `new_element -> () + | `enqueue_new_element -> enqueue t `new_element + | `unreachable -> reached_unreachable := true + end; + false + in + assert (does_raise (fun () -> filter_inplace t ~f)); + (* even though we said to drop the first element, the aborted call to [filter_inplace] + shouldn't have made that change *) + assert (peek_exn t = `drop_this); + assert (not !reached_unreachable) + ;; + + let%test_unit "filter-inplace-during-iteration" = + let reached_unreachable = ref false in + let t = of_list [`filter_inplace; `unreachable] in + let f x = + match x with + | `filter_inplace -> filter_inplace t ~f:(fun _ -> false) + | `unreachable -> reached_unreachable := true + in + assert (does_raise (fun () -> iter t ~f)); + assert (not !reached_unreachable) + ;; + + module Stable = struct + module V1 = Stable.V1 + include Stable_unit_test.Make (struct + type nonrec t = int V1.t [@@deriving sexp, bin_io, compare] + let equal = [%compare.equal: t] + let tests = + let manipulated = Queue.of_list [0;3;6;1] in + ignore (Queue.dequeue_exn manipulated : int); + ignore (Queue.dequeue_exn manipulated : int); + Queue.enqueue manipulated 4; + [ Queue.of_list [], "()", "\000" + ; Queue.of_list [1; 2; 6; 4], "(1 2 6 4)", "\004\001\002\006\004" + ; manipulated, "(6 1 4)", "\003\006\001\004" + ] + end) + end + end + (* This signature is here to remind us to update the unit tests whenever we + change [Core_queue]. *) + : module type of Queue)) diff --git a/test/test_queue.mli b/test/test_queue.mli new file mode 100644 index 0000000..74bb729 --- /dev/null +++ b/test/test_queue.mli @@ -0,0 +1 @@ +(*_ This signature is deliberately empty. *) diff --git a/test/test_random.ml b/test/test_random.ml new file mode 100644 index 0000000..bb88758 --- /dev/null +++ b/test/test_random.ml @@ -0,0 +1,185 @@ +open! Import +open! Random + +module State = struct + include State + + let%test_unit "random int above 2^30" [@tags "64-bits-only"] = + let state = make [| 1 ; 2 ; 3 ; 4 ; 5 |] in + for _ = 1 to 100 do + let bound = Int.shift_left 1 40 in + let n = int state bound in + if n < 0 || n >= bound then + failwith (Printf.sprintf "random result %d out of bounds (0,%d)" n (bound-1)) + done + ;; +end + +external random_seed: unit -> Caml.Obj.t = "caml_sys_random_seed";; +let%test_unit _ = + (* test that the return type of "caml_sys_random_seed" is what we expect *) + let module Obj = Caml.Obj in + let obj = random_seed () in + assert (Obj.is_block obj); + assert (Obj.tag obj = Obj.tag (Obj.repr [| 13 |])); + for i = 0 to Obj.size obj - 1 do + assert (Obj.is_int (Obj.field obj i)); + done +;; + +module type T = sig + type t [@@deriving compare, sexp_of] +end +;; + +(* We test that [count] trials of [generate ()] all produce values between [min, max], and + generate at least one value between [lo, hi]. *) +let test (type t) here m count generate ~min ~max ~check_range:(lo, hi) = + let (module T : T with type t = t) = m in + let between t ~lower_bound ~upper_bound = + T.compare t lower_bound >= 0 && + T.compare t upper_bound <= 0 + in + let generated = + List.init count ~f:(fun _ -> generate ()) + |> List.dedup_and_sort ~compare:T.compare + in + require here + (List.for_all generated ~f:(fun t -> + between t ~lower_bound:min ~upper_bound:max)) + ~if_false_then_print_s: + (lazy [%message + "generated values outside of bounds" + (min : T.t) + (max : T.t) + (generated : T.t list)]); + require here + (List.exists generated ~f:(fun t -> + between t ~lower_bound:lo ~upper_bound:hi)) + ~if_false_then_print_s: + (lazy [%message + "did not generate value inside range" + (lo : T.t) + (hi : T.t) + (generated : T.t list)]); +;; + +let%expect_test "float" = + test [%here] (module Float) 1_000 (fun () -> float 100.) + ~min:0. ~max:100. ~check_range:(10.,20.); + [%expect {||}]; +;; + +let%expect_test "float_range" = + test [%here] (module Float) 1_000 (fun () -> float_range (-100.) 100.) + ~min:(-100.) ~max:100. ~check_range:(-20.,-10.); + [%expect {||}]; +;; + +let%expect_test "int" = + test [%here] (module Int) 1_000 (fun () -> int 100) + ~min:0 ~max:99 ~check_range:(10,20); + [%expect {||}]; +;; + +let%expect_test "int_incl" = + test [%here] (module Int) 1_000 (fun () -> int_incl (-100) 100) + ~min:(-100) ~max:100 ~check_range:(-20,-10); + [%expect {||}]; + test [%here] (module Int) 1_000 (fun () -> int_incl 0 Int.max_value) + ~min:0 ~max:Int.max_value ~check_range:(0, Int.max_value / 100); + [%expect {||}]; + test [%here] (module Int) 1_000 (fun () -> int_incl Int.min_value Int.max_value) + ~min:Int.min_value ~max:Int.max_value + ~check_range:(Int.min_value / 100, Int.max_value / 100); + [%expect {||}]; +;; + +let%expect_test "int32" = + test [%here] (module Int32) 1_000 (fun () -> int32 100l) + ~min:0l ~max:99l ~check_range:(10l,20l); + [%expect {||}]; +;; + +let%expect_test "int32_incl" = + test [%here] (module Int32) 1_000 (fun () -> int32_incl (-100l) 100l) + ~min:(-100l) ~max:100l ~check_range:(-20l,-10l); + [%expect {||}]; + test [%here] (module Int32) 1_000 (fun () -> int32_incl 0l Int32.max_value) + ~min:0l ~max:Int32.max_value + ~check_range:(0l, Int32.( / ) Int32.max_value 100l); + [%expect {||}]; + test [%here] (module Int32) 1_000 (fun () -> int32_incl Int32.min_value Int32.max_value) + ~min:Int32.min_value ~max:Int32.max_value + ~check_range:(Int32.( / ) Int32.min_value 100l, Int32.( / ) Int32.max_value 100l); + [%expect {||}]; +;; + +let%expect_test "int64" = + test [%here] (module Int64) 1_000 (fun () -> int64 100L) + ~min:0L ~max:99L ~check_range:(10L,20L); + [%expect {||}]; +;; + +let%expect_test "int64_incl" = + test [%here] (module Int64) 1_000 (fun () -> int64_incl (-100L) 100L) + ~min:(-100L) ~max:100L ~check_range:(-20L,-10L); + [%expect {||}]; + test [%here] (module Int64) 1_000 (fun () -> int64_incl 0L Int64.max_value) + ~min:0L ~max:Int64.max_value + ~check_range:(0L, Int64.( / ) Int64.max_value 100L); + [%expect {||}]; + test [%here] (module Int64) 1_000 (fun () -> int64_incl Int64.min_value Int64.max_value) + ~min:Int64.min_value ~max:Int64.max_value + ~check_range:(Int64.( / ) Int64.min_value 100L, Int64.( / ) Int64.max_value 100L); + [%expect {||}]; +;; + +let%expect_test "nativeint" = + test [%here] (module Nativeint) 1_000 (fun () -> nativeint 100n) + ~min:0n ~max:99n ~check_range:(10n,20n); + [%expect {||}]; +;; + +let%expect_test "nativeint_incl" = + test [%here] (module Nativeint) 1_000 (fun () -> nativeint_incl (-100n) 100n) + ~min:(-100n) ~max:100n ~check_range:(-20n,-10n); + [%expect {||}]; + test [%here] (module Nativeint) 1_000 (fun () -> nativeint_incl 0n Nativeint.max_value) + ~min:0n ~max:Nativeint.max_value + ~check_range:(0n, Nativeint.( / ) Nativeint.max_value 100n); + [%expect {||}]; + test [%here] (module Nativeint) 1_000 (fun () -> + nativeint_incl Nativeint.min_value Nativeint.max_value) + ~min:Nativeint.min_value ~max:Nativeint.max_value + ~check_range:(Nativeint.( / ) Nativeint.min_value 100n, + Nativeint.( / ) Nativeint.max_value 100n); + [%expect {||}]; +;; + +(* The int63 functions come from [Int63] rather than [Random], but we test them here + along with the others anyway. *) + +let%expect_test "int63" = + let i = Int63.of_int in + test [%here] (module Int63) 1_000 (fun () -> Int63.random (i 100)) + ~min:(i 0) ~max:(i 99) ~check_range:(i 10,i 20); + [%expect {||}]; +;; + +let%expect_test "int63_incl" = + let i = Int63.of_int in + test [%here] (module Int63) 1_000 (fun () -> Int63.random_incl (i (-100)) (i 100)) + ~min:(i (-100)) ~max:(i 100) ~check_range:(i (-20),i (-10)); + [%expect {||}]; + test [%here] (module Int63) 1_000 (fun () -> Int63.random_incl (i 0) Int63.max_value) + ~min:(i 0) ~max:Int63.max_value + ~check_range:(i 0, Int63.( / ) Int63.max_value (i 100)); + [%expect {||}]; + test [%here] (module Int63) 1_000 (fun () -> + Int63.random_incl Int63.min_value Int63.max_value) + ~min:Int63.min_value ~max:Int63.max_value + ~check_range:(Int63.( / ) Int63.min_value (i 100), + Int63.( / ) Int63.max_value (i 100)); + [%expect {||}]; +;; diff --git a/test/test_random.mli b/test/test_random.mli new file mode 100644 index 0000000..74bb729 --- /dev/null +++ b/test/test_random.mli @@ -0,0 +1 @@ +(*_ This signature is deliberately empty. *) diff --git a/test/test_ref.ml b/test/test_ref.ml new file mode 100644 index 0000000..298d903 --- /dev/null +++ b/test/test_ref.ml @@ -0,0 +1,21 @@ +open! Core_kernel +open Ref + +let%test_unit "[set_temporarily] without raise" = + let r = ref 0 in + [%test_result: int] ~expect:1 (set_temporarily r 1 ~f:(fun () -> !r)); + [%test_result: int] ~expect:0 !r; +;; + +let%test_unit "[set_temporarily] with raise" = + let r = ref 0 in + try + never_returns (set_temporarily r 1 ~f:(fun () -> failwith "")); + with _ -> [%test_result: int] ~expect:0 !r +;; + +let%test_unit "[set_temporarily] where [f] sets the ref" = + let r = ref 0 in + set_temporarily r 1 ~f:(fun () -> r := 2); + [%test_result: int] ~expect:0 !r; +;; diff --git a/test/test_ref.mli b/test/test_ref.mli new file mode 100644 index 0000000..74bb729 --- /dev/null +++ b/test/test_ref.mli @@ -0,0 +1 @@ +(*_ This signature is deliberately empty. *) diff --git a/test/test_sequence.ml b/test/test_sequence.ml new file mode 100644 index 0000000..0887194 --- /dev/null +++ b/test/test_sequence.ml @@ -0,0 +1,488 @@ +open! Import +open! Sequence + +let%test_unit "of_lazy" = + let t = range 0 100 in + [%test_result: int list] + (to_list (of_lazy (lazy t))) + ~expect:(to_list t) + +let%test_unit _ = + let seq_of_seqs = + unfold ~init:0 ~f:(fun i -> + Some (unfold ~init:i ~f:(fun j -> Some ((i, j), j + 1)), + i + 1)) + in + [%test_result: (int * int) list] + (to_list (take (interleave seq_of_seqs) 10)) + ~expect:[ 0,0 + ; 0,1 ; 1,1 + ; 0,2 ; 1,2 ; 2,2 + ; 0,3 ; 1,3 ; 2,3 ; 3,3 + ] + +let%expect_test "round_robin vs interleave" = + let list_of_lists = + [ [1; 10; 100; 1000] + ; [2; 20; 200] + ; [3; 30] + ; [4] + ] + in + let list_of_seqs = List.map list_of_lists ~f:of_list in + let seq_of_seqs = of_list list_of_seqs in + print_s [%sexp (to_list (round_robin list_of_seqs) : int list)]; + [%expect {| (1 2 3 4 10 20 30 100 200 1_000) |}]; + print_s [%sexp (to_list (interleave seq_of_seqs) : int list)]; + [%expect {| (1 10 2 100 20 3 1_000 200 30 4) |}]; +;; + +let%test_unit _ = + let evens = unfold ~init:0 ~f:(fun i -> Some (i, i + 2)) in + let vowels = cycle_list_exn ['a';'e';'i';'o';'u'] in + [%test_result: (int * char) list] + (to_list (take (interleaved_cartesian_product evens vowels) 10)) + ~expect:[ 0,'a' + ; 0,'e' ; 2,'a' + ; 0,'i' ; 2,'e' ; 4,'a' + ; 0,'o' ; 2,'i' ; 4,'e' ; 6,'a' + ] + +let%test_module "Sequence.merge*" = + (module struct + let%test_unit _ = + [%test_eq: (int, int) Merge_with_duplicates_element.t list] + (to_list + (merge_with_duplicates + (of_list [ 1; 2; ]) + (of_list [ 2; 3; ]) + (* Can't use Core_int.compare because it would be a dependency cycle. *) + ~compare:Int.compare)) + [ Left 1; Both (2, 2); Right 3; ] + + let%test_unit _ = + [%test_eq: (int, int) Merge_with_duplicates_element.t list] + (to_list + (merge_with_duplicates + (of_list [ 2; 1; ]) + (of_list [ 2; 3; ]) + ~compare:Int.compare)) + [ Both (2, 2); Left 1; Right 3; ] + + let%test_unit _ = + [%test_eq: (int * string) list] + (to_list + (merge + (of_list [ (0, "A"); (1, "A"); ]) + (of_list [ (1, "B"); (2, "B"); ]) + ~compare:(fun a b -> [%compare: int] (fst a) (fst b)))) + [ (0, "A"); (1, "A"); (2, "B"); ] + end) + +let%test _ = fold ~f:(+) ~init:0 (of_list [1; 2; 3; 4; 5]) = 15 +let%test _ = fold ~f:(+) ~init:0 (of_list []) = 0 + +let%test_unit _ = + let test_equal l = [%test_result: int list] (to_list (of_list l)) ~expect:l in + test_equal []; + test_equal [1; 2; 3; 4; 5] +(* The test for longer list is after range *) + +let%test_unit _ = [%test_result: int list] (to_list (range 0 5)) ~expect:[0;1;2;3;4] +let%test_unit _ = [%test_result: int list] (to_list (range ~stop:`inclusive 0 5)) ~expect:[0;1;2;3;4;5] +let%test_unit _ = [%test_result: int list] (to_list (range ~start:`exclusive 0 5)) ~expect:[1;2;3;4] +let%test_unit _ = [%test_result: int list] (to_list (range ~stride:(-2) 5 1)) ~expect:[5;3] + +(* Test for to_list *) +let%test_unit _ = [%test_result: int list] (to_list (range 0 5000)) ~expect:(List.range 0 5000) + +(* Functions used for testing by comparing to List implementation*) +let test_to_list s f g = + [%test_result: int list] (to_list (f s)) ~expect:(g (to_list s)) + +(* For testing, we create a sequence which is equal to 1;2;3;4;5, but + with a more interesting structure inside*) + +let s12345 = map ~f:(fun x -> x / 2) (filter ~f:(fun x -> x % 2 = 0) + (of_list [1;2;3;4;5;6;7;8;9;10])) + +let sempty = filter ~f:(fun x -> x < 0) (of_list [1;2;3;4]) + +let test f g = test_to_list s12345 f g; test_to_list sempty f g + +let%test_unit _ = + [%test_result: int list] (to_list s12345) ~expect:[1; 2; 3; 4; 5]; + [%test_result: int list] (to_list sempty) ~expect:[] + +let%test_unit _ = + [%test_result: int list] + (to_list (unfold_with s12345 ~init:1 + ~f:(fun s _ -> + if s % 2 = 0 then + Skip (s+1) + else if s = 5 then + Done + else + Yield(s, s+1)))) + ~expect:[1;3] + +let test_delay init = + unfold_with_and_finish ~init + ~running_step:(fun prev next -> + Yield (prev, next)) + ~inner_finished:(fun x -> Some x) + ~finishing_step:(fun prev -> + match prev with + | None -> Done + | Some prev -> Yield (prev, None)) + +let%test_unit _ = + [%test_result: int list] + (to_list (test_delay 0 s12345)) + ~expect:[0; 1; 2; 3; 4; 5] + +let%test_unit _ = + [%test_result: int list] + (to_list (test_delay 0 sempty)) + ~expect:[0] + +let%test_unit _ = [%test_result: int list] (to_list s12345) ~expect:[1; 2; 3; 4; 5] + +let%test_unit _ = test + (map ~f:(fun i -> -i)) + (List.map ~f:(fun i -> -i)) + +let%test_unit _ = test + (mapi ~f:(fun i j -> j - 2 *i)) + (List.mapi ~f:(fun i j -> j - 2 *i)) + +let%test_unit _ = test + (filter ~f:(fun i -> i % 2 = 0)) + (List.filter ~f:(fun i -> i % 2 = 0)) + +let%test _ = length s12345 = 5 && length sempty = 0 + +let%test_unit _ = + [%test_result: int option] (find s12345 ~f:(fun x -> x = 3)) ~expect:(Some 3); + [%test_result: int option] (find s12345 ~f:(fun x -> x = 7)) ~expect:None + +let%test_unit _ = + [%test_result: string option] + (find_map s12345 ~f:(fun x -> if x = 3 then Some "a" else None)) + ~expect:(Some "a"); + [%test_result: string option] + (find_map s12345 ~f:(fun x -> if x = 7 then Some "a" else None)) + ~expect:None + +let%test_unit _ = + [%test_result: string option] + (find_mapi s12345 ~f:(fun _ x -> if x = 3 then Some "a" else None)) + ~expect:(Some "a") +let%test_unit _ = + [%test_result: string option] + (find_mapi s12345 ~f:(fun _ x -> if x = 7 then Some "a" else None)) + ~expect:None +let%test_unit _ = + [%test_result: (int * int) option] + (find_mapi s12345 ~f:(fun i x -> if i + x >= 6 then Some (i,x) else None)) + ~expect:(Some (3,4)) + +let%test _ = for_all sempty ~f:(fun _ -> false) +let%test _ = for_all s12345 ~f:(fun x -> x > 0) +let%test _ = not (for_all s12345 ~f:(fun x -> x < 5)) + +let%test _ = for_alli sempty ~f:(fun _ _ -> false) +let%test _ = for_alli s12345 ~f:(fun _ x -> x > 0) +let%test _ = not (for_alli s12345 ~f:(fun _ x -> x < 5)) +let%test _ = for_alli s12345 ~f:(fun i x -> x = i+1) + +let%test _ = not (exists sempty ~f:(fun _ -> assert false)) +let%test _ = exists s12345 ~f:(fun x -> x = 5) +let%test _ = not (exists s12345 ~f:(fun x -> x = 0)) + +let%test _ = not (existsi sempty ~f:(fun _ _ -> assert false)) +let%test _ = existsi s12345 ~f:(fun _ x -> x = 5) +let%test _ = not (existsi s12345 ~f:(fun _ x -> x = 0)) +let%test _ = not (existsi s12345 ~f:(fun i x -> x <> i+1)) + +let%test_unit _ = + let l = ref [] in + iter s12345 ~f:(fun x -> l := x::!l); + [%test_result: int list] !l ~expect:[5;4;3;2;1] + +let%test _ = is_empty sempty +let%test _ = not (is_empty (of_list [1])) + +let%test _ = mem s12345 1 ~equal:Int.equal +let%test _ = not (mem s12345 6 ~equal:Int.equal) + +let%test_unit _ = [%test_result: int list] (to_list empty) ~expect:[] + +let%test_unit _ = + [%test_result: int list] + (to_list (bind sempty ~f:(fun _ -> s12345))) + ~expect:[] +let%test_unit _ = + [%test_result: int list] + (to_list (bind s12345 ~f:(fun _ -> sempty))) + ~expect:[] +let%test_unit _ = + [%test_result: int list] + (to_list (bind s12345 ~f:(fun x -> of_list [x;-x]))) + ~expect:[1;-1;2;-2;3;-3;4;-4;5;-5] + +let%test_unit _ = [%test_result: int list] (to_list (return 1)) ~expect:[1] + +let%test_unit _ = [%test_result: int option] (nth s12345 3) ~expect:(Some 4) +let%test_unit _ = [%test_result: int option] (nth s12345 5) ~expect:None + +let%test_unit _ = [%test_result: int option] (hd s12345) ~expect:(Some 1) +let%test_unit _ = [%test_result: int option] (hd sempty) ~expect:None + +let%test_unit _ = [%test_result: int t option] (tl sempty) ~expect:None +let%test_unit _ = match tl s12345 with + | Some l -> [%test_result: int list] (to_list l) ~expect:[2;3;4;5] + | None -> failwith "expected Some" + +let%test_unit _ = [%test_result: (int * int t) option] (next sempty) ~expect:None +let%test_unit _ = match next s12345 with + | Some (hd,tl) -> + [%test_result: int] hd ~expect:1; + [%test_result: int list] (to_list tl) ~expect:[2;3;4;5] + | None -> failwith "expected Some" + +let%test_unit _ = + [%test_result: int list] + (to_list (filter_opt (of_list [None; Some 1; None ;Some 2; Some 3]))) + ~expect:[1;2;3] + +let%test_unit _ = + let (l,r) = split_n s12345 2 in + [%test_result: int list] l ~expect:[1;2]; + [%test_result: int list] (to_list r) ~expect:[3;4;5] + +let%test_unit _ = [%test_result: int list list] (to_list (chunks_exn s12345 2)) ~expect:[[1;2];[3;4];[5]] + +let%test_unit _ = [%test_result: int list] (to_list (append s12345 s12345)) ~expect:[1;2;3;4;5;1;2;3;4;5] +let%test_unit _ = [%test_result: int list] (to_list (append sempty s12345)) ~expect:[1;2;3;4;5] + +let%test_unit _ = + [%test_result: (int * int) list] + (to_list (zip s12345 sempty)) + ~expect:[] +let%test_unit _ = + [%test_result: (int * int) list] + (to_list (zip s12345 (of_list [6;5;4;3;2;1]))) + ~expect:[1,6;2,5;3,4;4,3;5,2] +let%test_unit _ = + [%test_result: (int * string) list] + (to_list (zip s12345 (of_list ["a"]))) + ~expect:[1,"a"] + +let%test_unit _ = + [%test_result: (int * int) option] + (find_consecutive_duplicate s12345 ~equal:(=)) + ~expect:None +let%test_unit _ = + [%test_result: (int * int) option] + (find_consecutive_duplicate (of_list [1;2;2;3;4;4;5]) ~equal:(=)) + ~expect:(Some (2,2)) + +let%test_unit _ = + [%test_result: int list] + (to_list + (remove_consecutive_duplicates ~equal:(=) (of_list [1;2;2;3;3;3;3;4;4;5;6;6;7]))) + ~expect:[1;2;3;4;5;6;7] +let%test_unit _ = + [%test_result: int list] + (to_list + (remove_consecutive_duplicates ~equal:(=) s12345)) + ~expect:[1;2;3;4;5] + +let%test_unit _ = + [%test_result: int list] + (to_list (remove_consecutive_duplicates ~equal:(fun _ _ -> true) s12345)) + ~expect:[1] + +let%test_unit _ = [%test_result: int list] (to_list (init (-1) ~f:(fun _ -> assert false))) ~expect:[] +let%test_unit _ = [%test_result: int list] (to_list (init 5 ~f:Fn.id)) ~expect:[0; 1; 2; 3; 4] + +let%test_unit _ = [%test_result: int list] (to_list (sub s12345 ~pos:4 ~len:10)) ~expect:[5] +let%test_unit _ = [%test_result: int list] (to_list (sub s12345 ~pos:1 ~len:2)) ~expect:[2;3] +let%test_unit _ = [%test_result: int list] (to_list (sub s12345 ~pos:0 ~len:0)) ~expect:[] + +let%test_unit _ = [%test_result: int list] (to_list (take s12345 2)) ~expect:[1;2] +let%test_unit _ = [%test_result: int list] (to_list (take s12345 0)) ~expect:[] +let%test_unit _ = [%test_result: int list] (to_list (take s12345 9)) ~expect:[1;2;3;4;5] + +let%test_unit _ = [%test_result: int list] (to_list (drop s12345 2)) ~expect:[3;4;5] +let%test_unit _ = [%test_result: int list] (to_list (drop s12345 0)) ~expect:[1;2;3;4;5] +let%test_unit _ = [%test_result: int list] (to_list (drop s12345 9)) ~expect:[] + +let%test_unit _ = [%test_result: int list] (to_list (take_while ~f:(fun x -> x < 3) s12345)) ~expect:[1;2] + +let%test_unit _ = [%test_result: int list] (to_list (drop_while ~f:(fun x -> x < 3) s12345)) ~expect:[3;4;5] + +let%test_unit _ = + [%test_result: int list] + (to_list (shift_right (shift_right s12345 0) (-1))) + ~expect:[-1;0;1;2;3;4;5] + +let%test_unit _ = [%test_result: char list] (to_list (intersperse ~sep:'a' (of_list []))) ~expect:[] +let%test_unit _ = [%test_result: char list] (to_list (intersperse ~sep:'a' (of_list ['b']))) ~expect:['b'] +let%test_unit _ = [%test_result: int list] (to_list (intersperse ~sep:(-1) (take s12345 1))) ~expect:[1] +let%test_unit _ = [%test_result: int list] (to_list (intersperse ~sep:0 s12345)) ~expect:[1;0;2;0;3;0;4;0;5] + +let%test_unit _ = [%test_result: int list] (to_list (take (repeat 1) 3)) ~expect:[1;1;1] + +let%test_unit _ = + [%test_result: int list] + (to_list (take (cycle_list_exn [1;2;3;4;5]) 7)) + ~expect:[1;2;3;4;5;1;2] + +let%test_unit _ = require_does_raise [%here] (fun () -> cycle_list_exn []) + +let%test_unit _ = + [%test_result: (char * int) list] + (to_list (cartesian_product (of_list ['a';'b']) s12345)) + ~expect:['a',1;'a',2;'a',3;'a',4;'a',5; + 'b',1;'b',2;'b',3;'b',4;'b',5] + +let%test_unit _ = + [%test_result: float] + (delayed_fold s12345 ~init:0.0 + ~f:(fun a i ~k -> + if Float.(<=) a 5.0 then + k (a +. (Float.of_int i)) else + a) + ~finish:(fun _ -> assert false)) + ~expect:6.0 + +let%expect_test "fold_m" = + let module Simple_monad = struct + type 'a t = + | Return of 'a + | Step of 'a t + [@@deriving sexp_of] + + let return a = Return a + + let rec bind t ~f = + match t with + | Return a -> f a + | Step t -> Step (bind t ~f) + ;; + + let step = Step (Return ()) + end in + fold_m ~bind:Simple_monad.bind ~return:Simple_monad.return s12345 + ~init:[] + ~f:(fun acc n -> + Simple_monad.bind Simple_monad.step ~f:(fun () -> + Simple_monad.return (n :: acc))) + |> printf !"%{sexp: int list Simple_monad.t}\n"; + [%expect {| (Step (Step (Step (Step (Step (Return (5 4 3 2 1))))))) |}] +;; + +let%expect_test "iter_m" = + iter_m ~bind:Generator.bind ~return:Generator.return s12345 ~f:Generator.yield + |> Generator.run + |> printf !"%{sexp: int t}\n"; + [%expect {| (1 2 3 4 5) |}] +;; + +let%test _ = + let num_computations = ref 0 in + let t = memoize (unfold ~init:() ~f:(fun () -> Int.incr num_computations; None)) in + iter t ~f:Fn.id; + iter t ~f:Fn.id; + !num_computations = 1 + +let%test_unit _ = [%test_result: int list] (to_list (drop_eagerly s12345 0)) ~expect:[1;2;3;4;5] +let%test_unit _ = [%test_result: int list] (to_list (drop_eagerly s12345 2)) ~expect:[3;4;5] +let%test_unit _ = [%test_result: int list] (to_list (drop_eagerly s12345 5)) ~expect:[] +let%test_unit _ = [%test_result: int list] (to_list (drop_eagerly s12345 8)) ~expect:[] + +let compare_tests = + [ [1; 2; 3] , [1; 2; 3] , 0 + ; [1; 2; 3] , [] , 1 + ; [] , [1; 2; 3] , -1 + ; [1; 2] , [1; 2; 3] , -1 + ; [1; 2; 3] , [1; 2] , 1 + ; [1; 3; 2] , [1; 2; 3] , 1 + ; [1; 2; 3] , [1; 3; 2] , -1 ] + +(* this test has to use base OCaml library functions to avoid circular dependencies *) +let%test _ = + List.for_all + ~f:Fn.id + (List.map + ~f:(fun (l1, l2, expected_res) -> + compare Int.compare (of_list l1) (of_list l2) = expected_res) + compare_tests) + +let%test_unit _ = + [%test_result: int list] + (folding_map (of_list [1;2;3;4]) ~init:0 + ~f:(fun acc x -> let y = acc+x in y,y) |> to_list) + ~expect:[1;3;6;10] +let%test_unit _ = + [%test_result: bool] + (folding_map empty ~init:0 + ~f:(fun acc x -> let y = acc+x in y,y) |> is_empty) + ~expect:true +let%test_unit _ = + [%test_result: int list] + (folding_mapi (of_list [1;2;3;4]) ~init:0 + ~f:(fun i acc x -> let y = acc+i*x in y,y) |> to_list) + ~expect:[0;2;8;20] +let%test_unit _ = + [%test_result: bool] + (folding_mapi empty ~init:0 + ~f:(fun i acc x -> let y = acc+i*x in y,y) |> is_empty) + ~expect:true + +let%expect_test _ = + let xs = init 3 ~f:Fn.id |> Generator.of_sequence in + let ( @ ) xs ys = Generator.bind xs ~f:(fun () -> ys) in + (xs @ xs @ xs @ xs @ xs) + |> Generator.run + |> [%sexp_of: int t] + |> print_s; + [%expect {| + (0 1 2 0 1 2 0 1 2 0 1 2 0 1 2) |}] +;; + +let%test_module "group" = + (module struct + let%test _ = + of_list [1; 2; 3; 4] + |> group ~break:(fun _ x -> Int.equal x 3) + |> [%compare.equal: int list t] (of_list [[1; 2]; [3; 4]]) + ;; + + let%test _ = + group empty ~break:(fun _ -> assert false) + |> [%compare.equal: unit list t] empty + ;; + + let mis = of_list ['M';'i';'s';'s';'i';'s';'s';'i';'p';'p';'i'] + ;; + + let equal_letters = + of_list [['M'];['i'];['s';'s'];['i'];['s';'s'];['i'];['p';'p'];['i']] + ;; + + let single_letters = of_list [['M';'i';'s';'s';'i';'s';'s';'i';'p';'p';'i']] + ;; + + let%test _ = + group ~break:Char.(<>) mis + |> [%compare.equal: char list t] equal_letters + ;; + + let%test _ = + group ~break:(fun _ _ -> false) mis + |> [%compare.equal: char list t] single_letters + ;; + end) diff --git a/test/test_sequence.mli b/test/test_sequence.mli new file mode 100644 index 0000000..74bb729 --- /dev/null +++ b/test/test_sequence.mli @@ -0,0 +1 @@ +(*_ This signature is deliberately empty. *) diff --git a/test/test_set.ml b/test/test_set.ml new file mode 100644 index 0000000..2129f9b --- /dev/null +++ b/test/test_set.ml @@ -0,0 +1,28 @@ +open! Import +open! Set + +type int_set = Set.M(Int).t [@@deriving compare, hash, sexp] + +let%test _ = + invariants + (of_increasing_iterator_unchecked (module Int) ~len:20 ~f:Fn.id) + +let%test _ = + invariants + (Poly.of_increasing_iterator_unchecked ~len:20 ~f:Fn.id) + +module Poly = struct + let%test _ = + length Poly.empty = 0 + ;; + + let%test _ = + Poly.equal (Poly.of_list []) Poly.empty + ;; + + let%test _ = + let a = Poly.of_list [1; 1] in + let b = Poly.of_list ["a"] in + length a = length b + ;; +end diff --git a/test/test_sexp.ml b/test/test_sexp.ml new file mode 100644 index 0000000..7014760 --- /dev/null +++ b/test/test_sexp.ml @@ -0,0 +1,40 @@ +open! Import + +let%expect_test "[sexp_array]" = + let module M = struct + type t = { x : int sexp_array } [@@deriving sexp_of] + end in + List.iter [ [| |]; [| 13 |] ] ~f:(fun x -> print_s [%sexp ({ x } : M.t)]); + [%expect {| + () + ((x (13))) |}]; +;; + +let%expect_test "[sexp_list]" = + let module M = struct + type t = { x : int sexp_list } [@@deriving sexp_of] + end in + List.iter [ [ ]; [ 13 ] ] ~f:(fun x -> print_s [%sexp ({ x } : M.t)]); + [%expect {| + () + ((x (13))) |}]; +;; + +let%expect_test "[sexp_opaque]" = + let module M = struct + type t = { x : int sexp_opaque } [@@deriving sexp_of] + end in + print_s [%sexp ({ x = 13 } : M.t)]; + [%expect {| + ((x <opaque>)) |}]; +;; + +let%expect_test "[sexp_option]" = + let module M = struct + type t = { x : int sexp_option } [@@deriving sexp_of] + end in + List.iter [ None; Some 13 ] ~f:(fun x -> print_s [%sexp ({ x } : M.t)]); + [%expect {| + () + ((x 13)) |}]; +;; diff --git a/test/test_sexp.mli b/test/test_sexp.mli new file mode 100644 index 0000000..74bb729 --- /dev/null +++ b/test/test_sexp.mli @@ -0,0 +1 @@ +(*_ This signature is deliberately empty. *) diff --git a/test/test_sign.ml b/test/test_sign.ml new file mode 100644 index 0000000..83471a6 --- /dev/null +++ b/test/test_sign.ml @@ -0,0 +1,17 @@ +open! Import +open! Sign + +let%test "of_int" = + of_int 37 = Pos && of_int (-22) = Neg && of_int 0 = Zero + +let%test_unit "( * )" = + List.cartesian_product all all + |> List.iter ~f:(fun (s1, s2) -> + [%test_result: int] + (to_int (s1 * s2)) + ~expect:(Int.( * ) (to_int s1) (to_int s2))) + +let%expect_test "hash coherence" [@tags "64-bits-only"] = + check_hash_coherence [%here] (module Sign) all; + [%expect {| |}]; +;; diff --git a/test/test_sign.mli b/test/test_sign.mli new file mode 100644 index 0000000..74bb729 --- /dev/null +++ b/test/test_sign.mli @@ -0,0 +1 @@ +(*_ This signature is deliberately empty. *) diff --git a/test/test_sign_or_nan.ml b/test/test_sign_or_nan.ml new file mode 100644 index 0000000..cb30856 --- /dev/null +++ b/test/test_sign_or_nan.ml @@ -0,0 +1,10 @@ +open! Import +open! Sign_or_nan + +let%test "of_int" = + of_int 37 = Pos && of_int (-22) = Neg && of_int 0 = Zero + +let%expect_test "hash coherence" [@tags "64-bits-only"] = + check_hash_coherence [%here] (module Sign_or_nan) all; + [%expect {| |}]; +;; diff --git a/test/test_sign_or_nan.mli b/test/test_sign_or_nan.mli new file mode 100644 index 0000000..74bb729 --- /dev/null +++ b/test/test_sign_or_nan.mli @@ -0,0 +1 @@ +(*_ This signature is deliberately empty. *) diff --git a/test/test_stack.ml b/test/test_stack.ml new file mode 100644 index 0000000..a3df0af --- /dev/null +++ b/test/test_stack.ml @@ -0,0 +1,235 @@ +open! Base +open! Import +open! Stack + +module Debug (Stack : S) : S with type 'a t = 'a Stack.t = struct + + open Stack + + type nonrec 'a t = 'a t + + let invariant = invariant + + let check_and_return t = invariant ignore t; t + + let debug t f = + let result = Result.try_with f in + invariant ignore t; + Result.ok_exn result; + ;; + + (* The return-type annotations are to prevent an error where we don't supply all the + arguments to the function, and thus wouldn't be checking the invariant after fully + applying the function. *) + let clear t : unit = debug t (fun () -> clear t) + let copy t : _ t = check_and_return (debug t (fun () -> copy t)) + let count t ~f : int = debug t (fun () -> count t ~f) + let sum m t ~f = debug t (fun () -> sum m t ~f) + let create () : _ t = check_and_return (create ()) + let exists t ~f : bool = debug t (fun () -> exists t ~f) + let find t ~f : _ option = debug t (fun () -> find t ~f) + let find_map t ~f : _ option = debug t (fun () -> find_map t ~f) + let fold (type a) t ~init ~f : a = debug t (fun () -> fold t ~init ~f) + let for_all t ~f : bool = debug t (fun () -> for_all t ~f) + let is_empty t : bool = debug t (fun () -> is_empty t) + let iter t ~f : unit = debug t (fun () -> iter t ~f) + let length t : int = debug t (fun () -> length t) + let mem t a ~equal : bool = debug t (fun () -> mem t a ~equal) + let of_list l : _ t = check_and_return (of_list l) + let pop t : _ option = debug t (fun () -> pop t) + let pop_exn (type a) t : a = debug t (fun () -> pop_exn t) + let push t a : unit = debug t (fun () -> push t a) + let sexp_of_t sexp_of_a t : Sexp.t = debug t (fun () -> [%sexp_of: a t] t) + let singleton x : _ t = check_and_return (singleton x) + let t_of_sexp a_of_sexp sexp : _ t = check_and_return ([%of_sexp: a t] sexp) + let to_array t : _ array = debug t (fun () -> to_array t) + let to_list t : _ list = debug t (fun () -> to_list t) + let top t : _ option = debug t (fun () -> top t) + let top_exn (type a) t : a = debug t (fun () -> top_exn t) + let until_empty t f : unit = debug t (fun () -> until_empty t f) + let min_elt t ~compare : _ option = debug t (fun () -> min_elt t ~compare) + let max_elt t ~compare : _ option = debug t (fun () -> max_elt t ~compare) + let fold_result t ~init ~f = debug t (fun () -> fold_result t ~init ~f) + let fold_until t ~init ~f = debug t (fun () -> fold_until t ~init ~f) +end + +module Test (Stack : S) + (* This signature is here to remind us to add a unit test whenever we add something to + the stack interface. *) + : S with type 'a t = 'a Stack.t = struct + + open Stack + + type nonrec 'a t = 'a t + + include Test_container.Test_S1 (Stack) + + let invariant = invariant + + let create = create + let is_empty = is_empty + let top_exn = top_exn + let pop_exn = pop_exn + let pop = pop + let top = top + let singleton = singleton + let%test_unit _ = + let empty = create () in + invariant ignore empty; + invariant (fun b -> assert b) (of_list [true]); + assert (is_empty empty); + let t = create () in + push t 0; + assert (not (is_empty t)); + assert (Exn.does_raise (fun () -> top_exn empty)); + let t = create () in + push t 0; + [%test_result: int] (top_exn t) ~expect:0; + assert (Exn.does_raise (fun () -> pop_exn empty)); + let t = create () in + push t 0; + [%test_result: int] (pop_exn t) ~expect:0; + assert (is_none (pop empty)); + assert (is_some (pop (of_list [0]))); + assert (is_none (top empty)); + assert (is_some (top (of_list [0]))); + assert (is_some (top (singleton 0))); + assert (is_some (pop (singleton 0))); + assert (let t = singleton 0 in + ignore (pop_exn t : int); + is_none (top t)); + ;; + + let min_elt = min_elt + let max_elt = max_elt + + let%test_unit _ = + let empty = create () in + [%test_result: _ option] (min_elt ~compare:Int.compare empty) ~expect:None; + [%test_result: _ option] (max_elt ~compare:Int.compare empty) ~expect:None; + [%test_result: int] (sum (module Int) ~f:Fn.id empty) ~expect:0 + ;; + + let push = push + let copy = copy + let until_empty = until_empty + + let%test_unit _ = + let t = + let t = create () in + push t 0; + push t 1; + push t 2; + t + in + [%test_result: bool] (is_empty t) ~expect:false; + [%test_result: int] (length t) ~expect:3; + [%test_result: int option] (top t) ~expect:(Some 2); + [%test_result: int] (top_exn t) ~expect:2; + [%test_result: int option] (min_elt ~compare:Int.compare t) ~expect:(Some 0); + [%test_result: int option] (max_elt ~compare:Int.compare t) ~expect:(Some 2); + [%test_result: int] (sum (module Int) ~f:Fn.id t) ~expect:3; + let t' = copy t in + [%test_result: int] (pop_exn t') ~expect:2; + [%test_result: int] (pop_exn t') ~expect:1; + [%test_result: int] (pop_exn t') ~expect:0; + [%test_result: int] (length t') ~expect:0; + [%test_result: bool] (is_empty t') ~expect:true; + let t' = copy t in + [%test_result: int option] (pop t') ~expect:(Some 2); + [%test_result: int option] (pop t') ~expect:(Some 1); + [%test_result: int option] (pop t') ~expect:(Some 0); + [%test_result: int] (length t') ~expect:0; + [%test_result: bool] (is_empty t') ~expect:true; + (* test that t was not modified by pops applied to copies *) + [%test_result: int] (length t) ~expect:3; + [%test_result: int] (top_exn t) ~expect:2; + [%test_result: int list] (to_list t) ~expect:[2; 1; 0]; + [%test_result: int array] (to_array t) ~expect:[|2; 1; 0|]; + [%test_result: int] (length t) ~expect:3; + [%test_result: int] (top_exn t) ~expect:2; + let t' = copy t in + let n = ref 0 in + until_empty t' (fun x -> n := !n + x); + [%test_result: int] !n ~expect:3; + [%test_result: bool] (is_empty t') ~expect:true; + [%test_result: int] (length t') ~expect:0 + ;; + + let%test_unit _ = + let t = create () in + [%test_result: bool] (is_empty t) ~expect:true; + [%test_result: int] (length t) ~expect:0; + [%test_result: _ list] (to_list t) ~expect:[]; + [%test_result: _ option] (pop t) ~expect:None; + push t 13; + [%test_result: bool] (is_empty t) ~expect:false; + [%test_result: int] (length t) ~expect:1; + [%test_result: int option] (min_elt ~compare:Int.compare t) ~expect:(Some 13); + [%test_result: int option] (max_elt ~compare:Int.compare t) ~expect:(Some 13); + [%test_result: int] (sum (module Int) ~f:Fn.id t) ~expect:13; + [%test_result: int] (pop_exn t) ~expect:13; + [%test_result: bool] (is_empty t) ~expect:true; + [%test_result: int] (length t) ~expect:0; + push t 13; + push t 14; + [%test_result: bool] (is_empty t) ~expect:false; + [%test_result: int] (length t) ~expect:2; + [%test_result: int list] (to_list t) ~expect:[14; 13]; + [%test_result: int option] (min_elt ~compare:Int.compare t) ~expect:(Some 13); + [%test_result: int option] (max_elt ~compare:Int.compare t) ~expect:(Some 14); + [%test_result: int] (sum (module Int) ~f:Fn.id t) ~expect:27; + [%test_result: bool] (is_some (pop t)) ~expect:true; + [%test_result: bool] (is_some (pop t)) ~expect:true + ;; + + let of_list = of_list + + let%test_unit _ = + for n = 0 to 5 do + let l = List.init n ~f:Fn.id in + [%test_result: int list] (to_list (of_list l)) ~expect:l + done + ;; + + let clear = clear + + let%test_unit _ = + for n = 0 to 5 do + let t = of_list (List.init n ~f:Fn.id) in + clear t; + assert (is_empty t); + push t 13; + [%test_result: int] (length t) ~expect:1 + done + ;; + + let%test_unit "float test" = + let s = create () in + push s 1.0; + push s 2.0; + push s 3.0 + +end + +include Test_container.Test_S1 (Stack) + +include Test (Debug (Stack)) + +let capacity = capacity +let set_capacity = set_capacity + +let%test_unit _ = + let t = create () in + [%test_result: int] (capacity t) ~expect:0; + set_capacity t (-1); + [%test_result: int] (capacity t) ~expect:0; + set_capacity t 10; + [%test_result: int] (capacity t) ~expect:10; + set_capacity t 0; + [%test_result: int] (capacity t) ~expect:0; + push t (); + set_capacity t 0; + [%test_result: int] (length t) ~expect:1; + [%test_pred: int] (fun c -> c >= 1) (capacity t) +;; diff --git a/test/test_stack.mli b/test/test_stack.mli new file mode 100644 index 0000000..2f12c18 --- /dev/null +++ b/test/test_stack.mli @@ -0,0 +1,5 @@ +open! Base +open! Import + +module Debug (S : Stack.S) : Stack.S with type 'a t = 'a S.t +module Test (S : Stack.S) : sig end diff --git a/test/test_stdlib_shadowing.mlt b/test/test_stdlib_shadowing.mlt new file mode 100644 index 0000000..0212225 --- /dev/null +++ b/test/test_stdlib_shadowing.mlt @@ -0,0 +1,41 @@ +(* Additional shadowing tests, to make sure the [@@deprecated] attributes are properly + transported in [Base] *) +open Base + +let () = seek_in stdin 0 +[%%expect{| +Line _, characters 9-16: +Error (Warning 3): deprecated: Base.seek_in +[2016-09] this element comes from the stdlib distributed with OCaml. +Use [Stdio.In_channel.seek] instead. +Line _, characters 17-22: +Error (Warning 3): deprecated: Base.stdin +[2016-09] this element comes from the stdlib distributed with OCaml. +Use [Stdio.stdin] instead. +|}] + +let _ = StringLabels.make 10 'x' +[%%expect{| +Line _, characters 8-25: +Error (Warning 3): deprecated: module Base.StringLabels +[2016-09] this element comes from the stdlib distributed with OCaml. +Referring to the stdlib directly is discouraged by Base. You should either +use the equivalent functionality offered by Base, or if you really want to +refer to the stdlib, use Caml.StringLabels instead +|}] + +let _ = ( == ) +[%%expect{| +Line _, characters 8-14: +Error (Warning 3): deprecated: Base.== +[2016-09] this element comes from the stdlib distributed with OCaml. +Use [phys_equal] instead. +|}] + +let _ = ( != ) +[%%expect{| +Line _, characters 8-14: +Error (Warning 3): deprecated: Base.!= +[2016-09] this element comes from the stdlib distributed with OCaml. +Use [not (phys_equal ...)] instead. +|}] diff --git a/test/test_string.ml b/test/test_string.ml new file mode 100644 index 0000000..40a5e25 --- /dev/null +++ b/test/test_string.ml @@ -0,0 +1,678 @@ +open! Import +open! String + +let%expect_test "hash coherence" [@tags "64-bits-only"] = + check_hash_coherence [%here] (module String) [ ""; "a"; "foo" ]; + [%expect {| |}]; +;; + +let%test_module "Caseless Suffix/Prefix" = + (module struct + let%test _ = Caseless.is_suffix "OCaml" ~suffix:"AmL" + let%test _ = Caseless.is_suffix "OCaml" ~suffix:"ocAmL" + let%test _ = Caseless.is_suffix "a@!$b" ~suffix:"a@!$B" + let%test _ = not (Caseless.is_suffix "a@!$b" ~suffix:"C@!$B") + let%test _ = not (Caseless.is_suffix "aa" ~suffix:"aaa") + let%test _ = Caseless.is_prefix "OCaml" ~prefix:"oc" + let%test _ = Caseless.is_prefix "OCaml" ~prefix:"ocAmL" + let%test _ = Caseless.is_prefix "a@!$b" ~prefix:"a@!$B" + let%test _ = not (Caseless.is_prefix "a@!$b" ~prefix:"a@!$C") + let%test _ = not (Caseless.is_prefix "aa" ~prefix:"aaa") + end) + +let%test_module "Caseless Comparable" = + (module struct + (* examples from docs *) + let%test _ = Caseless.equal "OCaml" "ocaml" + let%test _ = Caseless.("apple" < "Banana") + + let%test _ = Caseless.("aa" < "aaa") + let%test _ = Int.(<>) (Caseless.compare "apple" "Banana") (compare "apple" "Banana") + let%test _ = Caseless.equal "XxX" "xXx" + let%test _ = Caseless.("XxX" < "xXxX") + let%test _ = Caseless.("XxXx" > "xXx") + + let%test _ = List.is_sorted ~compare:Caseless.compare ["Apples"; "bananas"; "Carrots"] + + let%expect_test _ = + let x = Sys.opaque_identity "one string" in + let y = Sys.opaque_identity "another" in + require_no_allocation [%here] (fun () -> + ignore (Sys.opaque_identity (Caseless.equal x y) : bool)); + [%expect {||}]; + ;; + end) + +let%test_module "Caseless Hashable" = + (module struct + let%test _ = + Int.(<>) (hash "x") (hash "X") + && Int.(=) (Caseless.hash "x") (Caseless.hash "X") + let%test _ = Int.(=) (Caseless.hash "OCaml") (Caseless.hash "ocaml") + let%test _ = Int.(<>) (Caseless.hash "aaa") (Caseless.hash "aaaa") + let%test _ = Int.(<>) (Caseless.hash "aaa") (Caseless.hash "aab") + let%test _ = + let tbl = Hashtbl.create (module Caseless) in + Hashtbl.add_exn tbl ~key:"x" ~data:7; + [%compare.equal: int option] (Hashtbl.find tbl "X") (Some 7) + end) + +let%test _ = not (contains "" 'a') +let%test _ = contains "a" 'a' +let%test _ = not (contains "a" 'b') +let%test _ = contains "ab" 'a' +let%test _ = contains "ab" 'b' +let%test _ = not (contains "ab" 'c') +let%test _ = not (contains "abcd" 'b' ~pos:1 ~len:0) +let%test _ = contains "abcd" 'b' ~pos:1 ~len:1 +let%test _ = contains "abcd" 'c' ~pos:1 ~len:2 +let%test _ = not (contains "abcd" 'd' ~pos:1 ~len:2) +let%test _ = contains "abcd" 'd' ~pos:1 +let%test _ = not (contains "abcd" 'a' ~pos:1) + +let%test_module "Search_pattern" = + (module struct + open Search_pattern + + let%test_module "Search_pattern.create" = + (module struct + let prefix s n = sub s ~pos:0 ~len:n + let suffix s n = sub s ~pos:(length s - n) ~len:n + + let slow_create pattern = + (* Compute the longest prefix-suffix array from definition, O(n^3) *) + let n = length pattern in + let kmp_arr = Array.create ~len:n (-1) in + for i = 0 to n - 1 do + let x = prefix pattern (i + 1) in + for j = 0 to i do + if String.equal (prefix x j) (suffix x j) then + kmp_arr.(i) <- j + done + done; + (pattern, kmp_arr) + ;; + + let sexp_of_int = Base.Not_exposed_properly.Sexp_conv.sexp_of_int + + let test_both (s, a) = + let create_s = create s |> [%sexp_of: t ] in + let slow_create_s = slow_create s |> [%sexp_of: string * int array] in + let expected = [%sexp ((s, a) : string * int array)] in + require [%here] (Sexp.equal create_s expected && Sexp.equal slow_create_s expected) + ~if_false_then_print_s:(lazy ( + [%message "not equal" + (create_s : Sexp.t) + (slow_create_s : Sexp.t) + (expected : Sexp.t)])) + ;; + + let cmp_both s = + let create_s = create s |> [%sexp_of: t ] in + let slow_create_s = slow_create s |> [%sexp_of: string * int array] in + require [%here] (Sexp.equal create_s slow_create_s) + ~if_false_then_print_s:(lazy ( + [%message "not equal" + (create_s : Sexp.t) + (slow_create_s : Sexp.t)])) + ;; + + let%expect_test _ = + test_both ("", [| |]) + let%expect_test _ = + test_both ("ababab", [|0; 0; 1; 2; 3; 4|]) + let%expect_test _ = + test_both ("abaCabaD", [|0; 0; 1; 0; 1; 2; 3; 0|]) + let%expect_test _ = + test_both ("abaCabaDabaCabaCabaDabaCabaEabab", + [|0; 0; 1; 0; 1; 2; 3; 0; 1; 2; 3; 4; 5; 6; 7; 4; 5; 6; 7; 8; + 9; 10; 11; 12; 13; 14; 15; 0; 1; 2; 3; 2|]) + + let rec x k = + if Int.(<) k 0 then "" else + let b = x (k - 1) in + b ^ (make 1 (Caml.Char.unsafe_chr (65 + k))) ^ b + ;; + + let%expect_test _ = + cmp_both (x 10) + let%expect_test _ = + cmp_both ((x 5) ^ "E" ^ (x 4) ^ "D" ^ (x 3) ^ "B" ^ (x 2) ^ "C" ^ (x 3)) + end) + + let (=) = [%compare.equal: int option] + let%test _ = index (create "") ~in_:"abababac" = Some 0 + let%test _ = index ~pos:(-1) (create "") ~in_:"abababac" = None + let%test _ = index ~pos:1 (create "") ~in_:"abababac" = Some 1 + let%test _ = index ~pos:7 (create "") ~in_:"abababac" = Some 7 + let%test _ = index ~pos:8 (create "") ~in_:"abababac" = Some 8 + let%test _ = index ~pos:9 (create "") ~in_:"abababac" = None + let%test _ = index (create "abababaca") ~in_:"abababac" = None + let%test _ = index (create "abababac") ~in_:"abababac" = Some 0 + let%test _ = index ~pos:0 (create "abababac") ~in_:"abababac" = Some 0 + let%test _ = index (create "abac") ~in_:"abababac" = Some 4 + let%test _ = index ~pos:4 (create "abac") ~in_:"abababac" = Some 4 + let%test _ = index ~pos:5 (create "abac") ~in_:"abababac" = None + let%test _ = index ~pos:5 (create "abac") ~in_:"abababaca" = None + let%test _ = index ~pos:5 (create "baca") ~in_:"abababaca" = Some 5 + let%test _ = index ~pos:(-1) (create "a") ~in_:"abc" = None + let%test _ = index ~pos:2 (create "a") ~in_:"abc" = None + let%test _ = index ~pos:2 (create "c") ~in_:"abc" = Some 2 + let%test _ = index ~pos:3 (create "c") ~in_:"abc" = None + + let (=) = [%compare.equal: bool] + let%test _ = matches (create "") "abababac" = true + let%test _ = matches (create "abababaca") "abababac" = false + let%test _ = matches (create "abababac") "abababac" = true + let%test _ = matches (create "abac") "abababac" = true + let%test _ = matches (create "abac") "abababaca" = true + let%test _ = matches (create "baca") "abababaca" = true + let%test _ = matches (create "a") "abc" = true + let%test _ = matches (create "c") "abc" = true + + let (=) = [%compare.equal: int list] + let%test _ = index_all (create "") ~may_overlap:false ~in_:"abcd" = [0; 1; 2; 3; 4] + let%test _ = index_all (create "") ~may_overlap:true ~in_:"abcd" = [0; 1; 2; 3; 4] + let%test _ = index_all (create "abab") ~may_overlap:false ~in_:"abababab" = [0; 4] + let%test _ = index_all (create "abab") ~may_overlap:true ~in_:"abababab" = [0; 2; 4] + let%test _ = index_all (create "abab") ~may_overlap:false ~in_:"ababababab" = [0; 4] + let%test _ = index_all (create "abab") ~may_overlap:true ~in_:"ababababab" = [0; 2; 4; 6] + let%test _ = index_all (create "aaa") ~may_overlap:false ~in_:"aaaaBaaaaaa" = [0; 5; 8] + let%test _ = index_all (create "aaa") ~may_overlap:true ~in_:"aaaaBaaaaaa" = [0; 1; 5; 6; 7; 8] + + let (=) = [%compare.equal: string] + let%test _ = replace_first (create "abab") ~in_:"abababab" ~with_:"" = "abab" + let%test _ = replace_first (create "abab") ~in_:"abacabab" ~with_:"" = "abac" + let%test _ = replace_first (create "abab") ~in_:"ababacab" ~with_:"A" = "Aacab" + let%test _ = replace_first (create "abab") ~in_:"acabababab" ~with_:"A" = "acAabab" + let%test _ = replace_first (create "ababab") ~in_:"acabababab" ~with_:"A" = "acAab" + let%test _ = replace_first (create "abab") ~in_:"abababab" ~with_:"abababab" = "abababababab" + + let%test _ = replace_all (create "abab") ~in_:"abababab" ~with_:"" = "" + let%test _ = replace_all (create "abab") ~in_:"abacabab" ~with_:"" = "abac" + let%test _ = replace_all (create "abab") ~in_:"acabababab" ~with_:"A" = "acAA" + let%test _ = replace_all (create "ababab") ~in_:"acabababab" ~with_:"A" = "acAab" + let%test _ = replace_all (create "abaC") ~in_:"abaCabaDCababaCabaCaba" ~with_:"x" = "xabaDCabxxaba" + let%test _ = replace_all (create "a") ~in_:"aa" ~with_:"aaa" = "aaaaaa" + let%test _ = replace_all (create "") ~in_:"abcdeefff" ~with_:"X1" = "X1aX1bX1cX1dX1eX1eX1fX1fX1fX1" + + (* a doc comment in core_string.mli gives this as an example *) + let%test _ = replace_all (create "bc") ~in_:"aabbcc" ~with_:"cb" = "aabcbc" + end) + +let%test _ = rev "" = "";; +let%test _ = rev "a" = "a";; +let%test _ = rev "ab" = "ba";; +let%test _ = rev "abc" = "cba";; + +let%test_unit _ = + List.iter ~f:(fun (t, expect) -> + let actual = split_lines t in + if not ([%compare.equal: string list] actual expect) + then raise_s [%message "split_lines bug" + (t : t) (actual : t list) (expect : t list)]) + [ "" , []; + "\n" , [""]; + "a" , ["a"]; + "a\n" , ["a"]; + "a\nb" , ["a"; "b"]; + "a\nb\n" , ["a"; "b"]; + "a\n\n" , ["a"; "" ]; + "a\n\nb" , ["a"; "" ; "b"]; + ] +;; + +let%test_unit _ = + let lines = [ ""; "a"; "bc" ] in + let newlines = [ "\n"; "\r\n" ] in + let rec loop n expect to_concat = + if Int.(=) n 0 then begin + let input = concat to_concat in + let actual = Or_error.try_with (fun () -> split_lines input) in + if not ([%compare.equal: t list Or_error.t] actual (Ok expect)) + then raise_s [%message "split_lines bug" + (input : t) + (actual : t list Or_error.t) + (expect : t list)] + end else begin + loop (n - 1) expect to_concat; + List.iter lines ~f:(fun t -> + let loop to_concat = loop (n - 1) (t :: expect) (t :: to_concat) in + if not (is_empty t) && List.is_empty to_concat then loop []; + List.iter newlines ~f:(fun newline -> loop (newline :: to_concat))); + end + in + loop 3 [] [] +;; + +let%test_unit _ = + let s = init 10 ~f:(Char.of_int_exn) in + assert (phys_equal s (sub s ~pos:0 ~len:(String.length s))); + assert (phys_equal s (prefix s (String.length s))); + assert (phys_equal s (suffix s (String.length s))); + assert (phys_equal s (concat [s])); + assert (phys_equal s (tr s ~target:'\255' ~replacement:'\000')) + +let%test_module "tr_multi" = (module struct + let gold_standard ~target ~replacement string = + map string ~f:(fun char -> + match rindex target char with + | None -> char + | Some i -> get replacement (Int.min i (length replacement - 1))) + + module Test = struct + type nonrec t = + { target : t + ; replacement : t + ; string : t + ; expected : t sexp_option } + [@@deriving sexp_of] + + let quickcheck_generator = + let open Base_quickcheck.Generator in + let open Base_quickcheck.Generator.Let_syntax in + let%bind size = size in + let%bind target_len = int_log_uniform_inclusive 1 255 in + let%bind target = string_with_length ~length:target_len in + let%bind replacement_len = int_inclusive 1 target_len in + let%bind replacement = string_with_length ~length:replacement_len in + let%bind string_length = int_inclusive 0 size in + let%map string = string_with_length ~length:string_length in + { target; replacement; string; expected = None } + + let quickcheck_shrinker = Base_quickcheck.Shrinker.atomic + end + + let examples = + [ "" , "" , "abcdefg", "abcdefg" + ; "" , "a" , "abcdefg", "abcdefg" + ; "aaaa", "abcd", "abcdefg", "dbcdefg" + ; "abcd", "bcde", "abcdefg", "bcdeefg" + ; "abcd", "bcde", "" , "" + ; "abcd", "_" , "abcdefg", "____efg" + ; "abcd", "b_" , "abcdefg", "b___efg" + ; "a" , "dcba", "abcdefg", "dbcdefg" + ; "ab" , "dcba", "abcdefg", "dccdefg" + ] + |> List.map ~f:(fun (target, replacement, string, expected) -> + { Test.target; replacement; string; expected = Some expected }) + + let%test_unit _ = + Base_quickcheck.Test.run_exn (module Test) ~examples + ~f:(fun ({ target; replacement; string; expected } : Test.t) -> + (* test implementation behavior against gold standard *) + let impl_result = unstage (tr_multi ~target ~replacement) string in + let gold_result = gold_standard ~target ~replacement string in + [%test_result:t] ~expect:gold_result impl_result; + (* test against expected result, if one is provided (non-random examples) *) + Option.iter expected ~f:(fun expected -> + [%test_result:t] ~expect:expected impl_result); + (* test for returning input if the string is unchanged *) + if equal string impl_result + then assert (phys_equal string impl_result)) +end) + +let%test_unit _ = [%test_result: int option] (lfindi "bob" ~f:(fun _ -> Char.(=) 'b')) ~expect:(Some 0) +let%test_unit _ = [%test_result: int option] (lfindi ~pos:0 "bob" ~f:(fun _ -> Char.(=) 'b')) ~expect:(Some 0) +let%test_unit _ = [%test_result: int option] (lfindi ~pos:1 "bob" ~f:(fun _ -> Char.(=) 'b')) ~expect:(Some 2) +let%test_unit _ = [%test_result: int option] (lfindi "bob" ~f:(fun _ -> Char.(=) 'x')) ~expect:None + +let%test_unit _ = [%test_result: char option] + (find_map "fop" ~f:(fun c -> if Char.(c >= 'o') then Some c else None)) + ~expect:(Some 'o') +let%test_unit _ = [%test_result: _ option] (find_map "bar" ~f:(fun _ -> None)) ~expect:None +let%test_unit _ = [%test_result: _ option] (find_map "" ~f:(fun _ -> assert false)) ~expect:None + +let%test_unit _ = [%test_result: int option] (rfindi "bob" ~f:(fun _ -> Char.(=) 'b')) ~expect:(Some 2) +let%test_unit _ = [%test_result: int option] (rfindi ~pos:2 "bob" ~f:(fun _ -> Char.(=) 'b')) ~expect:(Some 2) +let%test_unit _ = [%test_result: int option] (rfindi ~pos:1 "bob" ~f:(fun _ -> Char.(=) 'b')) ~expect:(Some 0) +let%test_unit _ = [%test_result: int option] (rfindi "bob" ~f:(fun _ -> Char.(=) 'x')) ~expect:None + +let%test_unit _ = [%test_result: string] (strip " foo bar \n") ~expect:"foo bar" +let%test_unit _ = [%test_result: string] (strip ~drop:(Char.(=) '"') "\" foo bar ") ~expect:" foo bar " +let%test_unit _ = [%test_result: string] (strip ~drop:(Char.(=) '"') " \" foo bar ") ~expect:" \" foo bar " + +let%test_unit _ = [%test_result: bool] ~expect:false (exists "" ~f:(fun _ -> assert false)) +let%test_unit _ = [%test_result: bool] ~expect:false (exists "abc" ~f:(Fn.const false)) +let%test_unit _ = [%test_result: bool] ~expect:true (exists "abc" ~f:(Fn.const true)) +let%test_unit _ = [%test_result: bool] ~expect:true (exists "abc" ~f:(function + 'a' -> false | 'b' -> true | _ -> assert false)) + +let%test_unit _ = [%test_result: bool] ~expect:true (for_all "" ~f:(fun _ -> assert false)) +let%test_unit _ = [%test_result: bool] ~expect:true (for_all "abc" ~f:(Fn.const true)) +let%test_unit _ = [%test_result: bool] ~expect:false (for_all "abc" ~f:(Fn.const false)) +let%test_unit _ = [%test_result: bool] ~expect:false (for_all "abc" ~f:(function + 'a' -> true | 'b' -> false | _ -> assert false)) + +let%test_unit _ = [%test_result: (int * char) list] + (foldi "hello" ~init:[] ~f:(fun i acc ch -> (i,ch)::acc)) + ~expect:(List.rev [0,'h';1,'e';2,'l';3,'l';4,'o']) + +let%test_unit _ = [%test_result: t] (filter "hello" ~f:(Char.(<>) 'h')) ~expect:"ello" +let%test_unit _ = [%test_result: t] (filter "hello" ~f:(Char.(<>) 'l')) ~expect:"heo" +let%test_unit _ = [%test_result: t] (filter "hello" ~f:(fun _ -> false)) ~expect:"" +let%test_unit _ = [%test_result: t] (filter "hello" ~f:(fun _ -> true)) ~expect:"hello" +let%test_unit _ = + let s = "hello" in + [%test_result: bool] ~expect:true (phys_equal (filter s ~f:(fun _ -> true)) s) +let%test_unit _ = + let s = "abc" in + let r = ref 0 in + assert (phys_equal s (filter s ~f:(fun _ -> Int.incr r; true))); + assert (Int.(=) !r (String.length s)) +;; + +let%test_module "Hash" = + (module struct + external hash : string -> int = "Base_hash_string" [@@noalloc] + + let%test_unit _ = + List.iter ~f:(fun string -> + assert (Int.(=) (hash string) (Caml.Hashtbl.hash string)); + (* with 31-bit integers, the hash computed by ppx_hash overflows so it doesn't match + polymorphic hash exactly. *) + if Int.(>) Int.num_bits 31 then + assert (Int.(=) (hash string) ([%hash: string] string)) + ) + [ "Oh Gloria inmarcesible! Oh jubilo inmortal!" + ; "Oh say can you see, by the dawn's early light" + ; "Hahahaha\200" + ] + ;; + end) + +let%test _ = of_char_list ['a';'b';'c'] = "abc" +let%test _ = of_char_list [] = "" + +let%expect_test "mem does not allocate" = + let string = Sys.opaque_identity "abracadabra" in + let char = Sys.opaque_identity 'd' in + require_no_allocation [%here] (fun () -> + ignore (String.mem string char : bool)); + [%expect {||}]; +;; + +let%expect_test "is_substring_at" = + let string = "lorem ipsum dolor sit amet" in + let test pos substring = + match is_substring_at string ~pos ~substring with + | bool -> print_s [%sexp (bool : bool)] + | exception exn -> print_s [%message "raised" ~_:(exn : exn)] + in + test 0 "lorem"; + [%expect {| true |}]; + test 1 "lorem"; + [%expect {| false |}]; + test 6 "ipsum"; + [%expect {| true |}]; + test 5 "ipsum"; + [%expect {| false |}]; + test 22 "amet"; + [%expect {| true |}]; + test 23 "amet"; + [%expect {| false |}]; + test 22 "amet and some other stuff"; + [%expect {| false |}]; + test 0 ""; + [%expect {| true |}]; + test 10 ""; + [%expect {| true |}]; + test 26 ""; + [%expect {| true |}]; + test 100 ""; + [%expect {| + (raised ( + Invalid_argument + "String.is_substring_at: invalid index 100 for string of length 26")) |}]; + test (-1) ""; + [%expect {| + (raised ( + Invalid_argument + "String.is_substring_at: invalid index -1 for string of length 26")) |}]; +;; + +let%test_module "Escaping" = + (module struct + open Escaping + + let%test_module "escape_gen" = + (module struct + let escape = unstage + (escape_gen_exn + ~escapeworthy_map:[('%','p');('^','c')] ~escape_char:'_') + + let%test _ = escape "" = "" + let%test _ = escape "foo" = "foo" + let%test _ = escape "_" = "__" + let%test _ = escape "foo%bar" = "foo_pbar" + let%test _ = escape "^foo%" = "_cfoo_p" + + let escape2 = unstage + (escape_gen_exn + ~escapeworthy_map:[('_','.');('%','p');('^','c')] ~escape_char:'_') + + let%test _ = escape2 "_." = "_.." + let%test _ = escape2 "_" = "_." + let%test _ = escape2 "foo%_bar" = "foo_p_.bar" + let%test _ = escape2 "_foo%" = "_.foo_p" + + let checks_for_one_to_one escapeworthy_map = + Exn.does_raise (fun () -> escape_gen_exn ~escapeworthy_map ~escape_char:'_') + + let%test _ = checks_for_one_to_one [('%','p');('^','c');('$','c')] + let%test _ = checks_for_one_to_one [('%','p');('^','c');('%','d')] + end) + + let%test_module "unescape_gen" = + (module struct + let unescape = + unstage + (unescape_gen_exn ~escapeworthy_map:['%','p';'^','c'] ~escape_char:'_') + + let%test _ = unescape "__" = "_" + let%test _ = unescape "foo" = "foo" + let%test _ = unescape "__" = "_" + let%test _ = unescape "foo_pbar" = "foo%bar" + let%test _ = unescape "_cfoo_p" = "^foo%" + + let unescape2 = + unstage + (unescape_gen_exn ~escapeworthy_map:['_','.';'%','p';'^','c'] ~escape_char:'_') + + (* this one is ill-formed, just ignore the escape_char without escaped char *) + let%test _ = unescape2 "_" = "" + let%test _ = unescape2 "a_" = "a" + + let%test _ = unescape2 "__" = "_" + let%test _ = unescape2 "_.." = "_." + let%test _ = unescape2 "_." = "_" + let%test _ = unescape2 "foo_p_.bar" = "foo%_bar" + let%test _ = unescape2 "_.foo_p" = "_foo%" + + (* generate [n] random string and check if escaping and unescaping are consistent *) + let random_test ~escapeworthy_map ~escape_char n = + let escape = + unstage (escape_gen_exn ~escapeworthy_map ~escape_char) + in + let unescape = + unstage (unescape_gen_exn ~escapeworthy_map ~escape_char) + in + let test str = + let escaped = escape str in + let unescaped = unescape escaped in + if str <> unescaped then + failwith ( + Printf.sprintf + "string: %s\nescaped string: %s\nunescaped string: %s" + str escaped unescaped) + in + let random_char = + let print_chars = + List.range (Char.to_int Char.min_value) (Char.to_int Char.max_value + 1) + |> List.filter_map ~f:Char.of_int + |> List.filter ~f:Char.is_print + |> Array.of_list + in + fun () -> Array.random_element_exn print_chars + in + let escapeworthy_chars = + List.map escapeworthy_map ~f:fst |> Array.of_list + in + try + for _ = 0 to n - 1 do + let str = + List.init (Random.int 50) ~f:(fun _ -> + let p = Random.int 100 in + if Int.(p < 10) then + escape_char + else if Int.(p < 25) then + Array.random_element_exn escapeworthy_chars + else + random_char () + ) + |> of_char_list + in + test str + done; + true + with e -> + raise e + + let%test _ = random_test 1000 ~escapeworthy_map:['%','p';'^','c'] ~escape_char:'_' + let%test _ = random_test 1000 ~escapeworthy_map:['_','.';'%','p';'^','c'] ~escape_char:'_' + end) + + let%test_module "escape" = + (module struct + let escape = unstage (escape ~escape_char:'_' ~escapeworthy:['_'; '%'; '^']) + let%test _ = escape "foo" = "foo" + let%test _ = escape "_" = "__" + let%test _ = escape "foo%bar" = "foo_%bar" + let%test _ = escape "^foo%" = "_^foo_%" + end) + + let%test_module "unescape" = + (module struct + let unescape = unstage (unescape ~escape_char:'_') + let%test _ = unescape "foo" = "foo" + let%test _ = unescape "__" = "_" + let%test _ = unescape "foo_%bar" = "foo%bar" + let%test _ = unescape "_^foo_%" = "^foo%" + end) + + let%test_module "is_char_escaping" = + (module struct + let is = is_char_escaping ~escape_char:'_' + let%test_unit _ = [%test_result: bool] (is "___" 0) ~expect:true + let%test_unit _ = [%test_result: bool] (is "___" 1) ~expect:false + let%test_unit _ = [%test_result: bool] (is "___" 2) ~expect:true + (* considered escaping, though there's nothing to escape *) + + let%test_unit _ = [%test_result: bool] (is "a_b__c" 0) ~expect:false + let%test_unit _ = [%test_result: bool] (is "a_b__c" 1) ~expect:true + let%test_unit _ = [%test_result: bool] (is "a_b__c" 2) ~expect:false + let%test_unit _ = [%test_result: bool] (is "a_b__c" 3) ~expect:true + let%test_unit _ = [%test_result: bool] (is "a_b__c" 4) ~expect:false + let%test_unit _ = [%test_result: bool] (is "a_b__c" 5) ~expect:false + end) + + let%test_module "is_char_escaped" = + (module struct + let is = is_char_escaped ~escape_char:'_' + let%test_unit _ = [%test_result: bool] (is "___" 2) ~expect:false + let%test_unit _ = [%test_result: bool] (is "x" 0) ~expect:false + let%test_unit _ = [%test_result: bool] (is "_x" 1) ~expect:true + let%test_unit _ = [%test_result: bool] (is "sadflkas____sfff" 12) ~expect:false + let%test_unit _ = [%test_result: bool] (is "s_____s" 6) ~expect:true + end) + + let%test_module "is_char_literal" = + (module struct + let is_char_literal = is_char_literal ~escape_char:'_' + let%test_unit _ = [%test_result: bool] (is_char_literal "123456" 4) ~expect:true + let%test_unit _ = [%test_result: bool] (is_char_literal "12345_6" 6) ~expect:false + let%test_unit _ = [%test_result: bool] (is_char_literal "12345_6" 5) ~expect:false + let%test_unit _ = [%test_result: bool] (is_char_literal "123__456" 4) ~expect:false + let%test_unit _ = [%test_result: bool] (is_char_literal "123456__" 7) ~expect:false + let%test_unit _ = [%test_result: bool] (is_char_literal "__123456" 1) ~expect:false + let%test_unit _ = [%test_result: bool] (is_char_literal "__123456" 0) ~expect:false + let%test_unit _ = [%test_result: bool] (is_char_literal "__123456" 2) ~expect:true + end) + + let%test_module "index_from" = + (module struct + let f = index_from ~escape_char:'_' + let%test_unit _ = [%test_result: int option] (f "__" 0 '_') ~expect:None + let%test_unit _ = [%test_result: int option] (f "_.." 0 '.') ~expect:(Some 2) + let%test_unit _ = [%test_result: int option] (f "1273456_7789" 3 '7') ~expect:(Some 9) + let%test_unit _ = [%test_result: int option] (f "1273_7456_7789" 3 '7') ~expect:(Some 11) + let%test_unit _ = [%test_result: int option] (f "1273_7456_7789" 3 'z') ~expect:None + end) + + let%test_module "rindex" = + (module struct + let f = rindex_from ~escape_char:'_' + let%test_unit _ = [%test_result: int option] (f "__" 0 '_') ~expect:None + let%test_unit _ = [%test_result: int option] (f "123456_37839" 9 '3') ~expect:(Some 2) + let%test_unit _ = [%test_result: int option] (f "123_2321" 6 '2') ~expect:(Some 6) + let%test_unit _ = [%test_result: int option] (f "123_2321" 5 '2') ~expect:(Some 1) + + let%test_unit _ = [%test_result: int option] (rindex "" ~escape_char:'_' 'x') ~expect:None + let%test_unit _ = [%test_result: int option] (rindex "a_a" ~escape_char:'_' 'a') ~expect:(Some 0) + end) + + let%test_module "split" = + (module struct + let split = split ~escape_char:'_' ~on:',' + let%test_unit _ = [%test_result: string list] (split "foo,bar,baz") ~expect:["foo"; "bar"; "baz"] + let%test_unit _ = [%test_result: string list] (split "foo_,bar,baz") ~expect:["foo_,bar"; "baz"] + let%test_unit _ = [%test_result: string list] (split "foo_,bar_,baz") ~expect:["foo_,bar_,baz"] + let%test_unit _ = [%test_result: string list] (split "foo__,bar,baz") ~expect:["foo__"; "bar"; "baz"] + let%test_unit _ = [%test_result: string list] (split "foo,bar,baz_,") ~expect:["foo"; "bar"; "baz_,"] + let%test_unit _ = [%test_result: string list] (split "foo,bar_,baz_,,") ~expect:["foo"; "bar_,baz_,"; ""] + end) + + let%test_module "split_on_chars" = + (module struct + let split = split_on_chars ~escape_char:'_' ~on:[',';':'] + let%test_unit _ = [%test_result: string list] (split "foo,bar:baz") ~expect:["foo"; "bar"; "baz"] + let%test_unit _ = [%test_result: string list] (split "foo_,bar,baz") ~expect:["foo_,bar"; "baz"] + let%test_unit _ = [%test_result: string list] (split "foo_:bar_,baz") ~expect:["foo_:bar_,baz"] + let%test_unit _ = [%test_result: string list] (split "foo,bar,baz_,") ~expect:["foo"; "bar"; "baz_,"] + let%test_unit _ = [%test_result: string list] (split "foo:bar_,baz_,,") ~expect:["foo"; "bar_,baz_,"; ""] + end) + + let%test_module "split2" = + (module struct + let escape_char = '_' + let on = ',' + let%test_unit _ = [%test_result: (string * string) option] (lsplit2 ~escape_char ~on "foo_,bar,baz_,0") ~expect:(Some ("foo_,bar", "baz_,0")) + let%test_unit _ = [%test_result: (string * string) option] (rsplit2 ~escape_char ~on "foo_,bar,baz_,0") ~expect:(Some ("foo_,bar", "baz_,0")) + let%test_unit _ = [%test_result: string * string] (lsplit2_exn ~escape_char ~on "foo_,bar,baz_,0") ~expect:("foo_,bar", "baz_,0") + let%test_unit _ = [%test_result: string * string] (rsplit2_exn ~escape_char ~on "foo_,bar,baz_,0") ~expect:("foo_,bar", "baz_,0") + let%test_unit _ = [%test_result: (string * string) option] (lsplit2 ~escape_char ~on "foo_,bar") ~expect:None + let%test_unit _ = [%test_result: (string * string) option] (rsplit2 ~escape_char ~on "foo_,bar") ~expect:None + let%test _ = Exn.does_raise (fun () -> lsplit2_exn ~escape_char ~on "foo_,bar") + let%test _ = Exn.does_raise (fun () -> rsplit2_exn ~escape_char ~on "foo_,bar") + end) + + let%test _ = strip_literal ~escape_char:' ' " foo bar \n" = " foo bar \n" + let%test _ = strip_literal ~escape_char:' ' " foo bar \n\n" = " foo bar \n" + let%test _ = strip_literal ~escape_char:'\n' " foo bar \n" = "foo bar \n" + + let%test _ = lstrip_literal ~escape_char:' ' " foo bar \n\n" = " foo bar \n\n" + let%test _ = rstrip_literal ~escape_char:' ' " foo bar \n\n" = " foo bar \n" + let%test _ = lstrip_literal ~escape_char:'\n' " foo bar \n" = "foo bar \n" + let%test _ = rstrip_literal ~escape_char:'\n' " foo bar \n" = " foo bar \n" + + let%test _ = strip_literal ~drop:(Char.is_alpha) ~escape_char:'\\' "foo boar" = " " + let%test _ = strip_literal ~drop:(Char.is_alpha) ~escape_char:'\\' "fooboar" = "" + let%test _ = strip_literal ~drop:(Char.is_alpha) ~escape_char:'o' "foo boar" = "oo boa" + let%test _ = strip_literal ~drop:(Char.is_alpha) ~escape_char:'a' "foo boar" = " boar" + let%test _ = strip_literal ~drop:(Char.is_alpha) ~escape_char:'b' "foo boar" = " bo" + + let%test _ = lstrip_literal ~drop:(Char.is_alpha) ~escape_char:'o' "foo boar" = "oo boar" + let%test _ = rstrip_literal ~drop:(Char.is_alpha) ~escape_char:'o' "foo boar" = "foo boa" + let%test _ = lstrip_literal ~drop:(Char.is_alpha) ~escape_char:'b' "foo boar" = " boar" + let%test _ = rstrip_literal ~drop:(Char.is_alpha) ~escape_char:'b' "foo boar" = "foo bo" + end) diff --git a/test/test_string.mli b/test/test_string.mli new file mode 100644 index 0000000..74bb729 --- /dev/null +++ b/test/test_string.mli @@ -0,0 +1 @@ +(*_ This signature is deliberately empty. *) diff --git a/test/test_type_equal.ml b/test/test_type_equal.ml new file mode 100644 index 0000000..64b3b78 --- /dev/null +++ b/test/test_type_equal.ml @@ -0,0 +1,73 @@ +open! Import +open! Type_equal + +let%expect_test "[Id.sexp_of_t]" = + let id = Id.create ~name:"some-type-id" [%sexp_of: unit] in + print_s [%sexp (id : _ Id.t)]; + [%expect {| + some-type-id |}]; +;; + +let%test_module "Type_equal.Id" = + (module struct + open Type_equal.Id + let t1 = create ~name:"t1" [%sexp_of: _] + let t2 = create ~name:"t2" [%sexp_of: _] + + let%test _ = same t1 t1 + let%test _ = not (same t1 t2) + + let%test _ = Option.is_some (same_witness t1 t1) + let%test _ = Option.is_none (same_witness t1 t2) + + let%test_unit _ = ignore (same_witness_exn t1 t1 : (_, _) Type_equal.equal) + let%test _ = Result.is_error (Result.try_with (fun () -> same_witness_exn t1 t2)) + end) + +(* This test shows that we need [conv] even though [Type_equal.T] is exposed. *) +let%test_module "Type_equal" = + (module struct + open Type_equal + + let id = Id.create ~name:"int" [%sexp_of: int] + + module A : sig + type t + val id : t Id.t + end = struct + type t = int + let id = id + end + + module B : sig + type t + val id : t Id.t + end = struct + type t = int + let id = id + end + + let _a_to_b (a : A.t) = + let eq = Id.same_witness_exn A.id B.id in + (conv eq a : B.t) + ;; + + (* the following is rejected by the compiler *) + (* let _a_to_b (a : A.t) = + * let T = Id.same_witness_exn A.id B.id in + * (a : B.t) + *) + + module C = struct + type 'a t + end + + module Liftc = Lift (C) + + let _ac_to_bc (ac : A.t C.t) = + let eq = Liftc.lift (Id.same_witness_exn A.id B.id) in + (conv eq ac : B.t C.t) + ;; + end) + + diff --git a/test/test_type_equal.mli b/test/test_type_equal.mli new file mode 100644 index 0000000..74bb729 --- /dev/null +++ b/test/test_type_equal.mli @@ -0,0 +1 @@ +(*_ This signature is deliberately empty. *) diff --git a/test/test_uchar.ml b/test/test_uchar.ml new file mode 100644 index 0000000..f3ec288 --- /dev/null +++ b/test/test_uchar.ml @@ -0,0 +1,81 @@ +open! Import + +let min_int = Int.min_value +let max_int = Int.max_value + +let raises f v = Exn.does_raise (fun () -> f v) + +let%test_module "test_constants" = + (module struct + let%test _ = Uchar.(to_scalar min_value) = 0x0000 + let%test _ = Uchar.(to_scalar max_value) = 0x10FFFF + end) + +let%test_module "test_succ_exn" = + (module struct + let%test _ = raises Uchar.succ_exn Uchar.max_value + let%test _ = Uchar.(to_scalar (succ_exn min_value)) = 0x0001 + let%test _ = Uchar.(to_scalar (succ_exn (of_scalar_exn 0xD7FF))) = 0xE000 + let%test _ = Uchar.(to_scalar (succ_exn (of_scalar_exn 0xE000))) = 0xE001 + end) + +let%test_module "test_pred_exn" = + (module struct + let%test _ = raises Uchar.pred_exn Uchar.min_value + let%test _ = Uchar.(to_scalar (pred_exn (of_scalar_exn 0xD7FF))) = 0xD7FE + let%test _ = Uchar.(to_scalar (pred_exn (of_scalar_exn 0xE000))) = 0xD7FF + let%test _ = Uchar.(to_scalar (pred_exn max_value)) = 0x10FFFE + end) + +let%test_module "test_int_is_scalar" = + (module struct + let%test _ = not (Uchar.int_is_scalar (-1)) + let%test _ = Uchar.int_is_scalar 0x0000 + let%test _ = Uchar.int_is_scalar 0xD7FF + let%test _ = not (Uchar.int_is_scalar 0xD800) + let%test _ = not (Uchar.int_is_scalar 0xDFFF) + let%test _ = Uchar.int_is_scalar 0xE000 + let%test _ = Uchar.int_is_scalar 0x10FFFF + let%test _ = not (Uchar.int_is_scalar 0x110000) + let%test _ = not (Uchar.int_is_scalar min_int) + let%test _ = not (Uchar.int_is_scalar max_int) + end) + +let char_max = Uchar.of_scalar_exn 0x00FF + +let%test_module "test_is_char" = + (module struct + let%test _ = Uchar.(is_char Uchar.min_value) + let%test _ = Uchar.(is_char char_max) + let%test _ = Uchar.(not (is_char (of_scalar_exn 0x0100))) + let%test _ = not (Uchar.is_char Uchar.max_value) + end) + +let%test_module "test_of_char" = + (module struct + let%test _ = Uchar.(equal (of_char '\xFF') char_max) + let%test _ = Uchar.(equal (of_char '\x00') min_value) + end) + +let%test_module "test_to_char_exn" = + (module struct + let%test _ = Char.equal Uchar.(to_char_exn min_value) '\x00' + let%test _ = Char.equal Uchar.(to_char_exn char_max) '\xFF' + let%test _ = raises Uchar.to_char_exn (Uchar.succ_exn char_max) + let%test _ = raises Uchar.to_char_exn Uchar.max_value + end) + +let%test_module "test_equal" = + (module struct + let%test _ = Uchar.(equal min_value min_value) + let%test _ = Uchar.(equal max_value max_value) + let%test _ = not Uchar.(equal min_value max_value) + end) + +let%test_module "test_compare" = + (module struct + let%test _ = Uchar.(compare min_value min_value) = 0 + let%test _ = Uchar.(compare max_value max_value) = 0 + let%test _ = Uchar.(compare min_value max_value) = (-1) + let%test _ = Uchar.(compare max_value min_value) = 1 + end) diff --git a/test/test_uchar.mli b/test/test_uchar.mli new file mode 100644 index 0000000..74bb729 --- /dev/null +++ b/test/test_uchar.mli @@ -0,0 +1 @@ +(*_ This signature is deliberately empty. *) diff --git a/test/test_uniform_array.ml b/test/test_uniform_array.ml new file mode 100644 index 0000000..dab159c --- /dev/null +++ b/test/test_uniform_array.ml @@ -0,0 +1,13 @@ +open! Import +open Uniform_array + +module Sequence = struct + type nonrec 'a t = 'a t + type 'a z = 'a + let length = length + let get = get + let set = set + let create_bool ~len = create ~len false +end + +include Base_for_tests.Test_blit.Test1(Sequence)(Uniform_array) diff --git a/test/test_validate.ml b/test/test_validate.ml new file mode 100644 index 0000000..f88ace0 --- /dev/null +++ b/test/test_validate.ml @@ -0,0 +1,95 @@ +open! Import +open! Validate + +let print t = + List.iter (errors t) ~f:Caml.print_endline +;; + +let%expect_test "Validate.all" = + print + (all [ + (fun _ -> fail "a"); + (fun _ -> pass); + (fun _ -> fail "b"); + (fun _ -> pass); + (fun _ -> fail "c"); + ] + ()); + [%expect {| + ("" a) + ("" b) + ("" c) + |}] +;; + +let%expect_test _ = + print (first_failure pass (fail "foo")); + [%expect {| ("" foo) |}] +;; + +let%expect_test _ = + print (first_failure (fail "foo") (fail "bar")); + [%expect {| ("" foo) |}] +;; + +let two_errors = of_list [fail "foo"; fail "bar"] + +let%expect_test _ = + print (first_failure two_errors (fail "snoo")); + [%expect {| + ("" foo) + ("" bar) + |}] +;; + +let%expect_test _ = + print (first_failure (fail "snoo") two_errors); + [%expect {| ("" snoo) |}] +;; + +let%expect_test _ = + let v () = + if true + then + failwith "This unit validation raises"; + Validate.pass + in + print (protect v ()); + [%expect {| + ("" + ("Exception raised during validation" + (Failure "This unit validation raises"))) |}] +;; + +let%expect_test "try_with" = + let v () = + failwith "this function raises" + in + print (try_with v); + [%expect {| + ("" ("Exception raised during validation" (Failure "this function raises"))) |}] +;; + +type t = { x : bool } [@@deriving fields] + +let%expect_test "typical use of Validate.field_direct_folder doesn't allocate on success" = + let validate_x = Staged.unstage (Validate.field_direct_folder Validate.pass_bool) in + let validate t = + Fields.Direct.fold t ~init:[] ~x:validate_x + |> Validate.of_list + |> Validate.result + in + let t = { x = true } in + require_no_allocation [%here] (fun () -> ignore (validate t : unit Or_error.t)); +;; + +let%expect_test "Validate.all doesn't allocate on success" = + let checks = List.init 5 ~f:(Fn.const Validate.pass_bool) in + require_no_allocation [%here] (fun () -> ignore (Validate.all checks true : Validate.t)); +;; + +let%expect_test "Validate.combine doesn't allocate on success" = + require_no_allocation [%here] (fun () -> + ignore (Validate.combine Validate.pass Validate.pass : Validate.t)); +;; + diff --git a/test/test_validate.mli b/test/test_validate.mli new file mode 100644 index 0000000..74bb729 --- /dev/null +++ b/test/test_validate.mli @@ -0,0 +1 @@ +(*_ This signature is deliberately empty. *) diff --git a/test/test_with_return.ml b/test/test_with_return.ml new file mode 100644 index 0000000..be7423b --- /dev/null +++ b/test/test_with_return.ml @@ -0,0 +1,54 @@ +open! Import +open! With_return + +let test_loop loop_limit jump_out = + with_return (fun { return } -> + for i = 0 to loop_limit do begin + if i = jump_out then return (`Jumped_out i); + end done; + `Normal) +;; + +let ( = ) = Poly.equal + +let%test _ = test_loop 5 10 = `Normal +let%test _ = test_loop 10 5 = `Jumped_out 5 +let%test _ = test_loop 5 5 = `Jumped_out 5 + +let test_nested outer inner = + with_return (fun { return = return_outer } -> + if outer = `Outer_jump then return_outer `Outer_jump; + let inner_res = + with_return (fun { return = return_inner } -> + if inner = `Inner_jump_out_completely then return_outer `Inner_jump; + if inner = `Inner_jump then return_inner `Inner_jump; + `Inner_normal) + in + if outer = `Jump_with_inner then return_outer (`Outer_later_jump inner_res); + `Outer_normal inner_res) +;; + +let%test _ = test_nested `Outer_jump `Inner_jump = `Outer_jump +let%test _ = test_nested `Outer_jump `Inner_jump_out_completely = `Outer_jump +let%test _ = test_nested `Outer_jump `Foo = `Outer_jump + +let%test _ = test_nested `Jump_with_inner `Inner_jump_out_completely = `Inner_jump +let%test _ = test_nested `Jump_with_inner `Inner_jump = `Outer_later_jump `Inner_jump +let%test _ = test_nested `Jump_with_inner `Foo = `Outer_later_jump `Inner_normal + +let%test _ = test_nested `Foo `Inner_jump_out_completely = `Inner_jump +let%test _ = test_nested `Foo `Inner_jump = `Outer_normal `Inner_jump +let%test _ = test_nested `Foo `Foo = `Outer_normal `Inner_normal + +let test_loop loop_limit jump_out = + with_return_option (fun { return } -> + for i = 0 to loop_limit do begin + if i = jump_out then return (`Jumped_out i); + end done) +;; + +let ( = ) = Poly.equal + +let%test _ = test_loop 5 10 = None +let%test _ = test_loop 10 5 = Some (`Jumped_out 5) +let%test _ = test_loop 5 5 = Some (`Jumped_out 5) diff --git a/test/test_with_return.mli b/test/test_with_return.mli new file mode 100644 index 0000000..74bb729 --- /dev/null +++ b/test/test_with_return.mli @@ -0,0 +1 @@ +(*_ This signature is deliberately empty. *) diff --git a/test/test_word_size.ml b/test/test_word_size.ml new file mode 100644 index 0000000..1275de1 --- /dev/null +++ b/test/test_word_size.ml @@ -0,0 +1,11 @@ +open! Import +open! Word_size + +let%expect_test _ = + print_s [%message (W32 : t)]; + [%expect {| + (W32 W32) |}]; + print_s [%message (W64 : t)]; + [%expect {| + (W64 W64) |}]; +;; diff --git a/test/test_word_size.mli b/test/test_word_size.mli new file mode 100644 index 0000000..74bb729 --- /dev/null +++ b/test/test_word_size.mli @@ -0,0 +1 @@ +(*_ This signature is deliberately empty. *) diff --git a/test/validate_fields_folder.mlt b/test/validate_fields_folder.mlt new file mode 100644 index 0000000..2a9a01f --- /dev/null +++ b/test/validate_fields_folder.mlt @@ -0,0 +1,127 @@ +open! Core_kernel + +(* Regression tests to ensure that [Validate.field], [Validate.field_folder], and + [Validate.field_direct_folder] continue to work with private record types. *) + +module Fold_with_private + (M : sig + type t = private + { a : int } + [@@deriving fields] + end) + : sig + val validate : M.t -> Validate.t + end = struct + open M + + let validate t = + let w f = Validate.field_folder t f in + Fields.fold ~init:[] + ~a:(w (fun _ -> Validate.pass)) + |> Validate.of_list +end + +[%%expect {| +|}] + +module Fold_regular + (M : sig + type t = + { a : int } + [@@deriving fields] + end) + : sig + val validate : M.t -> Validate.t + end = struct + open M + + let validate t = + let w f = Validate.field_folder t f in + Fields.fold ~init:[] + ~a:(w (fun _ -> Validate.pass)) + |> Validate.of_list +end + +[%%expect {| +|}] + +module Fold_direct_private + (M : sig + type t = private + { a : int } + [@@deriving fields] + end) + : sig + val validate : M.t -> Validate.t + end = struct + open M + + let validate t = + let w f = unstage (Validate.field_direct_folder f) in + Fields.Direct.fold t ~init:[] + ~a:(w (fun _ -> Validate.pass)) + |> Validate.of_list +end + +[%%expect {| +|}] + +module Fold_direct_regular + (M : sig + type t = + { a : int } + [@@deriving fields] + end) + : sig + val validate : M.t -> Validate.t + end = struct + open M + + let validate t = + let w f = unstage (Validate.field_direct_folder f) in + Fields.Direct.fold t ~init:[] + ~a:(w (fun _ -> Validate.pass)) + |> Validate.of_list +end +[%%expect{| +|}] + +module Validate_field_private + (M : sig + type t = private + { a : int } + [@@deriving fields] + end) + : sig + val validate : M.t -> Validate.t + end = struct + open M + + let validate t = + let w check acc field = Validate.field t field check :: acc in + Fields.fold ~init:[] + ~a:(w (fun _ -> Validate.pass)) + |> Validate.of_list +end +[%%expect{| +|}] + +module Validate_field + (M : sig + type t = + { a : int } + [@@deriving fields] + end) + : sig + val validate : M.t -> Validate.t + end = struct + open M + + let validate t = + let w check acc field = Validate.field t field check :: acc in + Fields.fold ~init:[] + ~a:(w (fun _ -> Validate.pass)) + |> Validate.of_list +end +[%%expect{| +|}] |