diff --git a/swh/graph/graph.py b/swh/graph/graph.py index ed042d3..37214e4 100644 --- a/swh/graph/graph.py +++ b/swh/graph/graph.py @@ -1,166 +1,178 @@ # Copyright (C) 2019 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import asyncio import contextlib import functools from swh.graph.backend import Backend from swh.graph.dot import dot_to_svg, graph_dot, KIND_TO_SHAPE BASE_URL = 'https://archive.softwareheritage.org/browse' KIND_TO_URL_FRAGMENT = { 'ori': '/origin/{}', 'snp': '/snapshot/{}', 'rel': '/release/{}', 'rev': '/revision/{}', 'dir': '/directory/{}', 'cnt': '/content/sha1_git:{}/', } def call_async_gen(generator, *args, **kwargs): loop = asyncio.get_event_loop() it = generator(*args, **kwargs).__aiter__() while True: try: res = loop.run_until_complete(it.__anext__()) yield res except StopAsyncIteration: break class Neighbors: """Neighbor iterator with custom O(1) length method""" def __init__(self, graph, iterator, length_func): self.graph = graph self.iterator = iterator self.length_func = length_func def __iter__(self): return self def __next__(self): succ = self.iterator.nextLong() if succ == -1: raise StopIteration return GraphNode(self.graph, succ) def __len__(self): return self.length_func() class GraphNode: """Node in the SWH graph""" def __init__(self, graph, node_id): self.graph = graph self.id = node_id def children(self): return Neighbors( self.graph, self.graph.java_graph.successors(self.id), lambda: self.graph.java_graph.outdegree(self.id)) def parents(self): return Neighbors( self.graph, self.graph.java_graph.predecessors(self.id), lambda: self.graph.java_graph.indegree(self.id)) def simple_traversal(self, ttype, direction='forward', edges='*'): for node in call_async_gen( self.graph.backend.simple_traversal, ttype, direction, edges, self.id ): yield self.graph[node] def leaves(self, *args, **kwargs): yield from self.simple_traversal('leaves', *args, **kwargs) def visit_nodes(self, *args, **kwargs): yield from self.simple_traversal('visit_nodes', *args, **kwargs) def visit_paths(self, direction='forward', edges='*'): for path in call_async_gen( self.graph.backend.visit_paths, direction, edges, self.id ): yield [self.graph[node] for node in path] def walk(self, dst, direction='forward', edges='*', traversal='dfs'): for node in call_async_gen( self.graph.backend.walk, direction, edges, traversal, self.id, dst ): yield self.graph[node] def _count(self, ttype, direction='forward', edges='*'): return self.graph.backend.count(ttype, direction, edges, self.id) count_leaves = functools.partialmethod(_count, ttype='leaves') count_neighbors = functools.partialmethod(_count, ttype='neighbors') count_visit_nodes = functools.partialmethod(_count, ttype='visit_nodes') @property def pid(self): return self.graph.node2pid[self.id] @property def kind(self): return self.pid.split(':')[2] def __str__(self): return self.pid def __repr__(self): return '<{}>'.format(self.pid) def dot_fragment(self): swh, version, kind, hash = self.pid.split(':') label = '{}:{}..{}'.format(kind, hash[0:2], hash[-2:]) url = BASE_URL + KIND_TO_URL_FRAGMENT[kind].format(hash) shape = KIND_TO_SHAPE[kind] return ('{} [label="{}", href="{}", target="_blank", shape="{}"];' .format(self.id, label, url, shape)) def _repr_svg_(self): nodes = [self, *list(self.children()), *list(self.parents())] dot = graph_dot(nodes) svg = dot_to_svg(dot) return svg class Graph: def __init__(self, backend, node2pid, pid2node): self.backend = backend self.java_graph = backend.entry.get_graph() self.node2pid = node2pid self.pid2node = pid2node def stats(self): return self.backend.stats() @property def path(self): return self.java_graph.getPath() def __len__(self): return self.java_graph.getNbNodes() def __getitem__(self, node_id): if isinstance(node_id, int): self.node2pid[node_id] # check existence return GraphNode(self, node_id) elif isinstance(node_id, str): node_id = self.pid2node[node_id] return GraphNode(self, node_id) + def __iter__(self): + for pid, pos in self.backend.pid2node: + yield self[pid] + + def iter_prefix(self, prefix): + for pid, pos in self.backend.pid2node.iter_prefix(prefix): + yield self[pid] + + def iter_type(self, pid_type): + for pid, pos in self.backend.pid2node.iter_type(pid_type): + yield self[pid] + @contextlib.contextmanager def load(graph_path): with Backend(graph_path) as backend: yield Graph(backend, backend.node2pid, backend.pid2node) diff --git a/swh/graph/pid.py b/swh/graph/pid.py index 7ed9102..9f8758b 100644 --- a/swh/graph/pid.py +++ b/swh/graph/pid.py @@ -1,371 +1,400 @@ # Copyright (C) 2019 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import mmap import os import struct from collections.abc import MutableMapping from enum import Enum from mmap import MAP_SHARED, MAP_PRIVATE from typing import BinaryIO, Iterator, Tuple from swh.model.identifiers import PersistentId, parse_persistent_identifier PID_BIN_FMT = 'BB20s' # 2 unsigned chars + 20 bytes INT_BIN_FMT = '>q' # big endian, 8-byte integer PID_BIN_SIZE = 22 # in bytes INT_BIN_SIZE = 8 # in bytes class PidType(Enum): """types of existing PIDs, used to serialize PID type as a (char) integer note that the order does matter also for driving the binary search in PID-indexed maps """ content = 1 directory = 2 origin = 3 release = 4 revision = 5 snapshot = 6 def str_to_bytes(pid_str: str) -> bytes: """Convert a PID to a byte sequence The binary format used to represent PIDs as byte sequences is as follows: - 1 byte for the namespace version represented as a C `unsigned char` - 1 byte for the object type, as the int value of :class:`PidType` enums, represented as a C `unsigned char` - 20 bytes for the SHA1 digest as a byte sequence Args: pid: persistent identifier Returns: bytes: byte sequence representation of pid """ pid = parse_persistent_identifier(pid_str) return struct.pack(PID_BIN_FMT, pid.scheme_version, PidType[pid.object_type].value, bytes.fromhex(pid.object_id)) def bytes_to_str(bytes: bytes) -> str: """Inverse function of :func:`str_to_bytes` See :func:`str_to_bytes` for a description of the binary PID format. Args: bytes: byte sequence representation of pid Returns: pid: persistent identifier """ (version, type, bin_digest) = struct.unpack(PID_BIN_FMT, bytes) pid = PersistentId(object_type=PidType(type).name, object_id=bin_digest) return str(pid) class _OnDiskMap(): """mmap-ed on-disk sequence of fixed size records """ def __init__(self, record_size: int, fname: str, mode: str = 'rb', length: int = None): """open an existing on-disk map Args: record_size: size of each record in bytes fname: path to the on-disk map mode: file open mode, usually either 'rb' for read-only maps, 'wb' for creating new maps, or 'rb+' for updating existing ones (default: 'rb') length: map size in number of logical records; used to initialize writable maps at creation time. Must be given when mode is 'wb' and the map doesn't exist on disk; ignored otherwise """ os_modes = { 'rb': os.O_RDONLY, 'wb': os.O_RDWR | os.O_CREAT, 'rb+': os.O_RDWR } if mode not in os_modes: raise ValueError('invalid file open mode: ' + mode) new_map = (mode == 'wb') writable_map = mode in ['wb', 'rb+'] self.record_size = record_size self.fd = os.open(fname, os_modes[mode]) if new_map: if length is None: raise ValueError('missing length when creating new map') os.truncate(self.fd, length * self.record_size) self.size = os.path.getsize(fname) (self.length, remainder) = divmod(self.size, record_size) if remainder: raise ValueError( 'map size {} is not a multiple of the record size {}'.format( self.size, record_size)) self.mm = mmap.mmap( self.fd, self.size, flags=MAP_SHARED if writable_map else MAP_PRIVATE) def close(self) -> None: """close the map shuts down both the mmap and the underlying file descriptor """ if not self.mm.closed: self.mm.close() os.close(self.fd) def __len__(self) -> int: return self.length def __delitem__(self, pos: int) -> None: raise NotImplementedError('cannot delete records from fixed-size map') class PidToIntMap(_OnDiskMap, MutableMapping): """memory mapped map from PID (:ref:`persistent-identifiers`) to a continuous range 0..N of (8-byte long) integers This is the converse mapping of :class:`IntToPidMap`. The on-disk serialization format is a sequence of fixed length (30 bytes) records with the following fields: - PID (22 bytes): binary PID representation as per :func:`str_to_bytes` - long (8 bytes): big endian long integer The records are sorted lexicographically by PID type and checksum, where type is the integer value of :class:`PidType`. PID lookup in the map is performed via binary search. Hence a huge map with, say, 11 B entries, will require ~30 disk seeks. Note that, due to fixed size + ordering, it is not possible to create these maps by random writing. Hence, __setitem__ can be used only to *update* the value associated to an existing key, rather than to add a missing item. To create an entire map from scratch, you should do so *sequentially*, using static method :meth:`write_record` (or, at your own risk, by hand via the mmap :attr:`mm`). """ # record binary format: PID + a big endian 8-byte big endian integer RECORD_BIN_FMT = '>' + PID_BIN_FMT + 'q' RECORD_SIZE = PID_BIN_SIZE + INT_BIN_SIZE def __init__(self, fname: str, mode: str = 'rb', length: int = None): """open an existing on-disk map Args: fname: path to the on-disk map mode: file open mode, usually either 'rb' for read-only maps, 'wb' for creating new maps, or 'rb+' for updating existing ones (default: 'rb') length: map size in number of logical records; used to initialize read-write maps at creation time. Must be given when mode is 'wb'; ignored otherwise """ super().__init__(self.RECORD_SIZE, fname, mode=mode, length=length) def _get_bin_record(self, pos: int) -> Tuple[bytes, bytes]: """seek and return the (binary) record at a given (logical) position see :func:`_get_record` for an equivalent function with additional deserialization Args: pos: 0-based record number Returns: a pair `(pid, int)`, where pid and int are bytes """ rec_pos = pos * self.RECORD_SIZE int_pos = rec_pos + PID_BIN_SIZE return (self.mm[rec_pos:int_pos], self.mm[int_pos:int_pos+INT_BIN_SIZE]) def _get_record(self, pos: int) -> Tuple[str, int]: """seek and return the record at a given (logical) position moral equivalent of :func:`_get_bin_record`, with additional deserialization to non-bytes types Args: pos: 0-based record number Returns: a pair `(pid, int)`, where pid is a string-based PID and int the corresponding integer identifier """ (pid_bytes, int_bytes) = self._get_bin_record(pos) return (bytes_to_str(pid_bytes), struct.unpack(INT_BIN_FMT, int_bytes)[0]) @classmethod def write_record(cls, f: BinaryIO, pid: str, int: int) -> None: """write a logical record to a file-like object Args: f: file-like object to write the record to pid: textual PID int: PID integer identifier """ f.write(str_to_bytes(pid)) f.write(struct.pack(INT_BIN_FMT, int)) - def _find(self, pid_str: str) -> Tuple[int, int]: - """lookup the integer identifier of a pid and its position + def _bisect_pos(self, pid_str: str) -> int: + """bisect the position of the given identifier. If the identifier is + not found, the position of the pid immediately after is returned. Args: pid_str: the pid as a string Returns: - a pair `(pid, pos)` with pid integer identifier and its logical - record position in the map + the logical record of the bisected position in the map """ if not isinstance(pid_str, str): - raise TypeError('PID must be a str, not ' + type(pid_str)) + raise TypeError('PID must be a str, not {}'.format(type(pid_str))) try: target = str_to_bytes(pid_str) # desired PID as bytes except ValueError: raise ValueError('invalid PID: "{}"'.format(pid_str)) - min = 0 - max = self.length - 1 - while (min <= max): - mid = (min + max) // 2 - (pid, int) = self._get_bin_record(mid) + lo = 0 + hi = self.length - 1 + while lo < hi: + mid = (lo + hi) // 2 + (pid, _value) = self._get_bin_record(mid) if pid < target: - min = mid + 1 - elif pid > target: - max = mid - 1 - else: # pid == target - return (struct.unpack(INT_BIN_FMT, int)[0], mid) + lo = mid + 1 + else: + hi = mid + return lo + + def _find(self, pid_str: str) -> Tuple[int, int]: + """lookup the integer identifier of a pid and its position + + Args: + pid_str: the pid as a string + + Returns: + a pair `(pid, pos)` with pid integer identifier and its logical + record position in the map + """ + pos = self._bisect_pos(pid_str) + pid_found, value = self._get_record(pos) + if pid_found == pid_str: + return (value, pos) raise KeyError(pid_str) def __getitem__(self, pid_str: str) -> int: """lookup the integer identifier of a PID Args: pid: the PID as a string Returns: the integer identifier of pid """ return self._find(pid_str)[0] # return element, ignore position def __setitem__(self, pid_str: str, int: str) -> None: (_pid, pos) = self._find(pid_str) # might raise KeyError and that's OK rec_pos = pos * self.RECORD_SIZE int_pos = rec_pos + PID_BIN_SIZE self.mm[rec_pos:int_pos] = str_to_bytes(pid_str) self.mm[int_pos:int_pos+INT_BIN_SIZE] = struct.pack(INT_BIN_FMT, int) def __iter__(self) -> Iterator[Tuple[str, int]]: for pos in range(self.length): yield self._get_record(pos) + def iter_prefix(self, prefix: str): + swh, n, t, sha = prefix.split(':') + sha = sha.ljust(40, '0') + start_pid = ':'.join([swh, n, t, sha]) + start = self._bisect_pos(start_pid) + for pos in range(start, self.length): + pid, value = self._get_record(pos) + if not pid.startswith(prefix): + break + yield pid, value + + def iter_type(self, pid_type: str) -> Iterator[Tuple[str, int]]: + prefix = 'swh:1:{}:'.format(pid_type) + yield from self.iter_prefix(prefix) + class IntToPidMap(_OnDiskMap, MutableMapping): """memory mapped map from a continuous range of 0..N (8-byte long) integers to PIDs (:ref:`persistent-identifiers`) This is the converse mapping of :class:`PidToIntMap`. The on-disk serialization format is a sequence of fixed length (22 bytes), where each record is the binary representation of PID as per :func:`str_to_bytes`. The records are sorted by long integer, so that integer lookup is possible via fixed-offset seek. """ RECORD_BIN_FMT = PID_BIN_FMT RECORD_SIZE = PID_BIN_SIZE def __init__(self, fname: str, mode: str = 'rb', length: int = None): """open an existing on-disk map Args: fname: path to the on-disk map mode: file open mode, usually either 'rb' for read-only maps, 'wb' for creating new maps, or 'rb+' for updating existing ones (default: 'rb') size: map size in number of logical records; used to initialize read-write maps at creation time. Must be given when mode is 'wb'; ignored otherwise length: passed to :class:`_OnDiskMap` """ super().__init__(self.RECORD_SIZE, fname, mode=mode, length=length) def _get_bin_record(self, pos: int) -> bytes: """seek and return the (binary) PID at a given (logical) position Args: pos: 0-based record number Returns: PID as a byte sequence """ rec_pos = pos * self.RECORD_SIZE return self.mm[rec_pos:rec_pos+self.RECORD_SIZE] @classmethod def write_record(cls, f: BinaryIO, pid: str) -> None: """write a PID to a file-like object Args: f: file-like object to write the record to pid: textual PID """ f.write(str_to_bytes(pid)) def __getitem__(self, pos: int) -> str: orig_pos = pos if pos < 0: pos = len(self) + pos if not (0 <= pos < len(self)): raise IndexError(orig_pos) return bytes_to_str(self._get_bin_record(pos)) def __setitem__(self, pos: int, pid: str) -> None: rec_pos = pos * self.RECORD_SIZE self.mm[rec_pos:rec_pos+self.RECORD_SIZE] = str_to_bytes(pid) def __iter__(self) -> Iterator[Tuple[int, str]]: for pos in range(self.length): yield (pos, self[pos]) diff --git a/swh/graph/tests/test_graph.py b/swh/graph/tests/test_graph.py index 4fe848d..4de3734 100644 --- a/swh/graph/tests/test_graph.py +++ b/swh/graph/tests/test_graph.py @@ -1,112 +1,122 @@ import pytest def test_graph(graph): assert len(graph) == 21 obj = 'swh:1:dir:0000000000000000000000000000000000000008' node = graph[obj] assert str(node) == obj assert len(node.children()) == 3 assert len(node.parents()) == 2 actual = {p.pid for p in node.children()} expected = { 'swh:1:cnt:0000000000000000000000000000000000000001', 'swh:1:dir:0000000000000000000000000000000000000006', 'swh:1:cnt:0000000000000000000000000000000000000007' } assert expected == actual actual = {p.pid for p in node.parents()} expected = { 'swh:1:rev:0000000000000000000000000000000000000009', 'swh:1:dir:0000000000000000000000000000000000000012', } assert expected == actual def test_invalid_pid(graph): with pytest.raises(IndexError): graph[1337] with pytest.raises(IndexError): graph[len(graph) + 1] with pytest.raises(KeyError): graph['swh:1:dir:0000000000000000000000000000000420000012'] def test_leaves(graph): actual = list(graph['swh:1:ori:0000000000000000000000000000000000000021'] .leaves()) actual = [p.pid for p in actual] expected = [ 'swh:1:cnt:0000000000000000000000000000000000000001', 'swh:1:cnt:0000000000000000000000000000000000000004', 'swh:1:cnt:0000000000000000000000000000000000000005', 'swh:1:cnt:0000000000000000000000000000000000000007' ] assert set(actual) == set(expected) def test_visit_nodes(graph): actual = list(graph['swh:1:rel:0000000000000000000000000000000000000010'] .visit_nodes(edges='rel:rev,rev:rev')) actual = [p.pid for p in actual] expected = [ 'swh:1:rel:0000000000000000000000000000000000000010', 'swh:1:rev:0000000000000000000000000000000000000009', 'swh:1:rev:0000000000000000000000000000000000000003' ] assert set(actual) == set(expected) def test_visit_paths(graph): actual = list(graph['swh:1:snp:0000000000000000000000000000000000000020'] .visit_paths(edges='snp:*,rev:*')) actual = [tuple(n.pid for n in path) for path in actual] expected = [ ( 'swh:1:snp:0000000000000000000000000000000000000020', 'swh:1:rev:0000000000000000000000000000000000000009', 'swh:1:rev:0000000000000000000000000000000000000003', 'swh:1:dir:0000000000000000000000000000000000000002' ), ( 'swh:1:snp:0000000000000000000000000000000000000020', 'swh:1:rev:0000000000000000000000000000000000000009', 'swh:1:dir:0000000000000000000000000000000000000008' ), ( 'swh:1:snp:0000000000000000000000000000000000000020', 'swh:1:rel:0000000000000000000000000000000000000010' ) ] assert set(actual) == set(expected) def test_walk(graph): actual = list(graph['swh:1:dir:0000000000000000000000000000000000000016'] .walk('rel', edges='dir:dir,dir:rev,rev:*', direction='backward', traversal='bfs')) actual = [p.pid for p in actual] expected = [ 'swh:1:dir:0000000000000000000000000000000000000016', 'swh:1:dir:0000000000000000000000000000000000000017', 'swh:1:rev:0000000000000000000000000000000000000018', 'swh:1:rel:0000000000000000000000000000000000000019' ] assert set(actual) == set(expected) def test_count(graph): assert (graph['swh:1:ori:0000000000000000000000000000000000000021'] .count_leaves() == 4) assert (graph['swh:1:rel:0000000000000000000000000000000000000010'] .count_visit_nodes(edges='rel:rev,rev:rev') == 3) assert (graph['swh:1:rev:0000000000000000000000000000000000000009'] .count_neighbors(direction='backward') == 3) + + +def test_iter_type(graph): + rev_list = list(graph.iter_type('rev')) + actual = [n.pid for n in rev_list] + expected = ['swh:1:rev:0000000000000000000000000000000000000003', + 'swh:1:rev:0000000000000000000000000000000000000009', + 'swh:1:rev:0000000000000000000000000000000000000013', + 'swh:1:rev:0000000000000000000000000000000000000018'] + assert expected == actual diff --git a/swh/graph/tests/test_pid.py b/swh/graph/tests/test_pid.py index 033fe63..1b04059 100644 --- a/swh/graph/tests/test_pid.py +++ b/swh/graph/tests/test_pid.py @@ -1,183 +1,192 @@ # Copyright (C) 2019 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information import os import shutil import tempfile import unittest from itertools import islice from swh.graph.pid import str_to_bytes, bytes_to_str from swh.graph.pid import PidToIntMap, IntToPidMap +from swh.model.identifiers import PID_TYPES class TestPidSerialization(unittest.TestCase): pairs = [ ('swh:1:cnt:94a9ed024d3859793618152ea559a168bbcbb5e2', bytes.fromhex('01' + '01' + '94a9ed024d3859793618152ea559a168bbcbb5e2')), ('swh:1:dir:d198bc9d7a6bcf6db04f476d29314f157507d505', bytes.fromhex('01' + '02' + 'd198bc9d7a6bcf6db04f476d29314f157507d505')), ('swh:1:ori:b63a575fe3faab7692c9f38fb09d4bb45651bb0f', bytes.fromhex('01' + '03' + 'b63a575fe3faab7692c9f38fb09d4bb45651bb0f')), ('swh:1:rel:22ece559cc7cc2364edc5e5593d63ae8bd229f9f', bytes.fromhex('01' + '04' + '22ece559cc7cc2364edc5e5593d63ae8bd229f9f')), ('swh:1:rev:309cf2674ee7a0749978cf8265ab91a60aea0f7d', bytes.fromhex('01' + '05' + '309cf2674ee7a0749978cf8265ab91a60aea0f7d')), ('swh:1:snp:c7c108084bc0bf3d81436bf980b46e98bd338453', bytes.fromhex('01' + '06' + 'c7c108084bc0bf3d81436bf980b46e98bd338453')), ] def test_str_to_bytes(self): for (pid_str, pid_bytes) in self.pairs: self.assertEqual(str_to_bytes(pid_str), pid_bytes) def test_bytes_to_str(self): for (pid_str, pid_bytes) in self.pairs: self.assertEqual(bytes_to_str(pid_bytes), pid_str) def test_round_trip(self): for (pid_str, pid_bytes) in self.pairs: self.assertEqual(pid_str, bytes_to_str(str_to_bytes(pid_str))) self.assertEqual(pid_bytes, str_to_bytes(bytes_to_str(pid_bytes))) def gen_records(types=['cnt', 'dir', 'ori', 'rel', 'rev', 'snp'], length=10000): """generate sequential PID/int records, suitable for filling int<->pid maps for testing swh-graph on-disk binary databases Args: types (list): list of PID types to be generated, specified as the corresponding 3-letter component in PIDs length (int): number of PIDs to generate *per type* Yields: pairs (pid, int) where pid is a textual PID and int its sequential integer identifier """ pos = 0 for t in sorted(types): for i in range(0, length): seq = format(pos, 'x') # current position as hex string pid = 'swh:1:{}:{}{}'.format(t, '0' * (40 - len(seq)), seq) yield (pid, pos) pos += 1 # pairs PID/position in the sequence generated by :func:`gen_records` above MAP_PAIRS = [ ('swh:1:cnt:0000000000000000000000000000000000000000', 0), ('swh:1:cnt:000000000000000000000000000000000000002a', 42), ('swh:1:dir:0000000000000000000000000000000000002afc', 11004), ('swh:1:ori:00000000000000000000000000000000000056ce', 22222), ('swh:1:rel:0000000000000000000000000000000000008235', 33333), ('swh:1:rev:000000000000000000000000000000000000ad9c', 44444), ('swh:1:snp:000000000000000000000000000000000000ea5f', 59999), ] class TestPidToIntMap(unittest.TestCase): @classmethod def setUpClass(cls): """create reasonably sized (~2 MB) PID->int map to test on-disk DB """ cls.tmpdir = tempfile.mkdtemp(prefix='swh.graph.test.') cls.fname = os.path.join(cls.tmpdir, 'pid2int.bin') with open(cls.fname, 'wb') as f: for (pid, i) in gen_records(length=10000): PidToIntMap.write_record(f, pid, i) @classmethod def tearDownClass(cls): shutil.rmtree(cls.tmpdir) def setUp(self): self.map = PidToIntMap(self.fname) def tearDown(self): self.map.close() def test_lookup(self): for (pid, pos) in MAP_PAIRS: self.assertEqual(self.map[pid], pos) def test_missing(self): with self.assertRaises(KeyError): self.map['swh:1:ori:0101010100000000000000000000000000000000'], with self.assertRaises(KeyError): self.map['swh:1:cnt:0101010100000000000000000000000000000000'], def test_type_error(self): with self.assertRaises(TypeError): self.map[42] with self.assertRaises(TypeError): self.map[1.2] def test_update(self): fname2 = self.fname + '.update' shutil.copy(self.fname, fname2) # fresh map copy map2 = PidToIntMap(fname2, mode='rb+') for (pid, int) in islice(map2, 11): # update the first N items new_int = int + 42 map2[pid] = new_int self.assertEqual(map2[pid], new_int) # check updated value os.unlink(fname2) # tmpdir will be cleaned even if we don't reach this + def test_iter_type(self): + for t in PID_TYPES: + first_20 = list(islice(self.map.iter_type(t), 20)) + k = first_20[0][1] + expected = [('swh:1:{}:{:040x}'.format(t, i), i) + for i in range(k, k + 20)] + assert first_20 == expected + class TestIntToPidMap(unittest.TestCase): @classmethod def setUpClass(cls): """create reasonably sized (~1 MB) int->PID map to test on-disk DB """ cls.tmpdir = tempfile.mkdtemp(prefix='swh.graph.test.') cls.fname = os.path.join(cls.tmpdir, 'int2pid.bin') with open(cls.fname, 'wb') as f: for (pid, _i) in gen_records(length=10000): IntToPidMap.write_record(f, pid) @classmethod def tearDownClass(cls): shutil.rmtree(cls.tmpdir) def setUp(self): self.map = IntToPidMap(self.fname) def tearDown(self): self.map.close() def test_lookup(self): for (pid, pos) in MAP_PAIRS: self.assertEqual(self.map[pos], pid) def test_out_of_bounds(self): with self.assertRaises(IndexError): self.map[1000000] with self.assertRaises(IndexError): self.map[-1000000] def test_update(self): fname2 = self.fname + '.update' shutil.copy(self.fname, fname2) # fresh map copy map2 = IntToPidMap(fname2, mode='rb+') for (int, pid) in islice(map2, 11): # update the first N items new_pid = pid.replace(':0', ':f') # mangle first hex digit map2[int] = new_pid self.assertEqual(map2[int], new_pid) # check updated value os.unlink(fname2) # tmpdir will be cleaned even if we don't reach this