diff --git a/swh/journal/cli.py b/swh/journal/cli.py
--- a/swh/journal/cli.py
+++ b/swh/journal/cli.py
@@ -93,7 +93,8 @@
         ctx.fail('You must have a storage configured in your config file.')
 
     client = get_journal_client(
-        ctx, brokers=brokers, prefix=prefix, group_id=group_id)
+        ctx, brokers=brokers, prefix=prefix, group_id=group_id,
+        max_messages=max_messages)
     worker_fn = functools.partial(process_replay_objects, storage=storage)
 
     try:
@@ -207,7 +208,7 @@
 
     client = get_journal_client(
         ctx, brokers=brokers, prefix=prefix, group_id=group_id,
-        object_types=('content',))
+        max_messages=max_messages, object_types=('content',))
     worker_fn = functools.partial(process_replay_objects_content,
                                   src=objstorage_src,
                                   dst=objstorage_dst,
diff --git a/swh/journal/client.py b/swh/journal/client.py
--- a/swh/journal/client.py
+++ b/swh/journal/client.py
@@ -141,31 +141,38 @@
 
                 timeout = self.process_timeout - elapsed
 
-            message = self.consumer.poll(timeout=timeout)
-            if not message:
-                continue
+            num_messages = 20
+
+            if self.max_messages:
+                if nb_messages >= self.max_messages:
+                    break
+                num_messages = min(num_messages, self.max_messages - nb_messages)
 
-            error = message.error()
-            if error is not None:
-                if error.fatal():
-                    raise KafkaException(error)
-                logger.info('Received non-fatal kafka error: %s', error)
+            messages = self.consumer.consume(
+                timeout=timeout, num_messages=num_messages)
+            if not messages:
                 continue
 
-            nb_messages += 1
+            for message in messages:
+                error = message.error()
+                if error is not None:
+                    if error.fatal():
+                        raise KafkaException(error)
+                    logger.info('Received non-fatal kafka error: %s', error)
+                    continue
 
-            object_type = message.topic().split('.')[-1]
-            # Got a message from a topic we did not subscribe to.
-            assert object_type in self._object_types, object_type
+                nb_messages += 1
 
-            objects[object_type].append(
-                self.value_deserializer(message.value())
-            )
+                object_type = message.topic().split('.')[-1]
+                # Got a message from a topic we did not subscribe to.
+                assert object_type in self._object_types, object_type
 
-            if nb_messages >= self.max_messages:
-                break
+                objects[object_type].append(
+                    self.value_deserializer(message.value())
+                )
 
-        worker_fn(dict(objects))
+            if nb_messages:
+                worker_fn(dict(objects))
 
-        self.consumer.commit()
+                self.consumer.commit()
 
         return nb_messages
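(For context, a minimal standalone sketch of the batch-size arithmetic the
new client.py loop implements; the helper name is made up and not part of
the patch. Each consume() call asks for at most 20 messages, and when
max_messages is set the request is clamped so nb_messages never overshoots
the cap.)

    def clamp_batch_size(nb_messages, max_messages, batch=20):
        # max_messages == 0 means "no cap": keep the default batch of 20.
        # Otherwise, shrink the request to what the cap still allows; the
        # loop breaks before consuming again once the cap is reached.
        if max_messages:
            if nb_messages >= max_messages:
                return 0
            batch = min(batch, max_messages - nb_messages)
        return batch

    assert clamp_batch_size(0, 0) == 20    # uncapped
    assert clamp_batch_size(15, 18) == 3   # only 3 messages still allowed
    assert clamp_batch_size(18, 18) == 0   # cap reached; the loop breaks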
diff --git a/swh/journal/tests/test_replay.py b/swh/journal/tests/test_replay.py
--- a/swh/journal/tests/test_replay.py
+++ b/swh/journal/tests/test_replay.py
@@ -136,6 +136,8 @@
         writer.send('origin_visit', 'foo', visit)
 
     queue_size = len(queue)
+    assert replayer.max_messages == 0
+    replayer.max_messages = queue_size
 
     storage = get_storage('memory', {})
     worker_fn = functools.partial(process_replay_objects, storage=storage)
diff --git a/swh/journal/tests/test_write_replay.py b/swh/journal/tests/test_write_replay.py
--- a/swh/journal/tests/test_write_replay.py
+++ b/swh/journal/tests/test_write_replay.py
@@ -5,6 +5,7 @@
 
 import functools
 
+import attr
 from hypothesis import given, settings, HealthCheck
 from hypothesis.strategies import lists
 
@@ -18,6 +19,32 @@
 
 from .utils import MockedJournalClient, MockedKafkaWriter
 
+
+def empty_person_name_email(rev_or_rel):
+    """Empties the 'name' and 'email' fields of the author and committer
+    of a revision or release, leaving only the fullname."""
+    if getattr(rev_or_rel, 'author', None):
+        rev_or_rel = attr.evolve(
+            rev_or_rel,
+            author=attr.evolve(
+                rev_or_rel.author,
+                name=b'',
+                email=b'',
+            )
+        )
+
+    if getattr(rev_or_rel, 'committer', None):
+        rev_or_rel = attr.evolve(
+            rev_or_rel,
+            committer=attr.evolve(
+                rev_or_rel.committer,
+                name=b'',
+                email=b'',
+            )
+        )
+
+    return rev_or_rel
+
+
 @given(lists(object_dicts(), min_size=1))
 @settings(suppress_health_check=[HealthCheck.too_slow])
 def test_write_replay_same_order_batches(objects):
@@ -40,6 +67,8 @@
         pass
 
     queue_size = len(queue)
+    assert replayer.max_messages == 0
+    replayer.max_messages = queue_size
 
     storage2 = Storage()
     worker_fn = functools.partial(process_replay_objects, storage=storage2)
@@ -49,11 +78,25 @@
 
     assert replayer.consumer.committed
 
-    for attr_name in ('_contents', '_directories', '_revisions', '_releases',
+    for attr_name in ('_contents', '_directories',
                       '_snapshots', '_origin_visits', '_origins'):
         assert getattr(storage1, attr_name) == getattr(storage2, attr_name), \
             attr_name
 
+    # When hypothesis generates a revision and a release with the same
+    # author (or committer) fullname but a different name or email, the
+    # storage keeps the first name/email it sees. This first one comes
+    # from either the revision or the release and, since there are no
+    # ordering guarantees, storage2 has a 1/2 chance of not seeing the
+    # same order as storage1; we therefore strip name and email out
+    # before comparing.
+    for attr_name in ('_revisions', '_releases'):
+        items1 = {k: empty_person_name_email(v)
+                  for (k, v) in getattr(storage1, attr_name).items()}
+        items2 = {k: empty_person_name_email(v)
+                  for (k, v) in getattr(storage2, attr_name).items()}
+        assert items1 == items2, attr_name
+
 
 # TODO: add test for hash collision
@@ -78,6 +121,8 @@
         contents.append(obj)
 
     queue_size = len(queue)
+    assert replayer.max_messages == 0
+    replayer.max_messages = queue_size
 
     storage2 = Storage()
     worker_fn = functools.partial(process_replay_objects_content,
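(To make the rationale for empty_person_name_email() concrete, a hedged
illustration with a toy attrs class; this Person is a stand-in, not the
actual model type. Two persons can share a fullname while disagreeing on
name/email, which is exactly the nondeterminism the test strips out before
comparing the two storages.)

    import attr

    @attr.s
    class Person:  # toy stand-in for the attrs-based person objects
        fullname = attr.ib()
        name = attr.ib()
        email = attr.ib()

    a = Person(fullname=b'Ada <ada@example.com>',
               name=b'Ada', email=b'ada@example.com')
    b = Person(fullname=b'Ada <ada@example.com>',
               name=b'', email=b'')

    assert a != b                    # name/email differ...
    assert a.fullname == b.fullname  # ...but the fullname agrees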
diff --git a/swh/journal/tests/utils.py b/swh/journal/tests/utils.py
--- a/swh/journal/tests/utils.py
+++ b/swh/journal/tests/utils.py
@@ -47,8 +47,10 @@
         self.queue = queue
         self.committed = False
 
-    def poll(self, timeout=None):
-        return self.queue.pop(0)
+    def consume(self, num_messages, timeout=None):
+        messages = self.queue[0:num_messages]
+        self.queue[0:num_messages] = []
+        return messages
 
     def commit(self):
         if self.queue == []:
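(A usage sketch of the mocked consumer; it assumes, as the context lines
suggest, that MockedKafkaConsumer is constructed from a message queue, and
the queue contents here are illustrative. consume() drains up to
num_messages items from the head of the queue, mimicking confluent-kafka's
Consumer.consume(num_messages=..., timeout=...), which returns an empty
list when nothing is available.)

    consumer = MockedKafkaConsumer(['m1', 'm2', 'm3'])
    assert consumer.consume(num_messages=2) == ['m1', 'm2']
    assert consumer.consume(num_messages=2) == ['m3']
    assert consumer.consume(num_messages=2) == []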