Skip to content

Commit 56517e5

Browse files
committed
Make parameterized tests in email less hackish.
Or perhaps more hackish, depending on your perspective. But at least this way it is now possible to run the individual tests using the unittest CLI.
1 parent 01d7058 commit 56517e5

4 files changed

Lines changed: 122 additions & 93 deletions

File tree

Lib/test/test_email/__init__.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,3 +71,82 @@ def assertDefectsEqual(self, actual, expected):
7171
for i in range(len(actual)):
7272
self.assertIsInstance(actual[i], expected[i],
7373
'item {}'.format(i))
74+
75+
76+
# Metaclass to allow for parameterized tests
77+
class Parameterized(type):
78+
79+
"""Provide a test method parameterization facility.
80+
81+
Parameters are specified as the value of a class attribute that ends with
82+
the string '_params'. Call the portion before '_params' the prefix. Then
83+
a method to be parameterized must have the same prefix, the string
84+
'_as_', and an arbitrary suffix.
85+
86+
The value of the _params attribute may be either a dictionary or a list.
87+
The values in the dictionary and the elements of the list may either be
88+
single values, or a list. If single values, they are turned into single
89+
element tuples. However derived, the resulting sequence is passed via
90+
*args to the parameterized test function.
91+
92+
In a _params dictioanry, the keys become part of the name of the generated
93+
tests. In a _params list, the values in the list are converted into a
94+
string by joining the string values of the elements of the tuple by '_' and
95+
converting any blanks into '_'s, and this become part of the name. The
96+
full name of a generated test is the portion of the _params name before the
97+
'_params' portion, plus an '_', plus the name derived as explained above.
98+
99+
For example, if we have:
100+
101+
count_params = range(2)
102+
103+
def count_as_foo_arg(self, foo):
104+
self.assertEqual(foo+1, myfunc(foo))
105+
106+
we will get parameterized test methods named:
107+
test_foo_arg_0
108+
test_foo_arg_1
109+
test_foo_arg_2
110+
111+
Or we could have:
112+
113+
example_params = {'foo': ('bar', 1), 'bing': ('bang', 2)}
114+
115+
def example_as_myfunc_input(self, name, count):
116+
self.assertEqual(name+str(count), myfunc(name, count))
117+
118+
and get:
119+
test_myfunc_input_foo
120+
test_myfunc_input_bing
121+
122+
Note: if and only if the generated test name is a valid identifier can it
123+
be used to select the test individually from the unittest command line.
124+
125+
"""
126+
127+
def __new__(meta, classname, bases, classdict):
128+
paramdicts = {}
129+
for name, attr in classdict.items():
130+
if name.endswith('_params'):
131+
if not hasattr(attr, 'keys'):
132+
d = {}
133+
for x in attr:
134+
if not hasattr(x, '__iter__'):
135+
x = (x,)
136+
n = '_'.join(str(v) for v in x).replace(' ', '_')
137+
d[n] = x
138+
attr = d
139+
paramdicts[name[:-7] + '_as_'] = attr
140+
testfuncs = {}
141+
for name, attr in classdict.items():
142+
for paramsname, paramsdict in paramdicts.items():
143+
if name.startswith(paramsname):
144+
testnameroot = 'test_' + name[len(paramsname):]
145+
for paramname, params in paramsdict.items():
146+
test = (lambda self, name=name, params=params:
147+
getattr(self, name)(*params))
148+
testname = testnameroot + '_' + paramname
149+
test.__name__ = testname
150+
testfuncs[testname] = test
151+
classdict.update(testfuncs)
152+
return super().__new__(meta, classname, bases, classdict)

Lib/test/test_email/test_generator.py

Lines changed: 9 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44
from email import message_from_string, message_from_bytes
55
from email.generator import Generator, BytesGenerator
66
from email import policy
7-
from test.test_email import TestEmailBase
7+
from test.test_email import TestEmailBase, Parameterized
88

99

10-
class TestGeneratorBase:
10+
class TestGeneratorBase(metaclass=Parameterized):
1111

1212
policy = policy.default
1313

@@ -80,69 +80,46 @@ def msgmaker(self, msg, policy=None):
8080
"\n"
8181
"None\n")
8282

83-
def _test_maxheaderlen_parameter(self, n):
83+
length_params = [n for n in refold_long_expected]
84+
85+
def length_as_maxheaderlen_parameter(self, n):
8486
msg = self.msgmaker(self.typ(self.refold_long_expected[0]))
8587
s = self.ioclass()
8688
g = self.genclass(s, maxheaderlen=n, policy=self.policy)
8789
g.flatten(msg)
8890
self.assertEqual(s.getvalue(), self.typ(self.refold_long_expected[n]))
8991

90-
for n in refold_long_expected:
91-
locals()['test_maxheaderlen_parameter_' + str(n)] = (
92-
lambda self, n=n:
93-
self._test_maxheaderlen_parameter(n))
94-
95-
def _test_max_line_length_policy(self, n):
92+
def length_as_max_line_length_policy(self, n):
9693
msg = self.msgmaker(self.typ(self.refold_long_expected[0]))
9794
s = self.ioclass()
9895
g = self.genclass(s, policy=self.policy.clone(max_line_length=n))
9996
g.flatten(msg)
10097
self.assertEqual(s.getvalue(), self.typ(self.refold_long_expected[n]))
10198

102-
for n in refold_long_expected:
103-
locals()['test_max_line_length_policy' + str(n)] = (
104-
lambda self, n=n:
105-
self._test_max_line_length_policy(n))
106-
107-
def _test_maxheaderlen_parm_overrides_policy(self, n):
99+
def length_as_maxheaderlen_parm_overrides_policy(self, n):
108100
msg = self.msgmaker(self.typ(self.refold_long_expected[0]))
109101
s = self.ioclass()
110102
g = self.genclass(s, maxheaderlen=n,
111103
policy=self.policy.clone(max_line_length=10))
112104
g.flatten(msg)
113105
self.assertEqual(s.getvalue(), self.typ(self.refold_long_expected[n]))
114106

115-
for n in refold_long_expected:
116-
locals()['test_maxheaderlen_parm_overrides_policy' + str(n)] = (
117-
lambda self, n=n:
118-
self._test_maxheaderlen_parm_overrides_policy(n))
119-
120-
def _test_refold_none_does_not_fold(self, n):
107+
def length_as_max_line_length_with_refold_none_does_not_fold(self, n):
121108
msg = self.msgmaker(self.typ(self.refold_long_expected[0]))
122109
s = self.ioclass()
123110
g = self.genclass(s, policy=self.policy.clone(refold_source='none',
124111
max_line_length=n))
125112
g.flatten(msg)
126113
self.assertEqual(s.getvalue(), self.typ(self.refold_long_expected[0]))
127114

128-
for n in refold_long_expected:
129-
locals()['test_refold_none_does_not_fold' + str(n)] = (
130-
lambda self, n=n:
131-
self._test_refold_none_does_not_fold(n))
132-
133-
def _test_refold_all(self, n):
115+
def length_as_max_line_length_with_refold_all_folds(self, n):
134116
msg = self.msgmaker(self.typ(self.refold_long_expected[0]))
135117
s = self.ioclass()
136118
g = self.genclass(s, policy=self.policy.clone(refold_source='all',
137119
max_line_length=n))
138120
g.flatten(msg)
139121
self.assertEqual(s.getvalue(), self.typ(self.refold_all_expected[n]))
140122

141-
for n in refold_long_expected:
142-
locals()['test_refold_all' + str(n)] = (
143-
lambda self, n=n:
144-
self._test_refold_all(n))
145-
146123
def test_crlf_control_via_policy(self):
147124
source = "Subject: test\r\n\r\ntest body\r\n"
148125
expected = source

Lib/test/test_email/test_headerregistry.py

Lines changed: 8 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from email import errors
55
from email import policy
66
from email.message import Message
7-
from test.test_email import TestEmailBase
7+
from test.test_email import TestEmailBase, Parameterized
88
from email import headerregistry
99
from email.headerregistry import Address, Group
1010

@@ -175,9 +175,9 @@ def test_set_date_header_from_datetime(self):
175175
self.assertEqual(m['Date'].datetime, self.dt)
176176

177177

178-
class TestAddressHeader(TestHeaderBase):
178+
class TestAddressHeader(TestHeaderBase, metaclass=Parameterized):
179179

180-
examples = {
180+
example_params = {
181181

182182
'empty':
183183
('<>',
@@ -305,8 +305,8 @@ class TestAddressHeader(TestHeaderBase):
305305
# trailing comments, which aren't currently handled. comments in
306306
# general are not handled yet.
307307

308-
def _test_single_addr(self, source, defects, decoded, display_name,
309-
addr_spec, username, domain, comment):
308+
def example_as_address(self, source, defects, decoded, display_name,
309+
addr_spec, username, domain, comment):
310310
h = self.make_header('sender', source)
311311
self.assertEqual(h, decoded)
312312
self.assertDefectsEqual(h.defects, defects)
@@ -322,13 +322,8 @@ def _test_single_addr(self, source, defects, decoded, display_name,
322322
# XXX: we have no comment support yet.
323323
#self.assertEqual(a.comment, comment)
324324

325-
for name in examples:
326-
locals()['test_'+name] = (
327-
lambda self, name=name:
328-
self._test_single_addr(*self.examples[name]))
329-
330-
def _test_group_single_addr(self, source, defects, decoded, display_name,
331-
addr_spec, username, domain, comment):
325+
def example_as_group(self, source, defects, decoded, display_name,
326+
addr_spec, username, domain, comment):
332327
source = 'foo: {};'.format(source)
333328
gdecoded = 'foo: {};'.format(decoded) if decoded else 'foo:;'
334329
h = self.make_header('to', source)
@@ -344,11 +339,6 @@ def _test_group_single_addr(self, source, defects, decoded, display_name,
344339
self.assertEqual(a.username, username)
345340
self.assertEqual(a.domain, domain)
346341

347-
for name in examples:
348-
locals()['test_group_'+name] = (
349-
lambda self, name=name:
350-
self._test_group_single_addr(*self.examples[name]))
351-
352342
def test_simple_address_list(self):
353343
value = ('Fred <dinsdale@python.org>, foo@example.com, '
354344
'"Harry W. Hastings" <hasty@example.com>')
@@ -366,7 +356,7 @@ def test_simple_address_list(self):
366356
'Harry W. Hastings')
367357

368358
def test_complex_address_list(self):
369-
examples = list(self.examples.values())
359+
examples = list(self.example_params.values())
370360
source = ('dummy list:;, another: (empty);,' +
371361
', '.join([x[0] for x in examples[:4]]) + ', ' +
372362
r'"A \"list\"": ' +

Lib/test/test_email/test_pickleable.py

Lines changed: 26 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -6,83 +6,66 @@
66
import email.message
77
from email import policy
88
from email.headerregistry import HeaderRegistry
9-
from test.test_email import TestEmailBase
9+
from test.test_email import TestEmailBase, Parameterized
1010

11-
class TestPickleCopyHeader(TestEmailBase):
11+
class TestPickleCopyHeader(TestEmailBase, metaclass=Parameterized):
1212

1313
header_factory = HeaderRegistry()
1414

1515
unstructured = header_factory('subject', 'this is a test')
1616

17-
def _test_deepcopy(self, name, value):
17+
header_params = {
18+
'subject': ('subject', 'this is a test'),
19+
'from': ('from', 'frodo@mordor.net'),
20+
'to': ('to', 'a: k@b.com, y@z.com;, j@f.com'),
21+
'date': ('date', 'Tue, 29 May 2012 09:24:26 +1000'),
22+
}
23+
24+
def header_as_deepcopy(self, name, value):
1825
header = self.header_factory(name, value)
1926
h = copy.deepcopy(header)
2027
self.assertEqual(str(h), str(header))
2128

22-
def _test_pickle(self, name, value):
29+
def header_as_pickle(self, name, value):
2330
header = self.header_factory(name, value)
2431
p = pickle.dumps(header)
2532
h = pickle.loads(p)
2633
self.assertEqual(str(h), str(header))
2734

28-
headers = (
29-
('subject', 'this is a test'),
30-
('from', 'frodo@mordor.net'),
31-
('to', 'a: k@b.com, y@z.com;, j@f.com'),
32-
('date', 'Tue, 29 May 2012 09:24:26 +1000'),
33-
)
34-
35-
for header in headers:
36-
locals()['test_deepcopy_'+header[0]] = (
37-
lambda self, header=header:
38-
self._test_deepcopy(*header))
39-
40-
for header in headers:
41-
locals()['test_pickle_'+header[0]] = (
42-
lambda self, header=header:
43-
self._test_pickle(*header))
4435

36+
class TestPickleCopyMessage(TestEmailBase, metaclass=Parameterized):
4537

46-
class TestPickleCopyMessage(TestEmailBase):
47-
48-
msgs = {}
38+
# Message objects are a sequence, so we have to make them a one-tuple in
39+
# msg_params so they get passed to the parameterized test method as a
40+
# single argument instead of as a list of headers.
41+
msg_params = {}
4942

5043
# Note: there will be no custom header objects in the parsed message.
51-
msgs['parsed'] = email.message_from_string(textwrap.dedent("""\
44+
msg_params['parsed'] = (email.message_from_string(textwrap.dedent("""\
5245
Date: Tue, 29 May 2012 09:24:26 +1000
5346
From: frodo@mordor.net
5447
To: bilbo@underhill.org
5548
Subject: help
5649
5750
I think I forgot the ring.
58-
"""), policy=policy.default)
51+
"""), policy=policy.default),)
5952

60-
msgs['created'] = email.message.Message(policy=policy.default)
61-
msgs['created']['Date'] = 'Tue, 29 May 2012 09:24:26 +1000'
62-
msgs['created']['From'] = 'frodo@mordor.net'
63-
msgs['created']['To'] = 'bilbo@underhill.org'
64-
msgs['created']['Subject'] = 'help'
65-
msgs['created'].set_payload('I think I forgot the ring.')
53+
msg_params['created'] = (email.message.Message(policy=policy.default),)
54+
msg_params['created'][0]['Date'] = 'Tue, 29 May 2012 09:24:26 +1000'
55+
msg_params['created'][0]['From'] = 'frodo@mordor.net'
56+
msg_params['created'][0]['To'] = 'bilbo@underhill.org'
57+
msg_params['created'][0]['Subject'] = 'help'
58+
msg_params['created'][0].set_payload('I think I forgot the ring.')
6659

67-
def _test_deepcopy(self, msg):
60+
def msg_as_deepcopy(self, msg):
6861
msg2 = copy.deepcopy(msg)
6962
self.assertEqual(msg2.as_string(), msg.as_string())
7063

71-
def _test_pickle(self, msg):
64+
def msg_as_pickle(self, msg):
7265
p = pickle.dumps(msg)
7366
msg2 = pickle.loads(p)
7467
self.assertEqual(msg2.as_string(), msg.as_string())
7568

76-
for name, msg in msgs.items():
77-
locals()['test_deepcopy_'+name] = (
78-
lambda self, msg=msg:
79-
self._test_deepcopy(msg))
80-
81-
for name, msg in msgs.items():
82-
locals()['test_pickle_'+name] = (
83-
lambda self, msg=msg:
84-
self._test_pickle(msg))
85-
8669

8770
if __name__ == '__main__':
8871
unittest.main()

0 commit comments

Comments
 (0)