|
| 1 | +(* |
| 2 | + * Copyright (C) 2018 Citrix Inc |
| 3 | + * |
| 4 | + * This program is free software; you can redistribute it and/or modify |
| 5 | + * it under the terms of the GNU Lesser General Public License as published |
| 6 | + * by the Free Software Foundation; version 2.1 only. with the special |
| 7 | + * exception on linking described in file LICENSE. |
| 8 | + * |
| 9 | + * This program is distributed in the hope that it will be useful, |
| 10 | + * but WITHOUT ANY WARRANTY; without even the implied warranty of |
| 11 | + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the |
| 12 | + * GNU Lesser General Public License for more details. |
| 13 | + *) |
| 14 | + |
| 15 | +let write_str filename str = |
| 16 | + let oc = open_out filename in |
| 17 | + Printf.fprintf oc "%s" str; |
| 18 | + close_out oc |
| 19 | + |
| 20 | +let read_str filename = |
| 21 | + let ic = open_in filename in |
| 22 | + let n = in_channel_length ic in |
| 23 | + let s = Bytes.create n in |
| 24 | + really_input ic s 0 n; |
| 25 | + close_in ic; |
| 26 | + s |
| 27 | + |
| 28 | +open Idl |
| 29 | + |
| 30 | +module type PATHS = sig |
| 31 | + val test_data_path : string |
| 32 | + (** Path under which we look for or generate requests and responses. For example, |
| 33 | + if test_data_path = 'foo', this module will search for or generate requests |
| 34 | + matching 'foo/requests/<RPC name>.request.<n>' and responses matching |
| 35 | + 'foo/responses/<RPC name>.response.<n>' *) |
| 36 | +end |
| 37 | + |
| 38 | +(** The following module implements test cases that write test |
| 39 | + RPC requests and responses in JSON that can be used to |
| 40 | + verify that subsequent versions of an API can still parse |
| 41 | + them. |
| 42 | +
|
| 43 | + The test cases are obtained by obtaining the implementation |
| 44 | + of the module generated when applying the API functor to |
| 45 | + this module. |
| 46 | +
|
| 47 | + The test data will be written to the path specified in the |
| 48 | + PATH module passed in *) |
| 49 | +module GenTestData (P:PATHS) = struct |
| 50 | + type implementation = unit Alcotest.test_case list ref |
| 51 | + |
| 52 | + let tests : unit Alcotest.test_case list ref = ref [] |
| 53 | + let description = ref None |
| 54 | + |
| 55 | + let implement x = description := Some x; tests |
| 56 | + |
| 57 | + type ('a,'b) comp = 'a |
| 58 | + type 'a res = unit |
| 59 | + type _ fn = |
| 60 | + | Function : 'a Idl.Param.t * 'b fn -> ('a -> 'b) fn |
| 61 | + | Returning : ('a Idl.Param.t * 'b Idl.Error.t) -> ('a, _) comp fn |
| 62 | + |
| 63 | + let returning a err = Returning (a, err) |
| 64 | + let (@->) = fun t f -> Function (t, f) |
| 65 | + |
| 66 | + let declare name _ ty = |
| 67 | + let rec inner : type b. (((string * Rpc.t) list * Rpc.t list) list) -> b fn -> unit = fun params -> |
| 68 | + function |
| 69 | + | Function (t, f) -> begin |
| 70 | + let vs = Rpc_genfake.genall 2 (match t.Param.name with Some n -> n | None -> t.Param.typedef.Rpc.Types.name) t.Param.typedef.Rpc.Types.ty in |
| 71 | + let marshalled = List.map (fun v -> Rpcmarshal.marshal t.Param.typedef.Rpc.Types.ty v) vs in |
| 72 | + match t.Param.name with |
| 73 | + | Some n -> inner (List.flatten (List.map (fun marshalled -> List.map (fun (named,unnamed) -> (((n,marshalled)::named),unnamed)) params) marshalled)) f |
| 74 | + | None -> inner (List.flatten (List.map (fun marshalled -> List.map (fun (named,unnamed) -> (named,(marshalled::unnamed))) params) marshalled)) f |
| 75 | + end |
| 76 | + | Returning (t, e) -> |
| 77 | + let wire_name = Idl.get_wire_name !description name in |
| 78 | + let calls = List.map |
| 79 | + (fun (named,unnamed) -> |
| 80 | + let args = |
| 81 | + match named with |
| 82 | + | [] -> List.rev unnamed |
| 83 | + | _ -> (Rpc.Dict named) :: List.rev unnamed |
| 84 | + in |
| 85 | + let call = Rpc.call wire_name args in |
| 86 | + call) params in |
| 87 | + List.iteri (fun i call -> |
| 88 | + let request_str = Jsonrpc.string_of_call call in |
| 89 | + write_str |
| 90 | + (Printf.sprintf "%s/requests/%s.request.%d" P.test_data_path wire_name i) |
| 91 | + request_str) calls; |
| 92 | + let vs = Rpc_genfake.genall 2 (match t.Param.name with Some n -> n | None -> t.Param.typedef.Rpc.Types.name) t.Param.typedef.Rpc.Types.ty in |
| 93 | + let marshalled_vs = List.map (fun v -> Rpc.success (Rpcmarshal.marshal t.Param.typedef.Rpc.Types.ty v)) vs in |
| 94 | + let errs = Rpc_genfake.genall 2 "error" e.Error.def.Rpc.Types.ty in |
| 95 | + let marshalled_errs = List.map (fun err -> Rpc.failure (Rpcmarshal.marshal e.Error.def.Rpc.Types.ty err)) errs in |
| 96 | + List.iteri (fun i response -> |
| 97 | + let response_str = Jsonrpc.string_of_response response in |
| 98 | + write_str |
| 99 | + (Printf.sprintf "%s/responses/%s.response.%d" P.test_data_path wire_name i) |
| 100 | + response_str) (marshalled_vs @ marshalled_errs) |
| 101 | + in |
| 102 | + let test_fn () = |
| 103 | + let mkdir_safe p = begin try Unix.mkdir p 0o755 with Unix.Unix_error (EEXIST, _, _) -> () end in |
| 104 | + mkdir_safe P.test_data_path; |
| 105 | + mkdir_safe (Printf.sprintf "%s/requests" P.test_data_path); |
| 106 | + mkdir_safe (Printf.sprintf "%s/responses" P.test_data_path); |
| 107 | + inner [[],[]] ty in |
| 108 | + tests := (Printf.sprintf "Generate test data for '%s'" (Idl.get_wire_name !description name), `Quick, test_fn) :: !tests |
| 109 | +end |
| 110 | + |
| 111 | +let get_arg call has_named name = |
| 112 | + match has_named, name, call.Rpc.params with |
| 113 | + | true, Some n, (Rpc.Dict named)::unnamed -> begin |
| 114 | + match List.partition (fun (x,_) -> x = n) named with |
| 115 | + | (_,arg)::dups,others -> Result.Ok (arg, {call with Rpc.params = (Rpc.Dict (dups @ others))::unnamed }) |
| 116 | + | _,_ -> Result.Error (`Msg (Printf.sprintf "Expecting named argument '%s'" n)) |
| 117 | + end |
| 118 | + | true, None, (Rpc.Dict named)::unnamed -> begin |
| 119 | + match unnamed with |
| 120 | + | head::tail -> Result.Ok (head, {call with Rpc.params = (Rpc.Dict named)::tail}) |
| 121 | + | _ -> Result.Error (`Msg "Incorrect number of arguments") |
| 122 | + end |
| 123 | + | true, _, _ -> begin |
| 124 | + Result.Error (`Msg "Marshalling error: Expecting dict as first argument when named parameters exist") |
| 125 | + end |
| 126 | + | false, None, head::tail -> begin |
| 127 | + Result.Ok (head, {call with Rpc.params = tail}) |
| 128 | + end |
| 129 | + | false, None, [] -> |
| 130 | + Result.Error (`Msg "Incorrect number of arguments") |
| 131 | + | false, Some _, _ -> |
| 132 | + failwith "Can't happen by construction" |
| 133 | + |
| 134 | +exception NoDescription |
| 135 | +exception MarshalError of string |
| 136 | + |
| 137 | + |
| 138 | +(** The following module will generate alcotest test cases to verify |
| 139 | + that a set of requests and responses can be successfully parsed. |
| 140 | +
|
| 141 | + The PATHS module specifies the location for the test data as |
| 142 | + `test_data_path`. Requests and responses will be looked up in |
| 143 | + this location in the subdirectories `requests` and `responses`. |
| 144 | + The actual data must be in files following the naming convention |
| 145 | + <wire_name>.request.<n> and <wire_name>.response.<n>. |
| 146 | +
|
| 147 | + The code here closely follows that of the GenServer module to |
| 148 | + ensure it accurately represents how the server would parse the |
| 149 | + json. |
| 150 | + *) |
| 151 | +module TestOldRpcs (P : PATHS) = struct |
| 152 | + open Rpc |
| 153 | + type implementation = unit Alcotest.test_case list ref |
| 154 | + |
| 155 | + let tests : implementation = ref [] |
| 156 | + let description = ref None |
| 157 | + |
| 158 | + let implement x = description := Some x; tests |
| 159 | + |
| 160 | + type ('a,'b) comp = unit |
| 161 | + type 'a res = unit |
| 162 | + |
| 163 | + type _ fn = |
| 164 | + | Function : 'a Param.t * 'b fn -> ('a -> 'b) fn |
| 165 | + | Returning : ('a Param.t * 'b Error.t) -> (_, _) comp fn |
| 166 | + |
| 167 | + let returning a b = Returning (a,b) |
| 168 | + let (@->) = fun t f -> Function (t, f) |
| 169 | + |
| 170 | + let rec has_named_args : type a. a fn -> bool = |
| 171 | + function |
| 172 | + | Function (t, f) -> begin |
| 173 | + match t.Param.name with |
| 174 | + | Some _ -> true |
| 175 | + | None -> has_named_args f |
| 176 | + end |
| 177 | + | Returning (_, _) -> |
| 178 | + false |
| 179 | + |
| 180 | + let declare : string -> string list -> 'a fn -> _ res = fun name _ ty -> |
| 181 | + begin |
| 182 | + (* Sanity check: ensure the description has been set before we declare |
| 183 | + any RPCs *) |
| 184 | + match !description with |
| 185 | + | Some _ -> () |
| 186 | + | None -> raise NoDescription |
| 187 | + end; |
| 188 | + |
| 189 | + let wire_name = Idl.get_wire_name !description name in |
| 190 | + |
| 191 | + let rec read_all path extension i = |
| 192 | + try |
| 193 | + let call = |
| 194 | + read_str (Printf.sprintf "%s/%s/%s.%s.%d" P.test_data_path path wire_name extension i) in |
| 195 | + call :: read_all path extension (i+1) |
| 196 | + with _ -> [] |
| 197 | + in |
| 198 | + |
| 199 | + let calls = read_all "requests" "request" 0 |> List.map Jsonrpc.call_of_string in |
| 200 | + let responses = read_all "responses" "response" 0 |> List.map Jsonrpc.response_of_string in |
| 201 | + |
| 202 | + let verify : type a. a Rpc.Types.typ -> Rpc.t -> a = fun typ rpc -> |
| 203 | + match Rpcmarshal.unmarshal typ rpc with |
| 204 | + | Ok x -> |
| 205 | + let check = Rpcmarshal.marshal typ x in |
| 206 | + if check <> rpc then begin |
| 207 | + let err = Printf.sprintf "Round-trip failed. Before: '%s' After: '%s'" |
| 208 | + (Jsonrpc.to_string rpc) |
| 209 | + (Jsonrpc.to_string check) in |
| 210 | + raise (MarshalError err) |
| 211 | + end; |
| 212 | + x |
| 213 | + | Error (`Msg m) -> |
| 214 | + raise (MarshalError m) |
| 215 | + in |
| 216 | + |
| 217 | + let testfn call response = |
| 218 | + let has_named = has_named_args ty in |
| 219 | + let rec inner : type a. a fn -> Rpc.call -> unit = fun f call -> |
| 220 | + match f with |
| 221 | + | Function (t, f) -> begin |
| 222 | + let (arg_rpc, call') = |
| 223 | + match get_arg call has_named t.Param.name with |
| 224 | + | Result.Ok (x,y) -> (x,y) |
| 225 | + | Result.Error (`Msg m) -> raise (MarshalError m) |
| 226 | + in |
| 227 | + verify t.Param.typedef.Rpc.Types.ty arg_rpc |> ignore; |
| 228 | + inner f call' |
| 229 | + end |
| 230 | + | Returning (t,e) -> begin |
| 231 | + match response.success with |
| 232 | + | true -> |
| 233 | + verify t.Param.typedef.Rpc.Types.ty response.contents |> ignore |
| 234 | + | false -> |
| 235 | + verify e.Error.def.Rpc.Types.ty response.contents |> ignore |
| 236 | + end |
| 237 | + in inner ty call |
| 238 | + in |
| 239 | + (* Check all calls *) |
| 240 | + let request_tests = |
| 241 | + List.mapi (fun i call -> |
| 242 | + let response = List.hd responses in |
| 243 | + let name = Printf.sprintf "Check old request for '%s': %d" wire_name i in |
| 244 | + (name, `Quick, fun () -> testfn call response)) calls in |
| 245 | + (* Now check all responses *) |
| 246 | + let response_tests = |
| 247 | + List.mapi (fun i response -> |
| 248 | + let call = List.hd calls in |
| 249 | + let name = Printf.sprintf "Check old response for '%s': %d" wire_name i in |
| 250 | + (name, `Quick, fun () -> testfn call response)) responses in |
| 251 | + |
| 252 | + tests := !tests @ request_tests @ response_tests |
| 253 | + |
| 254 | +end |
0 commit comments