Skip to content

Commit b6cebd5

Browse files
committed
Added type annotations + some fixes to get them correct
One functional change: `CryptoOperation.read_infile()` now reads bytes from `sys.stdin` instead of text. This is necessary to be consistent with the rest of the code, which all deals with bytes.
1 parent 6760eb7 commit b6cebd5

File tree

12 files changed

+129
-116
lines changed

12 files changed

+129
-116
lines changed

rsa/_compat.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,11 @@
2121
from struct import pack
2222

2323

24-
def byte(num):## XXX
24+
def byte(num: int):
2525
"""
2626
Converts a number between 0 and 255 (both inclusive) to a base-256 (byte)
2727
representation.
2828
29-
Use it as a replacement for ``chr`` where you are expecting a byte
30-
because this will work on all current versions of Python::
31-
3229
:param num:
3330
An unsigned integer between 0 and 255 (both inclusive).
3431
:returns:
@@ -37,7 +34,7 @@ def byte(num):## XXX
3734
return pack("B", num)
3835

3936

40-
def xor_bytes(b1, b2):
37+
def xor_bytes(b1: bytes, b2: bytes) -> bytes:
4138
"""
4239
Returns the bitwise XOR result between two bytes objects, b1 ^ b2.
4340

rsa/cli.py

Lines changed: 26 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -22,20 +22,21 @@
2222
import abc
2323
import sys
2424
import typing
25-
from optparse import OptionParser
25+
import optparse
2626

2727
import rsa
2828
import rsa.key
2929
import rsa.pkcs1
3030

3131
HASH_METHODS = sorted(rsa.pkcs1.HASH_METHODS.keys())
32+
Indexable = typing.Union[typing.Tuple, typing.List[str]]
3233

3334

34-
def keygen():
35+
def keygen() -> None:
3536
"""Key generator."""
3637

3738
# Parse the CLI options
38-
parser = OptionParser(usage='usage: %prog [options] keysize',
39+
parser = optparse.OptionParser(usage='usage: %prog [options] keysize',
3940
description='Generates a new RSA keypair of "keysize" bits.')
4041

4142
parser.add_option('--pubout', type='string',
@@ -104,21 +105,22 @@ class CryptoOperation(metaclass=abc.ABCMeta):
104105

105106
key_class = rsa.PublicKey # type: typing.Type[rsa.key.AbstractKey]
106107

107-
def __init__(self):
108+
def __init__(self) -> None:
108109
self.usage = self.usage % self.__class__.__dict__
109110
self.input_help = self.input_help % self.__class__.__dict__
110111
self.output_help = self.output_help % self.__class__.__dict__
111112

112113
@abc.abstractmethod
113-
def perform_operation(self, indata, key, cli_args):
114+
def perform_operation(self, indata: bytes, key: rsa.key.AbstractKey,
115+
cli_args: Indexable):
114116
"""Performs the program's operation.
115117
116118
Implement in a subclass.
117119
118120
:returns: the data to write to the output.
119121
"""
120122

121-
def __call__(self):
123+
def __call__(self) -> None:
122124
"""Runs the program."""
123125

124126
(cli, cli_args) = self.parse_cli()
@@ -133,13 +135,13 @@ def __call__(self):
133135
if self.has_output:
134136
self.write_outfile(outdata, cli.output)
135137

136-
def parse_cli(self):
138+
def parse_cli(self) -> typing.Tuple[optparse.Values, typing.List[str]]:
137139
"""Parse the CLI options
138140
139141
:returns: (cli_opts, cli_args)
140142
"""
141143

142-
parser = OptionParser(usage=self.usage, description=self.description)
144+
parser = optparse.OptionParser(usage=self.usage, description=self.description)
143145

144146
parser.add_option('-i', '--input', type='string', help=self.input_help)
145147

@@ -158,7 +160,7 @@ def parse_cli(self):
158160

159161
return cli, cli_args
160162

161-
def read_key(self, filename, keyform):
163+
def read_key(self, filename: str, keyform: str) -> rsa.key.AbstractKey:
162164
"""Reads a public or private key."""
163165

164166
print('Reading %s key from %s' % (self.keyname, filename), file=sys.stderr)
@@ -167,7 +169,7 @@ def read_key(self, filename, keyform):
167169

168170
return self.key_class.load_pkcs1(keydata, keyform)
169171

170-
def read_infile(self, inname):
172+
def read_infile(self, inname: str) -> bytes:
171173
"""Read the input file"""
172174

173175
if inname:
@@ -176,9 +178,9 @@ def read_infile(self, inname):
176178
return infile.read()
177179

178180
print('Reading input from stdin', file=sys.stderr)
179-
return sys.stdin.read()
181+
return sys.stdin.buffer.read()
180182

181-
def write_outfile(self, outdata, outname):
183+
def write_outfile(self, outdata: bytes, outname: str) -> None:
182184
"""Write the output file"""
183185

184186
if outname:
@@ -200,9 +202,10 @@ class EncryptOperation(CryptoOperation):
200202
operation_past = 'encrypted'
201203
operation_progressive = 'encrypting'
202204

203-
def perform_operation(self, indata, pub_key, cli_args=None):
205+
def perform_operation(self, indata: bytes, pub_key: rsa.key.AbstractKey,
206+
cli_args: Indexable=()):
204207
"""Encrypts files."""
205-
208+
assert isinstance(pub_key, rsa.key.PublicKey)
206209
return rsa.encrypt(indata, pub_key)
207210

208211

@@ -217,9 +220,10 @@ class DecryptOperation(CryptoOperation):
217220
operation_progressive = 'decrypting'
218221
key_class = rsa.PrivateKey
219222

220-
def perform_operation(self, indata, priv_key, cli_args=None):
223+
def perform_operation(self, indata: bytes, priv_key: rsa.key.AbstractKey,
224+
cli_args: Indexable=()):
221225
"""Decrypts files."""
222-
226+
assert isinstance(priv_key, rsa.key.PrivateKey)
223227
return rsa.decrypt(indata, priv_key)
224228

225229

@@ -239,8 +243,10 @@ class SignOperation(CryptoOperation):
239243
output_help = ('Name of the file to write the signature to. Written '
240244
'to stdout if this option is not present.')
241245

242-
def perform_operation(self, indata, priv_key, cli_args):
246+
def perform_operation(self, indata: bytes, priv_key: rsa.key.AbstractKey,
247+
cli_args: Indexable):
243248
"""Signs files."""
249+
assert isinstance(priv_key, rsa.key.PrivateKey)
244250

245251
hash_method = cli_args[1]
246252
if hash_method not in HASH_METHODS:
@@ -264,8 +270,10 @@ class VerifyOperation(CryptoOperation):
264270
expected_cli_args = 2
265271
has_output = False
266272

267-
def perform_operation(self, indata, pub_key, cli_args):
273+
def perform_operation(self, indata: bytes, pub_key: rsa.key.AbstractKey,
274+
cli_args: Indexable):
268275
"""Verifies files."""
276+
assert isinstance(pub_key, rsa.key.PublicKey)
269277

270278
signature_file = cli_args[1]
271279

rsa/common.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,18 @@
1616

1717
"""Common functionality shared by several modules."""
1818

19+
import typing
20+
1921

2022
class NotRelativePrimeError(ValueError):
21-
def __init__(self, a, b, d, msg=None):
22-
super(NotRelativePrimeError, self).__init__(
23-
msg or "%d and %d are not relatively prime, divider=%i" % (a, b, d))
23+
def __init__(self, a, b, d, msg=''):
24+
super().__init__(msg or "%d and %d are not relatively prime, divider=%i" % (a, b, d))
2425
self.a = a
2526
self.b = b
2627
self.d = d
2728

2829

29-
def bit_size(num):
30+
def bit_size(num: int) -> int:
3031
"""
3132
Number of bits needed to represent a integer excluding any prefix
3233
0 bits.
@@ -54,7 +55,7 @@ def bit_size(num):
5455
raise TypeError('bit_size(num) only supports integers, not %r' % type(num))
5556

5657

57-
def byte_size(number):
58+
def byte_size(number: int) -> int:
5859
"""
5960
Returns the number of bytes required to hold a specific long number.
6061
@@ -79,7 +80,7 @@ def byte_size(number):
7980
return ceil_div(bit_size(number), 8)
8081

8182

82-
def ceil_div(num, div):
83+
def ceil_div(num: int, div: int) -> int:
8384
"""
8485
Returns the ceiling function of a division between `num` and `div`.
8586
@@ -103,7 +104,7 @@ def ceil_div(num, div):
103104
return quanta
104105

105106

106-
def extended_gcd(a, b):
107+
def extended_gcd(a: int, b: int) -> typing.Tuple[int, int, int]:
107108
"""Returns a tuple (r, i, j) such that r = gcd(a, b) = ia + jb
108109
"""
109110
# r = gcd(a,b) i = multiplicitive inverse of a mod b
@@ -128,7 +129,7 @@ def extended_gcd(a, b):
128129
return a, lx, ly # Return only positive values
129130

130131

131-
def inverse(x, n):
132+
def inverse(x: int, n: int) -> int:
132133
"""Returns the inverse of x % n under multiplication, a.k.a x^-1 (mod n)
133134
134135
>>> inverse(7, 4)
@@ -145,7 +146,7 @@ def inverse(x, n):
145146
return inv
146147

147148

148-
def crt(a_values, modulo_values):
149+
def crt(a_values: typing.Iterable[int], modulo_values: typing.Iterable[int]) -> int:
149150
"""Chinese Remainder Theorem.
150151
151152
Calculates x such that x = a[i] (mod m[i]) for each i.

rsa/core.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,14 @@
2121
"""
2222

2323

24-
def assert_int(var, name):
24+
def assert_int(var: int, name: str):
2525
if isinstance(var, int):
2626
return
2727

2828
raise TypeError('%s should be an integer, not %s' % (name, var.__class__))
2929

3030

31-
def encrypt_int(message, ekey, n):
31+
def encrypt_int(message: int, ekey: int, n: int) -> int:
3232
"""Encrypts a message using encryption key 'ekey', working modulo n"""
3333

3434
assert_int(message, 'message')
@@ -44,7 +44,7 @@ def encrypt_int(message, ekey, n):
4444
return pow(message, ekey, n)
4545

4646

47-
def decrypt_int(cyphertext, dkey, n):
47+
def decrypt_int(cyphertext: int, dkey: int, n: int) -> int:
4848
"""Decrypts a cypher text using the decryption key 'dkey', working modulo n"""
4949

5050
assert_int(cyphertext, 'cyphertext')

0 commit comments

Comments
 (0)