Skip to content

Commit ed7ced2

Browse files
committed
1 parent d597c55 commit ed7ced2

File tree

3 files changed

+342
-17
lines changed

3 files changed

+342
-17
lines changed

click_extra/testing.py

Lines changed: 333 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,16 +18,29 @@
1818
from __future__ import annotations
1919

2020
import os
21+
import sys
22+
import shlex
2123
import subprocess
2224
from pathlib import Path
2325
from textwrap import indent
24-
from typing import Iterable, Mapping, Optional, Union, cast, IO, Any, Mapping, Optional,ContextManager, Literal
26+
from typing import Iterable, Mapping, Optional, Union, cast, IO, Any, Type, Mapping, BinaryIO, Optional,ContextManager, Literal, Sequence, Iterator, Tuple
2527
import os
2628
import inspect
2729
from functools import partial
2830
from pathlib import Path
2931
from contextlib import nullcontext
3032
from unittest.mock import patch
33+
import io
34+
import contextlib
35+
import io
36+
import os
37+
import shlex
38+
import sys
39+
from types import TracebackType
40+
41+
from click import formatting
42+
from click import termui
43+
from click import utils
3144

3245
import click
3346
import click.testing
@@ -146,19 +159,128 @@ def run_cmd(*args, extra_env: EnvVars | None = None, print_output: bool = True):
146159
We need to collect them to help us identify which extra parameters passed to
147160
``invoke()`` collides with its original signature.
148161
149-
.. note::
162+
.. warning::
150163
This has been `reported upstream to Click project
151-
<https://github.com/pallets/click/issues/2110>`_ but was not considered an issue.
164+
<https://github.com/pallets/click/issues/2110>`_ but has been rejected and not
165+
considered an issue worth fixing.
152166
"""
153167

154168

169+
class BytesIOCopy(io.BytesIO):
170+
"""Patch ``io.BytesIO`` to let the written stream be copied to another.
171+
172+
.. caution::
173+
This has been `proposed upstream to Click project
174+
<https://github.com/pallets/click/pull/2523>`_ but has not been merged yet.
175+
"""
176+
177+
def __init__(self, copy_to: io.BytesIO) -> None:
178+
super().__init__()
179+
self.copy_to = copy_to
180+
181+
def flush(self) -> None:
182+
super().flush()
183+
self.copy_to.flush()
184+
185+
def write(self, b) -> int:
186+
self.copy_to.write(b)
187+
return super().write(b)
188+
189+
190+
class StreamMixer:
191+
"""Mixes ``<stdout>`` and ``<stderr>`` streams if ``mix_stderr=True``.
192+
193+
The result is available in the ``output`` attribute.
194+
195+
If ``mix_stderr=False``, the ``<stdout>`` and ``<stderr>`` streams are kept
196+
independent and the ``output`` is the same as the ``<stdout>`` stream.
197+
198+
.. caution::
199+
This has been `proposed upstream to Click project
200+
<https://github.com/pallets/click/pull/2523>`_ but has not been merged yet.
201+
"""
202+
203+
def __init__(self, mix_stderr: bool) -> None:
204+
if not mix_stderr:
205+
self.stdout = io.BytesIO()
206+
self.stderr = io.BytesIO()
207+
self.output = self.stdout
208+
209+
else:
210+
self.output = io.BytesIO()
211+
self.stdout = BytesIOCopy(copy_to=self.output)
212+
self.stderr = BytesIOCopy(copy_to=self.output)
213+
214+
215+
class ExtraResult(click.testing.Result):
216+
"""Like ``click.testing.Result``, with finer ``<stdout>`` and ``<stderr>`` streams.
217+
218+
.. caution::
219+
This has been `proposed upstream to Click project
220+
<https://github.com/pallets/click/pull/2523>`_ but has not been merged yet.
221+
"""
222+
223+
def __init__(
224+
self,
225+
runner: click.testing.CliRunner,
226+
stdout_bytes: bytes,
227+
stderr_bytes: bytes,
228+
output_bytes: bytes,
229+
return_value: Any,
230+
exit_code: int,
231+
exception: Optional[BaseException],
232+
exc_info: Optional[
233+
Tuple[Type[BaseException], BaseException, TracebackType]
234+
] = None,
235+
):
236+
"""Same as original but adds ``output_bytes`` parameter.
237+
238+
Also makes ``stderr_bytes`` mandatory.
239+
"""
240+
self.output_bytes = output_bytes
241+
super().__init__(
242+
runner=runner,
243+
stdout_bytes=stdout_bytes,
244+
stderr_bytes=stderr_bytes,
245+
return_value=return_value,
246+
exit_code=exit_code,
247+
exception=exception,
248+
exc_info=exc_info,
249+
)
250+
251+
@property
252+
def output(self) -> str:
253+
"""The terminal output as unicode string, as the user would see it.
254+
255+
.. caution::
256+
Contrary to original ``click.testing.Result.output``, it is not a proxy for
257+
``self.stdout``. It now possess its own stream to mix ``<stdout>`` and
258+
``<stderr>`` depending on the ``mix_stderr`` value.
259+
"""
260+
return self.output_bytes.decode(self.runner.charset, "replace").replace(
261+
"\r\n", "\n"
262+
)
263+
264+
@property
265+
def stderr(self) -> str:
266+
"""The standard error as unicode string.
267+
268+
.. caution::
269+
Contrary to original ``click.testing.Result.stderr``, it no longer raise an
270+
exception, and always returns the ``<stderr>`` string.
271+
"""
272+
return self.stderr_bytes.decode(self.runner.charset, "replace").replace(
273+
"\r\n", "\n"
274+
)
275+
276+
155277
class ExtraCliRunner(click.testing.CliRunner):
156278
"""Augment ``click.testing.CliRunner`` with extra features and bug fixes."""
157279

158280
force_color: bool = False
159281
"""Global class attribute to override the ``color`` parameter in ``invoke``.
160282
161-
.. note::
283+
.. info::
162284
This was initially developed to `force the initialization of the runner during
163285
the setup of Sphinx new directives <sphinx#click_extra.sphinx.setup>`_. This
164286
was the only way we found, as to patch some code we had to operate at the class
@@ -183,6 +305,210 @@ def __init__(
183305
mix_stderr=mix_stderr
184306
)
185307

308+
@contextlib.contextmanager
309+
def isolation(
310+
self,
311+
input: Optional[Union[str, bytes, IO[Any]]] = None,
312+
env: Optional[Mapping[str, Optional[str]]] = None,
313+
color: bool = False,
314+
) -> Iterator[Tuple[io.BytesIO, io.BytesIO, io.BytesIO]]:
315+
"""Copy of ``click.testing.CliRunner.isolation()`` with extra features.
316+
317+
- An additional output stream is returned, which is a mix of ``<stdout>`` and
318+
``<stderr>`` streams if ``mix_stderr=True``.
319+
320+
- Always returns the ``<stderr>`` stream.
321+
322+
.. caution::
323+
This is a hard-copy of the modified ``isolation()`` method `from click#2523 PR
324+
<https://github.com/pallets/click/pull/2523/files#diff-b07fd6fad9f9ea8be5cbcbeaf34c956703b929b2de95c56229e77c328a7c6010>`_
325+
which has not been merged upstream yet.
326+
"""
327+
bytes_input = click.testing.make_input_stream(input, self.charset)
328+
echo_input = None
329+
330+
old_stdin = sys.stdin
331+
old_stdout = sys.stdout
332+
old_stderr = sys.stderr
333+
old_forced_width = formatting.FORCED_WIDTH
334+
formatting.FORCED_WIDTH = 80
335+
336+
env = self.make_env(env)
337+
338+
stream_mixer = StreamMixer(mix_stderr=self.mix_stderr)
339+
340+
if self.echo_stdin:
341+
bytes_input = echo_input = cast(
342+
BinaryIO, click.testing.EchoingStdin(bytes_input, stream_mixer.stdout)
343+
)
344+
345+
sys.stdin = text_input = click.testing._NamedTextIOWrapper(
346+
bytes_input, encoding=self.charset, name="<stdin>", mode="r"
347+
)
348+
349+
if self.echo_stdin:
350+
# Force unbuffered reads, otherwise TextIOWrapper reads a
351+
# large chunk which is echoed early.
352+
text_input._CHUNK_SIZE = 1 # type: ignore
353+
354+
sys.stdout = click.testing._NamedTextIOWrapper(
355+
stream_mixer.stdout, encoding=self.charset, name="<stdout>", mode="w"
356+
)
357+
358+
sys.stderr = click.testing._NamedTextIOWrapper(
359+
stream_mixer.stderr,
360+
encoding=self.charset,
361+
name="<stderr>",
362+
mode="w",
363+
errors="backslashreplace",
364+
)
365+
366+
@click.testing._pause_echo(echo_input)
367+
def visible_input(prompt: Optional[str] = None) -> str:
368+
sys.stdout.write(prompt or "")
369+
val = text_input.readline().rstrip("\r\n")
370+
sys.stdout.write(f"{val}\n")
371+
sys.stdout.flush()
372+
return val
373+
374+
@click.testing._pause_echo(echo_input)
375+
def hidden_input(prompt: Optional[str] = None) -> str:
376+
sys.stdout.write(f"{prompt or ''}\n")
377+
sys.stdout.flush()
378+
return text_input.readline().rstrip("\r\n")
379+
380+
@click.testing._pause_echo(echo_input)
381+
def _getchar(echo: bool) -> str:
382+
char = sys.stdin.read(1)
383+
384+
if echo:
385+
sys.stdout.write(char)
386+
387+
sys.stdout.flush()
388+
return char
389+
390+
default_color = color
391+
392+
def should_strip_ansi(
393+
stream: Optional[IO[Any]] = None, color: Optional[bool] = None
394+
) -> bool:
395+
if color is None:
396+
return not default_color
397+
return not color
398+
399+
old_visible_prompt_func = termui.visible_prompt_func
400+
old_hidden_prompt_func = termui.hidden_prompt_func
401+
old__getchar_func = termui._getchar
402+
old_should_strip_ansi = utils.should_strip_ansi
403+
termui.visible_prompt_func = visible_input
404+
termui.hidden_prompt_func = hidden_input
405+
termui._getchar = _getchar
406+
utils.should_strip_ansi = should_strip_ansi
407+
408+
old_env = {}
409+
try:
410+
for key, value in env.items():
411+
old_env[key] = os.environ.get(key)
412+
if value is None:
413+
try:
414+
del os.environ[key]
415+
except Exception:
416+
pass
417+
else:
418+
os.environ[key] = value
419+
yield (stream_mixer.stdout, stream_mixer.stderr, stream_mixer.output)
420+
finally:
421+
for key, value in old_env.items():
422+
if value is None:
423+
try:
424+
del os.environ[key]
425+
except Exception:
426+
pass
427+
else:
428+
os.environ[key] = value
429+
sys.stdout = old_stdout
430+
sys.stderr = old_stderr
431+
sys.stdin = old_stdin
432+
termui.visible_prompt_func = old_visible_prompt_func
433+
termui.hidden_prompt_func = old_hidden_prompt_func
434+
termui._getchar = old__getchar_func
435+
utils.should_strip_ansi = old_should_strip_ansi
436+
formatting.FORCED_WIDTH = old_forced_width
437+
438+
def invoke2(
439+
self,
440+
cli: click.core.BaseCommand,
441+
args: Optional[Union[str, Sequence[str]]] = None,
442+
input: Optional[Union[str, bytes, IO[Any]]] = None,
443+
env: Optional[Mapping[str, Optional[str]]] = None,
444+
catch_exceptions: bool = True,
445+
color: bool = False,
446+
**extra: Any,
447+
) -> click.testing.Result:
448+
"""Copy of ``click.testing.CliRunner.invoke()`` with
449+
450+
.. caution::
451+
This is a hard-copy of the modified ``invoke()`` method `from click#2523 PR
452+
<https://github.com/pallets/click/pull/2523/files#diff-b07fd6fad9f9ea8be5cbcbeaf34c956703b929b2de95c56229e77c328a7c6010>`_
453+
which has not been merged upstream yet.
454+
"""
455+
exc_info = None
456+
with self.isolation(input=input, env=env, color=color) as outstreams:
457+
return_value = None
458+
exception: Optional[BaseException] = None
459+
exit_code = 0
460+
461+
if isinstance(args, str):
462+
args = shlex.split(args)
463+
464+
try:
465+
prog_name = extra.pop("prog_name")
466+
except KeyError:
467+
prog_name = self.get_default_prog_name(cli)
468+
469+
try:
470+
return_value = cli.main(args=args or (), prog_name=prog_name, **extra)
471+
except SystemExit as e:
472+
exc_info = sys.exc_info()
473+
e_code = cast(Optional[Union[int, Any]], e.code)
474+
475+
if e_code is None:
476+
e_code = 0
477+
478+
if e_code != 0:
479+
exception = e
480+
481+
if not isinstance(e_code, int):
482+
sys.stdout.write(str(e_code))
483+
sys.stdout.write("\n")
484+
e_code = 1
485+
486+
exit_code = e_code
487+
488+
except Exception as e:
489+
if not catch_exceptions:
490+
raise
491+
exception = e
492+
exit_code = 1
493+
exc_info = sys.exc_info()
494+
finally:
495+
sys.stdout.flush()
496+
stdout = outstreams[0].getvalue()
497+
stderr = outstreams[1].getvalue()
498+
output = outstreams[2].getvalue()
499+
500+
return ExtraResult(
501+
runner=self,
502+
stdout_bytes=stdout,
503+
stderr_bytes=stderr,
504+
output_bytes=output,
505+
return_value=return_value,
506+
exit_code=exit_code,
507+
exception=exception,
508+
exc_info=exc_info, # type: ignore
509+
)
510+
511+
186512
def invoke(
187513
self,
188514
cli: click.core.BaseCommand,
@@ -278,7 +604,7 @@ def invoke(
278604
extra_params_bypass = patch.object(cli, "main", partial(cli.main, **extra_bypass))
279605

280606
with extra_params_bypass:
281-
result = super().invoke(
607+
result = self.invoke2(
282608
cli=cli,
283609
args=args,
284610
input=input,
@@ -292,11 +618,12 @@ def invoke(
292618
if color is False:
293619
result.stdout_bytes = strip_ansi(result.stdout_bytes)
294620
result.stderr_bytes = strip_ansi(result.stderr_bytes)
621+
result.output_bytes = strip_ansi(result.output_bytes)
295622

296623
print_cli_run(
297624
[self.get_default_prog_name(cli)] + list(args),
298625
result.output,
299-
result.stderr if result.stderr_bytes is not None else "",
626+
result.stderr,
300627
result.exit_code,
301628
)
302629

0 commit comments

Comments
 (0)