diff --git a/swh/core/api/__init__.py b/swh/core/api/__init__.py
--- a/swh/core/api/__init__.py
+++ b/swh/core/api/__init__.py
@@ -141,68 +141,85 @@
     This backend class will never be instantiated, it only serves as
     a template."""
 
-    def __init__(self, api_exception, url, timeout=None):
+    def __init__(self, api_exception, url, timeout=None, chunk_size=4096):
         super().__init__()
         self.api_exception = api_exception
         base_url = url if url.endswith('/') else url + '/'
         self.url = base_url
         self.session = requests.Session()
         self.timeout = timeout
+        self.chunk_size = chunk_size
 
     def _url(self, endpoint):
         return '%s%s' % (self.url, endpoint)
 
-    def raw_post(self, endpoint, data, **opts):
+    def raw_verb(self, verb, endpoint, **opts):
+        """Send an HTTP `verb` request to `endpoint` and return the raw
+        requests response; `opts` are passed to the session method."""
+        if 'chunk_size' in opts:
+            # if the chunk_size argument has been passed, consider the user
+            # also wants stream=True, otherwise, what's the point.
+            # requests itself does not accept chunk_size, so drop it.
+            del opts['chunk_size']
+            opts['stream'] = True
         if self.timeout and 'timeout' not in opts:
             opts['timeout'] = self.timeout
         try:
-            return self.session.post(
+            return getattr(self.session, verb)(
                 self._url(endpoint),
-                data=data,
                 **opts
             )
         except requests.exceptions.ConnectionError as e:
             raise self.api_exception(e)
 
-    def raw_get(self, endpoint, params=None, **opts):
-        if self.timeout and 'timeout' not in opts:
-            opts['timeout'] = self.timeout
-        try:
-            return self.session.get(
-                self._url(endpoint),
-                params=params,
-                **opts
-            )
-        except requests.exceptions.ConnectionError as e:
-            raise self.api_exception(e)
-
-    def post(self, endpoint, data, params=None):
-        data = encode_data(data)
-        response = self.raw_post(
-            endpoint, data, params=params,
+    def _stream_or_decode(self, response, stream, chunk_size):
+        """Return the body as an iterator of raw chunks when streaming was
+        requested or the reply is chunked, else the decoded body.
+
+        Error replies always go through _decode_response, so a streamed
+        call cannot hand an error body back to the caller as data."""
+        if not response.ok:
+            return self._decode_response(response)
+        # NOTE: any chunked reply is returned undecoded, even for a call
+        # that did not ask for streaming.
+        if stream or \
+                response.headers.get('transfer-encoding') == 'chunked':
+            return response.iter_content(chunk_size)
+        return self._decode_response(response)
+
+    def post(self, endpoint, data, **opts):
+        """POST msgpack-encoded `data`; an iterator is encoded lazily,
+        item by item, as it is consumed."""
+        # collections.Iterator/Generator were removed in Python 3.10;
+        # every generator is an Iterator, so one check covers both.
+        from collections.abc import Iterator
+        if isinstance(data, Iterator):
+            data = (encode_data(x) for x in data)
+        else:
+            data = encode_data(data)
+        chunk_size = opts.pop('chunk_size', self.chunk_size)
+        response = self.raw_verb(
+            'post', endpoint, data=data,
             headers={'content-type': 'application/x-msgpack',
-                     'accept': 'application/x-msgpack'})
-        return self._decode_response(response)
-
-    def get(self, endpoint, params=None):
-        response = self.raw_get(
-            endpoint, params=params,
-            headers={'accept': 'application/x-msgpack'})
-        return self._decode_response(response)
-
-    def post_stream(self, endpoint, data, params=None):
-        if not isinstance(data, collections.Iterable):
-            raise ValueError("`data` must be Iterable")
-        response = self.raw_post(
-            endpoint, data, params=params,
-            headers={'accept': 'application/x-msgpack'})
-
-        return self._decode_response(response)
-
-    def get_stream(self, endpoint, params=None, chunk_size=4096):
-        response = self.raw_get(endpoint, params=params, stream=True,
-                                headers={'accept': 'application/x-msgpack'})
-        return response.iter_content(chunk_size)
+                     'accept': 'application/x-msgpack'},
+            **opts)
+        return self._stream_or_decode(
+            response, opts.get('stream'), chunk_size)
+
+    def post_stream(self, endpoint, data, **opts):
+        return self.post(endpoint, data, stream=True, **opts)
+
+    def get(self, endpoint, **opts):
+        chunk_size = opts.pop('chunk_size', self.chunk_size)
+        response = self.raw_verb(
+            'get', endpoint,
+            headers={'accept': 'application/x-msgpack'},
+            **opts)
+        return self._stream_or_decode(
+            response, opts.get('stream'), chunk_size)
+
+    def get_stream(self, endpoint, **opts):
+        return self.get(endpoint, stream=True, **opts)
 
     def _decode_response(self, response):
         if response.status_code == 404: