Skip to content

Commit f46e24e

Browse files
committed
Add support for nearest search node.
1 parent 896b9d0 commit f46e24e

File tree

2 files changed

+47
-10
lines changed

2 files changed

+47
-10
lines changed

typesense/api_call.py

Lines changed: 40 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import copy
2+
import logging
23
import json
34
import time
45

@@ -8,6 +9,8 @@
89
RequestMalformed, RequestUnauthorized,
910
ServerError, ServiceUnavailable, TypesenseClientError)
1011

12+
logger = logging.getLogger(__name__)
13+
1114

1215
class ApiCall(object):
1316
API_KEY_HEADER_NAME = 'X-TYPESENSE-API-KEY'
@@ -16,25 +19,44 @@ def __init__(self, config):
1619
self.config = config
1720
self.nodes = copy.deepcopy(self.config.nodes)
1821
self.node_index = 0
22+
self._initialize_nodes()
23+
24+
def _initialize_nodes(self):
25+
if self.config.nearest_node:
26+
self.set_node_healthcheck(self.config.nearest_node, True)
27+
28+
for node in self.nodes:
29+
self.set_node_healthcheck(node, True)
1930

20-
def check_failed_node(self, node):
31+
def node_due_for_health_check(self, node):
2132
current_epoch_ts = int(time.time())
22-
return (current_epoch_ts - node.last_access_ts) > self.config.healthcheck_interval_seconds
33+
due_for_check = (current_epoch_ts - node.last_access_ts) > self.config.healthcheck_interval_seconds
34+
if due_for_check:
35+
logger.debug('Node {}:{} is due for health check.'.format(node.host, node.port))
36+
return due_for_check
2337

2438
# Returns a healthy host from the pool in a round-robin fashion.
2539
# Might return an unhealthy host periodically to check for recovery.
2640
def get_node(self):
41+
if self.config.nearest_node:
42+
if self.config.nearest_node.healthy or self.node_due_for_health_check(self.config.nearest_node):
43+
logger.debug('Using nearest node.')
44+
return self.config.nearest_node
45+
else:
46+
logger.debug('Nearest node is unhealthy or not due for health check. Falling back to individual nodes.')
47+
2748
i = 0
2849
while i < len(self.nodes):
2950
i += 1
3051
node = self.nodes[self.node_index]
3152
self.node_index = (self.node_index + 1) % len(self.nodes)
3253

33-
if node.healthy or self.check_failed_node(node):
54+
if node.healthy or self.node_due_for_health_check(node):
3455
return node
3556

3657
# None of the nodes are marked healthy, but some of them could have become healthy since last health check.
3758
# So we will just return the next node.
59+
logger.debug('No healthy nodes were found. Returning the next node.')
3860
return self.nodes[self.node_index]
3961

4062
@staticmethod
@@ -63,14 +85,13 @@ def make_request(self, fn, endpoint, as_json, **kwargs):
6385
num_tries = 0
6486
last_exception = None
6587

88+
logger.debug('Making {} {}'.format(fn.__name__, endpoint))
89+
6690
while num_tries < (self.config.num_retries + 1):
6791
num_tries += 1
6892
node = self.get_node()
6993

70-
# We assume node to be unhealthy, unless proven healthy.
71-
# This way, we keep things DRY and don't have to repeat setting healthy as false multiple times.
72-
node.healthy = False
73-
node.last_access_ts = int(time.time())
94+
logger.debug('Try {} to node {}:{} -- healthy? {}'.format(num_tries, node.host, node.port, node.healthy))
7495

7596
try:
7697
url = node.url() + endpoint
@@ -82,24 +103,33 @@ def make_request(self, fn, endpoint, as_json, **kwargs):
82103
# Treat any status code > 0 and < 500 to be an indication that node is healthy
83104
# We exclude 0 since some clients return 0 when request fails
84105
if 0 < r.status_code < 500:
85-
node.healthy = True
106+
logger.debug('{}:{} is healthy. Status code: {}'.format(node.host, node.port, r.status_code))
107+
self.set_node_healthcheck(node, True)
86108

87109
# We should raise a custom exception if status code is not 20X
88-
if 200 <= r.status_code < 300:
110+
if not 200 <= r.status_code < 300:
89111
error_message = r.json().get('message', 'API error.')
90-
# Raised exception will be caught and retried only if it's a 50X
112+
# Raised exception will be caught and retried
91113
raise ApiCall.get_exception(r.status_code)(r.status_code, error_message)
92114

93115
return r.json() if as_json else r.text
94116
except (requests.exceptions.Timeout, requests.exceptions.ConnectionError, requests.exceptions.HTTPError,
95117
requests.exceptions.RequestException, requests.exceptions.SSLError,
96118
HTTPStatus0Error, ServerError, ServiceUnavailable) as e:
97119
# Catch the exception and retry
120+
self.set_node_healthcheck(node, False)
121+
logger.debug('Request to {}:{} failed because of {}'.format(node.host, node.port, e))
122+
logger.debug('Sleeping for {} and retrying...'.format(self.config.retry_interval_seconds))
98123
last_exception = e
99124
time.sleep(self.config.retry_interval_seconds)
100125

126+
logger.debug('No retries left. Raising last exception: {}'.format(last_exception))
101127
raise last_exception
102128

129+
def set_node_healthcheck(self, node, is_healthy):
130+
node.healthy = is_healthy
131+
node.last_access_ts = int(time.time())
132+
103133
def get(self, endpoint, params=None, as_json=True):
104134
params = params or {}
105135
return self.make_request(requests.get, endpoint, as_json,

typesense/configuration.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ def __init__(self, config_dict):
2727
Node(node_dict['host'], node_dict['port'], node_dict.get('path', ''), node_dict['protocol'])
2828
)
2929

30+
self.nearest_node = config_dict.get('nearest_node', None)
31+
3032
self.api_key = config_dict.get('api_key', '')
3133
self.connection_timeout_seconds = config_dict.get('connection_timeout_seconds', 3.0)
3234
self.num_retries = config_dict.get('num_retries', 3)
@@ -48,6 +50,11 @@ def validate_config_dict(config_dict):
4850
raise ConfigError('`node` entry must be a dictionary with the following required keys: '
4951
'host, port, protocol')
5052

53+
nearest_node = config_dict.get('nearest_node', None)
54+
if nearest_node and not Configuration.validate_node_fields(nearest_node):
55+
raise ConfigError('`nearest_node` entry must be a dictionary with the following required keys: '
56+
'host, port, protocol')
57+
5158
@staticmethod
5259
def validate_node_fields(node):
5360
expected_fields = {'host', 'port', 'protocol'}

0 commit comments

Comments
 (0)