summaryrefslogtreecommitdiff
path: root/src/cry.ml
diff options
context:
space:
mode:
Diffstat (limited to 'src/cry.ml')
-rw-r--r--src/cry.ml68
1 files changed, 49 insertions, 19 deletions
diff --git a/src/cry.ml b/src/cry.ml
index 8580ebd..7c6b223 100644
--- a/src/cry.ml
+++ b/src/cry.ml
@@ -20,6 +20,23 @@
(** OCaml low level implementation of the shout source protocol. *)
+external poll :
+ Unix.file_descr array ->
+ Unix.file_descr array ->
+ Unix.file_descr array ->
+ float ->
+ Unix.file_descr array * Unix.file_descr array * Unix.file_descr array
+ = "caml_cry_poll"
+
+let poll r w e timeout =
+ let r = Array.of_list r in
+ let w = Array.of_list w in
+ let e = Array.of_list e in
+ let r, w, e = poll r w e timeout in
+ (Array.to_list r, Array.to_list w, Array.to_list e)
+
+let select = match Sys.os_type with "Unix" -> poll | _ -> Unix.select
+
type error =
| Create of exn
| Connect of exn
@@ -52,7 +69,13 @@ and transport =
< name : string
; protocol : string
; default_port : int
- ; connect : ?bind_address:string -> ?timeout:float -> string -> int -> socket >
+ ; connect :
+ ?bind_address:string ->
+ ?timeout:float ->
+ ?prefer:[ `System_default | `Ipv4 | `Ipv6 ] ->
+ string ->
+ int ->
+ socket >
(* Wait for [`Read socker], [`Write socket] or [`Both socket] for at most
* [timeout] seconds on the given [socket]. Raises [Timeout] if timeout
@@ -68,7 +91,7 @@ let wait_for ?(log = fun _ -> ()) event timeout =
in
let rec wait t =
let r, w, _ =
- try Unix.select r w [] t
+ try select r w [] t
with Unix.Unix_error (Unix.EINTR, _, _) -> ([], [], [])
in
if r = [] && w = [] then (
@@ -106,18 +129,23 @@ let sockaddr_of_address address =
| addr :: _ -> addr.ai_addr
let addrinfo_order = function
- | Unix.ADDR_UNIX _ -> 2
- | Unix.ADDR_INET (s, _) -> if Unix.is_inet6_addr s then 1 else 0
+ | _, Unix.ADDR_UNIX _ -> 2
+ | `Ipv4, Unix.ADDR_INET (s, _) -> if Unix.is_inet6_addr s then 1 else 0
+ | `Ipv6, Unix.ADDR_INET (s, _) -> if Unix.is_inet6_addr s then 0 else 1
-let resolve_host host port =
+let resolve_host ~prefer host port =
match
- Unix.getaddrinfo host (string_of_int port) [AI_SOCKTYPE SOCK_STREAM]
+ ( prefer,
+ Unix.getaddrinfo host (string_of_int port) [AI_SOCKTYPE SOCK_STREAM] )
with
- | [] -> raise Not_found
- | l ->
+ | _, [] -> raise Not_found
+ | `System_default, l -> l
+ | ((`Ipv4, l) as v) | ((`Ipv6, l) as v) ->
List.sort
(fun { Unix.ai_addr = s; _ } { Unix.ai_addr = s'; _ } ->
- Stdlib.compare (addrinfo_order s) (addrinfo_order s'))
+ Stdlib.compare
+ (addrinfo_order (fst v, s))
+ (addrinfo_order (fst v, s')))
l
let connect_sockaddr ?bind_address ?timeout sockaddr =
@@ -138,7 +166,7 @@ let connect_sockaddr ?bind_address ?timeout sockaddr =
match timeout with
| Some timeout ->
(* Block in a select call for [timeout] seconds. *)
- let _, w, _ = Unix.select [] [socket] [] timeout in
+ let _, w, _ = select [] [socket] [] timeout in
if w = [] then raise Timeout;
Unix.clear_nonblock socket;
socket
@@ -163,7 +191,7 @@ let connect_sockaddr ?bind_address ?timeout sockaddr =
end;
Printexc.raise_with_backtrace e bt
-let unix_connect ?bind_address ?timeout host port =
+let unix_connect ?bind_address ?timeout ?(prefer = `System_default) host port =
let rec connect_any ?bind_address ?timeout (addrs : Unix.addr_info list) =
match addrs with
| [] -> raise Not_found
@@ -174,7 +202,7 @@ let unix_connect ?bind_address ?timeout host port =
try connect_sockaddr ?bind_address ?timeout addr.ai_addr
with _ -> connect_any ?bind_address ?timeout tail)
in
- connect_any ?bind_address ?timeout (resolve_host host port)
+ connect_any ?bind_address ?timeout (resolve_host ~prefer host port)
let unix_transport : transport =
object (self)
@@ -182,8 +210,8 @@ let unix_transport : transport =
method protocol = "http"
method default_port = 80
- method connect ?bind_address ?timeout host port =
- let socket = unix_connect ?bind_address ?timeout host port in
+ method connect ?bind_address ?timeout ?prefer host port =
+ let socket = unix_connect ?bind_address ?timeout ?prefer host port in
unix_socket self socket
end
@@ -321,11 +349,13 @@ let write_data ~timeout ?(offset = 0) ?length (socket : socket) request =
let close x =
try
let c = get_connection_data x in
- if x.chunked then write_data ~timeout:x.timeout c.socket "0\r\n\r\n";
- c.socket#close;
- x.chunked <- false;
- x.icy_cap <- false;
- x.status <- PrivDisconnected
+ Fun.protect
+ ~finally:(fun () -> c.socket#close)
+ (fun () ->
+ if x.chunked then write_data ~timeout:x.timeout c.socket "0\r\n\r\n";
+ x.chunked <- false;
+ x.icy_cap <- false;
+ x.status <- PrivDisconnected)
with
| Error _ as e -> raise e
| e -> raise (Error (Close e))