diff --git a/msal/oauth2cli/authcode.py b/msal/oauth2cli/authcode.py index 6227d544..bcef60b8 100644 --- a/msal/oauth2cli/authcode.py +++ b/msal/oauth2cli/authcode.py @@ -7,6 +7,7 @@ """ import logging import socket +import sys from string import Template import threading import time @@ -103,7 +104,17 @@ def log_message(self, format, *args): logger.debug(format, *args) # To override the default log-to-stderr behavior -class _AuthCodeHttpServer(HTTPServer): +class _AuthCodeHttpServer(HTTPServer, object): + def __init__(self, server_address, *args, **kwargs): + _, port = server_address + if port and (sys.platform == "win32" or is_wsl()): + # The default allow_reuse_address is True. It works fine on non-Windows. + # On Windows, it undesirably allows multiple servers listening on same port, + # yet the second server would not receive any incoming request. + # So, we need to turn it off. + self.allow_reuse_address = False + super(_AuthCodeHttpServer, self).__init__(server_address, *args, **kwargs) + def handle_timeout(self): # It will be triggered when no request comes in self.timeout seconds. # See https://docs.python.org/3/library/socketserver.html#socketserver.BaseServer.handle_timeout diff --git a/tests/test_authcode.py b/tests/test_authcode.py new file mode 100644 index 00000000..c7e7565f --- /dev/null +++ b/tests/test_authcode.py @@ -0,0 +1,26 @@ +import unittest +import socket +import sys + +from msal.oauth2cli.authcode import AuthCodeReceiver + + +class TestAuthCodeReceiver(unittest.TestCase): + def test_setup_at_a_given_port_and_teardown(self): + port = 12345 # Assuming this port is available + with AuthCodeReceiver(port=port) as receiver: + self.assertEqual(port, receiver.get_port()) + + def test_setup_at_a_ephemeral_port_and_teardown(self): + port = 0 + with AuthCodeReceiver(port=port) as receiver: + self.assertNotEqual(port, receiver.get_port()) + + def test_no_two_concurrent_receivers_can_listen_on_same_port(self): + port = 12345 # Assuming this port is available + with AuthCodeReceiver(port=port) as receiver: + expected_error = OSError if sys.version_info[0] > 2 else socket.error + with self.assertRaises(expected_error): + with AuthCodeReceiver(port=port) as receiver2: + pass +