diff options
Diffstat (limited to 'src/cry.ml')
-rw-r--r-- | src/cry.ml | 68 |
1 files changed, 49 insertions, 19 deletions
@@ -20,6 +20,23 @@ (** OCaml low level implementation of the shout source protocol. *) +external poll : + Unix.file_descr array -> + Unix.file_descr array -> + Unix.file_descr array -> + float -> + Unix.file_descr array * Unix.file_descr array * Unix.file_descr array + = "caml_cry_poll" + +let poll r w e timeout = + let r = Array.of_list r in + let w = Array.of_list w in + let e = Array.of_list e in + let r, w, e = poll r w e timeout in + (Array.to_list r, Array.to_list w, Array.to_list e) + +let select = match Sys.os_type with "Unix" -> poll | _ -> Unix.select + type error = | Create of exn | Connect of exn @@ -52,7 +69,13 @@ and transport = < name : string ; protocol : string ; default_port : int - ; connect : ?bind_address:string -> ?timeout:float -> string -> int -> socket > + ; connect : + ?bind_address:string -> + ?timeout:float -> + ?prefer:[ `System_default | `Ipv4 | `Ipv6 ] -> + string -> + int -> + socket > (* Wait for [`Read socker], [`Write socket] or [`Both socket] for at most * [timeout] seconds on the given [socket]. Raises [Timeout] if timeout @@ -68,7 +91,7 @@ let wait_for ?(log = fun _ -> ()) event timeout = in let rec wait t = let r, w, _ = - try Unix.select r w [] t + try select r w [] t with Unix.Unix_error (Unix.EINTR, _, _) -> ([], [], []) in if r = [] && w = [] then ( @@ -106,18 +129,23 @@ let sockaddr_of_address address = | addr :: _ -> addr.ai_addr let addrinfo_order = function - | Unix.ADDR_UNIX _ -> 2 - | Unix.ADDR_INET (s, _) -> if Unix.is_inet6_addr s then 1 else 0 + | _, Unix.ADDR_UNIX _ -> 2 + | `Ipv4, Unix.ADDR_INET (s, _) -> if Unix.is_inet6_addr s then 1 else 0 + | `Ipv6, Unix.ADDR_INET (s, _) -> if Unix.is_inet6_addr s then 0 else 1 -let resolve_host host port = +let resolve_host ~prefer host port = match - Unix.getaddrinfo host (string_of_int port) [AI_SOCKTYPE SOCK_STREAM] + ( prefer, + Unix.getaddrinfo host (string_of_int port) [AI_SOCKTYPE SOCK_STREAM] ) with - | [] -> raise Not_found - | l -> + | _, [] -> raise Not_found + | `System_default, l -> l + | ((`Ipv4, l) as v) | ((`Ipv6, l) as v) -> List.sort (fun { Unix.ai_addr = s; _ } { Unix.ai_addr = s'; _ } -> - Stdlib.compare (addrinfo_order s) (addrinfo_order s')) + Stdlib.compare + (addrinfo_order (fst v, s)) + (addrinfo_order (fst v, s'))) l let connect_sockaddr ?bind_address ?timeout sockaddr = @@ -138,7 +166,7 @@ let connect_sockaddr ?bind_address ?timeout sockaddr = match timeout with | Some timeout -> (* Block in a select call for [timeout] seconds. *) - let _, w, _ = Unix.select [] [socket] [] timeout in + let _, w, _ = select [] [socket] [] timeout in if w = [] then raise Timeout; Unix.clear_nonblock socket; socket @@ -163,7 +191,7 @@ let connect_sockaddr ?bind_address ?timeout sockaddr = end; Printexc.raise_with_backtrace e bt -let unix_connect ?bind_address ?timeout host port = +let unix_connect ?bind_address ?timeout ?(prefer = `System_default) host port = let rec connect_any ?bind_address ?timeout (addrs : Unix.addr_info list) = match addrs with | [] -> raise Not_found @@ -174,7 +202,7 @@ let unix_connect ?bind_address ?timeout host port = try connect_sockaddr ?bind_address ?timeout addr.ai_addr with _ -> connect_any ?bind_address ?timeout tail) in - connect_any ?bind_address ?timeout (resolve_host host port) + connect_any ?bind_address ?timeout (resolve_host ~prefer host port) let unix_transport : transport = object (self) @@ -182,8 +210,8 @@ let unix_transport : transport = method protocol = "http" method default_port = 80 - method connect ?bind_address ?timeout host port = - let socket = unix_connect ?bind_address ?timeout host port in + method connect ?bind_address ?timeout ?prefer host port = + let socket = unix_connect ?bind_address ?timeout ?prefer host port in unix_socket self socket end @@ -321,11 +349,13 @@ let write_data ~timeout ?(offset = 0) ?length (socket : socket) request = let close x = try let c = get_connection_data x in - if x.chunked then write_data ~timeout:x.timeout c.socket "0\r\n\r\n"; - c.socket#close; - x.chunked <- false; - x.icy_cap <- false; - x.status <- PrivDisconnected + Fun.protect + ~finally:(fun () -> c.socket#close) + (fun () -> + if x.chunked then write_data ~timeout:x.timeout c.socket "0\r\n\r\n"; + x.chunked <- false; + x.icy_cap <- false; + x.status <- PrivDisconnected) with | Error _ as e -> raise e | e -> raise (Error (Close e)) |