diff --git a/.gitarchive-info b/.gitarchive-info new file mode 100644 index 0000000..83e5b86 --- /dev/null +++ b/.gitarchive-info @@ -0,0 +1,2 @@ +Changeset: $Format:%H$ +Commit date: $Format:%cD$ diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000..f7bf506 --- /dev/null +++ b/.gitattributes @@ -0,0 +1 @@ +.gitarchive-info export-subst diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml new file mode 100644 index 0000000..a02a503 --- /dev/null +++ b/.github/workflows/main.yml @@ -0,0 +1,49 @@ +name: Build and test + +on: + push: + pull_request: + +jobs: + ocaml-test: + name: Ocaml tests + runs-on: ubuntu-20.04 + env: + package: "vhd-tool" + + steps: + - name: Checkout code + uses: actions/checkout@v2 + + - name: Pull configuration from xs-opam + run: | + curl --fail --silent https://raw.githubusercontent.com/xapi-project/xs-opam/master/tools/xs-opam-ci.env | cut -f2 -d " " > .env + + - name: Load environment file + id: dotenv + uses: falti/dotenv-action@v0.2.5 + + - name: Use ocaml + uses: avsm/setup-ocaml@v1 + with: + ocaml-version: ${{ steps.dotenv.outputs.ocaml_version_full }} + opam-repository: ${{ steps.dotenv.outputs.repository }} + + - name: Update opam metadata + run: | + opam update + opam pin add . 
--no-action + + - name: Install external dependencies + run: opam depext -u ${{ env.package }} + + - name: Install dependencies + run: | + opam upgrade + opam install ${{ env.package }} --deps-only --with-test -v + + - name: Build + run: opam exec -- make + + - name: Run tests + run: opam exec -- make test diff --git a/.gitignore b/.gitignore index 7ea67ab..69cbf2c 100644 --- a/.gitignore +++ b/.gitignore @@ -1,19 +1,6 @@ -*.annot -*.cmo -*.cma -*.cmi -*.a -*.o -*.cmx -*.cmxs -*.cmxa _build -*.native -*.swp -setup.data -setup.log -*.install -vhd-tool -sparse_dd -*.1 +.merlin config.mk +*.install + +scripts/*.pyc diff --git a/.ocamlformat b/.ocamlformat new file mode 100644 index 0000000..f865227 --- /dev/null +++ b/.ocamlformat @@ -0,0 +1,9 @@ +profile=ocamlformat +indicate-multiline-delimiters=closing-on-separate-line +if-then-else=fit-or-vertical +dock-collection-brackets=true +break-struct=natural +break-separators=before +break-infix=fit-or-vertical +break-infix-before-func=false +sequence-blank-line=preserve-one diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index 4b211f6..0000000 --- a/.travis.yml +++ /dev/null @@ -1,12 +0,0 @@ -language: c -services: docker -install: - - wget https://raw.githubusercontent.com/xenserver/xenserver-build-env/master/utils/travis-build-repo.sh -script: bash travis-build-repo.sh -sudo: true -env: - global: - - REPO_PACKAGE_NAME=vhd-tool - - REPO_CONFIGURE_CMD=./configure - - REPO_BUILD_CMD=make - - REPO_TEST_CMD=true diff --git a/Makefile b/Makefile index c6272ad..f230525 100644 --- a/Makefile +++ b/Makefile @@ -1,54 +1,31 @@ +PROFILE=dev +.PHONY: build release lint test stresstest install uninstall clean format + +release: + dune build -p vhd-tool + +build: + dune build --profile=$(PROFILE) + +lint: + pycodestyle scripts/*.py + pylint --disable too-many-locals scripts/get_nbd_extents.py + pylint --disable fixme,too-many-arguments,too-many-instance-attributes scripts/python_nbd_client.py + +test: + dune runtest + 
+stresstest: + dune build @stresstest + +install: + dune install -p vhd-tool -include config.mk - -build: setup.data - rm -f configure.cmo configure.cmi - ocaml setup.ml -build - rm -f vhd-tool - ln -s main.native vhd-tool - ./vhd-tool --help=groff > vhd-tool.1 -ifeq ($(ENABLE_XENSERVER), "--enable-xenserver") - rm -f sparse_dd - ln -s sparse_dd.native sparse_dd - ./sparse_dd --help=groff > sparse_dd.1 -endif - -setup.data: setup.ml - rm -f configure.cmo configure.cmi - ocaml setup.ml -configure ${ENABLE_XENSERVER} - -setup.ml: _oasis - oasis setup - -.PHONY: clean -clean: setup.data - rm -f configure.cmo configure.cmi - ocaml setup.ml -clean - rm -f vhd-tool - rm -f sparse_dd - -install: build - mkdir -p ${BINDIR} - install -m 755 main.native ${BINDIR}/vhd-tool || echo "Failed to install vhd-tool" -ifeq ($(ENABLE_XENSERVER), "--enable-xenserver") - mkdir -p ${LIBEXECDIR} - install -m 755 sparse_dd.native ${LIBEXECDIR}/sparse_dd || echo "Failed to install sparse_dd" - mkdir -p ${ETCDIR} - install -m 644 src/sparse_dd.conf ${ETCDIR}/sparse_dd.conf || echo "Failed to install sparse_dd.conf" -endif - -.PHONY: uninstall uninstall: - rm -f ${BINDIR}/vhd-tool -ifeq ($(ENABLE_XENSERVER), "--enable-xenserver") - rm -f ${LIBEXECDIR}/sparse_dd - rm -f ${ETCDIR}/sparse_dd.conf -endif - -config.mk: - @echo Running './configure' with the defaults - ./configure - -.PHONY: distclean -distclean: clean - rm -f config.mk + dune uninstall -p vhd-tool + +clean: + dune clean + +format: + dune build @fmt --auto-promote diff --git a/README.md b/README.md index ffddaae..85106f3 100644 --- a/README.md +++ b/README.md @@ -1,70 +1,83 @@ -vhd-tool -======== +# vhd-tool -Command-line tools to manipulate, transcode and stream +Command-line tools to manipulate, transcode and stream [vhd](http://en.wikipedia.org/wiki/VHD_(file_format)) format data. -Basic command-line tool examples --------------------------------- +## Basic command-line tool examples To create an empty dynamic (i.e. 
grows on demand) vhd: -``` + +```sh vhd-tool create filename.vhd --size 16GiB ``` To create an empty difference vhd: -``` + +```sh vhd-tool create filename.vhd --size 16GiB --parent otherfile.vhd ``` To query all the parameters of a vhd: -``` + +```sh vhd-tool info filename.vhd ``` To query a specific parameter: -``` + +```sh vhd-tool get filename.vhd current-size ``` -Example: incremental backup ---------------------------- +## Example: incremental backup (This is a work in progress) -When running VMs on a hypervisor like [XenServer](http://www.xenserver.org/), it's important to have a backup strategy for your important virtual disks. One possibility is to perform periodic disk snapshots and archive the "deltas" (or differences) between the new snapshot and the last. +When running VMs on a hypervisor like [XenServer](http://www.xenserver.org/), +it is important to have a backup strategy for your important virtual disks. +One possibility is to perform periodic disk snapshots and archive the "deltas" +(or differences) between the new snapshot and the last. -First take a snapshot: this will be the first backup: -``` +First take a snapshot (this will be the first backup): + +```sh xe vdi-snapshot uuid= ``` + Next download the snapshot as a single .vhd: -``` + +```sh xe vdi-export uuid= ``` + This will print a filename to the terminal. Periodically (e.g. 
from cron), perform a new snapshot: -``` + +```sh xe vdi-snapshot uuid= ``` + Next download the differences from a previous snapshot as a single .vhd: -``` + +```sh xe vdi-export uuid= relative-to= ``` -Next, to avoid using too much disk space, count the number of snapshots and delete the oldest if you have too many: -``` + +Next, to avoid using too much disk space, count the number of snapshots and +delete the oldest if you have too many: + +```sh xe vdi-destroy uuid= vhd-tool commit filename.vhd --into older.vhd ``` To restore a backup onto a fresh system use: -``` + +```sh vhd-tool stream --source filename.vhd --source-format vhd --destination http://user:password@xenserver/import_vdi --destination-format vhd --progress ``` - - diff --git a/_oasis b/_oasis deleted file mode 100644 index c081c5a..0000000 --- a/_oasis +++ /dev/null @@ -1,38 +0,0 @@ -OASISFormat: 0.4 -Name: vhd-tool -Version: 0.8.0 -Synopsis: .vhd file manipulation -Authors: Jonathan Ludlam, David Scott, Anil Madhavapeddy, Cheng Sun, Euan Harris, John Else, Mike McClurg, Phus Lu, Si Beaumont, Thomas Sanders -License: LGPL-2.1 with OCaml linking exception -Plugins: META (0.4) -BuildTools: ocamlbuild - -Flag xenserver - Default: false - -Executable "vhd-tool" - CompiledObject: best - Path: src - MainIs: main.ml - Custom: true - Install: false - BuildDepends: lwt, lwt.unix, lwt.syntax, lwt.preemptive, threads, vhd-format, vhd-format.lwt, cmdliner, nbd, nbd.lwt, uri, cohttp (>= 0.12.0), cohttp.lwt, tar, sha, sha.sha1, io-page.unix, threads, tapctl, re.str - CSources: sendfile64_stubs.c - -Executable "sparse_dd" - Build$: flag(xenserver) - CompiledObject: best - Path: src - MainIs: sparse_dd.ml - Custom: true - Install: false - BuildDepends: lwt, lwt.unix, lwt.syntax, lwt.preemptive, threads, vhd-format, vhd-format.lwt, cmdliner, nbd, nbd.lwt, uri, cohttp (>= 0.12.0), cohttp.lwt, xenstore, xenstore.client, xenstore.unix, xenstore_transport, xenstore_transport.unix, threads, tapctl, xcp, sha, sha.sha1, 
tar, io-page.unix, re.str - CSources: sendfile64_stubs.c - -Executable get_vhd_vsize - CompiledObject: best - Path: src - MainIs: get_vhd_vsize.ml - Custom: true - Install: false - BuildDepends: lwt, lwt.unix, vhd-format, vhd-format.lwt, cstruct, io-page.unix, threads diff --git a/_tags b/_tags deleted file mode 100644 index 2cc5589..0000000 --- a/_tags +++ /dev/null @@ -1,4 +0,0 @@ -# OASIS_START -# OASIS_STOP -: syntax_camlp4o, pkg_cstruct.syntax -: syntax_camlp4o, pkg_lwt.syntax diff --git a/cli/dune b/cli/dune new file mode 100644 index 0000000..3f3a062 --- /dev/null +++ b/cli/dune @@ -0,0 +1,27 @@ +(executables + (modes byte exe) + (names main sparse_dd get_vhd_vsize) + (public_names vhd-tool sparse_dd get_vhd_vsize) + (libraries local_lib cstruct)) + +(rule + (targets vhd-tool.1) + (deps + (:x main.exe)) + (action + (with-stdout-to + %{targets} + (run %{x} --help=groff)))) + +(rule + (targets sparse_dd.1) + (deps + (:x sparse_dd.exe)) + (action + (with-stdout-to + %{targets} + (run %{x} --help=groff)))) + +(install + (section man) + (files vhd-tool.1 sparse_dd.1)) diff --git a/cli/get_vhd_vsize.ml b/cli/get_vhd_vsize.ml new file mode 100644 index 0000000..7215f50 --- /dev/null +++ b/cli/get_vhd_vsize.ml @@ -0,0 +1,29 @@ +module Impl = Vhd_format.F.From_file (Vhd_format_lwt.IO) +open Vhd_format.F +open Vhd_format_lwt.IO +module In = From_input (Input) +open In + +let get_vhd_vsize filename = + Vhd_format_lwt.IO.openfile filename false >>= fun fd -> + let rec loop = function + | End -> + return () + | Cons (hd, tl) -> + ( match hd with + | Fragment.Footer x -> + let size = x.Footer.current_size in + Printf.printf "%Ld\n" size ; exit 0 + | _ -> + () + ) ; + tl () >>= fun x -> loop x + in + Vhd_format_lwt.IO.get_file_size filename >>= fun file_size -> + openstream (Some file_size) (Input.of_fd (Vhd_format_lwt.IO.to_file_descr fd)) + >>= fun stream -> + loop stream >>= fun () -> Vhd_format_lwt.IO.close fd + +let _ = + let t = get_vhd_vsize Sys.argv.(1) in + Lwt_main.run 
t diff --git a/cli/main.ml b/cli/main.ml new file mode 100644 index 0000000..1b75df0 --- /dev/null +++ b/cli/main.ml @@ -0,0 +1,380 @@ +(* + * Copyright (C) 2011-2013 Citrix Inc + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU Lesser General Public License as published + * by the Free Software Foundation; version 2.1 only. with the special + * exception on linking described in file LICENSE. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Lesser General Public License for more details. + *) + +let project_url = "http://github.com/djs55/ocaml-vhd" + +open Cmdliner + +(* Help sections common to all commands *) + +let _common_options = "COMMON OPTIONS" + +let help = + [ + `S _common_options + ; `P "These options are common to all commands." + ; `S "MORE HELP" + ; `P "Use `$(mname) $(i,COMMAND) --help' for help on a single command." + ; `Noblank + ; `S "BUGS" + ; `P (Printf.sprintf "Check bug reports at %s" project_url) + ] + +(* Options common to all commands *) +let common_options_t = + let docs = _common_options in + let debug = + let doc = "Give only debug output." in + Arg.(value & flag & info ["debug"] ~docs ~doc) + in + let verb = + let doc = "Give verbose output." in + let verbose = (true, Arg.info ["v"; "verbose"] ~docs ~doc) in + Arg.(last & vflag_all [false] [verbose]) + in + let unbuffered = + let doc = "Use unbuffered I/O." in + Arg.(value & flag & info ["unbuffered"; "direct"] ~docs ~doc) + in + let search_path = + let doc = "Search path for vhds." in + Arg.(value & opt string "." 
& info ["path"] ~docs ~doc) + in + Term.(pure Common.make $ debug $ verb $ unbuffered $ search_path) + +let get_cmd = + let doc = "query vhd metadata" in + let man = + [ + `S "DESCRIPTION" + ; `P "Look up a particular metadata property by name and print the value." + ] + @ help + in + let filename = + let doc = Printf.sprintf "Path to the vhd file." in + Arg.(value & pos 0 (some file) None & info [] ~doc) + in + let key = + let doc = "Key to query" in + Arg.(value & pos 1 (some string) None & info [] ~doc) + in + ( Term.(ret (pure Impl.get $ common_options_t $ filename $ key)) + , Term.info "get" ~sdocs:_common_options ~doc ~man ) + +let filename = + let doc = Printf.sprintf "Path to the vhd file." in + Arg.(value & pos 0 (some file) None & info [] ~doc) + +let info_cmd = + let doc = "display general information about a vhd" in + let man = + [ + `S "DESCRIPTION" + ; `P + "Display general information about a vhd, including header and footer \ + fields. This won't directly display block allocation tables or sector \ + bitmaps." + ] + @ help + in + ( Term.(ret (pure Impl.info $ common_options_t $ filename)) + , Term.info "info" ~sdocs:_common_options ~doc ~man ) + +let contents_cmd = + let doc = "display the contents of the vhd" in + let man = + [ + `S "DESCRIPTION" + ; `P + "Display the contents of the vhd: headers, metadata and data blocks. \ + Everything is displayed in the order it appears in the vhd file, not \ + the order it appears in the virtual disk image itself." + ] + @ help + in + ( Term.(ret (pure Impl.contents $ common_options_t $ filename)) + , Term.info "contents" ~sdocs:_common_options ~doc ~man ) + +let create_cmd = + let doc = "create a dynamic vhd" in + let man = + [ + `S "DESCRIPTION" + ; `P + "Create a dynamic vhd (i.e. one which may be sparse). A dynamic vhd \ + may be self-contained or it may have a backing-file or 'parent'." + ] + @ help + in + let filename = + let doc = Printf.sprintf "Path to the vhd file to be created." 
in + Arg.(value & pos 0 (some string) None & info [] ~doc) + in + let size = + let doc = Printf.sprintf "Virtual size of the disk." in + Arg.(value & opt (some string) None & info ["size"] ~doc) + in + let parent = + let doc = Printf.sprintf "Parent image" in + Arg.(value & opt (some file) None & info ["parent"] ~doc) + in + ( Term.(ret (pure Impl.create $ common_options_t $ filename $ size $ parent)) + , Term.info "create" ~sdocs:_common_options ~doc ~man ) + +let check_cmd = + let doc = "check the structure of a vhd file" in + let man = + [ + `S "DESCRIPTION" + ; `P + "Check the structure of a vhd file is valid, print any errors on the \ + console." + ] + @ help + in + let filename = + let doc = Printf.sprintf "Path to the vhd to be checked." in + Arg.(value & pos 0 (some file) None & info [] ~doc) + in + ( Term.(ret (pure Impl.check $ common_options_t $ filename)) + , Term.info "check" ~sdocs:_common_options ~doc ~man ) + +let source = + let doc = Printf.sprintf "The source disk" in + Arg.(value & opt string "stdin:" & info ["source"] ~doc) + +let source_fd = + let doc = + Printf.sprintf "An open-file descriptor pointing to the source disk" + in + Arg.(value & opt (some int) None & info ["source-fd"] ~doc) + +let source_format = + let doc = "Source format" in + Arg.(value & opt string "raw" & info ["source-format"] ~doc) + +let source_protocol = + let doc = "Transport protocol for the source data." in + Arg.(value & opt (some string) None & info ["source-protocol"] ~doc) + +let destination = + let doc = "Destination for streamed data." in + Arg.(value & opt string "stdout:" & info ["destination"] ~doc) + +let destination_fd = + let doc = "Write data to a file descriptor." 
in + Arg.(value & opt (some int) None & info ["destination-fd"] ~doc) + +let destination_format = + let doc = "Destination format" in + Arg.(value & opt string "raw" & info ["destination-format"] ~doc) + +let destination_size = + let doc = "Size of the destination disk" in + Arg.(value & opt (some int64) None & info ["destination-size"] ~doc) + +let prezeroed = + let doc = "Assume the destination is completely empty." in + Arg.(value & flag & info ["prezeroed"] ~doc) + +let progress = + let doc = "Display a progress bar." in + Arg.(value & flag & info ["progress"] ~doc) + +let machine = + let doc = "Machine readable output." in + Arg.(value & flag & info ["machine"] ~doc) + +let tar_filename_prefix = + let doc = "Filename prefix for tar/sha disk blocks" in + Arg.(value & opt (some string) None & info ["tar-filename-prefix"] ~doc) + +let good_ciphersuites = + let doc = "The list of ciphersuites to allow for TLS" in + Arg.(value & opt (some string) None & info ["good-ciphersuites"] ~doc) + +let serve_cmd = + let doc = "serve the contents of a disk" in + let man = + [ + `S "DESCRIPTION" + ; `P + "Allow the contents of a disk to be read or written over a network \ + protocol" + (* EXAMPLES is a section heading, as in stream_cmd, not a paragraph *) + ; `S "EXAMPLES" + ; `P + " vhd-tool serve --source fd:5 --source-protocol=chunked --destination \ + file:///foo.raw --destination-format raw" + ; `P + " vhd-tool serve --source fd:5 --source-protocol=nbd --destination \ + file:///foo.raw --destination-format raw" + ; `P + " vhd-tool serve --source fd:5 --source-format=vhd \ + --source-protocol=none --destination file:///foo.raw \ + --destination-format raw" + ] + in + let ignore_checksums = + let doc = "Do not verify checksums" in + Arg.(value & flag & info ["ignore-checksums"] ~doc) + in + ( Term.( + ret + (pure Impl.serve + $ common_options_t + $ source + $ source_fd + $ source_format + $ source_protocol + $ destination + $ destination_fd + $ destination_format + $ destination_size + $ prezeroed + $ progress + $ machine + $ tar_filename_prefix 
+ $ ignore_checksums + )) + , Term.info "serve" ~sdocs:_common_options ~doc ~man ) + +let stream_cmd = + let doc = "stream the contents of a vhd disk" in + let man = + [ + `S "DESCRIPTION" + ; `P + "Read the contents of a virtual disk from a source using (format, \ + protocol) and write it out to a destination using another (format, \ + protocol). This command allows disks to be uploaded, downloaded and \ + format-converted in a space-efficient manner." + ; `S "FORMATS" + ; `P + "The input format and the output format are specified separately: this \ + allows easy format conversion during the streaming process. The \ + following formats are defined:" + ; `P " raw: a single flat image" + ; `P " vhd: the Virtual Hard Disk format used in XenServer" + ; `P + "Note: the vhd format supports both self-contained single file images \ + and also \"differencing disks\" containing only the differences \ + between two disks. To input only the differences between two disks, \ + specify the reference disk with the \"--relative-to\" argument." + ; `S "PROTOCOLS" + ; `P + "Protocols are the means by which a disk image in a particular format \ + is written to a particular destination. The following protocols are \ + supported:" + ; `P " nbd: the Network Block Device protocol" + ; `P " chunked: the XenServer chunked disk upload protocol" + ; `P " none: unencoded write" + ; `P " tar: the XenServer import/export encoding using tar" + ; `P " human: human-readable description of the contents" + ; `P "The default behaviour is to auto-detect based on the destination." + ; `S "SOURCES and DESTINATIONS" + ; `P + "The source describes where the disk data comes from. The destination \ + describes where the disk data is written to. 
The following are \ + defined:" + ; `P " stdin:" + ; `P " read from standard input (input only)" + ; `P " stdout:" + ; `P " write to standard output (destination only)" + ; `P " fd:5" + ; `P " read and write from file descriptor 5" + ; `P " " + ; `P " read from or write to the file " + ; `P " unix://" + ; `P " connect to the Unix domain socket" + ; `P " tcp://server:port/path" + ; `P " to issue an HTTP PUT to server:port/path" + ; `P " tcp://host:port/" + ; `P " to connect to TCP port 'port' on host 'host'" + ; `S "OTHER OPTIONS" + ; `P + "When transferring a raw format image onto a medium which is \ + completely empty (i.e. full of zeroes) it is possible to optimise the \ + transfer by avoiding writing empty blocks. The default behaviour is \ + to write zeroes, which is always safe. If you know your media is \ + empty then supply the '--prezeroed' argument." + ; `P + "When running interactively, the --progress argument will cause a \ + progress bar and summary statistics to be printed." + ; `P + "When generating a tar/sha stream, the --tar-filename-prefix will be \ + prefixed onto each disk data block. This is typically used to place \ + the disk blocks of separate disks in different directories." + ; `S "NOTES" + ; `P + "Not all protocols can be used with all destinations. 
For example the \ + NBD protocol needs the ability to read (responses) and write \ + (requests); it therefore will not work with the stdout: destination" + ; `S "EXAMPLES" + ; `P + (* use the main command name here: the term name of this subcommand is + already stream, so it must not be used as the program prefix *) + " $(mname) stream --source=foo.vhd --source-format=vhd \ + --destination-format=raw \ + --destination=http://user:password@xenserver/import_raw_vdi?vdi=" + ] + @ help + in + let source = + let doc = Printf.sprintf "The disk to be streamed" in + Arg.(value & opt string "stdin:" & info ["source"] ~doc) + in + let relative_to = + let doc = "Output only differences from the given reference disk" in + Arg.(value & opt (some file) None & info ["relative-to"] ~doc) + in + let destination_protocol = + let doc = "Transport protocol for the destination data." in + Arg.(value & opt (some string) None & info ["destination-protocol"] ~doc) + in + let stream_args_t = + Term.( + pure StreamCommon.make + $ source + $ relative_to + $ source_format + $ destination_format + $ destination + $ destination_fd + $ source_protocol + $ destination_protocol + $ prezeroed + $ progress + $ machine + $ tar_filename_prefix + $ good_ciphersuites) + in + ( Term.(ret (pure Impl.stream $ common_options_t $ stream_args_t)) + , Term.info "stream" ~sdocs:_common_options ~doc ~man ) + +let default_cmd = + let doc = "manipulate virtual disks stored in vhd files" in + let man = help in + ( Term.(ret (pure (fun _ -> `Help (`Pager, None)) $ common_options_t)) + , Term.info "vhd-tool" ~version:"1.0.0" ~sdocs:_common_options ~doc ~man ) + +let cmds = + [ + info_cmd; contents_cmd; get_cmd; create_cmd; check_cmd; serve_cmd; stream_cmd + ] + +let _ = + match Term.eval_choice default_cmd cmds with + | `Error _ -> + exit 1 + | _ -> + exit 0 diff --git a/src/sparse_dd.conf b/cli/sparse_dd.conf similarity index 100% rename from src/sparse_dd.conf rename to cli/sparse_dd.conf diff --git a/cli/sparse_dd.ml b/cli/sparse_dd.ml new file mode 100644 index 0000000..b72db2e --- /dev/null +++ b/cli/sparse_dd.ml @@ -0,0 +1,498 @@ +(* Utility 
program which copies between two block devices, using vhd BATs and efficient zero-scanning + for performance. *) + +module D = Debug.Make (struct let name = "sparse_dd" end) + +open D + +let config_file = "/etc/sparse_dd.conf" + +let vhd_search_path = "/dev/mapper" + +let ionice_cmd = "/usr/bin/ionice" + +let renice_cmd = "/usr/bin/renice" + +type encryption_mode = Always | Never | User + +let string_of_encryption_mode = function + | Always -> + "always" + | Never -> + "never" + | User -> + "user" + +let encryption_mode_of_string = function + | "always" -> + Always + | "never" -> + Never + | "user" -> + User + | x -> + failwith + (Printf.sprintf "Unknown encryption mode %s. Use always, never or user." + x) + +let encryption_mode = ref User + +(* Niceness: strings that may or may not be valid ints. *) +let nice = ref None + +let ionice_class = ref None + +let ionice_class_data = ref None + +let base = ref None + +let src = ref None + +let dest = ref None + +let size = ref (-1L) + +let prezeroed = ref false + +let set_machine_logging = ref false + +let experimental_reads_bypass_tapdisk = ref false + +let experimental_writes_bypass_tapdisk = ref false + +let ssl_legacy = ref false + +let good_ciphersuites = ref None + +let legacy_ciphersuites = ref None + +let string_opt = function None -> "None" | Some x -> x + +let machine_readable_progress = ref false + +let options = + let str_option name var_ref description = + ( name + , Arg.String (fun x -> var_ref := Some x) + , (fun () -> string_opt !var_ref) + , description ) + in + [ + ( "unbuffered" + , Arg.Bool (fun b -> Vhd_format_lwt.File.use_unbuffered := b) + , (fun () -> string_of_bool !Vhd_format_lwt.File.use_unbuffered) + , "use unbuffered I/O via O_DIRECT" ) + ; ( "encryption-mode" + , Arg.String (fun x -> encryption_mode := encryption_mode_of_string x) + , (fun () -> string_of_encryption_mode !encryption_mode) + , "how to use encryption" ) + ; (* Want to ignore bad values for "nice" etc. 
so not using Arg.Int *) + str_option "nice" nice + "If supplied, the scheduling priority will be set using this value as \ + argument to the 'nice' command." + ; str_option "ionice_class" ionice_class + "If supplied, the io scheduling class will be set using this value as -c \ + argument to the 'ionice' command." + ; str_option "ionice_class_data" ionice_class_data + "If supplied, the io scheduling class data will be set using this value \ + as -n argument to the 'ionice' command." + ; ( "experimental-reads-bypass-tapdisk" + , Arg.Set experimental_reads_bypass_tapdisk + , (fun () -> string_of_bool !experimental_reads_bypass_tapdisk) + , "bypass tapdisk and read directly from the underlying vhd file" ) + ; ( "experimental-writes-bypass-tapdisk" + , Arg.Set experimental_writes_bypass_tapdisk + , (fun () -> string_of_bool !experimental_writes_bypass_tapdisk) + , "bypass tapdisk and write directly to the underlying vhd file" ) + ; ( "base" + , Arg.String (fun x -> base := Some x) + , (fun () -> string_opt !base) + , "base disk to search for differences from" ) + ; ( "src" + , Arg.String (fun x -> src := Some x) + , (fun () -> string_opt !src) + , "source disk" ) + ; ( "dest" + , Arg.String (fun x -> dest := Some x) + , (fun () -> string_opt !dest) + , "destination disk" ) + ; ( "size" + , Arg.String (fun x -> size := Int64.of_string x) + , (fun () -> Int64.to_string !size) + , "number of bytes to copy" ) + ; ( "prezeroed" + , Arg.Set prezeroed + , (fun () -> string_of_bool !prezeroed) + , "assume the destination disk has been prezeroed" ) + ; ( "machine" + , Arg.Set machine_readable_progress + , (fun () -> string_of_bool !machine_readable_progress) + , "emit machine-readable output" ) + ; ( "ssl-legacy" + , Arg.Set ssl_legacy + , (fun () -> string_of_bool !ssl_legacy) + , " for TLS, allow all protocol versions instead of just TLSv1.2" ) + ; ( "good-ciphersuites" + , Arg.String (fun x -> good_ciphersuites := Some x) + , (fun () -> string_opt !good_ciphersuites) + , " 
the list of ciphersuites to allow for TLS" ) + ; ( "legacy-ciphersuites" + , Arg.String (fun x -> legacy_ciphersuites := Some x) + , (fun () -> string_opt !legacy_ciphersuites) + , " additional TLS ciphersuites allowed only if ssl-legacy is set" ) + ] + +let ( +* ) = Int64.add + +let ( -* ) = Int64.sub + +let ( ** ) = Int64.mul + +let kib = 1024L + +let mib = kib ** kib + +let startswith prefix x = + let prefix' = String.length prefix and x' = String.length x in + prefix' <= x' && String.sub x 0 prefix' = prefix + +module Opt = struct let default d = function None -> d | Some x -> x end + +module Mutex = struct + include Mutex + + let execute m f = + Mutex.lock m ; + try + let result = f () in + Mutex.unlock m ; result + with e -> Mutex.unlock m ; raise e +end + +module Progress = struct + let header = Cstruct.create Chunked.sizeof + + (** Report progress complete to another program reading stdout *) + let report fraction = + if !machine_readable_progress then ( + let s = Printf.sprintf "Progress: %.0f" (fraction *. 100.) in + let data = Cstruct.create (String.length s) in + Cstruct.blit_from_string s 0 data 0 (String.length s) ; + Chunked.marshal header {Chunked.offset= 0L; data} ; + Printf.printf "%s%s%!" (Cstruct.to_string header) s + ) + + (** Emit the end-of-stream message *) + let close () = + if !machine_readable_progress then ( + let header = Cstruct.create Chunked.sizeof in + Chunked.marshal header {Chunked.offset= 0L; data= Cstruct.create 0} ; + Printf.printf "%s%!" (Cstruct.to_string header) + ) +end + +(** [after f g] runs [f ()] and then the cleanup [g ()], returning the result of [f]. + [g] is run exactly once: before re-raising if [f] raises, otherwise after [f] + returns. An exception raised by [g] itself propagates without re-running [g]. *) +let after f g = + let r = try f () with e -> g () ; raise e in + g () ; r + +(** [find_backend_device path] returns [Some path'] where [path'] is the backend path in + the driver domain corresponding to the frontend device [path] in this domain. *) +let find_backend_device path = + try + let open Xenstore in + (* If we're looking at a xen frontend device, see if the backend + is in the same domain. 
If so check if it looks like a .vhd *) + let rdev = (Unix.LargeFile.stat path).Unix.LargeFile.st_rdev in + let major = rdev / 256 and minor = rdev mod 256 in + let link = + Unix.readlink (Printf.sprintf "/sys/dev/block/%d:%d/device" major minor) + in + match List.rev (Re.Str.split (Re.Str.regexp_string "/") link) with + | id :: "xen" :: "devices" :: _ when startswith "vbd-" id -> + let id = int_of_string (String.sub id 4 (String.length id - 4)) in + with_xs (fun xs -> + let self = xs.Xs.read "domid" in + let backend = + xs.Xs.read (Printf.sprintf "device/vbd/%d/backend" id) + in + let params = xs.Xs.read (Printf.sprintf "%s/params" backend) in + match Re.Str.split (Re.Str.regexp_string "/") backend with + | "local" :: "domain" :: bedomid :: _ -> + assert (self = bedomid) ; + Some params + | _ -> + raise Not_found) + | _ -> + raise Not_found + with _ -> None + +let with_paused_tapdisk path f = + let path = find_backend_device path |> Opt.default path in + let context = Tapctl.create () in + match Tapctl.of_device context path with + | tapdev, _, Some (_driver, path) -> + debug "pausing tapdisk for %s" path ; + Tapctl.pause context tapdev ; + after f (fun () -> + debug "unpausing tapdisk for %s" path ; + Tapctl.unpause context tapdev path Tapctl.Vhd) + | _, _, _ -> + failwith (Printf.sprintf "Failed to pause tapdisk for %s" path) + +let deref_symlinks path = + let rec inner seen_already path = + if List.mem path seen_already then failwith "Circular symlink" ; + let stats = Unix.LargeFile.lstat path in + if stats.Unix.LargeFile.st_kind = Unix.S_LNK then + inner (path :: seen_already) (Unix.readlink path) + else + path + in + inner [] path + +(* Record when the binary started for performance measuring *) +let start = Unix.gettimeofday () + +(* Helper function to print nice progress info *) +let progress_cb = + let last_percent = ref (-1) in + function + | fraction -> + let new_percent = int_of_float (fraction *. 100.) 
in + if !last_percent <> new_percent then Progress.report fraction ; + if !last_percent / 10 <> new_percent / 10 then + debug "progress %d%%" new_percent ; + last_percent := new_percent + +let _ = + Vhd_format_lwt.File.use_unbuffered := true ; + Xcp_service.configure ~options () ; + let src = + match !src with + | None -> + debug "Must have -src argument\n" ; + exit 1 + | Some x -> + x + in + let dest = + match !dest with + | None -> + debug "Must have -dest argument\n" ; + exit 1 + | Some x -> + x + in + if !size = -1L then ( + debug "Must have -size argument\n" ; + exit 1 + ) ; + let size = !size in + let base = !base in + (* Helper function to bring an int into valid range *) + let clip v min max descr = + if v < min then ( + warn "Value %d is too low for %s. Using %d instead." v descr min ; + min + ) else if v > max then ( + warn "Value %d is too high for %s. Using %d instead." v descr max ; + max + ) else + v + in + (let parse_as_int str_opt int_opt_ref opt_name = + match str_opt with + | None -> + () + | Some str -> ( + try int_opt_ref := Some (int_of_string str) + with _ -> error "Ignoring invalid value for %s: %s" opt_name str + ) + in + (* renice this process if specified *) + let n_ref = ref None in + parse_as_int !nice n_ref "nice" ; + ( match !n_ref with + | None -> + () + | Some n -> + (* Run command like: renice -n priority -p pid *) + let n = clip n (-20) 19 "nice" in + let pid = string_of_int (Unix.getpid ()) in + let _ = + Forkhelpers.execute_command_get_output renice_cmd + ["-n"; string_of_int n; "-p"; pid] + in + () + ) ; + (* Possibly run command like: ionice -c class -n classdata -p pid *) + let c_ref = ref None in + let cd_ref = ref None in + parse_as_int !ionice_class c_ref "ionice_class" ; + parse_as_int !ionice_class_data cd_ref "ionice_class_data" ; + match !c_ref with + | None -> + () + | Some c -> ( + let pid = string_of_int (Unix.getpid ()) in + let ionice args = + let _ = Forkhelpers.execute_command_get_output ionice_cmd args in + () + 
in + let class_only c = ionice ["-c"; string_of_int c; "-p"; pid] in + let class_and_data c n = + ionice ["-c"; string_of_int c; "-n"; string_of_int n; "-p"; pid] + in + match c with + | 0 | 3 -> + class_only c + | 1 | 2 -> ( + match !cd_ref with + | None -> + class_only c + | Some n -> + let n = clip n 0 7 "ionice classdata" in + class_and_data c n + ) + | _ -> + error "Cannot use ionice due to invalid class value: %d" c + )) ; + debug "src = %s; dest = %s; base = %s; size = %Ld" src dest + (Opt.default "None" base) size ; + let src_image = Image.of_device src in + let dest_image = Image.of_device dest in + let base_image = + match base with None -> None | Some x -> Image.of_device x + in + let to_string = function None -> "None" | Some x -> Image.to_string x in + debug "src_image = %s; dest_image = %s; base_image = %s" (to_string src_image) + (to_string dest_image) (to_string base_image) ; + (* Add the directory of the vhd to the search path *) + let vhd_search_path = + match src_image with + | Some (`Vhd x) -> + vhd_search_path ^ ":" ^ Filename.dirname x + | _ -> + vhd_search_path + in + let common = Common.make true false true vhd_search_path in + if !experimental_reads_bypass_tapdisk then + warn "experimental_reads_bypass_tapdisk set: this may cause data corruption" ; + if !experimental_writes_bypass_tapdisk then + warn + "experimental_writes_bypass_tapdisk set: this may cause data corruption" ; + let relative_to = + match base_image with + | Some (`Vhd x) -> + Some x + | Some (`Raw _) -> + None + | Some (`Nbd _) -> + None (* TODO: make delta copies work with NBD, CA-289660 *) + | None -> + None + in + let rewrite_url device_or_url = + (* Ensure device_or_url is a valid URL, and apply our encryption policy *) + let uri = Uri.of_string device_or_url in + let rewrite_scheme scheme = + let uri = + Uri.make ~scheme ?userinfo:(Uri.userinfo uri) ?host:(Uri.host uri) + ?port:(Uri.port uri) ~path:(Uri.path uri) ~query:(Uri.query uri) + ?fragment:(Uri.fragment uri) () 
+ in + Uri.to_string uri + in + match Uri.scheme uri with + | Some "https" when !encryption_mode = Never -> + warn + "turning off encryption for this transfer as requested by config file" ; + rewrite_scheme "http" + | Some "http" when !encryption_mode = Always -> + warn + "turning on encryption for this transfer as requested by config file" ; + rewrite_scheme "https" + | Some ("http" | "https") -> + device_or_url + | _ -> + "file://" ^ device_or_url + in + let open Lwt in + let stream_t, destination, destination_format = + match + ( !experimental_reads_bypass_tapdisk + , src + , src_image + , !experimental_writes_bypass_tapdisk + , dest + , dest_image ) + with + | true, _, Some (`Vhd vhd), true, _, Some (`Vhd vhd') -> + prezeroed := false ; + (* the physical disk will have vhd metadata and other stuff on it *) + info "streaming from vhd %s (relative to %s) to vhd %s" vhd + (string_opt relative_to) vhd' ; + let t = Impl.make_stream common vhd relative_to "vhd" "vhd" in + (t, "file://" ^ vhd', "vhd") + | false, _, _, true, _, _ -> + error + "Not implemented: writes bypass tapdisk while reads go through \ + tapdisk" ; + failwith + "Not implemented: writing bypassing tapdisk while reading through \ + tapdisk" + | false, _, Some (`Vhd vhd), false, _, _ -> + let dest = rewrite_url dest in + info + "streaming from raw %s using BAT from %s (relative to %s) to raw %s" + src vhd (string_opt relative_to) dest ; + let t = + Impl.make_stream common (src ^ ":" ^ vhd) relative_to "hybrid" "raw" + in + (t, dest, "raw") + | _, _, Some (`Nbd (server, export_name)), _, _, _ -> + let dest = rewrite_url dest in + let t = + Impl.make_stream common + (src ^ ":" ^ server ^ ":" ^ export_name ^ ":" ^ Int64.to_string size) + None "nbdhybrid" "raw" + in + (t, dest, "raw") + | true, _, Some (`Vhd vhd), _, _, _ -> + let dest = rewrite_url dest in + info "streaming from vhd %s (relative to %s) to raw %s" vhd + (string_opt relative_to) dest ; + let t = Impl.make_stream common vhd relative_to 
"vhd" "raw" in + (t, dest, "raw") + | _, _, Some (`Raw raw), _, _, _ -> + let dest = rewrite_url dest in + info "streaming from raw %s (relative to %s) to raw %s" raw + (string_opt relative_to) dest ; + let t = Impl.make_stream common raw relative_to "raw" "raw" in + (t, dest, "raw") + | _, device, None, _, _, _ -> + let dest = rewrite_url dest in + info "streaming from raw %s (relative to %s) to raw %s" src + (string_opt relative_to) dest ; + let t = Impl.make_stream common device relative_to "raw" "raw" in + (t, dest, "raw") + in + progress_cb 0. ; + let progress total_work work_done = + let fraction = Int64.(to_float work_done /. to_float total_work) in + progress_cb fraction + in + let t = + stream_t >>= fun s -> + Impl.write_stream common s destination (Some "none") None !prezeroed + progress None !good_ciphersuites + in + if destination_format = "vhd" then + with_paused_tapdisk dest (fun () -> Lwt_main.run t) + else + Lwt_main.run t ; + let time = Unix.gettimeofday () -. start in + debug "Time: %.2f seconds" time ; + Progress.close () diff --git a/configure b/configure deleted file mode 100755 index b81f01c..0000000 --- a/configure +++ /dev/null @@ -1,7 +0,0 @@ -#!/bin/sh - -trap "rm -f /tmp/configure.$$" EXIT - -ocamlfind ocamlc -package findlib,cmdliner -linkpkg -o /tmp/configure.$$ configure.ml -/tmp/configure.$$ $* - diff --git a/configure.ml b/configure.ml deleted file mode 100644 index 2e559e5..0000000 --- a/configure.ml +++ /dev/null @@ -1,83 +0,0 @@ -let config_mk = "config.mk" - -let find_ocamlfind verbose name = - let found = - try - let (_: string) = Findlib.package_property [] name "requires" in - true - with - | Not_found -> - (* property within the package could not be found *) - true - | Findlib.No_such_package(_,_ ) -> - false in - if verbose then Printf.fprintf stderr "querying for ocamlfind package %s: %s" name (if found then "ok" else "missing"); - found - -(* Configure script *) -open Cmdliner - -let bindir = - let doc = "Set the directory 
for installing binaries" in - Arg.(value & opt string "/usr/bin" & info ["bindir"] ~docv:"BINDIR" ~doc) - -let libexecdir = - let doc = "Set the directory for installing helper executables" in - Arg.(value & opt string "/usr/lib/xapi" & info ["libexecdir"] ~docv:"LIBEXECDIR" ~doc) - -let etcdir = - let doc = "Set the directory for installing configuration files" in - Arg.(value & opt string "/etc" & info ["etcdir"] ~docv:"ETCDIR" ~doc) - -let info = - let doc = "Configures a package" in - Term.info "configure" ~version:"0.1" ~doc - -let output_file filename lines = - let oc = open_out filename in - let lines = List.map (fun line -> line ^ "\n") lines in - List.iter (output_string oc) lines; - close_out oc - -let configure bindir libexecdir etcdir = - - Printf.printf "Configuring with:\n\tbindir=%s\n\tlibexecdir=%s\n\tetcdir=%s\n" bindir libexecdir etcdir; - - let xcp = find_ocamlfind false "xcp" in - let xenstore_transport = find_ocamlfind false "xenstore_transport" in - let xenstore = find_ocamlfind false "xenstore" in - let tapctl = find_ocamlfind false "tapctl" in - let enable_xenserver = xcp && xenstore_transport && xenstore && tapctl in - let lines = - [ "# Warning - this file is autogenerated by the configure script"; - "# Do not edit"; - Printf.sprintf "BINDIR=%s" bindir; - Printf.sprintf "LIBEXECDIR=%s" libexecdir; - Printf.sprintf "ETCDIR=%s" etcdir; - Printf.sprintf "ENABLE_XENSERVER=--%sable-xenserver" (if enable_xenserver then "en" else "dis"); - ] in - output_file config_mk lines; - let lines = - [ "bin: ["; - " \"main.native\" { \"vhd-tool\" }"; - ] @ (if enable_xenserver - then [ " \"sparse_dd.native\" { \"sparse_dd\" }" ] - else []) @ [ - "]"; - "man: ["; - " \"vhd-tool.1\" { \"vhd-tool.1\" }"; - ] @ (if enable_xenserver - then [ " \"sparse_dd.1\" { \"sparse_dd.1\" }" ] - else []) @ [ - "]"; - ] in - output_file "vhd-tool.install" lines - -let configure_t = Term.(pure configure $ bindir $ libexecdir $ etcdir ) - -let () = - match - Term.eval 
(configure_t, info) - with - | `Error _ -> exit 1 - | _ -> exit 0 diff --git a/dune-project b/dune-project new file mode 100644 index 0000000..78daeb4 --- /dev/null +++ b/dune-project @@ -0,0 +1,3 @@ +(lang dune 2.0) + +(formatting (enabled_for ocaml)) diff --git a/myocamlbuild.ml b/myocamlbuild.ml deleted file mode 100644 index a1d7e1c..0000000 --- a/myocamlbuild.ml +++ /dev/null @@ -1,3 +0,0 @@ -(* OASIS_START *) -(* OASIS_STOP *) -Ocamlbuild_plugin.dispatch dispatch_default;; diff --git a/opam b/opam deleted file mode 100644 index 575f133..0000000 --- a/opam +++ /dev/null @@ -1,28 +0,0 @@ -opam-version: "1" -maintainer: "dave.scott@eu.citrix.com" -tags: [ - "org:mirage" - "org:xapi-project" -] -build: [ - ["oasis" "setup"] - ["./configure" "--bindir=%{bin}%" "--etcdir=%{etc}%" "--libexecdir=%{bin}%"] - [make] -] -depends: [ - "ocamlfind" - "lwt" {>= "2.4.3"} - "cstruct" {>= "1.0.1"} - "vhd-format" {>= "0.7.0"} - "uuidm" - "cmdliner" - "nbd" {>= "1.0.1"} - "ounit" - "uri" - "tar-format" - "sha" - "cohttp" {>="0.12.0"} - "ssl" - "xapi-tapctl" - "oasis" {build} -] diff --git a/scripts/get_nbd_extents.py b/scripts/get_nbd_extents.py new file mode 100644 index 0000000..2968c4d --- /dev/null +++ b/scripts/get_nbd_extents.py @@ -0,0 +1,158 @@ +#!/usr/bin/python + +""" +Returns a list of extents with their block statuses for an NBD export. + +This program uses new NBD capabilities introduced in QEMU 2.12. + +It uses the BLOCK_STATUS NBD extension, which relies on the structured replies +functionality. 
These are documented in the NBD protocol docs: +https://github.com/NetworkBlockDevice/nbd/blob/master/doc/proto.md +""" + +import argparse +import json +import logging +import logging.handlers + +from python_nbd_client import PythonNbdClient, assert_protocol +import python_nbd_client + + +LOGGER = logging.getLogger("get_nbd_extents") +LOGGER.setLevel(logging.DEBUG) +# The log level of python_nbd_client is not set, therefore it will default to +# that of the root logger, which is WARNING by default. + +# Request length is a 32-bit int. +# It looks like for qemu 2.12, the length parameter of a NBD_CMD_BLOCK_STATUS +# request is not limited by the maximum block size supported by the server (as +# defined by NBD_INFO_BLOCK_SIZE), only by the size of a 32-bit int. +MAX_REQUEST_LEN = 2 ** 32 - 1 + +# Make the NBD_CMD_BLOCK_STATUS request aligned to 512 bytes, just in case. But +# it looks like this is not required for qemu 2.12. +MAX_REQUEST_LEN = MAX_REQUEST_LEN - (MAX_REQUEST_LEN % 512) + + +def _get_extents(path, exportname, offset, length): + with PythonNbdClient(address=path, + exportname=exportname, + unix=True, + use_tls=False, + connect=False) as client: + + client.negotiate_structured_reply() + + # Select our metadata context. 
This context is documented at + # https://github.com/NetworkBlockDevice/nbd/blob/master/doc/proto.md#baseallocation-metadata-context + context = 'base:allocation' + selected_contexts = client.set_meta_contexts(exportname, [context]) + assert_protocol(len(selected_contexts) == 1) + (meta_context_id, meta_context_name) = selected_contexts[0] + assert_protocol(meta_context_name == context) + + client.connect(exportname) + + size = client.get_size() + LOGGER.debug( + 'Connected to NBD export %s served at path %s of size %d bytes', + exportname, path, size) + + if (offset < 0) or (length <= 0) or ((offset + length) > size): + raise ValueError("Offset={} and length={} out of bounds: " + "export has size {}".format(offset, length, size)) + end = offset + length + while offset < end: + request_len = min(MAX_REQUEST_LEN, end - offset) + replies = client.query_block_status(offset, request_len) + + # Process the returned structured reply chunks + # "For a successful return, the server MUST use a structured reply, + # containing exactly one chunk of type NBD_REPLY_TYPE_BLOCK_STATUS + # per selected context id" + assert_protocol(len(replies) == 1) + reply = replies[0] + + # First make sure it's a block status reply + if python_nbd_client.is_error_chunk( + reply_type=reply['reply_type']): + raise Exception('Received error: {}'.format(reply)) + if reply['reply_type'] != \ + python_nbd_client.NBD_REPLY_TYPE_BLOCK_STATUS: + raise Exception('Unexpected reply: {}'.format(reply)) + + # Then process the returned block status info + assert_protocol(reply['context_id'] == meta_context_id) + # Note: There might be consecutive descriptors with the same status + # value. 
+ descriptors = reply['descriptors'] + for i, descriptor in enumerate(descriptors, 1): + (extent_length, flags) = descriptor + if i == (len(descriptors)): + # The first N-1 extents must be smaller than the requested + # length, but the last extent can exceed the requested + # length + extent_length = min(extent_length, end - offset) + yield {'length': extent_length, 'flags': flags} + offset += extent_length + assert_protocol(offset <= end) + + +def _main(): + # Configure the root logger to log into syslog + # (Specifically, into /var/log/user.log) + syslog_handler = logging.handlers.SysLogHandler( + address='/dev/log', + facility=logging.handlers.SysLogHandler.LOG_USER) + # Ensure the program name is included in the log messages: + formatter = logging.Formatter('%(name)s: [%(levelname)s] %(message)s') + syslog_handler.setFormatter(formatter) + logging.getLogger().addHandler(syslog_handler) + + try: + parser = argparse.ArgumentParser( + description="Return a list of extents with their block statuses. " + "The returned extents are consecutive, non-" + "overlapping, in the correct order starting from the " + "specified offset, and exactly cover the requested " + "area. 
There might be consecutive extents with the " + "same status flags.") + parser.add_argument( + '--path', + required=True, + help="The path of the Unix domain socket of the NBD server") + parser.add_argument( + '--exportname', + required=True, + help="The export name of the device to connect to") + parser.add_argument( + '--offset', + required=True, + type=int, + help="The returned list of extents will be computed " + "starting from this offset") + parser.add_argument( + '--length', + required=True, + type=int, + help="The returned list of extents will be computed " + "for an area of this length starting at the given offset") + + args = parser.parse_args() + LOGGER.debug('Called with args %s', args) + + extents = list( + _get_extents( + path=args.path, + exportname=args.exportname, + offset=args.offset, + length=args.length)) + print json.dumps(extents) + except Exception as exc: + LOGGER.exception(exc) + raise + + +if __name__ == '__main__': + _main() diff --git a/scripts/python_nbd_client.py b/scripts/python_nbd_client.py new file mode 100644 index 0000000..c097ba3 --- /dev/null +++ b/scripts/python_nbd_client.py @@ -0,0 +1,694 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (C) 2013 Nodalink, SARL. +# +# Simple nbd client used to connect to qemu-nbd +# +# author: BenoƮt Canet +# +# This work is open source software, licensed under the terms of the +# BSD license as described in the LICENSE file in the top-level directory. +# + +# Original file from +# https://github.com/cloudius-systems/osv/blob/master/scripts/nbd_client.py , +# added support for (non-fixed) newstyle negotation, +# then @thomasmck added support for fixed-newstyle negotiation and TLS + +""" +A pure-Python NBD client. 
+ +This client implements the NBD protocol, and supports both the oldstyle and +newstyle negotiations: +https://github.com/NetworkBlockDevice/nbd/blob/master/doc/proto.md +Additionally, it supports the BLOCK_STATUS extension: +for the extension docs, see the same file in the extension-blockstatus branch. +""" + +import socket +import struct +import ssl +import logging + +LOGGER = logging.getLogger('python_nbd_client') + +# Request types +NBD_CMD_READ = 0 +NBD_CMD_WRITE = 1 +# a disconnect request +NBD_CMD_DISC = 2 +NBD_CMD_FLUSH = 3 +NBD_CMD_BLOCK_STATUS = 7 + +# Transmission flags +NBD_FLAG_HAS_FLAGS = (1 << 0) +NBD_FLAG_SEND_FLUSH = (1 << 2) + +# Client flags +NBD_FLAG_C_FIXED_NEWSTYLE = (1 << 0) + +# Option types +NBD_OPT_EXPORT_NAME = 1 +NBD_OPT_ABORT = 2 +NBD_OPT_STARTTLS = 5 +NBD_OPT_INFO = 6 +NBD_OPT_STRUCTURED_REPLY = 8 +NBD_OPT_LIST_META_CONTEXT = 9 +NBD_OPT_SET_META_CONTEXT = 10 + +# Option reply types +NBD_REP_ERROR_BIT = (1 << 31) +NBD_REP_ACK = 1 +NBD_REP_INFO = 3 +NBD_REP_META_CONTEXT = 4 + +OPTION_REPLY_MAGIC = 0x3e889045565a9 + +NBD_REQUEST_MAGIC = 0x25609513 +NBD_SIMPLE_REPLY_MAGIC = 0x67446698 +NBD_STRUCTURED_REPLY_MAGIC = 0x668e33ef + +# Structured reply types +NBD_REPLY_TYPE_NONE = 0 +NBD_REPLY_TYPE_OFFSET_DATA = 1 +NBD_REPLY_TYPE_OFFSET_HOLE = 2 +NBD_REPLY_TYPE_BLOCK_STATUS = 5 +NBD_REPLY_TYPE_ERROR_BIT = (1 << 15) +NBD_REPLY_TYPE_ERROR = ((1 << 15) + 1) +NBD_REPLY_TYPE_ERROR_OFFSET = ((1 << 15) + 2) + +# Structured reply flags +NBD_REPLY_FLAG_DONE = (1 << 0) + +# NBD_INFO information types +NBD_INFO_EXPORT = 0 +NBD_INFO_NAME = 1 +NBD_INFO_DESCRIPTION = 2 +NBD_INFO_BLOCK_SIZE = 3 + + +class NBDEOFError(EOFError): + """ + An end of file error happened while reading from the socket, because it has + been closed. + """ + pass + + +class NBDTransmissionError(Exception): + """ + The NBD server returned a non-zero error value in its response to a + request. + + :attribute error_code: The error code returned by the server. 
+ """ + def __init__(self, error_code): + super(NBDTransmissionError, self).__init__( + "Server returned error during transmission: {}".format(error_code)) + self.error_code = error_code + + +class NBDOptionError(Exception): + """ + The NBD server replied with an error to the option sent by the client. + + :attribute reply: The error reply sent by the server. + """ + def __init__(self, reply): + error_code = reply - NBD_REP_ERROR_BIT + super(NBDOptionError, self).__init__( + "Server returned error during option haggling: " + "reply type={}; error code={}".format(reply, error_code)) + self.reply = reply + + +class NBDUnexpectedOptionResponseError(Exception): + """ + The NBD server sent a response to an option different from the most recent + one that the client is expecting a response to. + + :attribute expected: The option that was last sent by the client, to which + it is expecting a response. + :attribute received: The server's response is a reply to this option. + """ + def __init__(self, expected, received): + super(NBDUnexpectedOptionResponseError, self).__init__( + "Received response to unexpected option {}; " + "was expecting a response to option {}" + .format(received, expected)) + self.expected = expected + self.received = received + + +class NBDUnexpectedStructuredReplyType(Exception): + """ + The NBD server sent a structured reply chunk with an unexpected type that + is not known by this client. + + :attribute type: The type of the structured chunk message. + """ + def __init__(self, reply_type): + super(NBDUnexpectedStructuredReplyType, self).__init__( + "Received a structured reply chunk message " + "with an unexpected type: {}".format(reply_type)) + self.reply_type = reply_type + + +class NBDUnexpectedReplyHandleError(Exception): + """ + The NBD server sent a reply to a request different from the most recent one + that the client is expecting a response to. 
+ + :attribute expected: The handle of the most recent request that the client + is expecting a reply to. + :attribute received: The server's reply contained this handle. + """ + def __init__(self, expected, received): + super(NBDUnexpectedReplyHandleError, self).__init__( + "Received reply with unexpected handle {}; " + "was expecting a response to the request with " + "handle {}" + .format(received, expected)) + self.expected = expected + self.received = received + + +class NBDProtocolError(Exception): + """ + The NBD server sent an invalid response that is not allowed by the NBD + protocol. + """ + pass + + +def assert_protocol(assertion): + """Raise an NBDProtocolError if the given condition is falsy.""" + if not assertion: + raise NBDProtocolError + + +def _check_alignment(name, value): + if not value % 512: + return + raise ValueError("%s=%i is not a multiple of 512" % (name, value)) + + +def _is_final_structured_reply_chunk(flags): + return flags & NBD_REPLY_FLAG_DONE == NBD_REPLY_FLAG_DONE + + +def is_error_chunk(reply_type): + """ + Returns true if the structured reply chunk with the given type is an error + chunk. + """ + return reply_type & NBD_REPLY_TYPE_ERROR_BIT != 0 + + +def _parse_block_status_descriptors(data): + while data: + (length, status_flags) = struct.unpack(">LL", data[:8]) + yield (length, status_flags) + data = data[8:] + + +class PythonNbdClient(object): + """ + A pure-Python NBD client. 
+ """ + + def __init__(self, + address, + exportname="", + port=10809, + timeout=60, + subject=None, + cert=None, + use_tls=True, + new_style_handshake=True, + unix=False, + connect=True): + LOGGER.info("Creating connection to address '%s' and port '%s'", + address, port) + self._flushed = True + self._closed = True + self._handle = 0 + self._last_sent_option = None + self._structured_reply = False + self._transmission_phase = False + if unix: + self._s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + else: + self._s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + if not unix: + address = (address, int(port)) + self._s.settimeout(timeout) + self._s.connect(address) + self._closed = False + if new_style_handshake: + self._fixed_new_style_handshake( + cert=cert, + subject=subject, + use_tls=use_tls) + if connect: + self.connect(exportname=exportname) + else: + self._old_style_handshake() + + def __enter__(self): + return self + + def __exit__(self, *args): + self.close() + + def close(self): + """ + Sends a flush request to the server if necessary and the server + supports it, followed by a disconnect request. 
+ """ + if self._transmission_phase and (not self._flushed): + self.flush() + if not self._closed: + self._disconnect() + self._closed = True + + def _recvall(self, length): + data = bytearray(length) + view = memoryview(data) + bytes_left = length + while bytes_left: + received = self._s.recv_into(view, bytes_left) + # If recv reads 0 bytes, that means the peer has properly + # shut down the TCP session (end-of-file error): + if not received: + raise NBDEOFError + view = view[received:] + bytes_left -= received + return data + + # Handshake phase + + # Newstyle handshake + + def _send_option(self, option, data=b''): + LOGGER.debug("NBD sending option header") + data_length = len(data) + LOGGER.debug("option='%d' data_length='%d'", option, data_length) + self._s.sendall(b'IHAVEOPT') + header = struct.pack(">LL", option, data_length) + self._s.sendall(header + data) + self._last_sent_option = option + + def _parse_option_reply(self): + LOGGER.debug("NBD parsing option reply") + reply = self._recvall(8 + 4 + 4 + 4) + (magic, option, reply_type, data_length) = struct.unpack( + ">QLLL", reply) + LOGGER.debug("NBD reply magic='0x%x' option='%d' reply_type='%d'", + magic, option, reply_type) + assert_protocol(magic == OPTION_REPLY_MAGIC) + if option != self._last_sent_option: + raise NBDUnexpectedOptionResponseError( + expected=self._last_sent_option, received=option) + if reply_type & NBD_REP_ERROR_BIT != 0: + raise NBDOptionError(reply=reply_type) + data = self._recvall(data_length) + return (reply_type, data) + + def _parse_option_reply_ack(self): + (reply_type, data) = self._parse_option_reply() + if reply_type != NBD_REP_ACK: + raise NBDProtocolError() + return data + + def _parse_meta_context_reply(self): + (reply_type, data) = self._parse_option_reply() + if reply_type == NBD_REP_ACK: + return None + assert_protocol(reply_type == NBD_REP_META_CONTEXT) + context_id = struct.unpack(">L", data[:4])[0] + name = (data[4:]).decode('utf-8') + return (context_id, name) + 
+ def _upgrade_socket_to_tls(self, cert, subject): + # Forcing the client to use TLSv1_2 + context = ssl.SSLContext(ssl.PROTOCOL_TLSv1_2) + context.options &= ~ssl.OP_NO_TLSv1 + context.options &= ~ssl.OP_NO_TLSv1_1 + context.options &= ~ssl.OP_NO_SSLv2 + context.options &= ~ssl.OP_NO_SSLv3 + context.verify_mode = ssl.CERT_REQUIRED + context.check_hostname = (subject is not None) + context.load_verify_locations(cadata=cert) + cleartext_socket = self._s + self._s = context.wrap_socket( + cleartext_socket, + server_side=False, + do_handshake_on_connect=True, + server_hostname=subject) + + def _initiate_tls_upgrade(self): + # start TLS negotiation + self._send_option(NBD_OPT_STARTTLS) + # receive reply + data = self._parse_option_reply_ack() + assert_protocol(len(data) == 0) + + def request_info(self, export_name, info_requests): + """Query information from the server.""" + data = struct.pack('>L', len(export_name)) + data += export_name.encode('utf-8') + data += struct.pack('>H', len(info_requests)) + for request in info_requests: + data += struct.pack('>H', request) + self._send_option(NBD_OPT_INFO, data) + infos = [] + while True: + (reply_type, data) = self._parse_option_reply() + if reply_type == NBD_REP_INFO: + info_type = struct.unpack(">H", data[:2])[0] + info = {'information_type': info_type} + payload = data[2:] + if info_type == NBD_INFO_BLOCK_SIZE: + assert_protocol(len(data) == 14) + sizes = struct.unpack('>LLL', payload) + (info['minimum_block_size'], + info['preferred_block_size'], + info['maximum_block_size']) = sizes + infos += [info] + elif info_type == NBD_INFO_EXPORT: + assert_protocol(len(data) == 12) + export_info = struct.unpack('>QH', payload) + (info['size'], + info['transmission_flags']) = export_info + infos += [info] + else: + # The client MUST ignore information replies it does not + # understand. 
+ LOGGER.warning('Unsupported info reply type: %d', + info_type) + elif reply_type == NBD_REP_ACK: + assert_protocol(not data) + break + else: + raise NBDProtocolError( + 'Unexpected reply type: {}'.format(reply_type)) + return infos + + def negotiate_structured_reply(self): + """ + Negotiate use of the structured reply extension, fail if unsupported. + Only valid during the handshake phase. + """ + self._send_option(NBD_OPT_STRUCTURED_REPLY) + self._parse_option_reply_ack() + self._structured_reply = True + + def _process_meta_context_option(self, option, export_name, queries): + data = struct.pack('>L', len(export_name)) + data += export_name.encode('utf-8') + data += struct.pack('>L', len(queries)) + for query in queries: + data += struct.pack('>L', len(query)) + data += query.encode('utf-8') + self._send_option(option, data) + while True: + reply = self._parse_meta_context_reply() + if reply is None: + break + yield reply + + def _send_meta_context_option(self, option, export_name, queries): + return list(self._process_meta_context_option( + option, export_name, queries)) + + def set_meta_contexts(self, export_name, queries): + """ + Change the set of active metadata contexts. Only valid during the + handshake phase. Returns the list of selected metadata contexts as + (metadata context ID, metadata context name) pairs. + Structured replies must be negotiated first using + negotiate_structured_reply. + """ + return self._send_meta_context_option( + option=NBD_OPT_SET_META_CONTEXT, + export_name=export_name, + queries=queries) + + def list_meta_contexts(self, export_name, queries): + """ + Return the metadata contexts available on the export matching one or + more of the queries as (metadata context ID, metadata context name) + pairs. + Structured replies must be negotiated first using + negotiate_structured_reply. 
+ """ + return self._send_meta_context_option( + option=NBD_OPT_LIST_META_CONTEXT, + export_name=export_name, + queries=queries) + + def _fixed_new_style_handshake(self, cert, subject, use_tls): + nbd_magic = self._recvall(len("NBDMAGIC")) + assert_protocol(nbd_magic == b'NBDMAGIC') + nbd_magic = self._recvall(len("IHAVEOPT")) + assert_protocol(nbd_magic == b'IHAVEOPT') + buf = self._recvall(2) + handshake_flags = struct.unpack(">H", buf)[0] + assert_protocol(handshake_flags & NBD_FLAG_HAS_FLAGS != 0) + client_flags = NBD_FLAG_C_FIXED_NEWSTYLE + client_flags = struct.pack('>L', client_flags) + self._s.sendall(client_flags) + + if use_tls: + # start TLS negotiation + self._initiate_tls_upgrade() + # upgrade socket to TLS + self._upgrade_socket_to_tls(cert, subject) + + def connect(self, exportname): + """ + Valid only during the handshake phase. Requests the given + export and enters the transmission phase. + """ + LOGGER.info("Connecting to export '%s' using newstyle negotiation", + exportname) + # request export + self._send_option(NBD_OPT_EXPORT_NAME, exportname.encode('utf-8')) + + # non-fixed newstyle negotiation: we get these if the server is willing + # to allow the export + buf = self._recvall(10) + (self._size, self._transmission_flags) = struct.unpack(">QH", buf) + LOGGER.debug("NBD got size=%d transmission flags=%d", + self._size, self._transmission_flags) + # ignore the zeroes + zeroes = self._recvall(124) + LOGGER.debug("NBD got zeroes: %s", zeroes) + self._transmission_phase = True + LOGGER.debug("Connected") + + # Oldstyle handshake + + def _old_style_handshake(self): + LOGGER.info("Connecting to server using oldstyle negotiation") + nbd_magic = self._recvall(len("NBDMAGIC")) + assert_protocol(nbd_magic == b'NBDMAGIC') + buf = self._recvall(8 + 8 + 4) + (magic, + self._size, + self._transmission_flags) = struct.unpack(">QQL", buf) + assert_protocol(magic == 0x00420281861253) + # ignore trailing zeroes + self._recvall(124) + self._transmission_phase = 
True + + # Transmission phase + + def _send_request_header(self, request_type, offset, length): + LOGGER.debug("NBD request offset=%d length=%d", offset, length) + command_flags = 0 + self._handle += 1 + header = struct.pack('>LHHQQL', NBD_REQUEST_MAGIC, command_flags, + request_type, self._handle, offset, length) + self._s.sendall(header) + + def _check_handle(self, handle): + if handle != self._handle: + raise NBDUnexpectedReplyHandleError( + expected=self._handle, received=handle) + + def _parse_simple_reply(self, data_length=0): + LOGGER.debug("NBD parsing simple reply, data_length=%d", data_length) + reply = self._recvall(4 + 4 + 8) + (magic, errno, handle) = struct.unpack(">LLQ", reply) + LOGGER.debug("NBD simple reply magic='0x%x' errno='%d' handle='%d'", + magic, errno, handle) + assert_protocol(magic == NBD_SIMPLE_REPLY_MAGIC) + self._check_handle(handle) + if errno != 0: + raise NBDTransmissionError(errno) + data = self._recvall(length=data_length) + LOGGER.debug("NBD response received data_length=%d bytes", data_length) + return data + + def _handle_block_status_reply(self, fields): + data_length = fields['data_length'] + assert_protocol((data_length >= 12) and (data_length % 8 == 4)) + data = self._recvall(data_length) + view = memoryview(data) + fields['context_id'] = struct.unpack(">L", view[:4])[0] + view = view[4:] + descriptors = list(_parse_block_status_descriptors(view)) + assert_protocol(descriptors) + fields['descriptors'] = descriptors + + def _handle_data_reply(self, fields): + data_length = fields['data_length'] + assert_protocol(data_length >= 9) + buf = self._recvall(8) + fields['offset'] = struct.unpack(">Q", buf)[0] + fields['data'] = self._recvall(data_length - 8) + assert_protocol(fields['data']) + + def _handle_hole_reply(self, fields): + assert_protocol(fields['data_length'] == 12) + buf = self._recvall(12) + (fields['offset'], fields['hole_size']) = struct.unpack(">QL", buf) + + def _handle_structured_reply_error(self, fields): + 
data_length = fields['data_length'] + assert_protocol(data_length >= 6) + buf = self._recvall(4 + 2) + (errno, message_length) = struct.unpack(">LH", buf) + fields['error'] = errno + remaining_length = data_length - 6 + # The client MAY continue transmission in case of an unexpected error + # type, unless message_length does not fit into the length: + if message_length > remaining_length: + raise NBDProtocolError( + 'message_length is too large to fit within data_length bytes') + data = self._recvall(remaining_length) + view = memoryview(data) + fields['message'] = view[0:message_length].tobytes().decode('utf-8') + view = view[message_length:] + if fields['reply_type'] == NBD_REPLY_TYPE_ERROR_OFFSET: + fields['offset'] = struct.unpack(">Q", view)[0] + + def _parse_structured_reply_chunk(self): + LOGGER.debug("NBD parsing structured reply chunk") + reply = self._recvall(4 + 2 + 2 + 8 + 4) + header = struct.unpack(">LHHQL", reply) + (magic, flags, reply_type, handle, data_length) = header + LOGGER.debug("NBD structured reply magic='%x' flags='%s' " + "reply_type='%d' handle='%d' data_length='%d'", + magic, flags, reply_type, handle, data_length) + assert_protocol(magic == NBD_STRUCTURED_REPLY_MAGIC) + self._check_handle(handle) + fields = {'flags': flags, + 'reply_type': reply_type, + 'data_length': data_length} + if reply_type == NBD_REPLY_TYPE_BLOCK_STATUS: + self._handle_block_status_reply(fields) + elif reply_type == NBD_REPLY_TYPE_NONE: + assert_protocol(data_length == 0) + assert_protocol(_is_final_structured_reply_chunk(flags=flags)) + elif reply_type == NBD_REPLY_TYPE_OFFSET_DATA: + self._handle_data_reply(fields) + elif reply_type == NBD_REPLY_TYPE_OFFSET_HOLE: + self._handle_hole_reply(fields) + elif is_error_chunk(reply_type=reply_type): + self._handle_structured_reply_error(fields) + else: + raise NBDUnexpectedStructuredReplyType(reply_type) + return fields + + def _parse_structured_reply_chunks(self): + while True: + reply = 
self._parse_structured_reply_chunk() + yield reply + if _is_final_structured_reply_chunk(flags=reply['flags']): + return + + def write(self, data, offset): + """ + Writes the given bytes to the export, starting at the given + offset. + """ + LOGGER.debug("NBD_CMD_WRITE") + _check_alignment("offset", offset) + _check_alignment("size", len(data)) + self._flushed = False + self._send_request_header(NBD_CMD_WRITE, offset, len(data)) + self._s.sendall(data) + # TODO: the server MAY respond with a structured reply (e.g. to report + # errors) + self._parse_simple_reply() + return len(data) + + def read(self, offset, length): + """ + Returns length number of bytes read from the export, starting at + the given offset. + If structured replies have been negotiated, it returns a generator + containing the reply chunks. The caller must consume this generator + before further NBD commands, since this client does not support + asynchronous request processing. + """ + LOGGER.debug("NBD_CMD_READ") + _check_alignment("offset", offset) + _check_alignment("length", length) + self._send_request_header(NBD_CMD_READ, offset, length) + if self._structured_reply: + return self._parse_structured_reply_chunks() + data = self._parse_simple_reply(length) + return data + + def _need_flush(self): + return self._transmission_flags & NBD_FLAG_SEND_FLUSH != 0 + + def flush(self): + """ + Sends a flush request to the server if the server supports it + and there are unflushed writes. This causes all completed writes + (the writes for which the server has already sent a reply to the + client) to be written to permanent storage. + """ + if self._need_flush() is False: + self._flushed = True + return False + LOGGER.debug("NBD_CMD_FLUSH") + self._send_request_header(NBD_CMD_FLUSH, 0, 0) + # TODO: the server MAY respond with a structured reply (e.g. 
to report + # errors) + self._parse_simple_reply() + self._flushed = True + return True + + def query_block_status(self, offset, length): + """ + Query block status in the range defined by length and offset. + Returns a list of structured reply chunks. + The required meta contexts must have been negotiated using + set_meta_contexts. + """ + LOGGER.debug("NBD_CMD_BLOCK_STATUS") + self._send_request_header(NBD_CMD_BLOCK_STATUS, offset, length) + return list(self._parse_structured_reply_chunks()) + + def _disconnect(self): + if self._transmission_phase: + LOGGER.debug("NBD_CMD_DISC") + self._send_request_header(NBD_CMD_DISC, 0, 0) + else: + self._send_option(NBD_OPT_ABORT) + + def get_size(self): + """ + Return the size of the device in bytes. + """ + return self._size diff --git a/src/channels.ml b/src/channels.ml index 39d111d..b2faecb 100644 --- a/src/channels.ml +++ b/src/channels.ml @@ -18,53 +18,100 @@ open Lwt open Lwt_preemptive -external _sendfile: Unix.file_descr -> Unix.file_descr -> int64 -> int64 = "stub_sendfile64" +type handle -let _sendfile from_fd to_fd len = - let from_fd = Lwt_unix.unix_file_descr from_fd in - let to_fd = Lwt_unix.unix_file_descr to_fd in - detach (_sendfile from_fd to_fd) len +external _init : Unix.file_descr -> Unix.file_descr -> handle = "stub_init" + +external _cleanup : handle -> unit = "stub_cleanup" + +external _direct_copy : handle -> int64 -> int64 = "stub_direct_copy" + +let with_handle from_fd to_fd f = + let unix_from_fd = Lwt_unix.unix_file_descr from_fd in + let unix_to_fd = Lwt_unix.unix_file_descr to_fd in + let handle = _init unix_from_fd unix_to_fd in + Lwt.finalize + (fun () -> f handle) + (fun () -> _cleanup handle ; Lwt.return_unit) + +let _direct_copy handle _from_fd _to_fd len = + (* Atlhough the FD is set to blocking mode (by [with_blocking_fd]), + * the fcntl flags are only updated lazily, either by + * first call to an Lwt_unix IO function or [wrap_syscall]. 
+ * Perform `wrap_syscall` here to avoid EAGAIN *) + detach (_direct_copy handle) len + +let maybe_fdatasync stat to_fd = + match stat.Lwt_unix.LargeFile.st_kind with + | Unix.S_REG | Unix.S_BLK -> + Lwt_unix.fdatasync to_fd + | _ -> + Lwt.return_unit (* The OS implementation can return short (e.g. Linux will stop at a 2GiB boundary). This function keeps copying until all the bytes are copied. *) -let rec sendfile from_fd to_fd len = - (* sendfile requires sockets in non-blocking mode *) +let direct_copy from_fd to_fd len = + (* direct_copy requires sockets in non-blocking mode *) let with_blocking_fd fd f = - Lwt_unix.blocking fd - >>= function - | true -> f fd - | false -> - Lwt_unix.set_blocking fd true; - Lwt.catch - (fun () -> - f fd - >>= fun r -> - Lwt_unix.set_blocking fd false; - return r - ) (fun e -> - Lwt_unix.set_blocking fd false; - fail e) in - with_blocking_fd from_fd - (fun from_fd -> - with_blocking_fd to_fd - (fun to_fd -> - let rec loop remaining = - if remaining > 0L then begin - _sendfile from_fd to_fd remaining - >>= fun written -> - loop (Int64.sub remaining written) - end else return () in - loop len - ) - ) + Lwt_unix.blocking fd >>= function + | true -> + f fd + | false -> ( + Lwt_unix.set_blocking fd true ; + (* [set_blocking] sets the flags lazily, + * force them to be set by querying it *) + Lwt_unix.blocking fd >>= function + | false -> + Lwt.fail_with "Failed to set FD to blocking mode" + | true -> + Lwt.catch + (fun () -> + f fd >>= fun r -> + Lwt_unix.set_blocking fd false ; + return r) + (fun e -> + Lwt_unix.set_blocking fd false ; + fail e) + ) + in + + let sync_limit = Int64.(mul 4L (mul 1024L 1024L)) in + + let write handle from_fd to_fd to_write = + let rec loop remaining = + if remaining > 0L then + _direct_copy handle from_fd to_fd remaining >>= fun written -> + loop (Int64.sub remaining written) + else + return () + in + loop to_write >>= fun () -> return to_write + in + + let min x y = if Int64.compare x y = -1 then x else 
y in + + with_blocking_fd from_fd (fun from_fd -> + with_blocking_fd to_fd (fun to_fd -> + with_handle from_fd to_fd (fun handle -> + Lwt_unix.LargeFile.fstat to_fd >>= fun stat -> + let rec loop remaining = + if remaining > 0L then + let to_write = min sync_limit remaining in + write handle from_fd to_fd to_write >>= fun written -> + maybe_fdatasync stat to_fd >>= fun () -> + loop (Int64.sub remaining written) + else + return () + in + loop len))) type t = { - really_read: Cstruct.t -> unit Lwt.t; - really_write: Cstruct.t -> unit Lwt.t; - offset: int64 ref; - skip: int64 -> unit Lwt.t; - copy_from: Lwt_unix.file_descr -> int64 -> int64 Lwt.t; - close: unit -> unit Lwt.t + really_read: Cstruct.t -> unit Lwt.t + ; really_write: Cstruct.t -> unit Lwt.t + ; offset: int64 ref + ; skip: int64 -> unit Lwt.t + ; copy_from: Lwt_unix.file_descr -> int64 -> int64 Lwt.t + ; close: unit -> unit Lwt.t } exception Impossible_to_seek @@ -73,66 +120,67 @@ let of_raw_fd fd = let offset = ref 0L in let really_read buf = IO.complete "read" (Some !offset) Lwt_bytes.read fd buf >>= fun () -> - offset := Int64.(add !offset (of_int (Cstruct.len buf))); - return () in + (offset := Int64.(add !offset (of_int (Cstruct.len buf)))) ; + return () + in let really_write buf = IO.complete "write" (Some !offset) Lwt_bytes.write fd buf >>= fun () -> - offset := Int64.(add !offset (of_int (Cstruct.len buf))); - return () in + (offset := Int64.(add !offset (of_int (Cstruct.len buf)))) ; + return () + in let skip _ = fail Impossible_to_seek in let copy_from from_fd len = - sendfile from_fd fd len - >>= fun () -> - offset := Int64.(add !offset len); - return len in + direct_copy from_fd fd len >>= fun () -> + (offset := Int64.(add !offset len)) ; + return len + in let close () = Lwt_unix.close fd in - return { really_read; really_write; offset; skip; copy_from; close } + return {really_read; really_write; offset; skip; copy_from; close} let of_seekable_fd fd = of_raw_fd fd >>= fun c -> let skip n = 
Lwt_unix.LargeFile.lseek fd n Unix.SEEK_CUR >>= fun offset -> - c.offset := offset; - return () in - return { c with skip } + c.offset := offset ; + return () + in + return {c with skip} -let _ = - Ssl.init () - -let legacy_sslctx good_ciphersuites legacy_ciphersuites = - let ctx = Ssl.create_context Ssl.SSLv23 Ssl.Client_context in - Ssl.set_cipher_list ctx (good_ciphersuites ^ (match legacy_ciphersuites with "" -> "" | s -> (":" ^ s))); - Ssl.disable_protocols ctx [Ssl.SSLv3]; - ctx +let _ = Ssl.init () -let good_sslctx good_ciphersuites = +let sslctx good_ciphersuites = let ctx = Ssl.create_context Ssl.TLSv1_2 Ssl.Client_context in - Ssl.set_cipher_list ctx good_ciphersuites; + Ssl.set_cipher_list ctx good_ciphersuites ; ctx -let of_ssl_fd fd ssl_legacy good_ciphersuites legacy_ciphersuites = - let good_ciphersuites = match good_ciphersuites with None -> failwith "good_ciphersuites not specified" | Some x -> x in - let legacy_ciphersuites = match legacy_ciphersuites with None -> "" | Some x -> x in - let sslctx = if ssl_legacy then legacy_sslctx good_ciphersuites legacy_ciphersuites else good_sslctx good_ciphersuites in +let of_ssl_fd fd good_ciphersuites = + let good_ciphersuites = + match good_ciphersuites with + | None -> + failwith "good_ciphersuites not specified" + | Some x -> + x + in + let sslctx = sslctx good_ciphersuites in Lwt_ssl.ssl_connect fd sslctx >>= fun sock -> let offset = ref 0L in let really_read buf = IO.complete "read" (Some !offset) Lwt_ssl.read_bytes sock buf >>= fun () -> - offset := Int64.(add !offset (of_int (Cstruct.len buf))); - return () in + (offset := Int64.(add !offset (of_int (Cstruct.len buf)))) ; + return () + in let really_write buf = - IO.complete "write" (Some !offset) Lwt_ssl.write_bytes sock buf >>= fun () -> - offset := Int64.(add !offset (of_int (Cstruct.len buf))); - return () in + IO.complete "write" (Some !offset) Lwt_ssl.write_bytes sock buf + >>= fun () -> + (offset := Int64.(add !offset (of_int (Cstruct.len 
buf)))) ; + return () + in let skip _ = fail Impossible_to_seek in let copy_from from_fd len = - sendfile from_fd fd len - >>= fun () -> - offset := Int64.(add !offset len); - return len in - - let close () = - Lwt_ssl.close sock in - return { really_read; really_write; offset; skip; copy_from; close } - + direct_copy from_fd fd len >>= fun () -> + (offset := Int64.(add !offset len)) ; + return len + in + let close () = Lwt_ssl.close sock in + return {really_read; really_write; offset; skip; copy_from; close} diff --git a/src/chunked.ml b/src/chunked.ml index eee15fd..e8152f2 100644 --- a/src/chunked.ml +++ b/src/chunked.ml @@ -12,25 +12,22 @@ * GNU Lesser General Public License for more details. *) -cstruct t { - uint64_t offset; - uint32_t len - (* data *) -} as little_endian +type%cstruct t = {offset: uint64_t; len: uint32_t (* data *)} [@@little_endian] let sizeof = sizeof_t type t = { - offset: int64; (** offset on the physical disk *) - data: Cstruct.t; (** data to write *) + offset: int64 (** offset on the physical disk *) + ; data: Cstruct.t (** data to write *) } -let marshal (buf: Cstruct.t) t = - set_t_offset buf t.offset; +let marshal (buf : Cstruct.t) t = + set_t_offset buf t.offset ; set_t_len buf (Int32.of_int (Cstruct.len t.data)) -let is_last_chunk (buf: Cstruct.t) = - get_t_offset buf = 0L && (get_t_len buf = 0l) +let is_last_chunk (buf : Cstruct.t) = + get_t_offset buf = 0L && get_t_len buf = 0l let get_offset = get_t_offset + let get_len = get_t_len diff --git a/src/cohttp_unbuffered_io.ml b/src/cohttp_unbuffered_io.ml index de9609e..edad5cd 100644 --- a/src/cohttp_unbuffered_io.ml +++ b/src/cohttp_unbuffered_io.ml @@ -15,13 +15,15 @@ * *) -open Lwt - type 'a t = 'a Lwt.t + let iter fn x = Lwt_list.iter_s fn x + let return = Lwt.return -let (>>=) = Lwt.bind -let (>>) m n = m >>= fun _ -> n + +let ( >>= ) = Lwt.bind + +let ( >> ) m n = m >>= fun _ -> n (** Use as few really_{read,write} calls as we can (for efficiency) without explicitly 
buffering the stream beyond the HTTP headers. This will @@ -29,101 +31,103 @@ let (>>) m n = m >>= fun _ -> n safely to another process *) type ic = { - mutable header_buffer: string option; (** buffered headers *) - mutable header_buffer_idx: int; (** next char within the buffered headers *) - c: Channels.t; + mutable header_buffer: string option (** buffered headers *) + ; mutable header_buffer_idx: int (** next char within the buffered headers *) + ; c: Channels.t } let make_input c = let header_buffer = None in let header_buffer_idx = 0 in - { header_buffer; header_buffer_idx; c } + {header_buffer; header_buffer_idx; c} type oc = Channels.t + type conn = Channels.t let really_read_into c buf ofs len = let tmp = Cstruct.create len in c.Channels.really_read tmp >>= fun () -> - Cstruct.blit_to_string tmp 0 buf ofs len; + Cstruct.blit_to_bytes tmp 0 buf ofs len ; return () let read_http_headers c = let buf = Buffer.create 128 in (* We can safely read everything up to this marker: *) let end_of_headers = "\r\n\r\n" in - let tmp = String.make (String.length end_of_headers) '\000' in + let tmp = Bytes.make (String.length end_of_headers) '\000' in let module Scanner = struct - type t = { - marker: string; - mutable i: int; - } - let make x = { marker = x; i = 0 } - let input x c = - if c = x.marker.[x.i] then x.i <- x.i + 1 else x.i <- 0 + type t = {marker: string; mutable i: int} + + let make x = {marker= x; i= 0} + + let input x c = if c = x.marker.[x.i] then x.i <- x.i + 1 else x.i <- 0 + let remaining x = String.length x.marker - x.i + let matched x = x.i = String.length x.marker + let to_string x = Printf.sprintf "%d" x.i end in let marker = Scanner.make end_of_headers in let rec loop () = - if not(Scanner.matched marker) then begin + if not (Scanner.matched marker) then ( (* We may be part way through reading the end of header marker, so be pessimistic and only read enough bytes to read until the end of the marker. 
*) let safe_to_read = Scanner.remaining marker in really_read_into c tmp 0 safe_to_read >>= fun () -> - for j = 0 to safe_to_read - 1 do - Scanner.input marker tmp.[j]; - Buffer.add_char buf tmp.[j] - done; + Scanner.input marker (Bytes.get tmp j) ; + Buffer.add_char buf (Bytes.get tmp j) + done ; loop () - end else return () in - loop () >>= fun () -> - return (Buffer.contents buf) + ) else + return () + in + loop () >>= fun () -> return (Buffer.contents buf) -let crlf = Re_str.regexp_string "\r\n" +let crlf = Re.Str.regexp_string "\r\n" (* We assume read_line is only used to read the HTTP header *) -let rec read_line ic = match ic.header_buffer, ic.header_buffer_idx with -| None, _ -> - read_http_headers ic.c >>= fun str -> - ic.header_buffer <- Some str; - read_line ic -| Some buf, i when i < (String.length buf) -> - begin +let rec read_line ic = + match (ic.header_buffer, ic.header_buffer_idx) with + | None, _ -> + read_http_headers ic.c >>= fun str -> + ic.header_buffer <- Some str ; + read_line ic + | Some buf, i when i < String.length buf -> ( try - let eol = Re_str.search_forward crlf buf i in + let eol = Re.Str.search_forward crlf buf i in let line = String.sub buf i (eol - i) in - ic.header_buffer_idx <- eol + 2; + ic.header_buffer_idx <- eol + 2 ; return (Some line) with Not_found -> return (Some "") - end -| Some _, _ -> - return (Some "") + ) + | Some _, _ -> + return (Some "") let read_into_exactly ic buf ofs len = - really_read_into ic.c buf ofs len >>= fun () -> - return true + really_read_into ic.c buf ofs len >>= fun () -> return true let read_exactly ic len = - let buf = String.create len in + let buf = Bytes.create len in read_into_exactly ic buf 0 len >>= function - | true -> return (Some buf) - | false -> return None + | true -> + return (Some buf) + | false -> + return None let read ic n = - let buf = String.make n '\000' in + let buf = Bytes.make n '\000' in really_read_into ic.c buf 0 n >>= fun () -> - return buf + return 
(Bytes.unsafe_to_string buf) let write oc x = let buf = Cstruct.create (String.length x) in - Cstruct.blit_from_string x 0 buf 0 (String.length x); + Cstruct.blit_from_string x 0 buf 0 (String.length x) ; oc.Channels.really_write buf -let flush oc = - return () +let flush _oc = return () diff --git a/src/common.ml b/src/common.ml index 2b88e33..788a382 100644 --- a/src/common.ml +++ b/src/common.ml @@ -12,21 +12,16 @@ * GNU Lesser General Public License for more details. *) -type t = { - debug: bool; - verb: bool; - unbuffered: bool; - path: string list; -} +type t = {debug: bool; verb: bool; unbuffered: bool; path: string list} -let colon = Re_str.regexp_string ":" +let colon = Re.Str.regexp_string ":" let make debug verb unbuffered path = - let path = Re_str.split colon path in - { debug; verb; unbuffered; path } + let path = Re.Str.split colon path in + {debug; verb; unbuffered; path} (* Keep this in sync with OCaml's Unix.file_descr *) -let file_descr_of_int (x: int) : Unix.file_descr = Obj.magic x +let file_descr_of_int (x : int) : Unix.file_descr = Obj.magic x let ( |> ) a b = b a @@ -38,30 +33,37 @@ let parse_size x = let endswith suffix x = let suffix' = String.length suffix in let x' = String.length x in - x' >= suffix' && (String.sub x (x' - suffix') suffix' = suffix) in + x' >= suffix' && String.sub x (x' - suffix') suffix' = suffix + in let remove suffix x = let suffix' = String.length suffix in let x' = String.length x in - String.sub x 0 (x' - suffix') in + String.sub x 0 (x' - suffix') + in try - if endswith "KiB" x then Int64.(mul kib (of_string (remove "KiB" x))) - else if endswith "MiB" x then Int64.(mul mib (of_string (remove "MiB" x))) - else if endswith "GiB" x then Int64.(mul gib (of_string (remove "GiB" x))) - else if endswith "TiB" x then Int64.(mul tib (of_string (remove "TiB" x))) - else Int64.of_string x - with _ -> - failwith (Printf.sprintf "Cannot parse size: %s" x) + if endswith "KiB" x then + Int64.(mul kib (of_string (remove "KiB" 
x))) + else if endswith "MiB" x then + Int64.(mul mib (of_string (remove "MiB" x))) + else if endswith "GiB" x then + Int64.(mul gib (of_string (remove "GiB" x))) + else if endswith "TiB" x then + Int64.(mul tib (of_string (remove "TiB" x))) + else + Int64.of_string x + with _ -> failwith (Printf.sprintf "Cannot parse size: %s" x) module type Floatable = sig type t - val to_float: t -> float + + val to_float : t -> float end let hms secs = - let h = secs / 3600 in - let m = (secs mod 3600) / 60 in - let s = secs mod 60 in - Printf.sprintf "%02d:%02d:%02d" h m s + let h = secs / 3600 in + let m = secs mod 3600 / 60 in + let s = secs mod 60 in + Printf.sprintf "%02d:%02d:%02d" h m s let size bytes = let open Int64 in @@ -69,72 +71,91 @@ let size bytes = let mib = mul kib 1024L in let gib = mul mib 1024L in let tib = mul gib 1024L in - if div bytes tib > 0L - then Printf.sprintf "%Ld TiB" (div bytes tib) - else if div bytes gib > 0L - then Printf.sprintf "%Ld GiB" (div bytes gib) - else if div bytes mib > 0L - then Printf.sprintf "%Ld MiB" (div bytes mib) - else if div bytes kib > 0L - then Printf.sprintf "%Ld KiB" (div bytes kib) - else Printf.sprintf "%Ld bytes" bytes - -module Progress_bar(T: Floatable) = struct + if div bytes tib > 0L then + Printf.sprintf "%Ld TiB" (div bytes tib) + else if div bytes gib > 0L then + Printf.sprintf "%Ld GiB" (div bytes gib) + else if div bytes mib > 0L then + Printf.sprintf "%Ld MiB" (div bytes mib) + else if div bytes kib > 0L then + Printf.sprintf "%Ld KiB" (div bytes kib) + else + Printf.sprintf "%Ld bytes" bytes + +module Progress_bar (T : Floatable) = struct type t = { - max_value: T.t; - mutable current_value: T.t; - width: int; - line: string; - mutable spin_index: int; - start_time: float; - mutable summarised: bool; + max_value: T.t + ; mutable current_value: T.t + ; width: int + ; line: bytes + ; mutable spin_index: int + ; start_time: float + ; mutable summarised: bool } let prefix_s = "[*] " + let prefix = 
String.length prefix_s + let suffix_s = " ( % ETA : : )" + let suffix = String.length suffix_s - let spinner = [| '-'; '\\'; '|'; '/' |] + let spinner = [|'-'; '\\'; '|'; '/'|] let create width current_value max_value = - let line = String.make width ' ' in - String.blit prefix_s 0 line 0 prefix; - String.blit suffix_s 0 line (width - suffix - 1) suffix; + let line = Bytes.make width ' ' in + String.blit prefix_s 0 line 0 prefix ; + String.blit suffix_s 0 line (width - suffix - 1) suffix ; let spin_index = 0 in let start_time = Unix.gettimeofday () in - { max_value; current_value; width; line; spin_index; start_time; summarised = false } - - let percent t = int_of_float (T.(to_float t.current_value /. (to_float t.max_value) *. 100.)) + { + max_value + ; current_value + ; width + ; line + ; spin_index + ; start_time + ; summarised= false + } + + let percent t = + int_of_float T.(to_float t.current_value /. to_float t.max_value *. 100.) let bar_width t value = - int_of_float (T.(to_float value /. (to_float t.max_value) *. (float_of_int (t.width - prefix - suffix)))) + int_of_float + T.( + to_float value + /. to_float t.max_value + *. float_of_int (t.width - prefix - suffix)) let eta t = let time_so_far = Unix.gettimeofday () -. t.start_time in - let total_time = T.(to_float t.max_value /. (to_float t.current_value)) *. time_so_far in + let total_time = + T.(to_float t.max_value /. to_float t.current_value) *. time_so_far + in let remaining = int_of_float (total_time -. 
time_so_far) in hms remaining let print_bar t = let w = bar_width t t.current_value in - t.line.[1] <- spinner.(t.spin_index); - t.spin_index <- (t.spin_index + 1) mod (Array.length spinner); + Bytes.set t.line 1 @@ spinner.(t.spin_index) ; + t.spin_index <- (t.spin_index + 1) mod Array.length spinner ; for i = 0 to w - 1 do - t.line.[prefix + i] <- (if i = w - 1 then '>' else '#') - done; + Bytes.set t.line (prefix + i) @@ if i = w - 1 then '>' else '#' + done ; let percent = Printf.sprintf "%3d" (percent t) in - String.blit percent 0 t.line (t.width - 19) 3; + String.blit percent 0 t.line (t.width - 19) 3 ; let eta = eta t in - String.blit eta 0 t.line (t.width - 10) (String.length eta); - - Printf.printf "\r%s%!" t.line + String.blit eta 0 t.line (t.width - 10) (String.length eta) ; + + Printf.printf "\r%s%!" (Bytes.to_string t.line) let update t new_value = let new_value = min new_value t.max_value in let old_bar = bar_width t t.current_value in let new_bar = bar_width t new_value in - t.current_value <- new_value; + t.current_value <- new_value ; new_bar <> old_bar let average_rate t = @@ -142,31 +163,45 @@ module Progress_bar(T: Floatable) = struct T.to_float t.current_value /. time_so_far let summarise t = - if not t.summarised then begin - t.summarised <- true; - Printf.printf "Total work done: %s\n" (size (Int64.of_float (T.to_float t.current_value))); - Printf.printf "Total time: %s\n" (hms (int_of_float (Unix.gettimeofday () -. t.start_time))); - Printf.printf "Average rate: %s / sec\n" (size (Int64.of_float (average_rate t))) - end + if not t.summarised then ( + t.summarised <- true ; + Printf.printf "Total work done: %s\n" + (size (Int64.of_float (T.to_float t.current_value))) ; + Printf.printf "Total time: %s\n" + (hms (int_of_float (Unix.gettimeofday () -. 
t.start_time))) ; + Printf.printf "Average rate: %s / sec\n" + (size (Int64.of_float (average_rate t))) + ) end let padto blank n s = - let result = String.make n blank in - String.blit s 0 result 0 (min n (String.length s)); - result + let result = Bytes.make n blank in + String.blit s 0 result 0 (min n (String.length s)) ; + Bytes.unsafe_to_string result let print_table header rows = let nth xs i = try List.nth xs i with Not_found -> "" in let width_of_column i = - let values = nth header i :: (List.map (fun r -> nth r i) rows) in + let values = nth header i :: List.map (fun r -> nth r i) rows in let widths = List.map String.length values in - List.fold_left max 0 widths in - let widths = List.rev (snd(List.fold_left (fun (i, acc) _ -> (i + 1, (width_of_column i) :: acc)) (0, []) header)) in + List.fold_left max 0 widths + in + let widths = + List.rev + (snd + (List.fold_left + (fun (i, acc) _ -> (i + 1, width_of_column i :: acc)) + (0, []) header)) + in let print_row row = - List.iter (fun (n, s) -> Printf.printf "%s |" (padto ' ' n s)) (List.combine widths row); - Printf.printf "\n" in - print_row header; - List.iter (fun (n, _) -> Printf.printf "%s-|" (padto '-' n "")) (List.combine widths header); - Printf.printf "\n"; + List.iter + (fun (n, s) -> Printf.printf "%s |" (padto ' ' n s)) + (List.combine widths row) ; + Printf.printf "\n" + in + print_row header ; + List.iter + (fun (n, _) -> Printf.printf "%s-|" (padto '-' n "")) + (List.combine widths header) ; + Printf.printf "\n" ; List.iter print_row rows - diff --git a/src/direct_copy_stubs.c b/src/direct_copy_stubs.c new file mode 100644 index 0000000..7af5087 --- /dev/null +++ b/src/direct_copy_stubs.c @@ -0,0 +1,229 @@ +/* + * Copyright (C) 2012-2013 Citrix Inc + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU Lesser General Public License as published + * by the Free Software Foundation; version 2.1 only. 
with the special + * exception on linking described in file LICENSE. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Lesser General Public License for more details. + */ + +#define _GNU_SOURCE + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +/* ocaml/ocaml/unixsupport.c */ +extern void uerror(char *cmdname, value cmdarg); +#define Nothing ((value) 0) + +enum direct_copy_rc { + OK = 0, + TRIED_AND_FAILED = 1, + READ_FAILED = 2, + WRITE_FAILED = 3, + WRITE_UNEXPECTED_EOF = 4, + WRITE_POLL_FAILED = 5, + READ_POLL_FAILED = 6 +}; + +#define XFER_BUFSIZ (2*1024*1024) + +struct direct_copy_handle { + int in_fd; + int out_fd; + char *buffer; +}; + +CAMLprim value stub_init(value in_fd, value out_fd) +{ + CAMLparam2(in_fd, out_fd); + CAMLlocal1(result); + int c_in_fd = Int_val(in_fd); + int c_out_fd = Int_val(out_fd); + struct direct_copy_handle *cpinfo = NULL; + int flags; + + /* This is where we will keep the handle on return to OCaml. The + * Abstract tag teaches OCaml's garbage collector not to mess with + * it */ + result = alloc(1, Abstract_tag); + + /* initialise handle */ + cpinfo = malloc(sizeof(struct direct_copy_handle)); + if (!cpinfo) caml_raise_out_of_memory(); + cpinfo->buffer = NULL; + if (posix_memalign((void **)&cpinfo->buffer, sysconf(_SC_PAGESIZE), XFER_BUFSIZ)) { + free(cpinfo); + caml_raise_out_of_memory(); + } + cpinfo->in_fd = c_in_fd; + cpinfo->out_fd = c_out_fd; + +#ifdef __linux__ + /* Force the output to have O_DIRECT if possible. + Because it may not be possible, ignore any error + we might get on setting the flag. 
+ */ + flags = fcntl(c_out_fd, F_GETFL, NULL); + if (flags > 0 && !(flags & O_DIRECT)) + fcntl(c_out_fd, F_SETFL, flags | O_DIRECT); +#endif + + Field(result, 0) = (uintptr_t)cpinfo; + CAMLreturn(result); + +} + +CAMLprim value stub_cleanup(value handle) +{ + CAMLparam1(handle); + struct direct_copy_handle *cpinfo = NULL; + + assert(Is_block(handle) && Tag_val(handle) == Abstract_tag); + cpinfo = (struct direct_copy_handle *)Field(handle, 0); + + free(cpinfo->buffer); + free(cpinfo); + Field(handle, 0) = (uintptr_t)NULL; + CAMLreturn(Val_unit); +} + +/* Wait for an fd. There will be a subsequent read() or write() + * to collect any fd error conditions that might occur */ +static inline int pollwait(int fd, short event) { + struct pollfd pfd; + + pfd.fd = fd; + pfd.events = event; + return poll(&pfd, 1, -1); +} + +CAMLprim value stub_direct_copy(value handle, value len){ + CAMLparam2(handle, len); + CAMLlocal1(result); + size_t c_len = Int64_val(len); + struct direct_copy_handle *cpinfo = NULL; + size_t bytes; + size_t remaining; + enum direct_copy_rc rc; + + assert(Is_block(handle) && Tag_val(handle) == Abstract_tag); + cpinfo = (struct direct_copy_handle *)Field(handle, 0); + if (!cpinfo) caml_failwith("direct_copy: NULL handle"); + + /* Calling enter_blocking_section() actually releases the OCaml + * runtime lock, so no OCaml exceptions may be thrown, and no OCaml + * values may be accessed, until it is reacquired. Also this + * means other OCaml threads may do things while this is going + * on so the caller must be careful. */ + enter_blocking_section(); + + rc = TRIED_AND_FAILED; + bytes = 0; + + remaining = c_len; + while (remaining > 0) { + ssize_t bread; + ssize_t bwritten = 0; + + bread = read(cpinfo->in_fd, cpinfo->buffer, (remaining < XFER_BUFSIZ)?remaining:XFER_BUFSIZ); + /* If we previously hit exactly the end of the input by accident, we're done. 
*/ + if (bread == 0) break; + if (bread < 0) { + if (errno == EINTR) continue; + if (errno == EAGAIN) { + if (pollwait(cpinfo->in_fd, POLLIN) < 0) { + /* If poll() got interrupted, hitting read() (or, later, write() + * again one extra time to try again is insignificant, and avoids + * another loop */ + if (errno == EINTR) continue; + rc = READ_POLL_FAILED; + goto fail; + } + continue; + } + rc = READ_FAILED; + goto fail; + } + while (bwritten < bread) { + ssize_t ret; + + ret = write(cpinfo->out_fd, cpinfo->buffer + bwritten, bread - bwritten); + if (ret == 0) { + rc = WRITE_UNEXPECTED_EOF; + goto fail; + } + if (ret < 0) { + if (errno == EINTR) continue; + /* If someone passed us a non-blocking FD and we got + * EAGAIN, we need to keep trying, because the input FD + * could be something we cannot rewind. */ + if (errno == EAGAIN) { + if (pollwait(cpinfo->out_fd, POLLOUT) < 0) { + if (errno == EINTR) continue; + rc = WRITE_POLL_FAILED; + goto fail; + } + continue; + } + rc = WRITE_FAILED; + goto fail; + } + bytes += ret; + bwritten += ret; + remaining -= ret; + } + } + rc = OK; +fail: + + leave_blocking_section(); + /* Now that the OCaml runtime lock is reacquired, it is safe to + * raise OCaml exceptions */ + + switch (rc) { + case TRIED_AND_FAILED: + caml_failwith("direct_copy: General error"); + break; + case WRITE_FAILED: + uerror("write", Nothing); + break; + case READ_FAILED: + uerror("read", Nothing); + break; + case WRITE_UNEXPECTED_EOF: + caml_failwith("direct_copy: Unexpected EOF on write"); + break; + case WRITE_POLL_FAILED: + uerror("write poll", Nothing); + break; + case READ_POLL_FAILED: + uerror("read poll", Nothing); + break; + case OK: + break; + } + result = caml_copy_int64(bytes); + CAMLreturn(result); +} diff --git a/src/dune b/src/dune new file mode 100644 index 0000000..aa75570 --- /dev/null +++ b/src/dune @@ -0,0 +1,28 @@ +(library + (foreign_stubs + (language c) + (names direct_copy_stubs)) + (name local_lib) + (wrapped false) + (flags + 
(:standard -w -34-32-39)) + (libraries + cohttp-lwt + cstruct + io-page.unix + lwt + nbd-unix + re.str + rpclib + rpclib.json + sha + tar + vhd-format + vhd-format-lwt + xapi-tapctl + xenstore.client + xenstore.unix + xenstore_transport + xenstore_transport.unix) + (preprocess + (pps ppx_deriving_rpc ppx_cstruct))) diff --git a/src/get_vhd_vsize.ml b/src/get_vhd_vsize.ml deleted file mode 100644 index 07b0d1b..0000000 --- a/src/get_vhd_vsize.ml +++ /dev/null @@ -1,31 +0,0 @@ -open Lwt - -module Impl = Vhd.F.From_file(Vhd_lwt.IO) -open Impl -open Vhd.F -open Vhd_lwt.IO - -module In = From_input(Input) -open In - -let get_vhd_vsize filename = - Vhd_lwt.IO.openfile filename false >>= fun fd -> - let rec loop = function - | End -> return () - | Cons (hd, tl) -> - begin match hd with - | Fragment.Footer x -> - let size = x.Footer.current_size in - Printf.printf "%Ld\n" size; - exit 0 - | _ -> - () - end; - tl () >>= fun x -> - loop x in - openstream (Input.of_fd (Vhd_lwt.IO.to_file_descr fd)) >>= fun stream -> - loop stream >>= fun () -> Vhd_lwt.IO.close fd - -let _ = - let t = get_vhd_vsize Sys.argv.(1) in - Lwt_main.run t diff --git a/src/iO.ml b/src/iO.ml index 537ab18..e46fbd2 100644 --- a/src/iO.ml +++ b/src/iO.ml @@ -15,13 +15,13 @@ * *) -open Lwt - let debug_io = ref false let complete name offset op fd buffer = - if !debug_io - then Printf.fprintf stderr "%s offset=%s length=%d\n%!" name (match offset with Some x -> Int64.to_string x | None -> "None") (Cstruct.len buffer); + if !debug_io then + Printf.fprintf stderr "%s offset=%s length=%d\n%!" 
name + (match offset with Some x -> Int64.to_string x | None -> "None") + (Cstruct.len buffer) ; let open Lwt in let ofs = buffer.Cstruct.off in let len = buffer.Cstruct.len in @@ -30,19 +30,21 @@ let complete name offset op fd buffer = op fd buf ofs len >>= fun n -> let len' = len - n in let acc' = acc + n in - if len' = 0 || n = 0 - then return acc' - else loop acc' fd buf (ofs + n) len' in + if len' = 0 || n = 0 then + return acc' + else + loop acc' fd buf (ofs + n) len' + in loop 0 fd buf ofs len >>= fun n -> - if n = 0 && len <> 0 - then fail End_of_file - else return () + if n = 0 && len <> 0 then + fail End_of_file + else + return () let alloc bytes = - if bytes = 0 - then Cstruct.create 0 + if bytes = 0 then + Cstruct.create 0 else let n = (bytes + 4095) / 4096 in let pages = Io_page.(to_cstruct (get n)) in Cstruct.sub pages 0 bytes - diff --git a/src/image.ml b/src/image.ml index 2991beb..2a4b62a 100644 --- a/src/image.ml +++ b/src/image.ml @@ -1,33 +1,59 @@ -let (|>) a b = b a -module Opt = struct - let default d = function - | None -> d - | Some x -> x -end +let get_device_numbers path = + let rdev = (Unix.LargeFile.stat path).Unix.LargeFile.st_rdev in + let major = rdev / 256 and minor = rdev mod 256 in + (major, minor) -let startswith prefix x = - let prefix' = String.length prefix - and x' = String.length x in - prefix' <= x' && (String.sub x 0 prefix' = prefix) +let is_nbd_device path = + let nbd_device_num = 43 in + let major, _ = get_device_numbers path in + major = nbd_device_num -type t = [ - | `Vhd of string - | `Raw of string -] +module Opt = struct let default d = function None -> d | Some x -> x end + +type t = [`Vhd of string | `Raw of string | `Nbd of string * string] let to_string = function - | `Vhd x -> "vhd:" ^ x - | `Raw x -> "raw:" ^ x + | `Vhd x -> + "vhd:" ^ x + | `Raw x -> + "raw:" ^ x + | `Nbd (x, y) -> + Printf.sprintf "nbd:(%s,%s)" x y + +type nbd_connect_info = {path: string; exportname: string} [@@deriving rpc] + +let 
get_nbd_device path = + let nbd_device_prefix = "/dev/nbd" in + if + Astring.String.is_prefix ~affix:nbd_device_prefix path && is_nbd_device path + then + let nbd_number = + String.sub path + (String.length nbd_device_prefix) + (String.length path - String.length nbd_device_prefix) + in + let {path; exportname} = + let persistent_nbd_info_dir = "/var/run/nonpersistent/nbd" in + let filename = persistent_nbd_info_dir ^ "/" ^ nbd_number in + Xapi_stdext_unix.Unixext.string_of_file filename + |> Jsonrpc.of_string + |> nbd_connect_info_of_rpc + in + Some (`Nbd (path, exportname)) + else + None let of_device path = - try - match Tapctl.of_device (Tapctl.create ()) path with - | _, _, (Some ("vhd", vhd)) -> Some (`Vhd vhd) - | _, _, (Some ("aio", vhd)) -> Some (`Raw vhd) - | _, _, _ -> raise Not_found - with Tapctl.Not_blktap -> - None - | Tapctl.Not_a_device -> - None - | _ -> - None + match Tapctl.of_device (Tapctl.create ()) path with + | _, _, Some ("vhd", vhd) -> + Some (`Vhd vhd) + | _, _, Some ("aio", vhd) -> + Some (`Raw vhd) + | _, _, _ -> + None + | exception Tapctl.Not_blktap -> + get_nbd_device path + | exception Tapctl.Not_a_device -> + None + | exception _ -> + None diff --git a/src/image.mli b/src/image.mli index e917346..439042d 100644 --- a/src/image.mli +++ b/src/image.mli @@ -1,14 +1,10 @@ - -type t = [ - | `Vhd of string - | `Raw of string -] (** An image may either be backed by a vhd-format file or a raw-format file. *) +type t = [`Vhd of string | `Raw of string | `Nbd of string * string] -val to_string: t -> string +val to_string : t -> string (** Pretty-print the image *) -val of_device: string -> t option +val of_device : string -> t option (** Examine the provided path and return the backing image, or None if one doesn't exist. 
*) diff --git a/src/impl.ml b/src/impl.ml index 7597c6d..03954b3 100644 --- a/src/impl.ml +++ b/src/impl.ml @@ -13,285 +13,351 @@ *) open Common -open Cmdliner open Lwt +module F = Vhd_format.F.From_file (Vhd_format_lwt.IO) +module In = Vhd_format.F.From_input (Input) -external sendfile: Unix.file_descr -> Unix.file_descr -> int64 -> int64 = "stub_sendfile64" - -module F = Vhd.F.From_file(Vhd_lwt.IO) -module In = Vhd.F.From_input(Input) -module Channel_In = Vhd.F.From_input(struct +module Channel_In = Vhd_format.F.From_input (struct include Lwt + type fd = Channels.t + let read c buf = c.Channels.really_read buf + let scratch = IO.alloc (1024 * 1024) + let skip_to c offset = let rec drop remaining = - if remaining = 0L - then return () + if remaining = 0L then + return () else - let this = Int64.(to_int (min (of_int (Cstruct.len scratch)) remaining)) in + let this = + Int64.(to_int (min (of_int (Cstruct.len scratch)) remaining)) + in let frag = Cstruct.sub scratch 0 this in - read c frag >>= fun () -> - drop Int64.(sub remaining (of_int this)) in + read c frag >>= fun () -> drop Int64.(sub remaining (of_int this)) + in drop Int64.(sub offset !(c.Channels.offset)) end) + open F + (* open Vhd -open Vhd_lwt +open Vhd_format_lwt *) let vhd_search_path = "/dev/mapper" -let require name arg = match arg with - | None -> failwith (Printf.sprintf "Please supply a %s argument" name) - | Some x -> x +let require name arg = + match arg with + | None -> + failwith (Printf.sprintf "Please supply a %s argument" name) + | Some x -> + x -let get common filename key = +let get _common filename key = try let filename = require "filename" filename in let key = require "key" key in let t = Vhd_IO.openfile filename false >>= fun t -> - let result = Vhd.F.Vhd.Field.get t key in - Vhd_IO.close t >>= fun () -> - return result in + let result = Vhd_format.F.Vhd.Field.get t key in + Vhd_IO.close t >>= fun () -> return result + in match Lwt_main.run t with | Some v -> - Printf.printf "%s\n" 
v; - `Ok () - | None -> raise Not_found + Printf.printf "%s\n" v ; `Ok () + | None -> + raise Not_found with - | Failure x -> - `Error(true, x) - | Not_found -> - `Error(true, Printf.sprintf "Unknown key. Known keys are: %s" (String.concat ", " Vhd.F.Vhd.Field.list)) - -let info common filename = + | Failure x -> + `Error (true, x) + | Not_found -> + `Error + ( true + , Printf.sprintf "Unknown key. Known keys are: %s" + (String.concat ", " Vhd_format.F.Vhd.Field.list) ) + +let info _common filename = try let filename = require "filename" filename in let t = Vhd_IO.openfile filename false >>= fun t -> - let all = List.map (fun f -> - match Vhd.F.Vhd.Field.get t f with - | Some v -> [ f; v ] - | None -> [ f; "" ] - ) Vhd.F.Vhd.Field.list in - print_table ["field"; "value"] all; - return () in - Lwt_main.run t; - `Ok () - with Failure x -> - `Error(true, x) + let all = + List.map + (fun f -> + match Vhd_format.F.Vhd.Field.get t f with + | Some v -> + [f; v] + | None -> + [f; ""]) + Vhd_format.F.Vhd.Field.list + in + print_table ["field"; "value"] all ; + return () + in + Lwt_main.run t ; `Ok () + with Failure x -> `Error (true, x) -let contents common filename = +let contents _common filename = try let filename = require "filename" filename in let t = let open In in - Vhd_lwt.IO.openfile filename false >>= fun fd -> + Vhd_format_lwt.IO.openfile filename false >>= fun fd -> let rec loop = function - | End -> return () - | Cons (hd, tl) -> - let open Vhd.F in - begin match hd with - | Fragment.Header x -> - Printf.printf "Header\n" - | Fragment.Footer x -> - Printf.printf "Footer\n" - | Fragment.BAT x -> - Printf.printf "BAT\n" - | Fragment.Batmap x -> - Printf.printf "batmap\n" - | Fragment.Block (offset, buffer) -> - Printf.printf "Block %Ld (len %d)\n" offset (Cstruct.len buffer) - end; - tl () >>= fun x -> - loop x in - openstream (Input.of_fd (Vhd_lwt.IO.to_file_descr fd)) >>= fun stream -> - loop stream in - Lwt_main.run t; - `Ok () - with Failure x -> - 
`Error(true, x) + | End -> + return () + | Cons (hd, tl) -> + let open Vhd_format.F in + ( match hd with + | Fragment.Header _x -> + Printf.printf "Header\n" + | Fragment.Footer _x -> + Printf.printf "Footer\n" + | Fragment.BAT _x -> + Printf.printf "BAT\n" + | Fragment.Batmap _x -> + Printf.printf "batmap\n" + | Fragment.Block (offset, buffer) -> + Printf.printf "Block %Ld (len %d)\n" offset (Cstruct.len buffer) + ) ; + tl () >>= fun x -> loop x + in + Vhd_format_lwt.IO.get_file_size filename >>= fun size -> + openstream (Some size) (Input.of_fd (Vhd_format_lwt.IO.to_file_descr fd)) + >>= fun stream -> loop stream + in + Lwt_main.run t ; `Ok () + with Failure x -> `Error (true, x) let create common filename size parent = try - begin let filename = require "filename" filename in - match parent, size with - | None, None -> failwith "Please supply either a size or a parent" - | None, Some size -> - let size = parse_size size in - let t = - Vhd_IO.create_dynamic ~filename ~size () >>= fun vhd -> - Vhd_IO.close vhd in - Lwt_main.run t - | Some parent, None -> - let t = - Vhd_IO.openchain ~path:common.path parent false >>= fun parent -> - Vhd_IO.create_difference ~filename ~parent () >>= fun vhd -> - Vhd_IO.close parent >>= fun () -> - Vhd_IO.close vhd >>= fun () -> - return () in - Lwt_main.run t - | Some parent, Some size -> - failwith "Overriding the size in a child node not currently implemented" - end; - `Ok () - with Failure x -> - `Error(true, x) + (let filename = require "filename" filename in + match (parent, size) with + | None, None -> + failwith "Please supply either a size or a parent" + | None, Some size -> + let size = parse_size size in + let t = + Vhd_IO.create_dynamic ~filename ~size () >>= fun vhd -> + Vhd_IO.close vhd + in + Lwt_main.run t + | Some parent, None -> + let t = + Vhd_IO.openchain ~path:common.path parent false >>= fun parent -> + Vhd_IO.create_difference ~filename ~parent () >>= fun vhd -> + Vhd_IO.close parent >>= fun () -> + 
Vhd_IO.close vhd >>= fun () -> return () + in + Lwt_main.run t + | Some _parent, Some _size -> + failwith + "Overriding the size in a child node not currently implemented") ; + `Ok () + with Failure x -> `Error (true, x) let check common filename = try let filename = require "filename" filename in let t = Vhd_IO.openchain ~path:common.path filename false >>= fun vhd -> - Vhd.F.Vhd.check_overlapping_blocks vhd; - return () in - Lwt_main.run t; - `Ok () - with Failure x -> - `Error(true, x) + Vhd_format.F.Vhd.check_overlapping_blocks vhd ; + return () + in + Lwt_main.run t ; `Ok () + with Failure x -> `Error (true, x) -module P = Progress_bar(Int64) +module P = Progress_bar (Int64) let console_progress_bar total_work = let p = P.create 80 0L total_work in fun work_done -> let progress_updated = P.update p work_done in - if progress_updated then P.print_bar p; - if work_done = total_work then begin - Printf.printf "\n"; - P.summarise p; - Printf.printf "%!" - end + if progress_updated then P.print_bar p ; + if work_done = total_work then ( + Printf.printf "\n" ; P.summarise p ; Printf.printf "%!" + ) let machine_progress_bar total_work = let last_percent = ref (-1) in fun work_done -> let new_percent = Int64.(to_int (div (mul work_done 100L) total_work)) in - if new_percent <= 100 && !last_percent <> new_percent then begin - Printf.printf "%03d%!" new_percent; + if new_percent <= 100 && !last_percent <> new_percent then ( + Printf.printf "%03d%!" new_percent ; last_percent := new_percent - end + ) let no_progress_bar _ _ = () -let stream_human common _ s _ _ ?(progress = no_progress_bar) () = +let[@warning "-27"] stream_human _common _ s _ _ ?(progress = no_progress_bar) + () = let decimal_digits = - let open Vhd.F in + let open Vhd_format.F in (* How much space will we need for the sector numbers? 
*) let sectors = Int64.(shift_right (add s.size.total 511L) sector_shift) in let decimal_digits = int_of_float (ceil (log10 (Int64.to_float sectors))) in - Printf.printf "# stream summary:\n"; - Printf.printf "# size of the final artifact: %Ld\n" s.size.total; - Printf.printf "# size of metadata blocks: %Ld\n" s.size.metadata; - Printf.printf "# size of empty space: %Ld\n" s.size.empty; - Printf.printf "# size of referenced blocks: %Ld\n" s.size.copy; - Printf.printf "# offset : contents\n"; - decimal_digits in - fold_left (fun sector x -> - Printf.printf "%s: %s\n" - (padto ' ' decimal_digits (Int64.to_string sector)) - (Vhd.Element.to_string x); - return (Int64.add sector (Vhd.Element.len x)) - ) 0L s.elements >>= fun _ -> - Printf.printf "# end of stream\n"; + Printf.printf "# stream summary:\n" ; + Printf.printf "# size of the final artifact: %Ld\n" s.size.total ; + Printf.printf "# size of metadata blocks: %Ld\n" s.size.metadata ; + Printf.printf "# size of empty space: %Ld\n" s.size.empty ; + Printf.printf "# size of referenced blocks: %Ld\n" s.size.copy ; + Printf.printf "# offset : contents\n" ; + decimal_digits + in + fold_left + (fun sector x -> + Printf.printf "%s: %s\n" + (padto ' ' decimal_digits (Int64.to_string sector)) + (Vhd_format.Element.to_string x) ; + return (Int64.add sector (Vhd_format.Element.len x))) + 0L s.elements + >>= fun _ -> + Printf.printf "# end of stream\n" ; return None -let stream_nbd common c s prezeroed _ ?(progress = no_progress_bar) () = - let c = { Nbd_lwt_client.read = c.Channels.really_read; write = c.Channels.really_write } in - - Nbd_lwt_client.negotiate c >>= fun (server, size, flags) -> +let stream_nbd _common c s prezeroed _ ?(progress = no_progress_bar) () = + let open Nbd_unix in + let c = + { + Nbd.Channel.read= c.Channels.really_read + ; write= c.Channels.really_write + ; close= c.Channels.close + ; is_tls= false + } + in + + Client.negotiate c "" >>= fun (server, _size, _flags) -> (* Work to do is: non-zero data 
to write + empty sectors if the target is not prezeroed *) - let total_work = let open Vhd.F in Int64.(add (add s.size.metadata s.size.copy) (if prezeroed then 0L else s.size.empty)) in + let total_work = + let open Vhd_format.F in + Int64.( + add + (add s.size.metadata s.size.copy) + (if prezeroed then 0L else s.size.empty)) + in let p = progress total_work in - ( if not prezeroed then expand_empty s else return s ) >>= fun s -> + (if not prezeroed then expand_empty s else return s) >>= fun s -> expand_copy s >>= fun s -> - - fold_left (fun (sector, work_done) x -> - ( match x with - | `Sectors data -> - Nbd_lwt_client.write server data (Int64.mul sector 512L) >>= fun () -> - return Int64.(of_int (Cstruct.len data)) - | `Empty n -> (* must be prezeroed *) - assert prezeroed; - return 0L - | _ -> fail (Failure (Printf.sprintf "unexpected stream element: %s" (Vhd.Element.to_string x))) ) >>= fun work -> - let sector = Int64.add sector (Vhd.Element.len x) in - let work_done = Int64.add work_done work in - p work_done; - return (sector, work_done) - ) (0L, 0L) s.elements >>= fun _ -> - p total_work; + fold_left + (fun (sector, work_done) x -> + ( match x with + | `Sectors data -> ( + Client.write server (Int64.mul sector 512L) [data] >>= function + | Ok () -> + return Int64.(of_int (Cstruct.len data)) + | Error _e -> + fail (Failure "Got error from NBD library") + ) + | `Empty _n -> + (* must be prezeroed *) + assert prezeroed ; + return 0L + | _ -> + fail + (Failure + (Printf.sprintf "unexpected stream element: %s" + (Vhd_format.Element.to_string x))) + ) + >>= fun work -> + let sector = Int64.add sector (Vhd_format.Element.len x) in + let work_done = Int64.add work_done work in + p work_done ; + return (sector, work_done)) + (0L, 0L) s.elements + >>= fun _ -> + p total_work ; return (Some total_work) -let stream_chunked common c s prezeroed _ ?(progress = no_progress_bar) () = +let stream_chunked _common c s prezeroed _ ?(progress = no_progress_bar) () = (* Work to 
do is: non-zero data to write + empty sectors if the target is not prezeroed *) - let total_work = let open Vhd.F in Int64.(add (add s.size.metadata s.size.copy) (if prezeroed then 0L else s.size.empty)) in + let total_work = + let open Vhd_format.F in + Int64.( + add + (add s.size.metadata s.size.copy) + (if prezeroed then 0L else s.size.empty)) + in let p = progress total_work in - ( if not prezeroed then expand_empty s else return s ) >>= fun s -> + (if not prezeroed then expand_empty s else return s) >>= fun s -> expand_copy s >>= fun s -> - let header = Cstruct.create Chunked.sizeof in - fold_left (fun(sector, work_done) x -> - ( match x with + fold_left + (fun (sector, work_done) x -> + ( match x with | `Sectors data -> - let t = { Chunked.offset = Int64.(mul sector 512L); data } in - Chunked.marshal header t; - c.Channels.really_write header >>= fun () -> - c.Channels.really_write data >>= fun () -> - return Int64.(of_int (Cstruct.len data)) - | `Empty n -> (* must be prezeroed *) - assert prezeroed; - return 0L - | _ -> fail (Failure (Printf.sprintf "unexpected stream element: %s" (Vhd.Element.to_string x))) ) >>= fun work -> - let sector = Int64.add sector (Vhd.Element.len x) in - let work_done = Int64.add work_done work in - p work_done; - return (sector, work_done) - ) (0L, 0L) s.elements >>= fun _ -> - p total_work; + let t = {Chunked.offset= Int64.(mul sector 512L); data} in + Chunked.marshal header t ; + c.Channels.really_write header >>= fun () -> + c.Channels.really_write data >>= fun () -> + return Int64.(of_int (Cstruct.len data)) + | `Empty _n -> + (* must be prezeroed *) + assert prezeroed ; + return 0L + | _ -> + fail + (Failure + (Printf.sprintf "unexpected stream element: %s" + (Vhd_format.Element.to_string x))) + ) + >>= fun work -> + let sector = Int64.add sector (Vhd_format.Element.len x) in + let work_done = Int64.add work_done work in + p work_done ; + return (sector, work_done)) + (0L, 0L) s.elements + >>= fun _ -> + p total_work ; (* 
Send the end-of-stream marker *) - Chunked.marshal header { Chunked.offset = 0L; data = Cstruct.create 0 }; - c.Channels.really_write header >>= fun () -> - - return (Some total_work) + Chunked.marshal header {Chunked.offset= 0L; data= Cstruct.create 0} ; + c.Channels.really_write header >>= fun () -> return (Some total_work) -let stream_raw common c s prezeroed _ ?(progress = no_progress_bar) () = +let stream_raw _common c s prezeroed _ ?(progress = no_progress_bar) () = (* Work to do is: non-zero data to write + empty sectors if the target is not prezeroed *) - let total_work = let open Vhd.F in Int64.(add (add s.size.metadata s.size.copy) (if prezeroed then 0L else s.size.empty)) in + let total_work = + let open Vhd_format.F in + Int64.( + add + (add s.size.metadata s.size.copy) + (if prezeroed then 0L else s.size.empty)) + in let p = progress total_work in - ( if not prezeroed then expand_empty s else return s ) >>= fun s -> - - fold_left (fun work_done x -> - (match x with - | `Copy(fd, sector_start, sector_len) -> - let fd = Vhd_lwt.IO.to_file_descr fd in - Lwt_unix.LargeFile.lseek fd (Int64.mul 512L sector_start) Unix.SEEK_SET - >>= fun (_: int64) -> - c.Channels.copy_from fd (Int64.mul 512L sector_len) + (if not prezeroed then expand_empty s else return s) >>= fun s -> + fold_left + (fun work_done x -> + ( match x with + | `Copy (fd, sector_start, sector_len) -> + let fd = Vhd_format_lwt.IO.to_file_descr fd in + Lwt_unix.LargeFile.lseek fd + (Int64.mul 512L sector_start) + Unix.SEEK_SET + >>= fun (_ : int64) -> + c.Channels.copy_from fd (Int64.mul 512L sector_len) | `Sectors data -> - c.Channels.really_write data >>= fun () -> - return Int64.(of_int (Cstruct.len data)) - | `Empty n -> (* must be prezeroed *) - c.Channels.skip (Int64.(mul n 512L)) >>= fun () -> - assert prezeroed; - return 0L - ) >>= fun work -> - let work_done = Int64.add work_done work in - p work_done; - return work_done - ) 0L s.elements >>= fun _ -> - p total_work; + 
c.Channels.really_write data >>= fun () -> + return Int64.(of_int (Cstruct.len data)) + | `Empty n -> + (* must be prezeroed *) + c.Channels.skip Int64.(mul n 512L) >>= fun () -> + assert prezeroed ; + return 0L + ) + >>= fun work -> + let work_done = Int64.add work_done work in + p work_done ; return work_done) + 0L s.elements + >>= fun _ -> + p total_work ; return (Some total_work) @@ -299,30 +365,41 @@ let sha1_update_cstruct ctx buffer = let ofs = buffer.Cstruct.off in let len = buffer.Cstruct.len in let buf = buffer.Cstruct.buffer in - let buffer' : (char, Bigarray.int8_unsigned_elt, Bigarray.c_layout) Bigarray.Array1.t = Bigarray.Array1.sub buf ofs len in + let buffer' : + (char, Bigarray.int8_unsigned_elt, Bigarray.c_layout) Bigarray.Array1.t = + Bigarray.Array1.sub buf ofs len + in (* XXX: need a better way to do this *) - let buffer'': (int, Bigarray.int8_unsigned_elt, Bigarray.c_layout) Bigarray.Array1.t = Obj.magic buffer' in - Sha1.update_buffer ctx buffer'' + (* let buffer'': (int, Bigarray.int8_unsigned_elt, Bigarray.c_layout) Bigarray.Array1.t = Obj.magic buffer' in *) + Sha1.update_buffer ctx buffer' module TarStream = struct type t = { - work_done: int64; - total_size: int64; - ctx: Sha1.ctx; - nr_bytes_remaining: int; (* start at 0 *) - next_counter: int; - mutable header: Tar.Header.t option; + work_done: int64 + ; total_size: int64 + ; ctx: Sha1.ctx + ; nr_bytes_remaining: int + ; (* start at 0 *) + next_counter: int + ; mutable header: Tar.Header.t option } let to_string t = - Printf.sprintf "work_done = %Ld; nr_bytes_remaining = %d; next_counter = %d; filename = %s" + Printf.sprintf + "work_done = %Ld; nr_bytes_remaining = %d; next_counter = %d; filename = \ + %s" t.work_done t.nr_bytes_remaining t.next_counter (match t.header with None -> "None" | Some h -> h.Tar.Header.file_name) - let initial total_size = { - work_done = 0L; ctx = Sha1.init (); nr_bytes_remaining = 0; - next_counter = 0; header = None; total_size - } + let initial 
total_size = + { + work_done= 0L + ; ctx= Sha1.init () + ; nr_bytes_remaining= 0 + ; next_counter= 0 + ; header= None + ; total_size + } let make_tar_header prefix counter suffix file_size = Tar.Header.make @@ -332,7 +409,7 @@ module TarStream = struct (Int64.of_int file_size) end -let stream_tar common c s _ prefix ?(progress = no_progress_bar) () = +let stream_tar _common c s _ prefix ?(progress = no_progress_bar) () = let open TarStream in let prefix = match prefix with None -> "" | Some x -> x in let block_size = 1024 * 1024 in @@ -340,218 +417,302 @@ let stream_tar common c s _ prefix ?(progress = no_progress_bar) () = let zeroes = IO.alloc block_size in for i = 0 to Cstruct.len zeroes - 1 do Cstruct.set_uint8 zeroes i 0 - done; + done ; (* This undercounts by missing the tar headers and occasional empty sector *) - let total_work = let open Vhd.F in Int64.(add s.size.metadata s.size.copy) in + let total_work = + let open Vhd_format.F in + Int64.(add s.size.metadata s.size.copy) + in let p = progress total_work in expand_copy s >>= fun s -> - (* Write [data] to the tar-format stream currnetly in [state] *) let rec input state data = (* Write as much as we can into the current file *) let len = Cstruct.len data in let this_block_len = min len state.nr_bytes_remaining in let this_block = Cstruct.sub data 0 this_block_len in - sha1_update_cstruct state.ctx this_block; + sha1_update_cstruct state.ctx this_block ; c.Channels.really_write this_block >>= fun () -> let nr_bytes_remaining = state.nr_bytes_remaining - this_block_len in - let state = { state with nr_bytes_remaining } in + let state = {state with nr_bytes_remaining} in let rest = Cstruct.shift data this_block_len in (* If we've hit the end of a block then output the hash *) - ( if nr_bytes_remaining = 0 then match state.header with - | Some hdr -> - c.Channels.really_write (Tar.Header.zero_padding hdr) >>= fun () -> - let hash = Sha1.(to_hex (finalize state.ctx)) in - let ctx = Sha1.init () in - let hdr' 
= { hdr with - Tar.Header.file_name = hdr.Tar.Header.file_name ^ ".checksum"; - file_size = Int64.of_int (String.length hash) - } in - Tar.Header.marshal header hdr'; - c.Channels.really_write header >>= fun () -> - Cstruct.blit_from_string hash 0 header 0 (String.length hash); - c.Channels.really_write (Cstruct.sub header 0 (String.length hash)) >>= fun () -> - c.Channels.really_write (Tar.Header.zero_padding hdr') >>= fun () -> - return { state with ctx; header = None } - | None -> - return state - else return state ) >>= fun state -> - + ( if nr_bytes_remaining = 0 then + match state.header with + | Some hdr -> + c.Channels.really_write (Tar.Header.zero_padding hdr) >>= fun () -> + let hash = Sha1.(to_hex (finalize state.ctx)) in + let ctx = Sha1.init () in + let hdr' = + { + hdr with + Tar.Header.file_name= hdr.Tar.Header.file_name ^ ".checksum" + ; file_size= Int64.of_int (String.length hash) + } + in + Tar.Header.marshal header hdr' ; + c.Channels.really_write header >>= fun () -> + Cstruct.blit_from_string hash 0 header 0 (String.length hash) ; + c.Channels.really_write (Cstruct.sub header 0 (String.length hash)) + >>= fun () -> + c.Channels.really_write (Tar.Header.zero_padding hdr') >>= fun () -> + return {state with ctx; header= None} + | None -> + return state + else + return state + ) + >>= fun state -> (* If we have unwritten data then output the next header *) - ( if nr_bytes_remaining = 0 && Cstruct.len rest > 0 then begin + ( if nr_bytes_remaining = 0 && Cstruct.len rest > 0 then ( (* XXX the last block might be smaller than block_size *) let hdr = make_tar_header prefix state.next_counter "" block_size in - Tar.Header.marshal header hdr; + Tar.Header.marshal header hdr ; c.Channels.really_write header >>= fun () -> - return { state with nr_bytes_remaining = block_size; - next_counter = state.next_counter + 1; - header = Some hdr } - end else return { state with nr_bytes_remaining } ) >>= fun state -> - - if Cstruct.len rest > 0 - then input state 
rest - else return state in + return + { + state with + nr_bytes_remaining= block_size + ; next_counter= state.next_counter + 1 + ; header= Some hdr + } + ) else + return {state with nr_bytes_remaining} + ) + >>= fun state -> + if Cstruct.len rest > 0 then + input state rest + else + return state + in let rec empty state bytes = let write state bytes = let this = Int64.(to_int (min bytes (of_int (Cstruct.len zeroes)))) in input state (Cstruct.sub zeroes 0 this) >>= fun state -> - empty state Int64.(sub bytes (of_int this)) in - if bytes = 0L - then return state - (* If we're in the middle of a block, then complete it *) - else if 0 < state.nr_bytes_remaining && state.nr_bytes_remaining < block_size - then begin + empty state Int64.(sub bytes (of_int this)) + in + if bytes = 0L then + return state (* If we're in the middle of a block, then complete it *) + else if + 0 < state.nr_bytes_remaining && state.nr_bytes_remaining < block_size + then let this = min (Int64.of_int state.nr_bytes_remaining) bytes in - write state this >>= fun state -> - empty state (Int64.sub bytes this) + write state this >>= fun state -> empty state (Int64.sub bytes this) (* If we're the first or last block then always include *) - end else if state.work_done = 0L || Int64.(sub state.total_size state.work_done <= (of_int block_size)) - then write state bytes - else if bytes >= (Int64.of_int block_size) then begin + else if + state.work_done = 0L + || Int64.(sub state.total_size state.work_done <= of_int block_size) + then + write state bytes + else if bytes >= Int64.of_int block_size then (* If n > block_size (in sectors) then we can omit empty blocks *) - empty { state with next_counter = state.next_counter + 1 } Int64.(sub bytes (of_int block_size)) - end else write state bytes in - let module E = Vhd.Element in - fold_left (fun state x -> - (match x with + empty + {state with next_counter= state.next_counter + 1} + Int64.(sub bytes (of_int block_size)) + else + write state bytes + in + let 
module E = Vhd_format.Element in + fold_left + (fun state x -> + ( match x with | `Sectors data -> - input state data + input state data | `Empty n -> - empty state (Int64.(mul n 512L)) - | _ -> fail (Failure (Printf.sprintf "unexpected stream element: %s" (Vhd.Element.to_string x))) ) >>= fun state -> - let work = Int64.mul (E.len x) 512L in - let work_done = Int64.add state.work_done work in - p work_done; - return { state with work_done } - ) (initial s.size.Vhd.F.total) s.elements >>= fun _ -> - p total_work; + empty state Int64.(mul n 512L) + | _ -> + fail + (Failure + (Printf.sprintf "unexpected stream element: %s" + (Vhd_format.Element.to_string x))) + ) + >>= fun state -> + let work = Int64.mul (E.len x) 512L in + let work_done = Int64.add state.work_done work in + p work_done ; + return {state with work_done}) + (initial s.size.Vhd_format.F.total) + s.elements + >>= fun _ -> + p total_work ; return (Some total_work) module TarInput = struct type t = { - ctx: Sha1.ctx; - offset: int64; - detected_block_size: int64 option; - last_sequence_number: int; - } - let initial () = { - ctx = Sha1.init (); - offset = 0L; - detected_block_size = None; - last_sequence_number = -1; + ctx: Sha1.ctx + ; offset: int64 + ; detected_block_size: int64 option + ; last_sequence_number: int } + + let initial () = + { + ctx= Sha1.init () + ; offset= 0L + ; detected_block_size= None + ; last_sequence_number= -1 + } end let startswith prefix x = let prefix_len = String.length prefix in let x_len = String.length x in - x_len >= prefix_len && (String.sub x 0 prefix_len = prefix) + x_len >= prefix_len && String.sub x 0 prefix_len = prefix let endswith suffix x = let suffix_len = String.length suffix in let x_len = String.length x in - x_len >= suffix_len && (String.sub x (x_len - suffix_len) suffix_len = suffix) + x_len >= suffix_len && String.sub x (x_len - suffix_len) suffix_len = suffix let serve_vhd_to_raw total_size c dest prezeroed progress _ _ = - if not prezeroed then failwith 
"unimplemented: prezeroed"; + if not prezeroed then failwith "unimplemented: prezeroed" ; let p = ref None in let open Channel_In in - let open Vhd.F in + let open Vhd_format.F in let rec loop block_size_sectors_shift last_block blocks_seen = function - | End -> return () - | Cons (Fragment.Header h, tl) -> tl () >>= loop h.Header.block_size_sectors_shift last_block blocks_seen + | End -> + return () + | Cons (Fragment.Header h, tl) -> + tl () >>= loop h.Header.block_size_sectors_shift last_block blocks_seen | Cons (Fragment.BAT x, tl) -> - (* total_size = number of bits set in the BAT *) - let total_size = BAT.fold (fun _ _ acc -> Int64.succ acc) x 0L in - p := Some (progress total_size); - tl () >>= loop block_size_sectors_shift last_block blocks_seen + (* total_size = number of bits set in the BAT *) + let total_size = BAT.fold (fun _ _ acc -> Int64.succ acc) x 0L in + p := Some (progress total_size) ; + tl () >>= loop block_size_sectors_shift last_block blocks_seen | Cons (Fragment.Block (offset, data), tl) -> - Vhd_lwt.IO.really_write dest (Int64.shift_left offset sector_shift) data >>= fun () -> - let this_block = Int64.(shift_right offset block_size_sectors_shift) in - let blocks_seen = if last_block <> this_block then Int64.succ blocks_seen else blocks_seen in - (match !p with Some p -> p blocks_seen | None -> ()); - tl () >>= loop block_size_sectors_shift this_block blocks_seen - | Cons (_, tl) -> tl () >>= loop block_size_sectors_shift last_block blocks_seen in - openstream c >>= fun stream -> - loop 0 (-1L) 0L stream - -let serve_tar_to_raw total_size c dest prezeroed progress expected_prefix ignore_checksums = - let module M = Tar.Archive(Lwt) in + Vhd_format_lwt.IO.really_write dest + (Int64.shift_left offset sector_shift) + data + >>= fun () -> + let this_block = Int64.(shift_right offset block_size_sectors_shift) in + let blocks_seen = + if last_block <> this_block then + Int64.succ blocks_seen + else + blocks_seen + in + (match !p with Some p -> p 
blocks_seen | None -> ()) ; + tl () >>= loop block_size_sectors_shift this_block blocks_seen + | Cons (_, tl) -> + tl () >>= loop block_size_sectors_shift last_block blocks_seen + in + openstream (Some total_size) c >>= fun stream -> loop 0 (-1L) 0L stream + +let serve_tar_to_raw total_size c dest prezeroed progress expected_prefix + ignore_checksums = let twomib = 2 * 1024 * 1024 in let buffer = IO.alloc twomib in let header = IO.alloc 512 in - if not prezeroed then failwith "unimplemented: prezeroed"; + if not prezeroed then failwith "unimplemented: prezeroed" ; let p = progress total_size in let open TarInput in let rec loop t = - p t.offset; - if t.offset = total_size - then return () + p t.offset ; + if t.offset = total_size then + return () else c.Channels.really_read header >>= fun () -> match Tar.Header.unmarshal header with - | None -> fail (Failure "failed to unmarshal header") + | None -> + fail (Failure "failed to unmarshal header") | Some hdr -> - ( match expected_prefix with - | None -> return (Filename.basename hdr.Tar.Header.file_name) + ( match expected_prefix with + | None -> + return (Filename.basename hdr.Tar.Header.file_name) | Some p -> - if not(startswith p hdr.Tar.Header.file_name) - then fail (Failure (Printf.sprintf "expected filename prefix %s, got %s" p hdr.Tar.Header.file_name)) + if not (startswith p hdr.Tar.Header.file_name) then + fail + (Failure + (Printf.sprintf "expected filename prefix %s, got %s" p + hdr.Tar.Header.file_name)) + else + let p_len = String.length p in + let file_name_len = String.length hdr.Tar.Header.file_name in + let filename = + String.sub hdr.Tar.Header.file_name p_len + (file_name_len - p_len) + in + return (Filename.basename filename) + ) + >>= fun filename -> + let zero = + Cstruct.sub header 0 (Tar.Header.compute_zero_padding_length hdr) + in + (* either 'counter' or 'counter.checksum' *) + if endswith ".checksum" filename then + let checksum = + Cstruct.sub buffer 0 (Int64.to_int 
hdr.Tar.Header.file_size) + in + c.Channels.really_read checksum >>= fun () -> + c.Channels.really_read zero >>= fun () -> + if ignore_checksums then + loop t else - let p_len = String.length p in - let file_name_len = String.length hdr.Tar.Header.file_name in - let filename = String.sub hdr.Tar.Header.file_name p_len (file_name_len - p_len) in - return (Filename.basename filename)) >>= fun filename -> - let zero = Cstruct.sub header 0 (Tar.Header.compute_zero_padding_length hdr) in - (* either 'counter' or 'counter.checksum' *) - if endswith ".checksum" filename then begin - let checksum = Cstruct.sub buffer 0 (Int64.to_int hdr.Tar.Header.file_size) in - c.Channels.really_read checksum >>= fun () -> - c.Channels.really_read zero >>= fun () -> - if ignore_checksums - then loop t - else begin - let checksum' = Cstruct.to_string checksum in - let hash = Sha1.(to_hex (finalize t.ctx)) in - if checksum' <> hash then begin - Printf.fprintf stderr "Unexpected checksum in %s: expected %s, we computed %s\n" - hdr.Tar.Header.file_name checksum' hash; - fail (Failure (Printf.sprintf "Unexpected checksum in block %s" hdr.Tar.Header.file_name)) - end else loop { t with ctx = Sha1.init () } - end - end else begin - let block_size = match t.detected_block_size with - | None -> hdr.Tar.Header.file_size - | Some x -> x in - ( try return (int_of_string filename) - with _ -> fail (Failure (Printf.sprintf "Expected sequence number, got %s" filename)) ) >>= fun sequence_number -> - let skipped_blocks = sequence_number - t.last_sequence_number - 1 in - let to_skip = Int64.(mul (of_int skipped_blocks) block_size) in - let offset = Int64.(add t.offset to_skip) in - (* XXX: prezeroed? 
*) - let rec copy offset remaining = - let this = Int64.(to_int (min remaining (of_int (Cstruct.len buffer)))) in - let block = Cstruct.sub buffer 0 this in - c.Channels.really_read block >>= fun () -> - Vhd_lwt.IO.really_write dest offset block >>= fun () -> - if not ignore_checksums then sha1_update_cstruct t.ctx block; - let remaining = Int64.(sub remaining (of_int this)) in - let offset = Int64.(add offset (of_int this)) in - if remaining = 0L - then return offset - else copy offset remaining in - copy offset hdr.Tar.Header.file_size >>= fun offset -> - c.Channels.really_read zero >>= fun () -> - loop { t with offset; detected_block_size = Some block_size; last_sequence_number = sequence_number } - end in + let checksum' = Cstruct.to_string checksum in + let hash = Sha1.(to_hex (finalize t.ctx)) in + if checksum' <> hash then ( + Printf.fprintf stderr + "Unexpected checksum in %s: expected %s, we computed %s\n" + hdr.Tar.Header.file_name checksum' hash ; + fail + (Failure + (Printf.sprintf "Unexpected checksum in block %s" + hdr.Tar.Header.file_name)) + ) else + loop {t with ctx= Sha1.init ()} + else + let block_size = + match t.detected_block_size with + | None -> + hdr.Tar.Header.file_size + | Some x -> + x + in + ( try return (int_of_string filename) + with _ -> + fail + (Failure + (Printf.sprintf "Expected sequence number, got %s" + filename)) + ) + >>= fun sequence_number -> + let skipped_blocks = sequence_number - t.last_sequence_number - 1 in + let to_skip = Int64.(mul (of_int skipped_blocks) block_size) in + let offset = Int64.(add t.offset to_skip) in + (* XXX: prezeroed? 
*) + let rec copy offset remaining = + let this = + Int64.(to_int (min remaining (of_int (Cstruct.len buffer)))) + in + let block = Cstruct.sub buffer 0 this in + c.Channels.really_read block >>= fun () -> + Vhd_format_lwt.IO.really_write dest offset block >>= fun () -> + if not ignore_checksums then sha1_update_cstruct t.ctx block ; + let remaining = Int64.(sub remaining (of_int this)) in + let offset = Int64.(add offset (of_int this)) in + if remaining = 0L then + return offset + else + copy offset remaining + in + copy offset hdr.Tar.Header.file_size >>= fun offset -> + c.Channels.really_read zero >>= fun () -> + loop + { + t with + offset + ; detected_block_size= Some block_size + ; last_sequence_number= sequence_number + } + in loop (TarInput.initial ()) open StreamCommon @@ -566,246 +727,403 @@ type endpoint = | Https of Uri.t let endpoint_of_string = function - | "stdout:" -> return Stdout - | "null:" -> return Null - | uri -> - let uri' = Uri.of_string uri in - begin match Uri.scheme uri', Uri.host uri' with - | Some "fd", Some fd -> - return (File_descr (fd |> int_of_string |> file_descr_of_int |> Lwt_unix.of_unix_file_descr)) - | Some "tcp", _ -> - let host = match Uri.host uri' with None -> failwith "Please supply a host in the URI" | Some host -> host in - let port = match Uri.port uri' with None -> failwith "Please supply a port in the URI" | Some port -> port in - Lwt_unix.gethostbyname host >>= fun host_entry -> - return (Sockaddr(Lwt_unix.ADDR_INET(host_entry.Lwt_unix.h_addr_list.(0), port))) - | Some "unix", _ -> - return (Sockaddr(Lwt_unix.ADDR_UNIX(Uri.path uri'))) - | Some "file", _ -> - return (File(Uri.path uri')) - | Some "http", _ -> - return (Http uri') - | Some "https", _ -> - return (Https uri') - | Some x, _ -> - fail (Failure (Printf.sprintf "Unknown URI scheme: %s" x)) - | None, _ -> - fail (Failure (Printf.sprintf "Failed to parse URI: %s" uri)) - end + | "stdout:" -> + return Stdout + | "null:" -> + return Null + | uri -> ( + let 
uri' = Uri.of_string uri in + match (Uri.scheme uri', Uri.host uri') with + | Some "fd", Some fd -> + return + (File_descr + (fd + |> int_of_string + |> file_descr_of_int + |> Lwt_unix.of_unix_file_descr + )) + | Some "tcp", _ -> + let host = + match Uri.host uri' with + | None -> + failwith "Please supply a host in the URI" + | Some host -> + host + in + let host = Scanf.ksscanf host (fun _ _ -> host) "[%s@]" Fun.id in + let port = + match Uri.port uri' with + | None -> + failwith "Please supply a port in the URI" + | Some port -> + port + in + Lwt_unix.getaddrinfo host (string_of_int port) [] >>= fun he -> + if he = [] then raise Not_found ; + return (Sockaddr (List.hd he).Unix.ai_addr) + | Some "unix", _ -> + return (Sockaddr (Lwt_unix.ADDR_UNIX (Uri.path uri'))) + | Some "file", _ -> + return (File (Uri.path uri')) + | Some "http", _ -> + return (Http uri') + | Some "https", _ -> + return (Https uri') + | Some x, _ -> + fail (Failure (Printf.sprintf "Unknown URI scheme: %s" x)) + | None, _ -> + fail (Failure (Printf.sprintf "Failed to parse URI: %s" uri)) + ) let socket sockaddr = - let family = match sockaddr with - | Lwt_unix.ADDR_INET(_, _) -> Unix.PF_INET - | Lwt_unix.ADDR_UNIX _ -> Unix.PF_UNIX in + let family = + match sockaddr with + | Lwt_unix.ADDR_INET (addr, port) -> + Unix.domain_of_sockaddr (Lwt_unix.ADDR_INET (addr, port)) + | Lwt_unix.ADDR_UNIX _ -> + Unix.PF_UNIX + in Lwt_unix.socket family Unix.SOCK_STREAM 0 -let colon = Re_str.regexp_string ":" +let colon = Re.Str.regexp_string ":" let retry common retries f = let rec aux n = - if n <= 0 then f () + if n <= 0 then + f () else - try_lwt f () - with exn -> - if common.Common.debug then - Printf.fprintf stderr "warning: caught %s; will retry %d more time%s...\n%!" - (Printexc.to_string exn) n (if n=1 then "" else "s"); - Lwt_unix.sleep 1. 
>>= fun () -> - aux (n - 1) in + Lwt.catch f (fun exn -> + if common.Common.debug then + Printf.fprintf stderr + "warning: caught %s; will retry %d more time%s...\n%!" + (Printexc.to_string exn) n + (if n = 1 then "" else "s") ; + Lwt_unix.sleep 1. >>= fun () -> aux (n - 1)) + in aux retries +(** [make_stream common source relative_to source_format destination_format] + returns a lazy stream of extents to copy. [source_format] determines the + way in which the [source] and [relative_to] strings should be interpreted + and how their data and metadata can be accessed. If [relative_to] is + specified, then the changes from it will be returned. *) let make_stream common source relative_to source_format destination_format = - match source_format, destination_format with - | "hybrid", "raw" -> + match (source_format, destination_format) with + | "nbdhybrid", "raw" -> ( + match Re.Str.bounded_split colon source 4 with + | [raw; nbd_server; export_name; size] -> + let size = Int64.of_string size in + Vhd_format_lwt.IO.openfile raw false >>= fun raw -> + Nbd_input.raw raw nbd_server export_name size + | _ -> + fail + (Failure + (Printf.sprintf + "Failed to parse nbdhybrid source: %s (expecting \ + :::" + source)) + ) + | "hybrid", "raw" -> ( (* expect source to be block_device:vhd *) - begin match Re_str.bounded_split colon source 2 with - | [ raw; vhd ] -> - let path = common.path @ [ Filename.dirname vhd ] in - retry common 3 (fun () -> Vhd_IO.openchain ~path vhd false) >>= fun t -> - Vhd_lwt.IO.openfile raw false >>= fun raw -> - ( match relative_to with None -> return None | Some f -> Vhd_IO.openchain ~path f false >>= fun t -> return (Some t) ) >>= fun from -> - Hybrid_input.raw ?from raw t + match Re.Str.bounded_split colon source 2 with + | [raw; vhd] -> + let path = common.path @ [Filename.dirname vhd] in + retry common 3 (fun () -> Vhd_IO.openchain ~path vhd false) >>= fun t -> + Vhd_IO.close t >>= fun () -> + Vhd_format_lwt.IO.openfile raw false >>= fun raw ->
+ ( match relative_to with + | None -> + return None + | Some f -> + Vhd_IO.openchain ~path f false >>= fun t -> + Vhd_IO.close t >>= fun () -> return (Some t) + ) + >>= fun from -> Hybrid_input.raw ?from raw t | _ -> - fail (Failure (Printf.sprintf "Failed to parse hybrid source: %s (expected raw_disk|vhd_disk)" source)) - end - | "hybrid", "vhd" -> + fail + (Failure + (Printf.sprintf + "Failed to parse hybrid source: %s (expected raw_disk|vhd_disk)" + source)) + ) + | "hybrid", "vhd" -> ( (* expect source to be block_device:vhd *) - begin match Re_str.bounded_split colon source 2 with - | [ raw; vhd ] -> - let path = common.path @ [ Filename.dirname vhd ] in - retry common 3 (fun () -> Vhd_IO.openchain ~path vhd false) >>= fun t -> - Vhd_lwt.IO.openfile raw false >>= fun raw -> - ( match relative_to with None -> return None | Some f -> Vhd_IO.openchain ~path f false >>= fun t -> return (Some t) ) >>= fun from -> - Hybrid_input.vhd ?from raw t + match Re.Str.bounded_split colon source 2 with + | [raw; vhd] -> + let path = common.path @ [Filename.dirname vhd] in + retry common 3 (fun () -> Vhd_IO.openchain ~path vhd false) >>= fun t -> + Vhd_IO.close t >>= fun () -> + Vhd_format_lwt.IO.openfile raw false >>= fun raw -> + ( match relative_to with + | None -> + return None + | Some f -> + Vhd_IO.openchain ~path f false >>= fun t -> + Vhd_IO.close t >>= fun () -> return (Some t) + ) + >>= fun from -> Hybrid_input.vhd ?from raw t | _ -> - fail (Failure (Printf.sprintf "Failed to parse hybrid source: %s (expected raw_disk|vhd_disk)" source)) - end + fail + (Failure + (Printf.sprintf + "Failed to parse hybrid source: %s (expected raw_disk|vhd_disk)" + source)) + ) | "vhd", "vhd" -> - let path = common.path @ [ Filename.dirname source ] in - retry common 3 (fun () -> Vhd_IO.openchain ~path source false) >>= fun t -> - ( match relative_to with None -> return None | Some f -> Vhd_IO.openchain ~path f false >>= fun t -> return (Some t) ) >>= fun from -> - Vhd_input.vhd ?from 
t + let path = common.path @ [Filename.dirname source] in + retry common 3 (fun () -> Vhd_IO.openchain ~path source false) + >>= fun t -> + ( match relative_to with + | None -> + return None + | Some f -> + Vhd_IO.openchain ~path f false >>= fun t -> return (Some t) + ) + >>= fun from -> Vhd_input.vhd ?from t | "vhd", "raw" -> - let path = common.path @ [ Filename.dirname source ] in - retry common 3 (fun () -> Vhd_IO.openchain ~path source false) >>= fun t -> - ( match relative_to with None -> return None | Some f -> Vhd_IO.openchain ~path f false >>= fun t -> return (Some t) ) >>= fun from -> - Vhd_input.raw ?from t + let path = common.path @ [Filename.dirname source] in + retry common 3 (fun () -> Vhd_IO.openchain ~path source false) + >>= fun t -> + ( match relative_to with + | None -> + return None + | Some f -> + Vhd_IO.openchain ~path f false >>= fun t -> return (Some t) + ) + >>= fun from -> Vhd_input.raw ?from t | "raw", "vhd" -> - let source = match Image.of_device source with - | Some (`Raw x) -> x (* bypass any tapdisk and use the raw file *) - | _ -> source in - Raw_IO.openfile source false >>= fun t -> - Raw_input.vhd t + let source = + match Image.of_device source with + | Some (`Raw x) -> + x (* bypass any tapdisk and use the raw file *) + | _ -> + source + in + Raw_IO.openfile source false >>= fun t -> Raw_input.vhd t | "raw", "raw" -> - let source = match Image.of_device source with - | Some (`Raw x) -> x (* bypass any tapdisk and use the raw file *) - | _ -> source in - Raw_IO.openfile source false >>= fun t -> - Raw_input.raw t - | _, _ -> assert false - -let write_stream common s destination source_protocol destination_protocol prezeroed progress tar_filename_prefix ssl_legacy good_ciphersuites legacy_ciphersuites = + let source = + match Image.of_device source with + | Some (`Raw x) -> + x (* bypass any tapdisk and use the raw file *) + | _ -> + source + in + Raw_IO.openfile source false >>= fun t -> Raw_input.raw t + | _, _ -> + assert false 
+ +let write_stream common s destination _source_protocol destination_protocol + prezeroed progress tar_filename_prefix good_ciphersuites = endpoint_of_string destination >>= fun endpoint -> let use_ssl = match endpoint with Https _ -> true | _ -> false in ( match endpoint with - | File path -> - Lwt_unix.openfile path [ Unix.O_RDWR; Unix.O_CREAT ] 0o0644 >>= fun fd -> + | File path -> + Lwt_unix.openfile path [Unix.O_RDWR; Unix.O_CREAT] 0o0644 >>= fun fd -> Channels.of_seekable_fd fd >>= fun c -> - return (c, [ NoProtocol; Human; Tar ]) - | Null -> - Lwt_unix.openfile "/dev/null" [ Unix.O_RDWR ] 0o0 >>= fun fd -> - Channels.of_raw_fd fd >>= fun c -> - return (c, [ NoProtocol; Human; Tar ]) - | Stdout -> + return (c, [NoProtocol; Human; Tar]) + | Null -> + Lwt_unix.openfile "/dev/null" [Unix.O_RDWR] 0o0 >>= fun fd -> + Channels.of_raw_fd fd >>= fun c -> return (c, [NoProtocol; Human; Tar]) + | Stdout -> Channels.of_raw_fd Lwt_unix.stdout >>= fun c -> - return (c, [ NoProtocol; Human; Tar ]) - | File_descr fd -> + return (c, [NoProtocol; Human; Tar]) + | File_descr fd -> Channels.of_raw_fd fd >>= fun c -> - return (c, [ Nbd; NoProtocol; Chunked; Human; Tar ]) - | Sockaddr sockaddr -> + return (c, [Nbd; NoProtocol; Chunked; Human; Tar]) + | Sockaddr sockaddr -> let sock = socket sockaddr in - Lwt_unix.connect sock sockaddr >>= fun () -> + Lwt.catch + (fun () -> Lwt_unix.connect sock sockaddr) + (fun e -> Lwt_unix.close sock >>= fun () -> Lwt.fail e) + >>= fun () -> Channels.of_raw_fd sock >>= fun c -> - return (c, [ Nbd; NoProtocol; Chunked; Human; Tar ]) - | Https uri' - | Http uri' -> + return (c, [Nbd; NoProtocol; Chunked; Human; Tar]) + | Https uri' | Http uri' -> ( (* TODO: https is not currently implemented *) - let port = match Uri.port uri' with None -> (if use_ssl then 443 else 80) | Some port -> port in - let host = match Uri.host uri' with None -> failwith "Please supply a host in the URI" | Some host -> host in - Lwt_unix.gethostbyname host >>= fun 
host_entry -> - let sockaddr = Lwt_unix.ADDR_INET(host_entry.Lwt_unix.h_addr_list.(0), port) in + let port = + match Uri.port uri' with + | None -> + if use_ssl then 443 else 80 + | Some port -> + port + in + let host = + match Uri.host uri' with + | None -> + failwith "Please supply a host in the URI" + | Some host -> + host + in + let host = Scanf.ksscanf host (fun _ _ -> host) "[%s@]" Fun.id in + Lwt_unix.getaddrinfo host (string_of_int port) [] >>= fun he -> + if he = [] then raise Not_found ; + + let sockaddr = (List.hd he).Unix.ai_addr in let sock = socket sockaddr in - Lwt_unix.connect sock sockaddr >>= fun () -> - + Lwt.catch + (fun () -> Lwt_unix.connect sock sockaddr) + (fun e -> Lwt_unix.close sock >>= fun () -> Lwt.fail e) + >>= fun () -> let open Cohttp in - ( if use_ssl then Channels.of_ssl_fd sock ssl_legacy good_ciphersuites legacy_ciphersuites else Channels.of_raw_fd sock ) >>= fun c -> - - let module Request = Request.Make(Cohttp_unbuffered_io) in - let module Response = Response.Make(Cohttp_unbuffered_io) in + ( if use_ssl then + Channels.of_ssl_fd sock good_ciphersuites + else + Channels.of_raw_fd sock + ) + >>= fun c -> + let module Request = Request.Make (Cohttp_unbuffered_io) in + let module Response = Response.Make (Cohttp_unbuffered_io) in let headers = Header.init () in - let k, v = Cookie.Cookie_hdr.serialize [ "chunked", "true" ] in + let k, v = Cookie.Cookie_hdr.serialize [("chunked", "true")] in let headers = Header.add headers k v in - let headers = match Uri.userinfo uri' with - | None -> headers - | Some x -> - begin match Re_str.bounded_split_delim (Re_str.regexp_string ":") x 2 with - | [ user; pass ] -> - let b = Cohttp.Auth.string_of_credential (`Basic (user, pass)) in - Header.add headers "authorization" b + let headers = + match Uri.userinfo uri' with + | None -> + headers + | Some x -> ( + match Re.Str.bounded_split_delim (Re.Str.regexp_string ":") x 2 with + | [user; pass] -> + let b = Cohttp.Auth.string_of_credential 
(`Basic (user, pass)) in + Header.add headers "authorization" b | _ -> - Printf.fprintf stderr "I don't know how to handle authentication for this URI.\n Try scheme://user:password@host/path\n"; - exit 1 - end in - let request = Cohttp.Request.make ~meth:`PUT ~version:`HTTP_1_1 ~headers uri' in + Printf.fprintf stderr + "I don't know how to handle authentication for this URI.\n\ + \ Try scheme://user:password@host/path\n" ; + exit 1 + ) + in + let request = + Cohttp.Request.make ~meth:`PUT ~version:`HTTP_1_1 ~headers uri' + in Request.write (fun _ -> return ()) request c >>= fun () -> Response.read (Cohttp_unbuffered_io.make_input c) >>= fun r -> - begin match r with - | `Invalid x -> fail (Failure (Printf.sprintf "Invalid HTTP response: %s" - x)) - | `Eof -> fail (Failure "EOF while parsing HTTP response") + match r with + | `Invalid x -> + fail (Failure (Printf.sprintf "Invalid HTTP response: %s" x)) + | `Eof -> + fail (Failure "EOF while parsing HTTP response") | `Ok x -> - let code = Code.code_of_status (Cohttp.Response.status x) in - if Code.is_success code then begin - let advertises_nbd = - let headers = Header.to_list (Cohttp.Response.headers x) in - let headers = List.map (fun (x, y) -> String.lowercase x, String.lowercase y) headers in - let te = "transfer-encoding" in - List.mem_assoc te headers && (List.assoc te headers = "nbd") in - if advertises_nbd - then return(c, [ Nbd ]) - else return(c, [ Chunked; NoProtocol ]) - end else fail (Failure (Code.reason_phrase_of_code code)) - end - ) >>= fun (c, possible_protocols) -> - let destination_protocol = match destination_protocol with - | Some x -> x - | None -> + let code = Code.code_of_status (Cohttp.Response.status x) in + if Code.is_success code then + let advertises_nbd = + let headers = Header.to_list (Cohttp.Response.headers x) in + let headers = + List.map + (fun (x, y) -> + (String.lowercase_ascii x, String.lowercase_ascii y)) + headers + in + let te = "transfer-encoding" in + List.mem_assoc te 
headers && List.assoc te headers = "nbd" + in + if advertises_nbd then + return (c, [Nbd]) + else + return (c, [Chunked; NoProtocol]) + else + fail (Failure (Code.reason_phrase_of_code code)) + ) + ) + >>= fun (c, possible_protocols) -> + let destination_protocol = + match destination_protocol with + | Some x -> + x + | None -> let t = List.hd possible_protocols in - Printf.fprintf stderr "Using protocol: %s\n%!" (string_of_protocol t); - t in - if not(List.mem destination_protocol possible_protocols) - then fail(Failure(Printf.sprintf "this destination only supports protocols: [ %s ]" (String.concat "; " (List.map string_of_protocol possible_protocols)))) - else - let start = Unix.gettimeofday () in - (match destination_protocol with - | Nbd -> stream_nbd - | Human -> stream_human - | Chunked -> stream_chunked - | Tar -> stream_tar - | NoProtocol -> stream_raw) common c s prezeroed tar_filename_prefix ~progress () >>= fun p -> - c.Channels.close () >>= fun () -> - match p with - | Some p -> + Printf.fprintf stderr "Using protocol: %s\n%!" (string_of_protocol t) ; + t + in + if not (List.mem destination_protocol possible_protocols) then + fail + (Failure + (Printf.sprintf "this destination only supports protocols: [ %s ]" + (String.concat "; " + (List.map string_of_protocol possible_protocols)))) + else + let start = Unix.gettimeofday () in + ( match destination_protocol with + | Nbd -> + stream_nbd + | Human -> + stream_human + | Chunked -> + stream_chunked + | Tar -> + stream_tar + | NoProtocol -> + stream_raw + ) + common c s prezeroed tar_filename_prefix ~progress () + >>= fun p -> + c.Channels.close () >>= fun () -> + match p with + | Some p -> let time = Unix.gettimeofday () -. start in let physical_rate = Int64.(to_float p /. time) in - if common.Common.verb then begin + if common.Common.verb then ( let add_unit x = let kib = 1024. in let mib = kib *. 1024. in let gib = mib *. 1024. in let tib = gib *. 1024. in - if x /. tib > 1. 
then Printf.sprintf "%.1f TiB" (x /. tib) - else if x /. gib > 1. then Printf.sprintf "%.1f GiB" (x /. gib) - else if x /. mib > 1. then Printf.sprintf "%.1f MiB" (x /. mib) - else if x /. kib > 1. then Printf.sprintf "%.1f KiB" (x /. kib) - else Printf.sprintf "%.1f B" x in - - Printf.printf "Time taken: %s\n" (hms (int_of_float time)); - Printf.printf "Physical data rate: %s/sec\n" (add_unit physical_rate); - let open Vhd.F in - let speedup = Int64.(to_float s.size.total /. (to_float p)) in - Printf.printf "Speedup: %.1f\n" speedup; - Printf.printf "Virtual data rate: %s/sec\n" (add_unit (physical_rate *. speedup)); - end; + if x /. tib > 1. then + Printf.sprintf "%.1f TiB" (x /. tib) + else if x /. gib > 1. then + Printf.sprintf "%.1f GiB" (x /. gib) + else if x /. mib > 1. then + Printf.sprintf "%.1f MiB" (x /. mib) + else if x /. kib > 1. then + Printf.sprintf "%.1f KiB" (x /. kib) + else + Printf.sprintf "%.1f B" x + in + + Printf.printf "Time taken: %s\n" (hms (int_of_float time)) ; + Printf.printf "Physical data rate: %s/sec\n" (add_unit physical_rate) ; + let open Vhd_format.F in + let speedup = Int64.(to_float s.size.total /. to_float p) in + Printf.printf "Speedup: %.1f\n" speedup ; + Printf.printf "Virtual data rate: %s/sec\n" + (add_unit (physical_rate *. 
speedup)) + ) ; + return () + | None -> return () - | None -> return () - let stream_t common args ?(progress = no_progress_bar) () = - make_stream common args.StreamCommon.source args.StreamCommon.relative_to args.StreamCommon.source_format args.StreamCommon.destination_format >>= fun s -> - write_stream common s args.StreamCommon.destination args.StreamCommon.source_protocol args.StreamCommon.destination_protocol args.StreamCommon.prezeroed progress args.StreamCommon.tar_filename_prefix args.StreamCommon.ssl_legacy args.StreamCommon.good_ciphersuites args.StreamCommon.legacy_ciphersuites + make_stream common args.StreamCommon.source args.StreamCommon.relative_to + args.StreamCommon.source_format args.StreamCommon.destination_format + >>= fun s -> + write_stream common s args.StreamCommon.destination + args.StreamCommon.source_protocol args.StreamCommon.destination_protocol + args.StreamCommon.prezeroed progress args.StreamCommon.tar_filename_prefix + args.StreamCommon.good_ciphersuites let stream common args = try - Vhd_lwt.File.use_unbuffered := common.Common.unbuffered; - - let progress_bar = match args with - | { StreamCommon.progress = true; machine = true } -> machine_progress_bar - | { StreamCommon.progress = true; machine = false } -> console_progress_bar - | _ -> no_progress_bar in + Vhd_format_lwt.File.use_unbuffered := common.Common.unbuffered ; + + let progress_bar = + match args with + | {StreamCommon.progress= true; machine= true; _} -> + machine_progress_bar + | {StreamCommon.progress= true; machine= false; _} -> + console_progress_bar + | _ -> + no_progress_bar + in let thread = stream_t common args ~progress:progress_bar () in - Lwt_main.run thread; - `Ok () - with Failure x -> - `Error(true, x) + Lwt_main.run thread ; `Ok () + with Failure x -> `Error (true, x) let serve_nbd_to_raw common size c dest _ _ _ _ = let flags = [] in - let open Nbd in - let buf = Cstruct.create Negotiate.sizeof in - Negotiate.marshal buf { Negotiate.size; flags }; + 
let open Nbd.Protocol in + let buf = Cstruct.create (Negotiate.sizeof `V1) in + Negotiate.marshal buf (Negotiate.V1 {Negotiate.size; flags}) ; c.Channels.really_write buf >>= fun () -> - let twomib = 2 * 1024 * 1024 in let block = IO.alloc twomib in let inblocks fn request = @@ -815,38 +1133,48 @@ let serve_nbd_to_raw common size c dest _ _ _ _ = fn offset subblock >>= fun () -> let remaining = remaining - n in let offset = Int64.(add offset (of_int n)) in - if remaining > 0 then loop offset remaining else return () in - loop request.Request.from (Int32.to_int request.Request.len) in + if remaining > 0 then loop offset remaining else return () + in + loop request.Request.from (Int32.to_int request.Request.len) + in let req = Cstruct.create Request.sizeof in let rep = Cstruct.create Reply.sizeof in let rec serve_requests () = c.Channels.really_read req >>= fun () -> match Request.unmarshal req with - | `Error e -> fail e - | `Ok request -> - if common.Common.debug - then Printf.fprintf stderr "%s\n%!" (Request.to_string request); - begin match request.Request.ty with - | Command.Write -> - inblocks (fun offset subblock -> - c.Channels.really_read subblock >>= fun () -> - Vhd_lwt.IO.really_write dest offset subblock - ) request >>= fun () -> - Reply.marshal rep { Reply.error = 0l; handle = request.Request.handle }; - c.Channels.really_write rep - | Command.Read -> - Reply.marshal rep { Reply.error = 0l; handle = request.Request.handle }; - c.Channels.really_write rep >>= fun () -> - inblocks (fun offset subblock -> - Vhd_lwt.IO.really_read dest offset subblock >>= fun () -> - c.Channels.really_write subblock - ) request - | _ -> - Reply.marshal rep { Reply.error = 1l; handle = request.Request.handle }; - c.Channels.really_write rep - end >>= fun () -> - serve_requests () in + | Error e -> + fail e + | Ok request -> + if common.Common.debug then + Printf.fprintf stderr "%s\n%!" 
(Request.to_string request) ; + ( match request.Request.ty with + | Command.Write -> + inblocks + (fun offset subblock -> + c.Channels.really_read subblock >>= fun () -> + Vhd_format_lwt.IO.really_write dest offset subblock) + request + >>= fun () -> + Reply.marshal rep + {Reply.error= Ok (); handle= request.Request.handle} ; + c.Channels.really_write rep + | Command.Read -> + Reply.marshal rep + {Reply.error= Ok (); handle= request.Request.handle} ; + c.Channels.really_write rep >>= fun () -> + inblocks + (fun offset subblock -> + Vhd_format_lwt.IO.really_read dest offset subblock >>= fun () -> + c.Channels.really_write subblock) + request + | _ -> + Reply.marshal rep + {Reply.error= Error `EPERM; handle= request.Request.handle} ; + c.Channels.really_write rep + ) + >>= fun () -> serve_requests () + in serve_requests () let serve_chunked_to_raw _ c dest _ _ _ _ = @@ -855,112 +1183,196 @@ let serve_chunked_to_raw _ c dest _ _ _ _ = let buffer = IO.alloc twomib in let rec loop () = c.Channels.really_read header >>= fun () -> - if Chunked.is_last_chunk header then begin - Printf.fprintf stderr "Received last chunk.\n%!"; + if Chunked.is_last_chunk header then ( + Printf.fprintf stderr "Received last chunk.\n%!" 
; return () - end else begin + ) else let rec block offset remaining = let this = Int32.(to_int (min (of_int twomib) remaining)) in let buf = if this < twomib then Cstruct.sub buffer 0 this else buffer in c.Channels.really_read buf >>= fun () -> - Vhd_lwt.IO.really_write dest offset buf >>= fun () -> + Vhd_format_lwt.IO.really_write dest offset buf >>= fun () -> let offset = Int64.(add offset (of_int this)) in let remaining = Int32.(sub remaining (of_int this)) in - if remaining > 0l - then block offset remaining - else return () in + if remaining > 0l then + block offset remaining + else + return () + in block (Chunked.get_offset header) (Chunked.get_len header) >>= fun () -> loop () - end in + in loop () +(* If we're using unbuffered IO, we write in whole sectors. We therefore might + need to extend the cstruct to the next sector boundary *) +let round_up_to_sector unbuffered len = + if unbuffered then + let sector_size = 512 in + (((len - 1) / sector_size) + 1) * sector_size + else + len + let serve_raw_to_raw common size c dest _ progress _ _ = let twomib = 2 * 1024 * 1024 in let buffer = IO.alloc twomib in let p = progress size in let rec loop offset remaining = - let this = Int64.(to_int (min remaining (of_int (Cstruct.len buffer)))) in - let block = Cstruct.sub buffer 0 this in - c.Channels.really_read block >>= fun () -> - Vhd_lwt.IO.really_write dest offset block >>= fun () -> - let offset = Int64.(add offset (of_int this)) in - let remaining = Int64.(sub remaining (of_int this)) in begin - p Int64.(sub size remaining); - if remaining > 0L - then loop offset remaining - else return () - end in + let n = Int64.(to_int (min remaining (of_int (Cstruct.len buffer)))) in + let rounded_n = round_up_to_sector common.unbuffered n in + (* Create a buffer of the rounded-up size *) + let block = Cstruct.sub buffer 0 rounded_n in + ( if n <> rounded_n then + Vhd_format_lwt.IO.really_read dest offset block + else + Lwt.return () + ) + >>= fun () -> + (* Create a 
cstruct that's an alias to the above block, + but only as long as the amount of data we're expecting *) + let block2 = Cstruct.sub block 0 n in + c.Channels.really_read block2 >>= fun () -> + Vhd_format_lwt.IO.really_write dest offset block >>= fun () -> + let offset = Int64.(add offset (of_int n)) in + let remaining = Int64.(sub remaining (of_int n)) in + p Int64.(sub size remaining) ; + if remaining > 0L then + loop offset remaining + else + return () + in loop 0L size -let serve common_options source source_fd source_format source_protocol destination destination_fd destination_format destination_size prezeroed progress machine expected_prefix ignore_checksums = +let serve common_options source source_fd source_format source_protocol + destination destination_fd destination_format destination_size prezeroed + progress machine expected_prefix ignore_checksums = try - Vhd_lwt.File.use_unbuffered := common_options.Common.unbuffered; - - let source_protocol = protocol_of_string (require "source-protocol" source_protocol) in - - let supported_formats = [ "raw"; "vhd" ] in - if not (List.mem source_format supported_formats) - then failwith (Printf.sprintf "%s is not a supported format" source_format); - let supported_formats = [ "raw" ] in - if not (List.mem destination_format supported_formats) - then failwith (Printf.sprintf "%s is not a supported format" destination_format); - let supported_protocols = [ NoProtocol; Chunked; Nbd; Tar ] in - if not (List.mem source_protocol supported_protocols) - then failwith (Printf.sprintf "%s is not a supported source protocol" (string_of_protocol source_protocol)); - - let destination = match destination_fd with - | None -> destination - | Some fd -> "fd://" ^ (string_of_int fd) in - - let progress_bar = match progress, machine with - | true, true -> machine_progress_bar - | true, false -> console_progress_bar - | _, _ -> no_progress_bar in + Vhd_format_lwt.File.use_unbuffered := common_options.Common.unbuffered ; + + let 
source_protocol = + protocol_of_string (require "source-protocol" source_protocol) + in + + let supported_formats = ["raw"; "vhd"] in + if not (List.mem source_format supported_formats) then + failwith (Printf.sprintf "%s is not a supported format" source_format) ; + let supported_formats = ["raw"] in + if not (List.mem destination_format supported_formats) then + failwith + (Printf.sprintf "%s is not a supported format" destination_format) ; + let supported_protocols = [NoProtocol; Chunked; Nbd; Tar] in + if not (List.mem source_protocol supported_protocols) then + failwith + (Printf.sprintf "%s is not a supported source protocol" + (string_of_protocol source_protocol)) ; + + let destination = + match destination_fd with + | None -> + destination + | Some fd -> + "fd://" ^ string_of_int fd + in + + let progress_bar = + match (progress, machine) with + | true, true -> + machine_progress_bar + | true, false -> + console_progress_bar + | _, _ -> + no_progress_bar + in let thread = endpoint_of_string destination >>= fun destination_endpoint -> ( match source_fd with - | None -> endpoint_of_string source - | Some fd -> return (File_descr (Lwt_unix.of_unix_file_descr (file_descr_of_int fd))) ) >>= fun source_endpoint -> + | None -> + endpoint_of_string source + | Some fd -> + return + (File_descr (Lwt_unix.of_unix_file_descr (file_descr_of_int fd))) + ) + >>= fun source_endpoint -> ( match source_endpoint with - | File_descr fd -> - Channels.of_raw_fd fd >>= fun c -> - return c - | Sockaddr s -> + | File_descr fd -> + Channels.of_raw_fd fd >>= fun c -> return c + | Sockaddr s -> let sock = socket s in - Lwt_unix.bind sock s; - Lwt_unix.setsockopt sock Unix.SO_REUSEADDR true; - Lwt_unix.listen sock 1; + Lwt_unix.bind sock s >>= fun () -> + Lwt_unix.setsockopt sock Unix.SO_REUSEADDR true ; + Lwt_unix.listen sock 1 ; Lwt_unix.accept sock >>= fun (fd, _) -> - Channels.of_raw_fd fd >>= fun c -> - return c - | File path -> - let fd = Vhd_lwt.File.openfile path false 0 in + 
Channels.of_raw_fd fd >>= fun c -> return c + | File path -> + let fd = Vhd_format_lwt.File.openfile path false 0 in Channels.of_raw_fd (Lwt_unix.of_unix_file_descr fd) - | _ -> failwith (Printf.sprintf "Not implemented: serving from source %s" source) ) >>= fun source_sock -> + | _ -> + failwith + (Printf.sprintf "Not implemented: serving from source %s" source) + ) + >>= fun source_sock -> ( match destination_endpoint with - | File path -> - ( if not(Sys.file_exists path) then begin - Lwt_unix.openfile path [ Unix.O_CREAT; Unix.O_RDONLY ] 0o0644 >>= fun fd -> - Lwt_unix.close fd - end else return () ) >>= fun () -> - Vhd_lwt.IO.openfile path true >>= fun fd -> - let size = match destination_size with - | None -> Vhd_lwt.File.get_file_size path - | Some x -> x in - return (fd, size) - | _ -> failwith (Printf.sprintf "Not implemented: writing to destination %s" destination) ) >>= fun (destination_fd, size) -> - let fn = match source_format, source_protocol with - | "raw", NoProtocol -> serve_raw_to_raw common_options size - | "raw", Nbd -> serve_nbd_to_raw common_options size - | "raw", Chunked -> serve_chunked_to_raw common_options - | "raw", Tar -> serve_tar_to_raw size - | "vhd", NoProtocol -> serve_vhd_to_raw size - | _, _ -> failwith (Printf.sprintf "Not implemented: receiving format %s via protocol %s" source_format (StreamCommon.string_of_protocol source_protocol)) in - fn source_sock destination_fd prezeroed progress_bar expected_prefix ignore_checksums >>= fun () -> - let fd = Lwt_unix.unix_file_descr (Vhd_lwt.IO.to_file_descr destination_fd) in - (try Vhd_lwt.File.fsync fd; return () with _ -> fail (Failure "fsync failed")) in - Lwt_main.run thread; - `Ok () - with Failure x -> - `Error(false, x) + | File path -> + ( if not (Sys.file_exists path) then + Lwt_unix.openfile path [Unix.O_CREAT; Unix.O_RDONLY] 0o0644 + >>= fun fd -> Lwt_unix.close fd + else + return () + ) + >>= fun () -> + Vhd_format_lwt.IO.openfile path true >>= fun fd -> + let size = + 
match destination_size with + | None -> + Vhd_format_lwt.File.get_file_size path + | Some x -> + x + in + if size = 0L then + fail + (Failure + "Non-zero size required (either a pre-existing destination \ + file or specified via --destination-size on the command \ + line)") + else + return (fd, size) + | _ -> + failwith + (Printf.sprintf "Not implemented: writing to destination %s" + destination) + ) + >>= fun (destination_fd, size) -> + let fn = + match (source_format, source_protocol) with + | "raw", NoProtocol -> + serve_raw_to_raw common_options size + | "raw", Nbd -> + serve_nbd_to_raw common_options size + | "raw", Chunked -> + serve_chunked_to_raw common_options + | "raw", Tar -> + serve_tar_to_raw size + | "vhd", NoProtocol -> + serve_vhd_to_raw size + | _, _ -> + failwith + (Printf.sprintf + "Not implemented: receiving format %s via protocol %s" + source_format + (StreamCommon.string_of_protocol source_protocol)) + in + fn source_sock destination_fd prezeroed progress_bar expected_prefix + ignore_checksums + >>= fun () -> + let fd = + Lwt_unix.unix_file_descr + (Vhd_format_lwt.IO.to_file_descr destination_fd) + in + try + Vhd_format_lwt.File.fsync fd ; + return () + with _ -> fail (Failure "fsync failed") + in + Lwt_main.run thread ; `Ok () + with Failure x -> `Error (false, x) diff --git a/src/input.ml b/src/input.ml index ac3e1eb..d25d782 100644 --- a/src/input.ml +++ b/src/input.ml @@ -11,38 +11,39 @@ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Lesser General Public License for more details. 
*) - + type 'a t = 'a Lwt.t -let (>>=) = Lwt.(>>=) +let ( >>= ) = Lwt.( >>= ) + let return = Lwt.return + let fail = Lwt.fail open Lwt -type fd = { - fd: Lwt_unix.file_descr; - mutable offset: int64; -} +type fd = {fd: Lwt_unix.file_descr; mutable offset: int64} let of_fd fd = let offset = 0L in - { fd; offset } + {fd; offset} let read fd buf = - lwt () = IO.complete "read" (Some fd.offset) Lwt_bytes.read fd.fd buf in - fd.offset <- Int64.(add fd.offset (of_int (Cstruct.len buf))); + IO.complete "read" (Some fd.offset) Lwt_bytes.read fd.fd buf >>= fun () -> + fd.offset <- Int64.(add fd.offset (of_int (Cstruct.len buf))) ; return () let skip_to fd n = let buf = Io_page.(to_cstruct (get 1)) in let rec loop remaining = - if remaining = 0L - then return () + if remaining = 0L then + return () else let this = Int64.(to_int (min remaining (of_int (Cstruct.len buf)))) in let frag = Cstruct.sub buf 0 this in - lwt () = IO.complete "read" (Some fd.offset) Lwt_bytes.read fd.fd frag in - fd.offset <- Int64.(add fd.offset (of_int this)); - loop Int64.(sub remaining (of_int this)) in + IO.complete "read" (Some fd.offset) Lwt_bytes.read fd.fd frag + >>= fun () -> + fd.offset <- Int64.(add fd.offset (of_int this)) ; + loop Int64.(sub remaining (of_int this)) + in loop Int64.(sub n fd.offset) diff --git a/src/main.ml b/src/main.ml deleted file mode 100644 index 6110dd9..0000000 --- a/src/main.ml +++ /dev/null @@ -1,262 +0,0 @@ -(* - * Copyright (C) 2011-2013 Citrix Inc - * - * This program is free software; you can redistribute it and/or modify - * it under the terms of the GNU Lesser General Public License as published - * by the Free Software Foundation; version 2.1 only. with the special - * exception on linking described in file LICENSE. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
See the - * GNU Lesser General Public License for more details. - *) - -let project_url = "http://github.com/djs55/ocaml-vhd" - -open Common -open Cmdliner - -(* Help sections common to all commands *) - -let _common_options = "COMMON OPTIONS" -let help = [ - `S _common_options; - `P "These options are common to all commands."; - `S "MORE HELP"; - `P "Use `$(mname) $(i,COMMAND) --help' for help on a single command."; `Noblank; - `S "BUGS"; `P (Printf.sprintf "Check bug reports at %s" project_url); -] - -(* Options common to all commands *) -let common_options_t = - let docs = _common_options in - let debug = - let doc = "Give only debug output." in - Arg.(value & flag & info ["debug"] ~docs ~doc) in - let verb = - let doc = "Give verbose output." in - let verbose = true, Arg.info ["v"; "verbose"] ~docs ~doc in - Arg.(last & vflag_all [false] [verbose]) in - let unbuffered = - let doc = "Use unbuffered I/O." in - Arg.(value & flag & info ["unbuffered"; "direct"] ~docs ~doc) in - let search_path = - let doc = "Search path for vhds." in - Arg.(value & opt string "." & info [ "path" ] ~docs ~doc) in - Term.(pure Common.make $ debug $ verb $ unbuffered $ search_path) - -let get_cmd = - let doc = "query vhd metadata" in - let man = [ - `S "DESCRIPTION"; - `P "Look up a particular metadata property by name and print the value." - ] @ help in - let filename = - let doc = Printf.sprintf "Path to the vhd file." in - Arg.(value & pos 0 (some file) None & info [] ~doc) in - let key = - let doc = "Key to query" in - Arg.(value & pos 1 (some string) None & info [] ~doc) in - Term.(ret(pure Impl.get $ common_options_t $ filename $ key)), - Term.info "get" ~sdocs:_common_options ~doc ~man - -let filename = - let doc = Printf.sprintf "Path to the vhd file." 
in - Arg.(value & pos 0 (some file) None & info [] ~doc) - -let info_cmd = - let doc = "display general information about a vhd" in - let man = [ - `S "DESCRIPTION"; - `P "Display general information about a vhd, including header and footer fields. This won't directly display block allocation tables or sector bitmaps."; - ] @ help in - Term.(ret(pure Impl.info $ common_options_t $ filename)), - Term.info "info" ~sdocs:_common_options ~doc ~man - -let contents_cmd = - let doc = "display the contents of the vhd" in - let man = [ - `S "DESCRIPTION"; - `P "Display the contents of the vhd: headers, metadata and data blocks. Everything is displayed in the order it appears in the vhd file, not the order it appears in the virtual disk image itself."; - ] @ help in - Term.(ret(pure Impl.contents $ common_options_t $ filename)), - Term.info "contents" ~sdocs:_common_options ~doc ~man - - -let create_cmd = - let doc = "create a dynamic vhd" in - let man = [ - `S "DESCRIPTION"; - `P "Create a dynamic vhd (i.e. one which may be sparse). A dynamic vhd may be self-contained or it may have a backing-file or 'parent'."; - ] @ help in - let filename = - let doc = Printf.sprintf "Path to the vhd file to be created." in - Arg.(value & pos 0 (some string) None & info [] ~doc) in - let size = - let doc = Printf.sprintf "Virtual size of the disk." in - Arg.(value & opt (some string) None & info [ "size" ] ~doc) in - let parent = - let doc = Printf.sprintf "Parent image" in - Arg.(value & opt (some file) None & info [ "parent" ] ~doc) in - Term.(ret(pure Impl.create $ common_options_t $ filename $ size $ parent)), - Term.info "create" ~sdocs:_common_options ~doc ~man - -let check_cmd = - let doc = "check the structure of a vhd file" in - let man = [ - `S "DESCRIPTION"; - `P "Check the structure of a vhd file is valid, print any errors on the console."; - ] @ help in - let filename = - let doc = Printf.sprintf "Path to the vhd to be checked." 
in - Arg.(value & pos 0 (some file) None & info [] ~doc) in - Term.(ret(pure Impl.check $ common_options_t $ filename)), - Term.info "check" ~sdocs:_common_options ~doc ~man - -let source = - let doc = Printf.sprintf "The source disk" in - Arg.(value & opt string "stdin:" & info [ "source" ] ~doc) - -let source_fd = - let doc = Printf.sprintf "An open-file descriptor pointing to the source disk" in - Arg.(value & opt (some int) None & info [ "source-fd" ] ~doc) - -let source_format = - let doc = "Source format" in - Arg.(value & opt string "raw" & info [ "source-format" ] ~doc) - -let source_protocol = - let doc = "Transport protocol for the source data." in - Arg.(value & opt (some string) None & info [ "source-protocol" ] ~doc) - -let destination = - let doc = "Destination for streamed data." in - Arg.(value & opt string "stdout:" & info [ "destination" ] ~doc) - -let destination_fd = - let doc = "Write data to a file descriptor." in - Arg.(value & opt (some int) None & info [ "destination-fd" ] ~doc) - -let destination_format = - let doc = "Destination format" in - Arg.(value & opt string "raw" & info [ "destination-format" ] ~doc) - -let destination_size = - let doc = "Size of the destination disk" in - Arg.(value & opt (some int64) None & info [ "destination-size" ] ~doc) - -let prezeroed = - let doc = "Assume the destination is completely empty." in - Arg.(value & flag & info [ "prezeroed" ] ~doc) - -let progress = - let doc = "Display a progress bar." in - Arg.(value & flag & info ["progress"] ~doc) - -let machine = - let doc = "Machine readable output." 
in - Arg.(value & flag & info ["machine"] ~doc) - -let tar_filename_prefix = - let doc = "Filename prefix for tar/sha disk blocks" in - Arg.(value & opt (some string) None & info ["tar-filename-prefix"] ~doc) - -let ssl_legacy = - let doc = "For TLS, allow all protocol versions instead of just TLSv1.2" in - Arg.(value & flag & info ["ssl-legacy"] ~doc) - -let good_ciphersuites = - let doc = "The list of ciphersuites to allow for TLS" in - Arg.(value & opt (some string) None & info ["good-ciphersuites"] ~doc) - -let legacy_ciphersuites = - let doc = "Additional TLS ciphersuites allowed only if ssl-legacy is set" in - Arg.(value & opt (some string) None & info ["legacy-ciphersuites"] ~doc) - -let serve_cmd = - let doc = "serve the contents of a disk" in - let man = [ - `S "DESCRIPTION"; - `P "Allow the contents of a disk to be read or written over a network protocol"; - `P "EXAMPLES"; - `P " vhd-tool serve --source fd:5 --source-protocol=chunked --destination file:///foo.raw --destination-format raw"; - `P " vhd-tool serve --source fd:5 --source-protocol=nbd --destination file:///foo.raw --destination-format raw"; - `P " vhd-tool serve --source fd:5 --source-format=vhd --source-protocol=none --destination file:///foo.raw --destination-format raw"; - ] in - let ignore_checksums = - let doc = "Do not verify checksums" in - Arg.(value & flag & info ["ignore-checksums"] ~doc) in - Term.(ret(pure Impl.serve $ common_options_t $ source $ source_fd $ source_format $ source_protocol $ destination $ destination_fd $ destination_format $ destination_size $ prezeroed $ progress $ machine $ tar_filename_prefix $ ignore_checksums)), - Term.info "serve" ~sdocs:_common_options ~doc ~man - -let stream_cmd = - let doc = "stream the contents of a vhd disk" in - let man = [ - `S "DESCRIPTION"; - `P "Read the contents of a virtual disk from a source using (format, protocol) and write it out to a destination using another (format, protocol). 
This command allows disks to be uploaded, downloaded and format-converted in a space-efficient manner."; - `S "FORMATS"; - `P "The input format and the output format are specified separately: this allows easy format conversion during the streaming process. The following formats are defined:"; - `P " raw: a single flat image"; - `P " vhd: the Virtual Hard Disk format used in XenServer"; - `P "Note: the vhd format supports both self-contained single file images and also \"differencing disks\" containing only the differences between two disks. To input only the differences between two disks, specify the reference disk with the \"--relative-to\" argument."; - `S "PROTOCOLS"; - `P "Protocols are the means by which a disk image in a particular format is written to a particular destination. The following protocols are supported:"; - `P " nbd: the Network Block Device protocol"; - `P " chunked: the XenServer chunked disk upload protocol"; - `P " none: unencoded write"; - `P " tar: the XenServer import/export encoding using tar"; - `P " human: human-readable description of the contents"; - `P "The default behaviour is to auto-detect based on the destination."; - `S "SOURCES and DESTINATIONS"; - `P "The source describes where the disk data comes from. The destination describes where the disk data is written to. The following are defined:"; - `P " stdin:"; - `P " read from standard input (input only)"; - `P " stdout:"; - `P " write to standard output (destination only)"; - `P " fd:5"; - `P " read and write from file descriptor 5"; - `P " "; - `P " read from or write to the file "; - `P " unix://"; - `P " connect to the Unix domain socket"; - `P " tcp://server:port/path"; - `P " to issue an HTTP PUT to server:port/path"; - `P " tcp://host:port/"; - `P " to connect to TCP port 'port' on host 'host'"; - `S "OTHER OPTIONS"; - `P "When transferring a raw format image onto a medium which is completely empty (i.e. 
full of zeroes) it is possible to optimise the transfer by avoiding writing empty blocks. The default behaviour is to write zeroes, which is always safe. If you know your media is empty then supply the '--prezeroed' argument."; - `P "When running interactively, the --progress argument will cause a progress bar and summary statistics to be printed."; - `P "When generating a tar/sha stream, the --tar-filename-prefix will be prefixed onto each disk data block. This is typically used to place the disk blocks of separate disks in different directories."; - `S "NOTES"; - `P "Not all protocols can be used with all destinations. For example the NBD protocol needs the ability to read (responses) and write (requests); it therefore will not work with the stdout: destination"; - `S "EXAMPLES"; - `P " $(tname) stream --source=foo.vhd --source-format=vhd --destination-format=raw --destination=http://user:password@xenserver/import_raw_vdi?vdi="; - ] @ help in - let source = - let doc = Printf.sprintf "The disk to be streamed" in - Arg.(value & opt string "stdin:" & info [ "source" ] ~doc) in - let relative_to = - let doc = "Output only differences from the given reference disk" in - Arg.(value & opt (some file) None & info [ "relative-to" ] ~doc) in - let destination_protocol = - let doc = "Transport protocol for the destination data." 
in - Arg.(value & opt (some string) None & info [ "destination-protocol" ] ~doc) in - let stream_args_t = - Term.(pure StreamCommon.make $ source $ relative_to $ source_format $ destination_format $ destination $ destination_fd $ source_protocol $ destination_protocol $ prezeroed $ progress $ machine $ tar_filename_prefix $ ssl_legacy $ good_ciphersuites $ legacy_ciphersuites) in - Term.(ret(pure Impl.stream $ common_options_t $ stream_args_t)), - Term.info "stream" ~sdocs:_common_options ~doc ~man - - -let default_cmd = - let doc = "manipulate virtual disks stored in vhd files" in - let man = help in - Term.(ret (pure (fun _ -> `Help (`Pager, None)) $ common_options_t)), - Term.info "vhd-tool" ~version:"1.0.0" ~sdocs:_common_options ~doc ~man - -let cmds = [info_cmd; contents_cmd; get_cmd; create_cmd; check_cmd; serve_cmd; stream_cmd] - -let _ = - match Term.eval_choice default_cmd cmds with - | `Error _ -> exit 1 - | _ -> exit 0 diff --git a/src/nbd_input.ml b/src/nbd_input.ml new file mode 100644 index 0000000..47e146f --- /dev/null +++ b/src/nbd_input.ml @@ -0,0 +1,102 @@ +module F = Vhd_format.F.From_file (Vhd_format_lwt.IO) +open Lwt.Infix + +type extent = {flags: int32; length: int64} [@@deriving rpc] + +(* The flags returned for the base:allocation NBD metadata context are defined here: + https://github.com/NetworkBlockDevice/nbd/blob/extension-blockstatus/doc/proto.md#baseallocation-metadata-context *) +let flag_hole = 1l + +let flag_zero = 2l + +type extent_list = extent list [@@deriving rpc] + +(** We query the block status for an area of up to 1GiB at a time, to avoid + excessive memory usage when marshalling/unmarshalling the JSON containing + the extent list. 
*) +let max_query_length = Int64.(mul 1024L (mul 1024L 1024L)) + +let min a b = if Int64.compare a b < 0 then a else b + +(** The extents returned by this Python script must be consecutive, + non-overlapping, in the correct order starting from the specified offset, + and must exactly cover the requested area. *) +let get_extents_json ~extent_reader ~server ~export_name ~offset ~length = + Lwt_process.pread + ( "" + , [| + extent_reader + ; "--path" + ; server + ; "--exportname" + ; export_name + ; "--offset" + ; Int64.to_string offset + ; "--length" + ; Int64.to_string length + |] ) + +let raw ?(extent_reader = "/opt/xensource/libexec/get_nbd_extents.py") raw + server export_name size = + let to_sectors b = Int64.div b 512L in + let is_empty e = + let has_flag flag = Int32.logand e.flags flag = flag in + (* We assume the destination is prezeroed, so we do not have to copy zeroed extents *) + has_flag flag_hole || has_flag flag_zero + in + let assert_integer_sectors b = + if Int64.rem b 512L <> 0L then failwith "Expecting sector aligned extents" + in + let rec operations extents offset acc = + match extents with + | e :: es -> + assert_integer_sectors e.length ; + let op = + if is_empty e then + `Empty (to_sectors e.length) + else + `Copy (raw, to_sectors offset, to_sectors e.length) + in + operations es (Int64.add offset e.length) (op :: acc) + | [] -> + (List.rev acc, offset) + in + let operations ~offset ~length = + get_extents_json ~extent_reader ~server ~export_name ~offset ~length + >>= fun extents_json -> + let extents = extent_list_of_rpc (Jsonrpc.of_string extents_json) in + let ops, final_offset = operations extents offset [] in + ( if final_offset <> Int64.add offset length then + Lwt.fail_with + (Printf.sprintf + "Nbd_input.raw: extents returned for offset=%Ld & length=%Ld \ + finished at incorrect offset %Ld," + offset length final_offset) + else + Lwt.return_unit + ) + >|= fun () -> ops + in + + let rec block ops offset = + match (ops, offset) with + | 
[], offset when offset >= size -> + ( if offset <> size then + Lwt.fail_with + (Printf.sprintf + "Nbd_input.raw finished with offset=%Ld <> size=%Ld" offset + size) + else + Lwt.return_unit + ) + >>= fun () -> Lwt.return F.End + | [], _ -> + let length = min (Int64.sub size offset) max_query_length in + operations ~offset ~length >>= fun ops -> + block ops (Int64.add offset length) + | op :: ops, _ -> + Lwt.return (F.Cons (op, fun () -> block ops offset)) + in + block [] 0L >>= fun elements -> + let size = Vhd_format.F.{total= size; metadata= 0L; empty= 0L; copy= 0L} in + Lwt.return F.{elements; size} diff --git a/src/sendfile64_stubs.c b/src/sendfile64_stubs.c deleted file mode 100644 index eaa6f7b..0000000 --- a/src/sendfile64_stubs.c +++ /dev/null @@ -1,77 +0,0 @@ -/* - * Copyright (C) 2012-2013 Citrix Inc - * - * This program is free software; you can redistribute it and/or modify - * it under the terms of the GNU Lesser General Public License as published - * by the Free Software Foundation; version 2.1 only. with the special - * exception on linking described in file LICENSE. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Lesser General Public License for more details. 
- */ - -#define _GNU_SOURCE - -#include -#include -#include - -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include - -#ifdef __linux__ -# include -# include -#endif - -/* ocaml/ocaml/unixsupport.c */ -extern void uerror(char *cmdname, value cmdarg); -#define Nothing ((value) 0) - -#define NOT_IMPLEMENTED (-1) -#define TRIED_AND_FAILED (1) -#define OK 0 - -CAMLprim value stub_sendfile64(value in_fd, value out_fd, value len){ - CAMLparam3(in_fd, out_fd, len); - CAMLlocal1(result); - size_t c_len = Int64_val(len); - size_t bytes; - int c_in_fd = Int_val(in_fd); - int c_out_fd = Int_val(out_fd); - - int rc = NOT_IMPLEMENTED; - - enter_blocking_section(); - -#ifdef __linux__ - rc = TRIED_AND_FAILED; - bytes = sendfile(c_out_fd, c_in_fd, NULL, c_len); - if (bytes != -1) rc = OK; -#endif - - leave_blocking_section(); - - switch (rc) { - case NOT_IMPLEMENTED: - caml_failwith("This platform does not support sendfile()"); - break; - case TRIED_AND_FAILED: - uerror("sendfile", Nothing); - break; - default: break; - } - result = caml_copy_int64(bytes); - CAMLreturn(result); -} diff --git a/src/sparse_dd.ml b/src/sparse_dd.ml deleted file mode 100644 index 88339c4..0000000 --- a/src/sparse_dd.ml +++ /dev/null @@ -1,384 +0,0 @@ -(* Utility program which copies between two block devices, using vhd BATs and efficient zero-scanning - for performance. *) - -module D = Debug.Make(struct let name = "sparse_dd" end) -open D - -let config_file = "/etc/sparse_dd.conf" - -let vhd_search_path = "/dev/mapper" - -let ionice_cmd = "/usr/bin/ionice" -let renice_cmd = "/usr/bin/renice" - -type encryption_mode = - | Always - | Never - | User -let string_of_encryption_mode = function - | Always -> "always" - | Never -> "never" - | User -> "user" -let encryption_mode_of_string = function - | "always" -> Always - | "never" -> Never - | "user" -> User - | x -> failwith (Printf.sprintf "Unknown encryption mode %s. Use always, never or user." 
x) -let encryption_mode = ref User - -(* Niceness: strings that may or may not be valid ints. *) -let nice = ref None -let ionice_class = ref None -let ionice_class_data = ref None - -let base = ref None -let src = ref None -let dest = ref None -let size = ref (-1L) -let prezeroed = ref false -let set_machine_logging = ref false -let experimental_reads_bypass_tapdisk = ref false -let experimental_writes_bypass_tapdisk = ref false - -let ssl_legacy = ref false -let good_ciphersuites = ref None -let legacy_ciphersuites = ref None - -let string_opt = function - | None -> "None" - | Some x -> x - -let machine_readable_progress = ref false - -let options = - let str_option name var_ref description = - name, Arg.String (fun x -> var_ref := Some x), (fun () -> string_opt !var_ref), description - in - [ - "unbuffered", Arg.Bool (fun b -> Vhd_lwt.File.use_unbuffered := b), (fun () -> string_of_bool !Vhd_lwt.File.use_unbuffered), "use unbuffered I/O via O_DIRECT"; - "encryption-mode", Arg.String (fun x -> encryption_mode := encryption_mode_of_string x), (fun () -> string_of_encryption_mode !encryption_mode), "how to use encryption"; - (* Want to ignore bad values for "nice" etc. 
so not using Arg.Int *) - str_option "nice" nice "If supplied, the scheduling priority will be set using this value as argument to the 'nice' command."; - str_option "ionice_class" ionice_class "If supplied, the io scheduling class will be set using this value as -c argument to the 'ionice' command."; - str_option "ionice_class_data" ionice_class_data "If supplied, the io scheduling class data will be set using this value as -n argument to the 'ionice' command."; - "experimental-reads-bypass-tapdisk", Arg.Set experimental_reads_bypass_tapdisk, (fun () -> string_of_bool !experimental_reads_bypass_tapdisk), "bypass tapdisk and read directly from the underlying vhd file"; - "experimental-writes-bypass-tapdisk", Arg.Set experimental_writes_bypass_tapdisk, (fun () -> string_of_bool !experimental_writes_bypass_tapdisk), "bypass tapdisk and write directly to the underlying vhd file"; - "base", Arg.String (fun x -> base := Some x), (fun () -> string_opt !base), "base disk to search for differences from"; - "src", Arg.String (fun x -> src := Some x), (fun () -> string_opt !src), "source disk"; - "dest", Arg.String (fun x -> dest := Some x), (fun () -> string_opt !dest), "destination disk"; - "size", Arg.String (fun x -> size := Int64.of_string x), (fun () -> Int64.to_string !size), "number of bytes to copy"; - "prezeroed", Arg.Set prezeroed, (fun () -> string_of_bool !prezeroed), "assume the destination disk has been prezeroed"; - "machine", Arg.Set machine_readable_progress, (fun () -> string_of_bool !machine_readable_progress), "emit machine-readable output"; - "ssl-legacy", Arg.Set ssl_legacy, (fun () -> string_of_bool !ssl_legacy), " for TLS, allow all protocol versions instead of just TLSv1.2"; - "good-ciphersuites", Arg.String (fun x -> good_ciphersuites := Some x), (fun () -> string_opt !good_ciphersuites), " the list of ciphersuites to allow for TLS"; - "legacy-ciphersuites", Arg.String (fun x -> legacy_ciphersuites := Some x), (fun () -> string_opt 
!legacy_ciphersuites), " additional TLS ciphersuites allowed only if ssl-legacy is set"; -] - -let ( +* ) = Int64.add -let ( -* ) = Int64.sub -let ( ** ) = Int64.mul -let kib = 1024L -let mib = kib ** kib - -let startswith prefix x = - let prefix' = String.length prefix - and x' = String.length x in - prefix' <= x' && (String.sub x 0 prefix' = prefix) - -let (|>) a b = b a -module Opt = struct - let default d = function - | None -> d - | Some x -> x -end -module Mutex = struct - include Mutex - let execute m f = - Mutex.lock m; - try - let result = f () in - Mutex.unlock m; - result - with e -> - Mutex.unlock m; - raise e -end - -module Progress = struct - let header = Cstruct.create Chunked.sizeof - - (** Report progress complete to another program reading stdout *) - let report fraction = - if !machine_readable_progress then begin - let s = Printf.sprintf "Progress: %.0f" (fraction *. 100.) in - let data = Cstruct.create (String.length s) in - Cstruct.blit_from_string s 0 data 0 (String.length s); - Chunked.marshal header { Chunked.offset = 0L; data }; - Printf.printf "%s%s%!" (Cstruct.to_string header) s - end - - (** Emit the end-of-stream message *) - let close () = - if !machine_readable_progress then begin - let header = Cstruct.create Chunked.sizeof in - Chunked.marshal header { Chunked.offset = 0L; data = Cstruct.create 0 }; - Printf.printf "%s%!" (Cstruct.to_string header) - end -end - -let after f g = - try - let r = f () in - g (); - r - with e -> - g (); - raise e - -(** [find_backend_device path] returns [Some path'] where [path'] is the backend path in - the driver domain corresponding to the frontend device [path] in this domain. *) -let find_backend_device path = - try - let open Xenstore in - (* If we're looking at a xen frontend device, see if the backend - is in the same domain. 
If so check if it looks like a .vhd *) - let rdev = (Unix.LargeFile.stat path).Unix.LargeFile.st_rdev in - let major = rdev / 256 and minor = rdev mod 256 in - let link = Unix.readlink (Printf.sprintf "/sys/dev/block/%d:%d/device" major minor) in - match List.rev (Re_str.split (Re_str.regexp_string "/") link) with - | id :: "xen" :: "devices" :: _ when startswith "vbd-" id -> - let id = int_of_string (String.sub id 4 (String.length id - 4)) in - with_xs (fun xs -> - let self = xs.Xs.read "domid" in - let backend = xs.Xs.read (Printf.sprintf "device/vbd/%d/backend" id) in - let params = xs.Xs.read (Printf.sprintf "%s/params" backend) in - match Re_str.split (Re_str.regexp_string "/") backend with - | "local" :: "domain" :: bedomid :: _ -> - assert (self = bedomid); - Some params - | _ -> raise Not_found - ) - | _ -> raise Not_found - with _ -> None - -let with_paused_tapdisk path f = - let path = find_backend_device path |> Opt.default path in - - let context = Tapctl.create () in - match Tapctl.of_device context path with - | tapdev, _, (Some (driver, path)) -> - debug "pausing tapdisk for %s" path; - Tapctl.pause context tapdev; - after f (fun () -> - debug "unpausing tapdisk for %s" path; - Tapctl.unpause context tapdev path Tapctl.Vhd - ) - | _, _, _ -> failwith (Printf.sprintf "Failed to pause tapdisk for %s" path) - -let deref_symlinks path = - let rec inner seen_already path = - if List.mem path seen_already - then failwith "Circular symlink"; - let stats = Unix.LargeFile.lstat path in - if stats.Unix.LargeFile.st_kind = Unix.S_LNK - then inner (path :: seen_already) (Unix.readlink path) - else path in - inner [] path - - -(* Record when the binary started for performance measuring *) -let start = Unix.gettimeofday () - -(* Helper function to print nice progress info *) -let progress_cb = - let last_percent = ref (-1) in - - function fraction -> - let new_percent = int_of_float (fraction *. 100.) 
in - if !last_percent <> new_percent then Progress.report fraction; - if !last_percent / 10 <> new_percent / 10 then debug "progress %d%%" new_percent; - last_percent := new_percent - -let _ = - Vhd_lwt.File.use_unbuffered := true; - Xcp_service.configure ~options (); - - let src = match !src with - | None -> - debug "Must have -src argument\n"; - exit 1 - | Some x -> x in - let dest = match !dest with - | None -> - debug "Must have -dest argument\n"; - exit 1 - | Some x -> x in - if !size = (-1L) then begin - debug "Must have -size argument\n"; - exit 1 - end; - let size = !size in - let base = !base in - - (* Helper function to bring an int into valid range *) - let clip v min max descr = - if v < min then ( - warn "Value %d is too low for %s. Using %d instead." v descr min; - min - ) else if v > max then ( - warn "Value %d is too high for %s. Using %d instead." v descr max; - max - ) else v - in - - ( - let parse_as_int str_opt int_opt_ref opt_name = - match str_opt with - | None -> () - | Some str -> - try - int_opt_ref := Some (int_of_string str) - with _ -> - error "Ignoring invalid value for %s: %s" opt_name str - in - - (* renice this process if specified *) - let n_ref = ref None in - parse_as_int !nice n_ref "nice"; - (match !n_ref with - | None -> () - | Some n -> ( - (* Run command like: renice -n priority -p pid *) - let n = clip n (-20) 19 "nice" in - let pid = string_of_int (Unix.getpid ()) in - let (stdout, stderr) = Forkhelpers.execute_command_get_output renice_cmd ["-n"; string_of_int n; "-p"; pid] - in () - ) - ); - - (* Possibly run command like: ionice -c class -n classdata -p pid *) - let c_ref = ref None in - let cd_ref = ref None in - parse_as_int !ionice_class c_ref "ionice_class"; - parse_as_int !ionice_class_data cd_ref "ionice_class_data"; - - match !c_ref with - | None -> () - | Some c -> - let pid = string_of_int (Unix.getpid ()) in - let ionice args = - let (stdout, stderr) = Forkhelpers.execute_command_get_output ionice_cmd args - in 
() - in - let class_only c = - ionice ["-c"; string_of_int c; "-p"; pid] - in - let class_and_data c n = - ionice ["-c"; string_of_int c; "-n"; string_of_int n; "-p"; pid] - in - match c with - | 0 | 3 -> - class_only c - | 1 | 2 -> ( - match !cd_ref with - | None -> class_only c - | Some n -> - let n = clip n 0 7 "ionice classdata" in - class_and_data c n) - | _ -> error "Cannot use ionice due to invalid class value: %d" c - ); - - debug "src = %s; dest = %s; base = %s; size = %Ld" src dest (Opt.default "None" base) size; - let src_image = Image.of_device src in - let dest_image = Image.of_device dest in - let base_image = match base with - | None -> None - | Some x -> Image.of_device x in - let to_string = function None -> "None" | Some x -> Image.to_string x in - debug "src_image = %s; dest_image = %s; base_image = %s" (to_string src_image) (to_string dest_image) (to_string base_image); - - (* Add the directory of the vhd to the search path *) - let vhd_search_path = match src_image with - | Some (`Vhd x) -> vhd_search_path ^ ":" ^ (Filename.dirname x) - | _ -> vhd_search_path in - - let common = Common.make true false true vhd_search_path in - - if !experimental_reads_bypass_tapdisk - then warn "experimental_reads_bypass_tapdisk set: this may cause data corruption"; - if !experimental_writes_bypass_tapdisk - then warn "experimental_writes_bypass_tapdisk set: this may cause data corruption"; - - let relative_to = match base_image with - | Some (`Vhd x) -> Some x - | Some (`Raw _) -> None - | None -> None in - - let rewrite_url device_or_url = - (* Ensure device_or_url is a valid URL, and apply our encryption policy *) - let uri = Uri.of_string device_or_url in - let rewrite_scheme scheme = - let uri = Uri.make ~scheme - ?userinfo:(Uri.userinfo uri) - ?host:(Uri.host uri) - ?port:(Uri.port uri) - ~path:(Uri.path uri) - ~query:(Uri.query uri) - ?fragment:(Uri.fragment uri) - () in - Uri.to_string uri in - begin match Uri.scheme uri with - | Some "https" when 
!encryption_mode = Never -> - warn "turning off encryption for this transfer as requested by config file"; - rewrite_scheme "http" - | Some "http" when !encryption_mode = Always -> - warn "turning on encryption for this transfer as requested by config file"; - rewrite_scheme "https" - | Some ("http" | "https") -> device_or_url - | _ -> "file://" ^ device_or_url - end in - - let open Lwt in - let stream_t, destination, destination_format = match !experimental_reads_bypass_tapdisk, src, src_image, !experimental_writes_bypass_tapdisk, dest, dest_image with - | true, _, Some (`Vhd vhd), true, _, Some (`Vhd vhd') -> - prezeroed := false; (* the physical disk will have vhd metadata and other stuff on it *) - info "streaming from vhd %s (relative to %s) to vhd %s" vhd (string_opt relative_to) vhd'; - let t = Impl.make_stream common vhd relative_to "vhd" "vhd" in - t, "file://" ^ vhd', "vhd" - | false, _, _, true, _, _ -> - error "Not implemented: writes bypass tapdisk while reads go through tapdisk"; - failwith "Not implemented: writing bypassing tapdisk while reading through tapdisk" - | false, _, Some (`Vhd vhd), false, _, _ -> - let dest = rewrite_url dest in - info "streaming from raw %s using BAT from %s (relative to %s) to raw %s" src vhd (string_opt relative_to) dest; - let t = Impl.make_stream common (src ^ ":" ^ vhd) relative_to "hybrid" "raw" in - t, dest, "raw" - | true, _, Some (`Vhd vhd), _, _, _ -> - let dest = rewrite_url dest in - info "streaming from vhd %s (relative to %s) to raw %s" vhd (string_opt relative_to) dest; - let t = Impl.make_stream common vhd relative_to "vhd" "raw" in - t, dest, "raw" - | _, _, Some (`Raw raw), _, _, _ -> - let dest = rewrite_url dest in - info "streaming from raw %s (relative to %s) to raw %s" raw (string_opt relative_to) dest; - let t = Impl.make_stream common raw relative_to "raw" "raw" in - t, dest, "raw" - | _, device, None, _, _, _ -> - let dest = rewrite_url dest in - info "streaming from raw %s (relative to %s) to 
raw %s" src (string_opt relative_to) dest; - let t = Impl.make_stream common device relative_to "raw" "raw" in - t, dest, "raw" in - - progress_cb 0.; - let progress total_work work_done = - let fraction = Int64.(to_float work_done /. (to_float total_work)) in - progress_cb fraction in - let t = - stream_t >>= fun s -> - Impl.write_stream common s destination (Some "none") None !prezeroed progress None !ssl_legacy !good_ciphersuites !legacy_ciphersuites in - if destination_format = "vhd" - then with_paused_tapdisk dest (fun () -> Lwt_main.run t) - else Lwt_main.run t; - let time = Unix.gettimeofday () -. start in - debug "Time: %.2f seconds" time; - Progress.close () diff --git a/src/streamCommon.ml b/src/streamCommon.ml index 6ccef01..368a77c 100644 --- a/src/streamCommon.ml +++ b/src/streamCommon.ml @@ -15,49 +15,91 @@ type protocol = Nbd | Chunked | Human | Tar | NoProtocol let protocol_of_string = function - | "nbd" -> Nbd | "chunked" -> Chunked | "human" -> Human - | "tar" -> Tar | "none" -> NoProtocol - | x -> failwith (Printf.sprintf "Unsupported protocol: %s" x) + | "nbd" -> + Nbd + | "chunked" -> + Chunked + | "human" -> + Human + | "tar" -> + Tar + | "none" -> + NoProtocol + | x -> + failwith (Printf.sprintf "Unsupported protocol: %s" x) let string_of_protocol = function - | Nbd -> "nbd" | Chunked -> "chunked" | Human -> "human" - | Tar -> "tar" | NoProtocol -> "none" + | Nbd -> + "nbd" + | Chunked -> + "chunked" + | Human -> + "human" + | Tar -> + "tar" + | NoProtocol -> + "none" -let supported_formats = [ "raw"; "vhd"; "hybrid" ] +let supported_formats = ["raw"; "vhd"; "hybrid"] -let require name arg = match arg with - | None -> failwith (Printf.sprintf "Please supply a %s argument" name) - | Some x -> x +let require name arg = + match arg with + | None -> + failwith (Printf.sprintf "Please supply a %s argument" name) + | Some x -> + x type t = { - source: string; - relative_to: string option; - source_format: string; - destination_format: string; - 
destination: string; - source_protocol: protocol; - destination_protocol: protocol option; - prezeroed: bool; - progress: bool; - machine: bool; - tar_filename_prefix: string option; - ssl_legacy: bool; - good_ciphersuites: string option; - legacy_ciphersuites: string option; + source: string + ; relative_to: string option + ; source_format: string + ; destination_format: string + ; destination: string + ; source_protocol: protocol + ; destination_protocol: protocol option + ; prezeroed: bool + ; progress: bool + ; machine: bool + ; tar_filename_prefix: string option + ; good_ciphersuites: string option } -let make source relative_to source_format destination_format destination destination_fd source_protocol destination_protocol prezeroed progress machine tar_filename_prefix ssl_legacy good_ciphersuites legacy_ciphersuites = - let source_protocol = protocol_of_string (require "source-protocol" source_protocol) in - let destination_protocol = match destination_protocol with - | None -> None - | Some x -> Some (protocol_of_string x) in - if not (List.mem source_format supported_formats) - then failwith (Printf.sprintf "%s is not a supported format" source_format); - if not (List.mem destination_format supported_formats) - then failwith (Printf.sprintf "%s is not a supported format" destination_format); - let destination = match destination_fd with - | None -> destination - | Some fd -> "fd://" ^ (string_of_int fd) in - - { source; relative_to; source_format; destination_format; destination; source_protocol; destination_protocol; prezeroed; progress; machine; tar_filename_prefix; ssl_legacy; good_ciphersuites; legacy_ciphersuites } +let make source relative_to source_format destination_format destination + destination_fd source_protocol destination_protocol prezeroed progress + machine tar_filename_prefix good_ciphersuites = + let source_protocol = + protocol_of_string (require "source-protocol" source_protocol) + in + let destination_protocol = + match 
destination_protocol with + | None -> + None + | Some x -> + Some (protocol_of_string x) + in + if not (List.mem source_format supported_formats) then + failwith (Printf.sprintf "%s is not a supported format" source_format) ; + if not (List.mem destination_format supported_formats) then + failwith (Printf.sprintf "%s is not a supported format" destination_format) ; + let destination = + match destination_fd with + | None -> + destination + | Some fd -> + "fd://" ^ string_of_int fd + in + { + source + ; relative_to + ; source_format + ; destination_format + ; destination + ; source_protocol + ; destination_protocol + ; prezeroed + ; progress + ; machine + ; tar_filename_prefix + ; good_ciphersuites + } diff --git a/src/tar.md b/src/tar.md index a99a22a..ad21302 100644 --- a/src/tar.md +++ b/src/tar.md @@ -1,49 +1,54 @@ -The XenServer VM import/format -============================== +# The XenServer VM import/format This document describes the disk encoding used in version XXX of the XenServer VM import/export format. Each disk is encoded as a directory full of files, within a stream in 'tar' format. The directory name must match the name of the VDI within the VM metadata XML file. Each disk is subdivided into blocks each of which is represented by 2 files: - ---------- 1 djs djs 1048576 Jan 1 1970 00000000 - ---------- 1 djs djs 40 Jan 1 1970 00000000.checksum + ---------- 1 djs djs 1048576 Jan 1 1970 00000000 + ---------- 1 djs djs 40 Jan 1 1970 00000000.checksum The file stem is treated as a counter, *not as a disk offset*. The counter increases monotonically through the stream. The file with the suffix .checksum contains the sha1sum of the corresponding block e.g. $ sha1sum 00000000 3b71f43ff30f4b15b5cd85dd9e95ebc7e84eb5a3 00000000 - $ cat 00000000.checksum + $ cat 00000000.checksum 3b71f43ff30f4b15b5cd85dd9e95ebc7e84eb5a3 The first and last blocks must be present. Readers expect the block size to be encoded in the size of the first block. 
A typical block size is 1MiB. -Omitting empty blocks -===================== +## Omitting empty blocks -If it is known that a block is entirely empty (ie full of zeroes) then it may be ommitted from the stream *provided it is neither the first or last block*. The ommission is signaled by incrementing the counter value by 1 for every ommitted block. +If it is known that a block is entirely empty (i.e. full of zeroes) then it may +be ommitted from the stream *provided it is neither the first nor the last block*. +The ommission is signaled by incrementing the counter value by 1 for every +ommitted block. Example sequence 1: +```sh 00000000 00000000.checksum 00000001 00000001.checksum 00000002 00000002.checksum +``` --- no blocks have been ommitted, since the counter value increases by 1 each block. +no blocks have been ommitted, since the counter value increases by 1 each block. Example sequence 2: +```sh 00000000 00000000.checksum 00000002 00000002.checksum +``` --- one block has been ommitted, since the counter value increased by 2. +one block has been ommitted, since the counter value increased by 2. -Inserting dummy blocks -====================== - -Sometimes it is convenient to insert extra information into the stream. This can be done by adding extra zero-length blocks, incrementing the counter in the normal way. +## Inserting dummy blocks +Sometimes it is convenient to insert extra information into the stream. This can +be done by adding extra zero-length blocks, incrementing the counter in the +normal way. diff --git a/src/xenstore.ml b/src/xenstore.ml index 7e713a2..47f0cd2 100644 --- a/src/xenstore.ml +++ b/src/xenstore.ml @@ -14,89 +14,97 @@ let error fmt = Printf.ksprintf (output_string stderr) fmt -module Client = Xs_client_unix.Client(Xs_transport_unix_client) +module Client = Xs_client_unix.Client (Xs_transport_unix_client) + let make_client () = - try - Client.make () - with e -> - error "Failed to connect to xenstore. 
The raw error was: %s" (Printexc.to_string e); - begin match e with - | Unix.Unix_error(Unix.EACCES, _, _) -> - error "Access to xenstore was denied."; - let euid = Unix.geteuid () in - if euid <> 0 then begin - error "My effective uid is %d." euid; - error "Typically xenstore can only be accessed by root (uid 0)."; - error "Please switch to root (uid 0) and retry." - end - | Unix.Unix_error(Unix.ECONNREFUSED, _, _) -> - error "Access to xenstore was refused."; - error "This normally indicates that the service is not running."; - error "Please start the xenstore service and retry." - | _ -> () - end; - raise e + try Client.make () + with e -> + error "Failed to connect to xenstore. The raw error was: %s" + (Printexc.to_string e) ; + ( match e with + | Unix.Unix_error (Unix.EACCES, _, _) -> + error "Access to xenstore was denied." ; + let euid = Unix.geteuid () in + if euid <> 0 then ( + error "My effective uid is %d." euid ; + error "Typically xenstore can only be accessed by root (uid 0)." ; + error "Please switch to root (uid 0) and retry." + ) + | Unix.Unix_error (Unix.ECONNREFUSED, _, _) -> + error "Access to xenstore was refused." ; + error "This normally indicates that the service is not running." ; + error "Please start the xenstore service and retry." 
+ | _ -> + () + ) ; + raise e let get_client = - let client = ref None in - fun () -> match !client with - | None -> - let c = make_client () in - client := Some c; - c - | Some c -> c + let client = ref None in + fun () -> + match !client with + | None -> + let c = make_client () in + client := Some c ; + c + | Some c -> + c type domid = int module Xs = struct - type domid = int + type domid = int - type xsh = { -(* + type xsh = { + (* debug: string list -> string; *) - directory : string -> string list; - read : string -> string; -(* + directory: string -> string list + ; read: string -> string + ; (* readv : string -> string list -> string list; *) - write : string -> string -> unit; - writev : string -> (string * string) list -> unit; - mkdir : string -> unit; - rm : string -> unit; -(* + write: string -> string -> unit + ; writev: string -> (string * string) list -> unit + ; mkdir: string -> unit + ; rm: string -> unit + ; (* getperms : string -> perms; setpermsv : string -> string list -> perms -> unit; release : domid -> unit; resume : domid -> unit; *) - setperms : string -> Xs_protocol.ACL.t -> unit; + setperms: string -> Xs_protocol.ACL.t -> unit + ; getdomainpath: domid -> string + ; watch: string -> string -> unit + ; unwatch: string -> string -> unit + ; introduce: domid -> nativeint -> int -> unit + ; set_target: domid -> domid -> unit + } - getdomainpath : domid -> string; - watch : string -> string -> unit; - unwatch : string -> string -> unit; - introduce : domid -> nativeint -> int -> unit; - set_target : domid -> domid -> unit; -} - - let ops h = { - read = Client.read h; - directory = Client.directory h; - write = Client.write h; - writev = (fun base_path -> List.iter (fun (k, v) -> Client.write h (base_path ^ "/" ^ k) v)); - mkdir = Client.mkdir h; - rm = (fun path -> try Client.rm h path with Xs_protocol.Enoent _ -> ()); - setperms = Client.setperms h; - getdomainpath = Client.getdomainpath h; - watch = Client.watch h; - unwatch = 
Client.unwatch h; - introduce = Client.introduce h; - set_target = Client.set_target h; + let ops h = + { + read= Client.read h + ; directory= Client.directory h + ; write= Client.write h + ; writev= + (fun base_path -> + List.iter (fun (k, v) -> Client.write h (base_path ^ "/" ^ k) v)) + ; mkdir= Client.mkdir h + ; rm= (fun path -> try Client.rm h path with Xs_protocol.Enoent _ -> ()) + ; setperms= Client.setperms h + ; getdomainpath= Client.getdomainpath h + ; watch= Client.watch h + ; unwatch= Client.unwatch h + ; introduce= Client.introduce h + ; set_target= Client.set_target h } - let with_xs f = Client.immediate (get_client ()) (fun h -> f (ops h)) - let wait f = Client.wait (get_client ()) (fun h -> f (ops h)) - let transaction _ f = Client.transaction (get_client ()) (fun h -> f (ops h)) + let with_xs f = Client.immediate (get_client ()) (fun h -> f (ops h)) + + let wait f = Client.wait (get_client ()) (fun h -> f (ops h)) + + let transaction _ f = Client.transaction (get_client ()) (fun h -> f (ops h)) end module Xst = Xs diff --git a/test/dummy_extent_reader.py b/test/dummy_extent_reader.py new file mode 100755 index 0000000..54cd22e --- /dev/null +++ b/test/dummy_extent_reader.py @@ -0,0 +1,31 @@ +#!/usr/bin/python + +""" +Dummy extent reader that returns a huge extent list +""" + +import json +import sys + +# We simulate a 4 TiB disk +DUMMY_DISK_SIZE = 4 * 1024 * 1024 * 1024 * 1024 + +# Every second 64 KiB block in the disk will be allocated to get a large extent +# list. This is the granularity at which QEMU 2.12 reports allocated blocks for +# qcow images. 
+BLOCK_SIZE = 64 * 1024 + +def _main(): + offset = int(sys.argv[6]) + length = int(sys.argv[8]) + if not (length + offset <= DUMMY_DISK_SIZE and length > 0 and offset >= 0): + raise ValueError("Invalid length={}, offset={} for disk size {}".format( + length, offset, DUMMY_DISK_SIZE)) + extents = [ + {'flags': 0, 'length': BLOCK_SIZE} + for _offset in xrange(0, length, BLOCK_SIZE) + ] + print json.dumps(extents) + +if __name__ == '__main__': + _main() diff --git a/test/dune b/test/dune new file mode 100644 index 0000000..259525a --- /dev/null +++ b/test/dune @@ -0,0 +1,12 @@ +(executable + (modes byte exe) + (name stress) + (libraries alcotest alcotest-lwt local_lib vhd-format vhd-format-lwt)) + +(rule + (alias stresstest) + (deps + (:x stress.exe) + (source_tree .)) + (action + (run %{x}))) diff --git a/test/stress.ml b/test/stress.ml new file mode 100644 index 0000000..ac7ffac --- /dev/null +++ b/test/stress.ml @@ -0,0 +1,35 @@ +module F = Vhd_format.F.From_file (Vhd_format_lwt.IO) +open Lwt.Infix + +(** We simulate a 4TiB disk *) +let size = Int64.(mul 4L (mul 1024L (mul 1024L (mul 1024L 1024L)))) + +let rec process_stream total = function + | F.Cons (data, s) -> + let sectors = + match data with + | `Empty sectors | `Copy (_, _, sectors) -> + sectors + | _ -> + failwith "unexpected element" + in + s () >>= fun s -> process_stream (Int64.add total sectors) s + | F.End -> + Lwt.return total + +let test_huge_input switch () = + let raw = `anything in + let server = "" in + let export_name = "" in + Nbd_input.raw ~extent_reader:"./dummy_extent_reader.py" raw server export_name + size + >>= fun s -> + process_stream 0L s.F.elements >|= fun sectors -> + Alcotest.(check int64) + "total size of elements in stream" s.F.size.total (Int64.mul 512L sectors) + +let test_set = + let t = Alcotest_lwt.test_case in + [t "VDI with a large allocated extent list" `Quick test_huge_input] + +let () = Alcotest.run "stress test" [("Nbd_input", test_set)] diff --git a/vhd-tool.opam 
b/vhd-tool.opam new file mode 100644 index 0000000..99877ed --- /dev/null +++ b/vhd-tool.opam @@ -0,0 +1,38 @@ +opam-version: "2.0" +maintainer: "xen-api@lists.xen.org" +authors: [ "xen-api@lists.xen.org" ] +homepage: "https://github.com/xapi-project/vhd-tool" +bug-reports: "https://github.com/xapi-project/vhd-tool/issues" +dev-repo: "git+https://github.com/xapi-project/vhd-tool.git" +tags: [ + "org:mirage" + "org:xapi-project" +] +build: [[ "dune" "build" "-p" name "-j" jobs ] +] +depends: [ + "ocaml" + "dune" {build} + "cohttp-lwt" + "conf-libssl" + "cstruct" {>= "3.0.0"} + "io-page" + "lwt" + "nbd-unix" + "ocaml-migrate-parsetree" + "ppx_cstruct" + "ppx_deriving_rpc" + "re" + "rpclib" + "sha" + "tar" + "vhd-format" + "vhd-format-lwt" + "xapi-tapctl" + "xenstore" + "xenstore_transport" +] +synopsis: ".vhd file manipulation" +url { + src: "https://github.com/xapi-project/vhd-tool/archive/master.tar.gz" +}