Review fixes folded into this patch. It was reconstructed from a copy
whose newlines and indentation had been collapsed, so the indentation and
blank lines of the context lines are inferred and may need adjusting.

- Graph.__iter__, iter_prefix and iter_type use self.pid2node, like
  __getitem__, and build each GraphNode from the node id stored in the map
  instead of looking the pid up a second time.
- PidToIntMap._bisect_pos bisects over [0, self.length], so a pid greater
  than every stored pid (or an empty map) yields self.length.
- PidToIntMap._find checks that position before reading it, and compares
  binary pids as the bisection does.
- PidToIntMap.iter_prefix rejects prefixes that do not have four
  ':'-separated fields with a ValueError instead of an unpacking error.

diff --git a/swh/graph/graph.py b/swh/graph/graph.py
--- a/swh/graph/graph.py
+++ b/swh/graph/graph.py
@@ -159,6 +159,21 @@
             node_id = self.pid2node[node_id]
         return GraphNode(self, node_id)
 
+    def __iter__(self):
+        """iterate over all the nodes, in the order of the pid2node map"""
+        for _pid, node_id in self.pid2node:
+            yield GraphNode(self, node_id)
+
+    def iter_prefix(self, prefix):
+        """iterate over the nodes whose pid starts with the given prefix"""
+        for _pid, node_id in self.pid2node.iter_prefix(prefix):
+            yield GraphNode(self, node_id)
+
+    def iter_type(self, pid_type):
+        """iterate over the nodes of the given pid type (e.g., 'rev')"""
+        for _pid, node_id in self.pid2node.iter_type(pid_type):
+            yield GraphNode(self, node_id)
+
 
 @contextlib.contextmanager
 def load(graph_path):
diff --git a/swh/graph/pid.py b/swh/graph/pid.py
--- a/swh/graph/pid.py
+++ b/swh/graph/pid.py
@@ -236,36 +236,61 @@
         f.write(str_to_bytes(pid))
         f.write(struct.pack(INT_BIN_FMT, int))
 
-    def _find(self, pid_str: str) -> Tuple[int, int]:
-        """lookup the integer identifier of a pid and its position
+    def _bisect_pos(self, pid_str: str) -> int:
+        """bisect the position of the given identifier. If the identifier is
+        not found, the position of the first pid greater than it is returned
+        (i.e., ``self.length`` if every pid in the map is smaller).
 
         Args:
             pid_str: the pid as a string
 
         Returns:
-            a pair `(pid, pos)` with pid integer identifier and its logical
-            record position in the map
+            the logical record position at which the pid is stored, or would
+            be inserted to keep the map sorted (at most ``self.length``)
 
         """
         if not isinstance(pid_str, str):
-            raise TypeError('PID must be a str, not ' + type(pid_str))
+            raise TypeError('PID must be a str, not {}'.format(type(pid_str)))
         try:
             target = str_to_bytes(pid_str)  # desired PID as bytes
         except ValueError:
             raise ValueError('invalid PID: "{}"'.format(pid_str))
 
-        min = 0
-        max = self.length - 1
-        while (min <= max):
-            mid = (min + max) // 2
-            (pid, int) = self._get_bin_record(mid)
+        # bisect over the half-open interval [lo, hi); starting with
+        # hi = self.length (rather than self.length - 1) lets a pid greater
+        # than every stored pid, or an empty map, give a past-the-end position
+        lo = 0
+        hi = self.length
+        while lo < hi:
+            mid = (lo + hi) // 2
+            (pid, _value) = self._get_bin_record(mid)
             if pid < target:
-                min = mid + 1
-            elif pid > target:
-                max = mid - 1
-            else:  # pid == target
-                return (struct.unpack(INT_BIN_FMT, int)[0], mid)
+                lo = mid + 1
+            else:
+                hi = mid
+        return lo
+
+    def _find(self, pid_str: str) -> Tuple[int, int]:
+        """lookup the integer identifier of a pid and its position
+
+        Args:
+            pid_str: the pid as a string
+
+        Returns:
+            a pair `(pid, pos)` with pid integer identifier and its logical
+            record position in the map
+
+        Raises:
+            KeyError: if the pid is not in the map
+
+        """
+        pos = self._bisect_pos(pid_str)  # validates pid_str too
+        if pos < self.length:  # pos == self.length: past the last record
+            # compare binary forms, as the bisection does
+            (pid, value) = self._get_bin_record(pos)
+            if pid == str_to_bytes(pid_str):
+                return (struct.unpack(INT_BIN_FMT, value)[0], pos)
 
         raise KeyError(pid_str)
 
     def __getitem__(self, pid_str: str) -> int:
@@ -292,6 +317,38 @@
         for pos in range(self.length):
             yield self._get_record(pos)
 
+    def iter_prefix(self, prefix: str) -> Iterator[Tuple[str, int]]:
+        """iterate over the records whose pid starts with the given prefix
+
+        Args:
+            prefix: a pid prefix of the form ``swh:1:<type>:<hex>``, where
+                the hexadecimal part may be empty or incomplete
+
+        Yields:
+            pairs `(pid, int)`, in map order, for all the matching pids
+
+        Raises:
+            ValueError: if the prefix does not have the above form
+
+        """
+        parts = prefix.split(':')
+        if len(parts) != 4:
+            raise ValueError('invalid PID prefix: "{}"'.format(prefix))
+        (swh, n, t, sha) = parts
+        # smallest pid having the prefix: complete the hash with zeros
+        start_pid = ':'.join([swh, n, t, sha.ljust(40, '0')])
+        start = self._bisect_pos(start_pid)
+        for pos in range(start, self.length):
+            pid, value = self._get_record(pos)
+            if not pid.startswith(prefix):
+                break  # records are sorted: matches are contiguous
+            yield pid, value
+
+    def iter_type(self, pid_type: str) -> Iterator[Tuple[str, int]]:
+        """iterate over the records of the given pid type (e.g., 'rev')"""
+        prefix = 'swh:1:{}:'.format(pid_type)
+        yield from self.iter_prefix(prefix)
+
 
 class IntToPidMap(_OnDiskMap, MutableMapping):
     """memory mapped map from a continuous range of 0..N (8-byte long) integers to
diff --git a/swh/graph/tests/test_graph.py b/swh/graph/tests/test_graph.py
--- a/swh/graph/tests/test_graph.py
+++ b/swh/graph/tests/test_graph.py
@@ -110,3 +110,13 @@
             .count_visit_nodes(edges='rel:rev,rev:rev') == 3)
     assert (graph['swh:1:rev:0000000000000000000000000000000000000009']
             .count_neighbors(direction='backward') == 3)
+
+
+def test_iter_type(graph):
+    rev_list = list(graph.iter_type('rev'))
+    actual = [n.pid for n in rev_list]
+    expected = ['swh:1:rev:0000000000000000000000000000000000000003',
+                'swh:1:rev:0000000000000000000000000000000000000009',
+                'swh:1:rev:0000000000000000000000000000000000000013',
+                'swh:1:rev:0000000000000000000000000000000000000018']
+    assert expected == actual
diff --git a/swh/graph/tests/test_pid.py b/swh/graph/tests/test_pid.py
--- a/swh/graph/tests/test_pid.py
+++ b/swh/graph/tests/test_pid.py
@@ -12,6 +12,7 @@
 
 from swh.graph.pid import str_to_bytes, bytes_to_str
 from swh.graph.pid import PidToIntMap, IntToPidMap
+from swh.model.identifiers import PID_TYPES
 
 
 class TestPidSerialization(unittest.TestCase):
@@ -137,6 +138,18 @@
         os.unlink(fname2)
         # tmpdir will be cleaned even if we don't reach this
 
+    def test_iter_type(self):
+        for t in PID_TYPES:
+            first_20 = list(islice(self.map.iter_type(t), 20))
+            k = first_20[0][1]
+            expected = [('swh:1:{}:{:040x}'.format(t, i), i)
+                        for i in range(k, k + 20)]
+            assert first_20 == expected
+
+    def test_iter_prefix_invalid(self):
+        with self.assertRaises(ValueError):
+            list(self.map.iter_prefix('swh:1:rev'))
+
 
 class TestIntToPidMap(unittest.TestCase):
 