Skip to content

Commit 77dd1a5

Browse files
committed
Merge pull request #11 from rdobson/ssh_perf
Improve the performance of grabbing remote data from a host
2 parents 82211e4 + d2fffe5 commit 77dd1a5

File tree

2 files changed

+33
-10
lines changed

2 files changed

+33
-10
lines changed

hwinfo/tools/inspector.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,18 +13,20 @@
1313
from hwinfo.host import dmidecode
1414
from hwinfo.host import cpuinfo
1515

16-
def remote_command(host, username, password, cmd):
16+
def get_ssh_client(host, username, password, timeout=10):
1717
client = paramiko.SSHClient()
1818
client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
19-
client.connect(host, username=username, password=password, timeout=10)
19+
client.connect(host, username=username, password=password, timeout=timeout)
20+
return client
21+
22+
def remote_command(client, cmd):
2023
cmdstr = ' '.join(cmd)
2124
#print "Executing '%s' on host '%s'" % (cmdstr, host)
2225
_, stdout, stderr = client.exec_command(cmdstr)
2326
output = stdout.readlines()
2427
error = stderr.readlines()
2528
if error:
2629
raise Exception("stderr: %s" % error)
27-
client.close()
2830
return ''.join(output)
2931

3032
def local_command(cmd):
@@ -44,12 +46,22 @@ def __init__(self, host='localhost', username=None, password=None):
4446
self.host = host
4547
self.username = username
4648
self.password = password
49+
self.client = None
50+
if self.is_remote():
51+
self.client = get_ssh_client(self.host, self.username, self.password)
52+
53+
def __del__(self):
54+
if self.client:
55+
self.client.close()
56+
57+
def is_remote(self):
58+
return self.host != 'localhost'
4759

4860
def exec_command(self, cmd):
49-
if self.host == 'localhost':
50-
return local_command(cmd)
61+
if self.is_remote():
62+
return remote_command(self.client, cmd)
5163
else:
52-
return remote_command(self.host, self.username, self.password, cmd)
64+
return local_command(cmd)
5365

5466
def get_lspci_data(self):
5567
return self.exec_command(['lspci', '-nnmm'])

hwinfo/tools/tests/test_inspector.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,13 @@ def test_local_exec_command(self, local_command):
1313
host.exec_command('ls')
1414
inspector.local_command.assert_called_once_with('ls')
1515

16+
@patch('hwinfo.tools.inspector.get_ssh_client')
1617
@patch('hwinfo.tools.inspector.remote_command')
17-
def test_remote_exec_command(self, remote_command):
18+
def test_remote_exec_command(self, remote_command, get_ssh_client):
19+
mclient = get_ssh_client.return_value = mock.MagicMock()
1820
host = inspector.Host('mymachine', 'root', 'pass')
1921
host.exec_command('ls')
20-
inspector.remote_command.assert_called_once_with('mymachine', 'root', 'pass', 'ls')
22+
inspector.remote_command.assert_called_once_with(mclient, 'ls')
2123

2224
@patch('hwinfo.tools.inspector.Host.exec_command')
2325
def test_get_pci_devices(self, exec_command):
@@ -35,6 +37,15 @@ def test_get_info(self, mock_exec_command, mock_dmidecode_parser_cls):
3537
rec = host.get_info()
3638
self.assertEqual(rec, {'key':'value'})
3739

40+
def test_is_not_remote(self):
41+
host = inspector.Host()
42+
self.assertEqual(host.is_remote(), False)
43+
44+
@patch('hwinfo.tools.inspector.get_ssh_client')
45+
def test_is_remote(self, get_ssh_client):
46+
get_ssh_client.return_value = mock.MagicMock()
47+
host = inspector.Host('test', 'user', 'pass')
48+
self.assertEqual(host.is_remote(), True)
3849

3950
class RemoteCommandTests(unittest.TestCase):
4051

@@ -47,15 +58,15 @@ def setUp(self):
4758
def test_ssh_connect(self, ssh_client_cls):
4859
client = ssh_client_cls.return_value = mock.Mock()
4960
client.exec_command.return_value = self.stdout, self.stdin, self.stderr
50-
inspector.remote_command('test', 'user', 'pass', 'ls')
61+
inspector.get_ssh_client('test', 'user', 'pass')
5162
client.connect.assert_called_with('test', password='pass', username='user', timeout=10)
5263

5364
@patch('paramiko.SSHClient')
5465
def test_ssh_connect_error(self, ssh_client_cls):
5566
client = ssh_client_cls.return_value = mock.Mock()
5667
client.exec_command.return_value = self.stdout, self.stdin, StringIO("Error")
5768
with self.assertRaises(Exception) as context:
58-
inspector.remote_command('test', 'user', 'pass', 'ls')
69+
inspector.remote_command(client, 'ls')
5970
self.assertEqual(context.exception.message, "stderr: ['Error']")
6071

6172
class LocalCommandTests(unittest.TestCase):

0 commit comments

Comments
 (0)