diff --git a/swh/core/statsd.py b/swh/core/statsd.py --- a/swh/core/statsd.py +++ b/swh/core/statsd.py @@ -59,6 +59,7 @@ import logging import os import socket +import threading import warnings @@ -190,7 +191,8 @@ self.port = int(port) # Socket - self.socket = None + self._socket = None + self.lock = threading.Lock() self.max_buffer_size = max_buffer_size self._send = self._send_to_server self.encoding = 'utf-8' @@ -316,19 +318,21 @@ """ self._report(metric, 's', value, tags, sample_rate) - def get_socket(self): + @property + def socket(self): """ Return a connected socket. Note: connect the socket before assigning it to the class instance to avoid bad thread race conditions. """ - if not self.socket: - sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - sock.connect((self.host, self.port)) - self.socket = sock + with self.lock: + if not self._socket: + sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + sock.connect((self.host, self.port)) + self._socket = sock - return self.socket + return self._socket def open_buffer(self, max_buffer_size=50): """ @@ -358,9 +362,10 @@ """ Closes connected socket if connected. """ - if self.socket: - self.socket.close() - self.socket = None + with self.lock: + if self._socket: + self._socket.close() + self._socket = None def _report(self, metric, metric_type, value, tags, sample_rate): """ @@ -393,7 +398,7 @@ def _send_to_server(self, packet): try: # If set, use socket directly - (self.socket or self.get_socket()).send(packet.encode('utf-8')) + self.socket.send(packet.encode('utf-8')) except socket.timeout: return except socket.error: diff --git a/swh/core/tests/test_statsd.py b/swh/core/tests/test_statsd.py --- a/swh/core/tests/test_statsd.py +++ b/swh/core/tests/test_statsd.py @@ -105,7 +105,7 @@ """ # self.statsd = Statsd() - self.statsd.socket = FakeSocket() + self.statsd._socket = FakeSocket() def recv(self): return self.statsd.socket.recv() @@ -212,12 +212,12 @@ ) def test_socket_error(self): - self.statsd.socket = BrokenSocket() + self.statsd._socket = BrokenSocket() self.statsd.gauge('no error', 1) assert True, 'success' def test_socket_timeout(self): - self.statsd.socket = SlowSocket() + self.statsd._socket = SlowSocket() self.statsd.gauge('no error', 1) assert True, 'success' @@ -409,7 +409,7 @@ def test_context_manager(self): fake_socket = FakeSocket() with Statsd() as statsd: - statsd.socket = fake_socket + statsd._socket = fake_socket statsd.gauge('page.views', 123) statsd.timing('timer', 123) @@ -418,7 +418,7 @@ def test_batched_buffer_autoflush(self): fake_socket = FakeSocket() with Statsd() as statsd: - statsd.socket = fake_socket + statsd._socket = fake_socket for i in range(51): statsd.increment('mycounter') self.assertEqual( @@ -434,28 +434,28 @@ def test_instantiating_does_not_connect(self): local_statsd = Statsd() - self.assertEqual(None, local_statsd.socket) + self.assertEqual(None, local_statsd._socket) def test_accessing_socket_opens_socket(self): local_statsd = Statsd() try: - self.assertIsNotNone(local_statsd.get_socket()) + self.assertIsNotNone(local_statsd.socket) finally: - local_statsd.socket.close() + local_statsd.close_socket() def test_accessing_socket_multiple_times_returns_same_socket(self): local_statsd = Statsd() fresh_socket = FakeSocket() - local_statsd.socket = fresh_socket - self.assertEqual(fresh_socket, local_statsd.get_socket()) - self.assertNotEqual(FakeSocket(), local_statsd.get_socket()) + local_statsd._socket = fresh_socket + self.assertEqual(fresh_socket, local_statsd.socket) + self.assertNotEqual(FakeSocket(), local_statsd.socket) def test_tags_from_environment(self): with preserve_envvars('STATSD_TAGS'): os.environ['STATSD_TAGS'] = 'country:china,age:45' statsd = Statsd() - statsd.socket = FakeSocket() + statsd._socket = FakeSocket() statsd.gauge('gt', 123.4) self.assertEqual('gt:123.4|g|#age:45,country:china', statsd.socket.recv()) @@ -464,7 +464,7 @@ with preserve_envvars('STATSD_TAGS'): os.environ['STATSD_TAGS'] = 'country:china,age:45' statsd = Statsd(constant_tags={'country': 'canada'}) - statsd.socket = FakeSocket() + statsd._socket = FakeSocket() statsd.gauge('gt', 123.4) self.assertEqual('gt:123.4|g|#age:45,country:canada', statsd.socket.recv()) @@ -538,7 +538,7 @@ def test_namespace_added(self): local_statsd = Statsd(namespace='test-namespace') - local_statsd.socket = FakeSocket() + local_statsd._socket = FakeSocket() local_statsd.gauge('gauge', 123.4) assert local_statsd.socket.recv() == 'test-namespace.gauge:123.4|g'