diff --git a/controller/server/update_client.ml b/controller/server/update_client.ml index 1839e755..ee4cd616 100644 --- a/controller/server/update_client.ml +++ b/controller/server/update_client.ml @@ -10,12 +10,14 @@ end module type UpdateClientDeps = sig val base_url: Uri.t + val download_dir: string val get_proxy: unit -> Uri.t option Lwt.t end -let make_deps get_proxy base_url : (module UpdateClientDeps) = (module struct +let make_deps ?(download_dir="/tmp") get_proxy base_url : (module UpdateClientDeps) = (module struct let base_url = base_url let get_proxy = get_proxy + let download_dir = download_dir end) let bundle_name = Config.System.bundle_name @@ -32,6 +34,7 @@ let ensure_trailing_slash uri = module UpdateClient (DepsI: UpdateClientDeps) = struct let get_proxy = DepsI.get_proxy + let download_dir = DepsI.download_dir let base_url = ensure_trailing_slash DepsI.base_url let download_url version_string = @@ -51,7 +54,9 @@ module UpdateClient (DepsI: UpdateClientDeps) = struct (** download RAUC bundle *) let download version = let url = download_url version in - let bundle_path = Format.sprintf "/tmp/%s" (bundle_file_name version) in + let bundle_path = Format.sprintf + "%s/%s" download_dir (bundle_file_name version) + in let options = [ "--continue-at"; "-" (* resume download *) ; "--limit-rate"; "10M" diff --git a/controller/server/update_client.mli b/controller/server/update_client.mli index fc4457ee..1b35e2de 100644 --- a/controller/server/update_client.mli +++ b/controller/server/update_client.mli @@ -11,10 +11,11 @@ end module type UpdateClientDeps = sig val base_url: Uri.t + val download_dir: string val get_proxy: unit -> Uri.t option Lwt.t end -val make_deps : (unit -> Uri.t option Lwt.t) -> Uri.t -> (module UpdateClientDeps) +val make_deps : ?download_dir:string -> (unit -> Uri.t option Lwt.t) -> Uri.t -> (module UpdateClientDeps) module Make (DepsI : UpdateClientDeps) : S diff --git a/controller/tests/update_client_tests.ml b/controller/tests/update_client_tests.ml index adc384d9..0ccd82fe 100644 --- a/controller/tests/update_client_tests.ml +++ b/controller/tests/update_client_tests.ml @@ -21,6 +21,8 @@ type state = { available_bundles: (string, string) Hashtbl.t ; } +type range = (int Option.t) * (int Option.t) + let stub_server () = object (self) val mutable state = ref { latest_version = "0.0.0"; @@ -42,11 +44,71 @@ let stub_server () = object (self) in Lwt.return resp + method private extract_range_bytes req : range = + let headers = Request.headers req in + let range = Cohttp.Header.get headers "Range" in + match range with + | Some range_str -> begin + try + let regex = Str.regexp "bytes=\\([0-9]*\\)-\\([0-9]*\\)" in + let m = Str.string_match regex range_str 0 in + let r_str_to_opt s = + if (String.length s > 0) then + Some (int_of_string s) + else + None + in + if (m) then + let range_start = Str.matched_group 1 range_str in + let range_end = Str.matched_group 2 range_str in + (r_str_to_opt range_start, + r_str_to_opt range_end) + else + failwith @@ "Unsupported range string: " ^ range_str + with + | e -> + failwith @@ + "Failed to parse range headers: " ^ (Printexc.to_string e) + end + | None -> (None, None) + + method private range_resp (range_start, range_end) bundle = + let bundle_bytes = String.to_bytes bundle in + let total = Bytes.length bundle_bytes in + let b_start = Option.value ~default:0 range_start in + let b_end = Option.value ~default:total range_end in + let bytes_trunc = Bytes.sub bundle_bytes b_start (b_end-b_start) in + (bytes_trunc, (b_start, b_end, total)) + method private download_bundle_handler req = let vsn = Router.param req "vsn" in + let range = self#extract_range_bytes req in let bundle = Hashtbl.find_opt !state.available_bundles vsn in let resp = match bundle with - | Some bund -> Response.of_string_body bund + | Some bund -> begin + match range with + | (None, None) -> Response.of_string_body bund + | _ -> + let (bundle_trunc, (b_start, b_end, b_total)) = + self#range_resp range bund in + let body = bundle_trunc + |> Bytes.to_string + |> Body.of_string + in + let headers = Cohttp.Header.of_list + [( + "Content-Range", + (Format.sprintf + "bytes %d-%d/%d" + b_start b_end b_total + ) + )] + in + Response.create + ~headers + ~body + () + end | None -> Response.of_string_body ~code:`Not_found "Bundle version not found" in @@ -127,7 +189,13 @@ let run_test_case ?(proxy = NoProxy) switch f = let (proxy_url, base_url) = process_proxy_spec proxy (Uri.of_string server_url) in let get_proxy () = Lwt.return proxy_url in - let module DepsI = (val Update_client.make_deps get_proxy base_url) in + let temp_dir = Format.sprintf "%s/upd-client-test-%d" + (Filename.get_temp_dir_name ()) + (Unix.gettimeofday () |> fun x -> (x *. 1000.0) |> int_of_float) + in + let () = Sys.mkdir temp_dir 0o777 in + let module DepsI = (val Update_client.make_deps + ~download_dir:temp_dir get_proxy base_url) in let module UpdateC = Update_client.Make (DepsI) in f server (module UpdateC : S) @@ -157,6 +225,39 @@ let test_download_bundle_ok server (module Client : S) = bundle; return () + +(* NOTE: This test checks that the client resumes the download + from where it finished, but also it is an example of why naive + resuming might not be a great idea.*) +let test_resume_bundle_download server (module Client : S) = + let version = "1.0.0" in + let bundle_contents = "BUNDLE_CONTENTS: 123" in + let () = server#add_bundle version bundle_contents in + let%lwt bundle_path = Client.download version in + Alcotest.(check string) + "Bundle contents are only partial" + (read_file bundle_path) + bundle_contents; + + (* NOTE that bundle_contents is not a prefix of bundle_contents_extra ! + This is on purpose: to check that download client does not simply + overwrite the downloaded file, otherwise we would not be testing + whether it really resumes the downloaded. It also illustrates + that curl / HTTP range request do not involve any integrity checking, + bytes are just being appended to the end. + *) + let bundle_contents_extra = "BUNDLE_CONTENTS: 111999" in + let () = server#add_bundle version bundle_contents_extra in + + let%lwt bundle_path = Client.download version in + Alcotest.(check string) + "Bundle contents are resumed, not overwritten" + (read_file bundle_path) + (* NOTE: this is not the same as [bundle_contents_extra], it is only + the last bytes of it beyond the length of [bundle_contents] *) + "BUNDLE_CONTENTS: 123999"; + return () + (* invalid proxy URL is set in the `run_test_case` function, see below *) let test_invalid_proxy_fail _ (module Client : S) = Lwt.try_bind Client.get_latest_version @@ -180,6 +281,7 @@ let () = let test_cases = [ ("Get latest version", test_get_version_ok); ("Download bundle", test_download_bundle_ok); + ("Resume download works", test_resume_bundle_download); ] in (* An extra case to check that proxy settings are honored in general *) let invalid_proxy_case = Alcotest_lwt.test_case