diff --git a/sql/upgrades/143.sql b/sql/upgrades/143.sql new file mode 100644 --- /dev/null +++ b/sql/upgrades/143.sql @@ -0,0 +1,94 @@ +-- SWH DB schema upgrade +-- from_version: 142 +-- to_version: 143 +-- description: Remove origin ids + +insert into dbversion(version, release, description) + values(143, now(), 'Work In Progress'); + +create or replace function swh_origin_visit_add(origin_url text, date timestamptz, type text) + returns bigint + language sql +as $$ + with origin_id as ( + select id + from origin + where url = origin_url + ), last_known_visit as ( + select coalesce(max(visit), 0) as visit + from origin_visit + where origin = (select id from origin_id) + ) + insert into origin_visit (origin, date, type, visit, status) + values ((select id from origin_id), date, type, + (select visit from last_known_visit) + 1, 'ongoing') + returning visit; +$$; + +create or replace function swh_visit_find_by_date(origin_url text, visit_date timestamptz default NOW()) + returns setof origin_visit + language plpgsql + stable +as $$ +declare + origin_id bigint; +begin + select id into origin_id from origin where url=origin_url; + return query + with closest_two_visits as (( + select ov, (date - visit_date), visit as interval + from origin_visit ov + where ov.origin = origin_id + and ov.date >= visit_date + order by ov.date asc, ov.visit desc + limit 1 + ) union ( + select ov, (visit_date - date), visit as interval + from origin_visit ov + where ov.origin = origin_id + and ov.date < visit_date + order by ov.date desc, ov.visit desc + limit 1 + )) select (ov).* from closest_two_visits order by interval, visit limit 1; +end +$$; + +drop function swh_visit_get; + +alter type origin_metadata_signature + rename attribute origin_id to origin_url; + +alter type origin_metadata_signature + alter attribute origin_url set data type text; + +create or replace function swh_origin_metadata_get_by_origin( + origin text) + returns setof origin_metadata_signature + language sql + 
stable +as $$ + select om.id as id, o.url as origin_url, discovery_date, tool_id, om.metadata, + mp.id as provider_id, provider_name, provider_type, provider_url + from origin_metadata as om + inner join metadata_provider mp on om.provider_id = mp.id + inner join origin o on om.origin_id = o.id + where o.url = origin + order by discovery_date desc; +$$; + +create or replace function swh_origin_metadata_get_by_provider_type( + origin_url text, + provider_type text) + returns setof origin_metadata_signature + language sql + stable +as $$ + select om.id as id, o.url as origin_url, discovery_date, tool_id, om.metadata, + mp.id as provider_id, provider_name, provider_type, provider_url + from origin_metadata as om + inner join metadata_provider mp on om.provider_id = mp.id + inner join origin o on om.origin_id = o.id + where o.url = origin_url + and mp.provider_type = provider_type + order by discovery_date desc; +$$; diff --git a/swh/storage/algos/origin.py b/swh/storage/algos/origin.py --- a/swh/storage/algos/origin.py +++ b/swh/storage/algos/origin.py @@ -13,7 +13,6 @@ Yields: dict: the origin dictionary with the keys: - - id: origin's id - type: origin's type - url: origin's url """ @@ -28,6 +27,8 @@ if not origins: break start = origins[-1]['id'] + 1 - yield from origins + for origin in origins: + del origin['id'] + yield origin if origin_to and start > origin_to: break diff --git a/swh/storage/api/client.py b/swh/storage/api/client.py --- a/swh/storage/api/client.py +++ b/swh/storage/api/client.py @@ -213,16 +213,16 @@ def tool_get(self, tool): return self.post('tool/data', {'tool': tool}) - def origin_metadata_add(self, origin_id, ts, provider, tool, metadata): - return self.post('origin/metadata/add', {'origin_id': origin_id, + def origin_metadata_add(self, origin_url, ts, provider, tool, metadata): + return self.post('origin/metadata/add', {'origin_url': origin_url, 'ts': ts, 'provider': provider, 'tool': tool, 'metadata': metadata}) - def 
origin_metadata_get_by(self, origin_id, provider_type=None): + def origin_metadata_get_by(self, origin_url, provider_type=None): return self.post('origin/metadata/get', { - 'origin_id': origin_id, + 'origin_url': origin_url, 'provider_type': provider_type }) diff --git a/swh/storage/db.py b/swh/storage/db.py --- a/swh/storage/db.py +++ b/swh/storage/db.py @@ -186,14 +186,15 @@ yield from cur - def snapshot_get_by_origin_visit(self, origin_id, visit_id, cur=None): + def snapshot_get_by_origin_visit(self, origin_url, visit_id, cur=None): cur = self._cursor(cur) query = """\ - SELECT snapshot from origin_visit where - origin_visit.origin=%s and origin_visit.visit=%s; + SELECT snapshot FROM origin_visit + INNER JOIN origin ON origin.id = origin_visit.origin + WHERE origin.url=%s AND origin_visit.visit=%s; """ - cur.execute(query, (origin_id, visit_id)) + cur.execute(query, (origin_url, visit_id)) ret = cur.fetchone() if ret: return ret[0] @@ -323,9 +324,10 @@ cur = self._cursor(cur) update_cols = [] values = [] - where = ['origin=%s AND visit=%s'] + where = ['origin.id = origin_visit.origin', + 'origin.url=%s', + 'visit=%s'] where_values = [origin_id, visit_id] - from_ = '' if 'status' in updates: update_cols.append('status=%s') values.append(updates.pop('status')) @@ -337,17 +339,20 @@ values.append(updates.pop('snapshot')) assert not updates, 'Unknown fields: %r' % updates query = """UPDATE origin_visit - SET {update_cols} - {from} - WHERE {where}""".format(**{ + SET {update_cols} + FROM origin + WHERE {where}""".format(**{ 'update_cols': ', '.join(update_cols), - 'from': from_, 'where': ' AND '.join(where) }) cur.execute(query, (*values, *where_values)) def origin_visit_upsert(self, origin, visit, date, type, status, metadata, snapshot, cur=None): + # doing an extra query like this is way simpler than trying to join + # the origin id in the query below + origin_id = next(self.origin_id_get_by_url([origin])) + cur = self._cursor(cur) query = """INSERT INTO 
origin_visit ({cols}) VALUES ({values}) ON CONFLICT ON CONSTRAINT origin_visit_pkey DO @@ -357,10 +362,14 @@ updates=', '.join('{0}=excluded.{0}'.format(col) for col in self.origin_visit_get_cols)) cur.execute( - query, (origin, visit, date, type, status, metadata, snapshot)) + query, (origin_id, visit, date, type, status, metadata, snapshot)) - origin_visit_get_cols = ['origin', 'visit', 'date', 'type', 'status', - 'metadata', 'snapshot'] + origin_visit_get_cols = [ + 'origin', 'visit', 'date', 'type', + 'status', 'metadata', 'snapshot'] + origin_visit_select_cols = [ + 'origin.url AS origin', 'visit', 'date', 'origin_visit.type AS type', + 'status', 'metadata', 'snapshot'] def origin_visit_get_all(self, origin_id, last_visit=None, limit=None, cur=None): @@ -385,10 +394,11 @@ query = """\ SELECT %s FROM origin_visit - WHERE origin=%%s %s + INNER JOIN origin ON origin.id = origin_visit.origin + WHERE origin.url=%%s %s order by visit asc limit %%s""" % ( - ', '.join(self.origin_visit_get_cols), extra_condition + ', '.join(self.origin_visit_select_cols), extra_condition ) cur.execute(query, args) @@ -411,8 +421,9 @@ query = """\ SELECT %s FROM origin_visit - WHERE origin = %%s AND visit = %%s - """ % (', '.join(self.origin_visit_get_cols)) + INNER JOIN origin ON origin.id = origin_visit.origin + WHERE origin.url = %%s AND visit = %%s + """ % (', '.join(self.origin_visit_select_cols)) cur.execute(query, (origin_id, visit_id)) r = cur.fetchall() @@ -457,10 +468,11 @@ cur = self._cursor(cur) query_parts = [ - 'SELECT %s' % ', '.join(self.origin_visit_get_cols), - 'FROM origin_visit'] + 'SELECT %s' % ', '.join(self.origin_visit_select_cols), + 'FROM origin_visit', + 'INNER JOIN origin ON origin.id = origin_visit.origin'] - query_parts.append('WHERE origin = %s') + query_parts.append('WHERE origin.url = %s') if require_snapshot: query_parts.append('AND snapshot is not null') @@ -607,15 +619,15 @@ def origin_add(self, url, cur=None): """Insert a new origin and return the 
new identifier.""" insert = """INSERT INTO origin (url) values (%s) - RETURNING id""" + RETURNING url""" cur.execute(insert, (url,)) return cur.fetchone()[0] - origin_cols = ['id', 'url'] + origin_cols = ['url'] def origin_get_by_url(self, origins, cur=None): - """Retrieve origin `(id, type, url)` from urls if found.""" + """Retrieve origin `(type, url)` from urls if found.""" cur = self._cursor(cur) query = """SELECT %s FROM (VALUES %%s) as t(url) @@ -625,18 +637,19 @@ yield from execute_values_generator( cur, query, ((url,) for url in origins)) - def origin_get_by_id(self, ids, cur=None): - """Retrieve origin `(id, type, url)` from ids if found. - - """ + def origin_id_get_by_url(self, origins, cur=None): + """Retrieve origin `(type, url)` from urls if found.""" cur = self._cursor(cur) - query = """SELECT %s FROM (VALUES %%s) as t(id) - LEFT JOIN origin ON t.id = origin.id - """ % ','.join('origin.' + col for col in self.origin_cols) + query = """SELECT id FROM (VALUES %s) as t(url) + LEFT JOIN origin ON t.url = origin.url + """ - yield from execute_values_generator( - cur, query, ((id,) for id in ids)) + for row in execute_values_generator( + cur, query, ((url,) for url in origins)): + yield row[0] + + origin_get_range_cols = ['id', 'url'] def origin_get_range(self, origin_from=1, origin_count=100, cur=None): """Retrieve ``origin_count`` origins whose ids are greater @@ -653,7 +666,7 @@ query = """SELECT %s FROM origin WHERE id >= %%s ORDER BY id LIMIT %%s - """ % ','.join(self.origin_cols) + """ % ','.join(self.origin_get_range_cols) cur.execute(query, (origin_from, origin_count)) yield from cur @@ -770,19 +783,17 @@ """ cur = self._cursor(cur) insert = """INSERT INTO origin_metadata (origin_id, discovery_date, - provider_id, tool_id, metadata) values (%s, %s, %s, %s, %s) - RETURNING id""" - cur.execute(insert, (origin, ts, provider, tool, jsonize(metadata))) - - return cur.fetchone()[0] + provider_id, tool_id, metadata) + SELECT id, %s, %s, %s, %s FROM origin 
WHERE url = %s""" + cur.execute(insert, (ts, provider, tool, jsonize(metadata), origin)) - origin_metadata_get_cols = ['origin_id', 'discovery_date', + origin_metadata_get_cols = ['origin_url', 'discovery_date', 'tool_id', 'metadata', 'provider_id', 'provider_name', 'provider_type', 'provider_url'] - def origin_metadata_get_by(self, origin_id, provider_type=None, cur=None): - """Retrieve all origin_metadata entries for one origin_id + def origin_metadata_get_by(self, origin_url, provider_type=None, cur=None): + """Retrieve all origin_metadata entries for one origin_url """ cur = self._cursor(cur) @@ -792,7 +803,7 @@ %%s)''' % (','.join( self.origin_metadata_get_cols)) - cur.execute(query, (origin_id, )) + cur.execute(query, (origin_url, )) else: query = '''SELECT %s @@ -800,7 +811,7 @@ %%s, %%s)''' % (','.join( self.origin_metadata_get_cols)) - cur.execute(query, (origin_id, provider_type)) + cur.execute(query, (origin_url, provider_type)) yield from cur diff --git a/swh/storage/in_memory.py b/swh/storage/in_memory.py --- a/swh/storage/in_memory.py +++ b/swh/storage/in_memory.py @@ -3,7 +3,6 @@ # License: GNU General Public License version 3, or any later version # See top-level LICENSE file for more information -import os import re import bisect import dateutil @@ -32,10 +31,6 @@ return datetime.datetime.now(tz=datetime.timezone.utc) -ENABLE_ORIGIN_IDS = \ - os.environ.get('SWH_STORAGE_IN_MEMORY_ENABLE_ORIGIN_IDS', 'true') == 'true' - - class Storage: def __init__(self, journal_writer=None): self._contents = {} @@ -876,7 +871,7 @@ and :meth:`snapshot_get_branches` should be used instead. Args: - origin (Union[str,int]): the origin's URL or identifier + origin (str): the origin's URL allowed_statuses (list of str): list of visit statuses considered to find the latest snapshot for the origin. 
For instance, ``allowed_statuses=['full']`` will only consider visits that @@ -1013,15 +1008,8 @@ def _convert_origin(self, t): if t is None: return None - (origin_id, origin) = t - origin = origin.to_dict() - if ENABLE_ORIGIN_IDS: - origin['id'] = origin_id - if 'type' in origin: - del origin['type'] - - return origin + return t.to_dict() def origin_get(self, origins): """Return origins, either all identified by their ids or all @@ -1069,16 +1057,12 @@ results = [] for origin in origins: result = None - if 'id' in origin: - assert ENABLE_ORIGIN_IDS, 'origin ids are disabled' - if origin['id'] <= len(self._origins_by_id): - result = self._origins[self._origins_by_id[origin['id']-1]] - elif 'url' in origin: + if 'url' in origin: if origin['url'] in self._origins: result = self._origins[origin['url']] else: raise ValueError( - 'Origin must have either id or url.') + 'Origin must have an url.') results.append(self._convert_origin(result)) if return_single: @@ -1099,7 +1083,8 @@ Yields: dicts containing origin information as returned - by :meth:`swh.storage.in_memory.Storage.origin_get`. + by :meth:`swh.storage.in_memory.Storage.origin_get`, plus + an 'id' key. 
""" origin_from = max(origin_from, 1) if origin_from <= len(self._origins_by_id): @@ -1107,8 +1092,9 @@ if max_idx > len(self._origins_by_id): max_idx = len(self._origins_by_id) for idx in range(origin_from-1, max_idx): - yield self._convert_origin( + origin = self._convert_origin( self._origins[self._origins_by_id[idx]]) + yield {'id': idx+1, **origin} def origin_search(self, url_pattern, offset=0, limit=50, regexp=False, with_visit=False, db=None, cur=None): @@ -1139,9 +1125,6 @@ origins = [orig for orig in origins if len(self._origin_visits[orig['url']]) > 0] - if ENABLE_ORIGIN_IDS: - origins.sort(key=lambda origin: origin['id']) - return origins[offset:offset+limit] def origin_count(self, url_pattern, regexp=False, with_visit=False, @@ -1179,10 +1162,7 @@ """ origins = copy.deepcopy(origins) for origin in origins: - if ENABLE_ORIGIN_IDS: - origin['id'] = self.origin_add_one(origin) - else: - self.origin_add_one(origin) + self.origin_add_one(origin) return origins def origin_add_one(self, origin): @@ -1200,34 +1180,27 @@ """ origin = Origin.from_dict(origin) - - if origin.url in self._origins: - if ENABLE_ORIGIN_IDS: - (origin_id, _) = self._origins[origin.url] - else: + if origin.url not in self._origins: if self.journal_writer: self.journal_writer.write_addition('origin', origin) - if ENABLE_ORIGIN_IDS: - # origin ids are in the range [1, +inf[ - origin_id = len(self._origins) + 1 - self._origins_by_id.append(origin.url) - assert len(self._origins_by_id) == origin_id - else: - origin_id = None - self._origins[origin.url] = (origin_id, origin) + + # generate an origin_id because it is needed by origin_get_range. 
+ # TODO: remove this when we remove origin_get_range + origin_id = len(self._origins) + 1 + self._origins_by_id.append(origin.url) + assert len(self._origins_by_id) == origin_id + + self._origins[origin.url] = origin self._origin_visits[origin.url] = [] self._objects[origin.url].append(('origin', origin.url)) - if ENABLE_ORIGIN_IDS: - return origin_id - else: - return origin.url + return origin.url def origin_visit_add(self, origin, date, type): """Add an origin_visit for the origin at date with status 'ongoing'. Args: - origin (Union[int,str]): visited origin's identifier or URL + origin (str): visited origin's identifier or URL date (Union[str,datetime]): timestamp of such visit type (str): the type of loader used for the visit (hg, git, ...) @@ -1238,7 +1211,7 @@ - visit: the visit's identifier for the new visit occurrence """ - origin_url = self._get_origin_url(origin) + origin_url = origin if origin_url is None: raise ValueError('Unknown origin.') @@ -1250,12 +1223,12 @@ visit_ret = None if origin_url in self._origins: - (origin_id, origin) = self._origins[origin_url] + origin = self._origins[origin_url] # visit ids are in the range [1, +inf[ visit_id = len(self._origin_visits[origin_url]) + 1 status = 'ongoing' visit = OriginVisit( - origin=origin, + origin=origin.url, date=date, type=type, status=status, @@ -1265,7 +1238,7 @@ ) self._origin_visits[origin_url].append(visit) visit_ret = { - 'origin': origin_id if ENABLE_ORIGIN_IDS else origin.url, + 'origin': origin.url, 'visit': visit_id, } @@ -1273,6 +1246,7 @@ ('origin_visit', None)) if self.journal_writer: + visit = attr.evolve(visit, origin=origin) self.journal_writer.write_addition('origin_visit', visit) return visit_ret @@ -1282,7 +1256,7 @@ """Update an origin_visit's status. 
Args: - origin (Union[int,str]): visited origin's identifier or URL + origin (str): visited origin's URL visit_id (int): visit's identifier status: visit's new status metadata: data associated to the visit @@ -1314,8 +1288,9 @@ visit = attr.evolve(visit, **updates) if self.journal_writer: - (_, origin) = self._origins[origin_url] - self.journal_writer.write_update('origin_visit', visit) + origin = self._origins[origin_url] + journal_visit = attr.evolve(visit, origin=origin) + self.journal_writer.write_update('origin_visit', journal_visit) self._origin_visits[origin_url][visit_id-1] = visit @@ -1346,13 +1321,15 @@ for visit in visits: visit = attr.evolve( visit, - origin=self._origins[visit.origin.url][1]) + origin=self._origins[visit.origin.url]) self.journal_writer.write_addition('origin_visit', visit) for visit in visits: visit_id = visit.visit origin_url = visit.origin.url + visit = attr.evolve(visit, origin=origin_url) + self._objects[(origin_url, visit_id)].append( ('origin_visit', None)) @@ -1365,12 +1342,7 @@ if visit is None: return - (origin_id, origin) = self._origins[visit.origin.url] visit = visit.to_dict() - if ENABLE_ORIGIN_IDS: - visit['origin'] = origin_id - else: - visit['origin'] = origin.url return visit @@ -1467,10 +1439,9 @@ snapshot (Optional[sha1_git]): identifier of the snapshot associated to the visit """ - res = self._origins.get(origin) - if not res: + origin = self._origins.get(origin) + if not origin: return - (_, origin) = res visits = self._origin_visits[origin.url] if allowed_statuses is not None: visits = [visit for visit in visits @@ -1513,49 +1484,46 @@ """Recomputes the statistics for `stat_counters`.""" pass - def origin_metadata_add(self, origin_id, ts, provider, tool, metadata, + def origin_metadata_add(self, origin_url, ts, provider, tool, metadata, db=None, cur=None): """ Add an origin_metadata for the origin at ts with provenance and metadata. 
Args: - origin_id (int): the origin's id for which the metadata is added + origin_url (str): the origin url for which the metadata is added ts (datetime): timestamp of the found metadata provider: id of the provider of metadata (ex:'hal') tool: id of the tool used to extract metadata metadata (jsonb): the metadata retrieved at the time and location """ - if isinstance(origin_id, str): - origin = self.origin_get({'url': origin_id}) - if not origin: - return - origin_id = origin['id'] + if not isinstance(origin_url, str): + raise TypeError('origin_id must be str, not %r' % (origin_url,)) if isinstance(ts, str): ts = dateutil.parser.parse(ts) origin_metadata = { - 'origin_id': origin_id, + 'origin_url': origin_url, 'discovery_date': ts, 'tool_id': tool, 'metadata': metadata, 'provider_id': provider, } - self._origin_metadata[origin_id].append(origin_metadata) + self._origin_metadata[origin_url].append(origin_metadata) return None - def origin_metadata_get_by(self, origin_id, provider_type=None, db=None, + def origin_metadata_get_by(self, origin_url, provider_type=None, db=None, cur=None): - """Retrieve list of all origin_metadata entries for the origin_id + """Retrieve list of all origin_metadata entries for the origin_url Args: - origin_id (int): the unique origin's identifier + origin_url (str): the origin's url provider_type (str): (optional) type of provider Returns: list of dicts: the origin_metadata dictionary with the keys: - - origin_id (int): origin's identifier + - origin_url (int): origin's URL - discovery_date (datetime): timestamp of discovery - tool_id (int): metadata's extracting tool - metadata (jsonb) @@ -1565,14 +1533,10 @@ - provider_url (str) """ - if isinstance(origin_id, str): - origin = self.origin_get({'url': origin_id}) - if not origin: - return - origin_id = origin['id'] - + if not isinstance(origin_url, str): + raise TypeError('origin_url must be str, not %r' % (origin_url,)) metadata = [] - for item in self._origin_metadata[origin_id]: + 
for item in self._origin_metadata[origin_url]: item = copy.deepcopy(item) provider = self.metadata_provider_get(item['provider_id']) for attr_name in ('name', 'type', 'url'): @@ -1678,13 +1642,8 @@ def _get_origin_url(self, origin): if isinstance(origin, str): return origin - elif isinstance(origin, int): - if origin <= len(self._origins_by_id): - return self._origins_by_id[origin-1] - else: - return None else: - raise TypeError('origin must be a string or an integer.') + raise TypeError('origin must be a string.') def _person_add(self, person): """Add a person in storage. diff --git a/swh/storage/sql/30-swh-schema.sql b/swh/storage/sql/30-swh-schema.sql --- a/swh/storage/sql/30-swh-schema.sql +++ b/swh/storage/sql/30-swh-schema.sql @@ -17,7 +17,7 @@ -- latest schema version insert into dbversion(version, release, description) - values(142, now(), 'Work In Progress'); + values(143, now(), 'Work In Progress'); -- a SHA1 checksum create domain sha1 as bytea check (length(value) = 20); diff --git a/swh/storage/sql/40-swh-func.sql b/swh/storage/sql/40-swh-func.sql --- a/swh/storage/sql/40-swh-func.sql +++ b/swh/storage/sql/40-swh-func.sql @@ -693,17 +693,22 @@ -- add a new origin_visit for origin origin_id at date. -- -- Returns the new visit id. 
-create or replace function swh_origin_visit_add(origin_id bigint, date timestamptz, type text) +create or replace function swh_origin_visit_add(origin_url text, date timestamptz, type text) returns bigint language sql as $$ - with last_known_visit as ( + with origin_id as ( + select id + from origin + where url = origin_url + ), last_known_visit as ( select coalesce(max(visit), 0) as visit from origin_visit - where origin = origin_id + where origin = (select id from origin_id) ) insert into origin_visit (origin, date, type, visit, status) - values (origin_id, date, type, (select visit from last_known_visit) + 1, 'ongoing') + values ((select id from origin_id), date, type, + (select visit from last_known_visit) + 1, 'ongoing') returning visit; $$; @@ -828,40 +833,34 @@ select dir_id, name from path order by depth desc limit 1; $$; --- Find the visit of origin id closest to date visit_date +-- Find the visit of origin closest to date visit_date -- Breaks ties by selecting the largest visit id -create or replace function swh_visit_find_by_date(origin bigint, visit_date timestamptz default NOW()) - returns origin_visit - language sql +create or replace function swh_visit_find_by_date(origin_url text, visit_date timestamptz default NOW()) + returns setof origin_visit + language plpgsql stable as $$ +declare + origin_id bigint; +begin + select id into origin_id from origin where url=origin_url; + return query with closest_two_visits as (( select ov, (date - visit_date), visit as interval from origin_visit ov - where ov.origin = origin + where ov.origin = origin_id and ov.date >= visit_date order by ov.date asc, ov.visit desc limit 1 ) union ( select ov, (visit_date - date), visit as interval from origin_visit ov - where ov.origin = origin + where ov.origin = origin_id and ov.date < visit_date order by ov.date desc, ov.visit desc limit 1 - )) select (ov).* from closest_two_visits order by interval, visit limit 1 -$$; - --- Find the visit of origin id closest to date 
visit_date -create or replace function swh_visit_get(origin bigint) - returns origin_visit - language sql - stable -as $$ - select * - from origin_visit - where origin=origin - order by date desc + )) select (ov).* from closest_two_visits order by interval, visit limit 1; +end $$; -- Object listing by object_id @@ -927,7 +926,7 @@ -- origin_metadata functions create type origin_metadata_signature as ( id bigint, - origin_id bigint, + origin_url text, discovery_date timestamptz, tool_id bigint, metadata jsonb, @@ -937,32 +936,34 @@ provider_url text ); create or replace function swh_origin_metadata_get_by_origin( - origin integer) + origin text) returns setof origin_metadata_signature language sql stable as $$ - select om.id as id, origin_id, discovery_date, tool_id, om.metadata, + select om.id as id, o.url as origin_url, discovery_date, tool_id, om.metadata, mp.id as provider_id, provider_name, provider_type, provider_url from origin_metadata as om inner join metadata_provider mp on om.provider_id = mp.id - where om.origin_id = origin + inner join origin o on om.origin_id = o.id + where o.url = origin order by discovery_date desc; $$; create or replace function swh_origin_metadata_get_by_provider_type( - origin integer, - type text) + origin_url text, + provider_type text) returns setof origin_metadata_signature language sql stable as $$ - select om.id as id, origin_id, discovery_date, tool_id, om.metadata, + select om.id as id, o.url as origin_url, discovery_date, tool_id, om.metadata, mp.id as provider_id, provider_name, provider_type, provider_url from origin_metadata as om inner join metadata_provider mp on om.provider_id = mp.id - where om.origin_id = origin - and mp.provider_type = type + inner join origin o on om.origin_id = o.id + where o.url = origin_url + and mp.provider_type = provider_type order by discovery_date desc; $$; -- end origin_metadata functions diff --git a/swh/storage/storage.py b/swh/storage/storage.py --- a/swh/storage/storage.py +++ 
b/swh/storage/storage.py @@ -212,13 +212,12 @@ if content_without_data: content_without_data = \ [cont.copy() for cont in content_without_data] - origins = db.origin_get_by_url( + origin_ids = db.origin_id_get_by_url( [cont.get('origin') for cont in content_without_data], cur=cur) - for (cont, origin) in zip(content_without_data, origins): - origin = dict(zip(db.origin_cols, origin)) + for (cont, origin_id) in zip(content_without_data, origin_ids): if 'origin' in cont: - cont['origin'] = origin['id'] + cont['origin'] = origin_id db.mktemp('skipped_content', cur) db.copy_to(content_without_data, 'tmp_skipped_content', db.skipped_content_keys, cur) @@ -1100,7 +1099,7 @@ should be used instead. Args: - origin (Union[str,int]): the origin's URL or identifier + origin (str): the origin's URL allowed_statuses (list of str): list of visit statuses considered to find the latest snapshot for the visit. For instance, ``allowed_statuses=['full']`` will only consider visits that @@ -1216,7 +1215,7 @@ """Add an origin_visit for the origin at ts with status 'ongoing'. Args: - origin (Union[int,str]): visited origin's identifier or URL + origin (str): visited origin's identifier or URL date (Union[str,datetime]): timestamp of such visit type (str): the type of loader used for the visit (hg, git, ...) 
@@ -1227,30 +1226,25 @@ - visit: the visit identifier for the new visit occurrence """ - if isinstance(origin, str): - origin = self.origin_get({'url': origin}, db=db, cur=cur) - origin_id = origin['id'] - else: - origin = self.origin_get({'id': origin}, db=db, cur=cur) - origin_id = origin['id'] + origin_url = origin + origin = self.origin_get({'url': origin_url}, db=db, cur=cur) if isinstance(date, str): # FIXME: Converge on iso8601 at some point date = dateutil.parser.parse(date) - visit_id = db.origin_visit_add(origin_id, date, type, cur) + visit_id = db.origin_visit_add(origin_url, date, type, cur) if self.journal_writer: # We can write to the journal only after inserting to the # DB, because we want the id of the visit - del origin['id'] self.journal_writer.write_addition('origin_visit', { 'origin': origin, 'date': date, 'type': type, 'visit': visit_id, 'status': 'ongoing', 'metadata': None, 'snapshot': None}) return { - 'origin': origin_id, + 'origin': origin_url, 'visit': visit_id, } @@ -1261,7 +1255,7 @@ """Update an origin_visit's status. 
Args: - origin (Union[int,str]): visited origin's identifier or URL + origin (str): visited origin's URL visit_id: Visit's id status: Visit's new status metadata: Data associated to the visit @@ -1272,12 +1266,8 @@ None """ - if isinstance(origin, str): - origin_id = self.origin_get({'url': origin}, db=db, cur=cur)['id'] - else: - origin_id = origin - - visit = db.origin_visit_get(origin_id, visit_id, cur=cur) + origin_url = origin + visit = db.origin_visit_get(origin_url, visit_id, cur=cur) if not visit: raise ValueError('Invalid visit_id for this origin.') @@ -1295,12 +1285,11 @@ if updates: if self.journal_writer: origin = self.origin_get( - [{'id': origin_id}], db=db, cur=cur)[0] - del origin['id'] + [{'url': origin_url}], db=db, cur=cur)[0] self.journal_writer.write_update('origin_visit', { **visit, **updates, 'origin': origin}) - db.origin_visit_update(origin_id, visit_id, updates, cur) + db.origin_visit_update(origin_url, visit_id, updates, cur) @db_transaction() def origin_visit_upsert(self, visits, db=None, cur=None): @@ -1331,11 +1320,10 @@ visit = copy.deepcopy(visit) if visit.get('type') is None: visit['type'] = visit['origin']['type'] - del visit['origin']['id'] self.journal_writer.write_addition('origin_visit', visit) for visit in visits: - visit['origin'] = visit['origin']['id'] + visit['origin'] = visit['origin']['url'] # TODO: upsert them all in a single query db.origin_visit_upsert(**visit, cur=cur) @@ -1345,7 +1333,7 @@ """Retrieve all the origin's visit's information. Args: - origin (Union[int,str]): The occurrence's origin (identifier/URL). + origin (str): The visited origin last_visit: Starting point from which listing the next visits Default to None limit (int): Number of results to return from the last visit. @@ -1355,11 +1343,6 @@ List of visits. 
""" - if isinstance(origin, str): - origin = self.origin_get([{'url': origin}], db=db, cur=cur)[0] - if not origin: - return - origin = origin['id'] for line in db.origin_visit_get_all( origin, last_visit=last_visit, limit=limit, cur=cur): data = dict(zip(db.origin_visit_get_cols, line)) @@ -1379,10 +1362,6 @@ A visit. """ - origin = self.origin_get([{'url': origin}], db=db, cur=cur)[0] - if not origin: - return - origin = origin['id'] line = db.origin_visit_find_by_date(origin, visit_date, cur=cur) if line: return dict(zip(db.origin_visit_get_cols, line)) @@ -1399,11 +1378,6 @@ it does not exist """ - if isinstance(origin, str): - origin = self.origin_get({'url': origin}, db=db, cur=cur) - if not origin: - return - origin = origin['id'] ori_visit = db.origin_visit_get(origin, visit, cur) if not ori_visit: return None @@ -1438,11 +1412,6 @@ snapshot (Optional[sha1_git]): identifier of the snapshot associated to the visit """ - origin = self.origin_get({'url': origin}, db=db, cur=cur) - if not origin: - return - origin = origin['id'] - origin_visit = db.origin_visit_get_latest( origin, allowed_statuses=allowed_statuses, require_snapshot=require_snapshot, cur=cur) @@ -1475,8 +1444,6 @@ return ret - origin_keys = ['id', 'url'] - @db_transaction(statement_timeout=500) def origin_get(self, origins, db=None, cur=None): """Return origins, either all identified by their ids or all @@ -1488,14 +1455,10 @@ Args: origin: a list of dictionaries representing the individual origins to find. 
- These dicts have either the key url: + These dicts have the key url: - url (bytes): the url the origin points to - or the id: - - - id: the origin id - Returns: dict: the origin dictionary with the keys: @@ -1515,36 +1478,19 @@ else: return_single = False - origin_ids = [origin.get('id') for origin in origins] - origin_urls = [origin.get('url') for origin in origins] - if any(origin_ids): - # Lookup per ID - if all(origin_ids): - results = db.origin_get_by_id(origin_ids, cur) - else: - raise ValueError( - 'Either all origins or none at all should have an "id".') - elif any(origin_urls): - # Lookup per type + URL - if all(origin_urls): - results = db.origin_get_by_url(origin_urls, cur) - else: - raise ValueError( - 'Either all origins or none at all should have ' - 'an "url" key.') - else: # unsupported lookup - raise ValueError('Origin must have either id or url.') + origin_urls = [origin['url'] for origin in origins] + results = db.origin_get_by_url(origin_urls, cur) - results = [dict(zip(self.origin_keys, result)) + results = [dict(zip(db.origin_cols, result)) for result in results] if return_single: assert len(results) == 1 - if results[0]['id'] is not None: + if results[0]['url'] is not None: return results[0] else: return None else: - return [None if res['id'] is None else res for res in results] + return [None if res['url'] is None else res for res in results] @db_transaction_generator() def origin_get_range(self, origin_from=1, origin_count=100, @@ -1563,7 +1509,7 @@ by :meth:`swh.storage.storage.Storage.origin_get`. 
""" for origin in db.origin_get_range(origin_from, origin_count, cur): - yield dict(zip(self.origin_keys, origin)) + yield dict(zip(db.origin_get_range_cols, origin)) @db_transaction_generator() def origin_search(self, url_pattern, offset=0, limit=50, @@ -1587,7 +1533,7 @@ """ for origin in db.origin_search(url_pattern, offset, limit, regexp, with_visit, cur): - yield dict(zip(self.origin_keys, origin)) + yield dict(zip(db.origin_cols, origin)) @db_transaction() def origin_count(self, url_pattern, regexp=False, @@ -1625,7 +1571,7 @@ """ origins = copy.deepcopy(origins) for origin in origins: - origin['id'] = self.origin_add_one(origin, db=db, cur=cur) + self.origin_add_one(origin, db=db, cur=cur) return origins @@ -1645,10 +1591,10 @@ exists. """ - origin_id = list(db.origin_get_by_url( - [origin['url']], cur))[0][0] - if origin_id: - return origin_id + origin_row = list(db.origin_get_by_url([origin['url']], cur))[0] + origin_url = dict(zip(db.origin_cols, origin_row))['url'] + if origin_url: + return origin_url if self.journal_writer: self.journal_writer.write_addition('origin', origin) @@ -1688,40 +1634,31 @@ cur.execute('select * from swh_update_counter(%s)', (key,)) @db_transaction() - def origin_metadata_add(self, origin_id, ts, provider, tool, metadata, + def origin_metadata_add(self, origin_url, ts, provider, tool, metadata, db=None, cur=None): """ Add an origin_metadata for the origin at ts with provenance and metadata. 
Args: - origin_id (int): the origin's id for which the metadata is added + origin_url (str): the origin url for which the metadata is added ts (datetime): timestamp of the found metadata provider (int): the provider of metadata (ex:'hal') tool (int): tool used to extract metadata metadata (jsonb): the metadata retrieved at the time and location - - Returns: - id (int): the origin_metadata unique id """ - if isinstance(origin_id, str): - origin = self.origin_get({'url': origin_id}, db=db, cur=cur) - if not origin: - return - origin_id = origin['id'] - if isinstance(ts, str): ts = dateutil.parser.parse(ts) - return db.origin_metadata_add(origin_id, ts, provider, tool, - metadata, cur) + db.origin_metadata_add(origin_url, ts, provider, tool, + metadata, cur) @db_transaction_generator(statement_timeout=500) - def origin_metadata_get_by(self, origin_id, provider_type=None, db=None, + def origin_metadata_get_by(self, origin_url, provider_type=None, db=None, cur=None): """Retrieve list of all origin_metadata entries for the origin_id Args: - origin_id (int): the unique origin identifier + origin_url (str): the origin's URL provider_type (str): (optional) type of provider Returns: @@ -1737,13 +1674,7 @@ - provider_url (str) """ - if isinstance(origin_id, str): - origin = self.origin_get({'url': origin_id}, db=db, cur=cur) - if not origin: - return - origin_id = origin['id'] - - for line in db.origin_metadata_get_by(origin_id, provider_type, cur): + for line in db.origin_metadata_get_by(origin_url, provider_type, cur): yield dict(zip(db.origin_metadata_get_cols, line)) @db_transaction() diff --git a/swh/storage/tests/test_in_memory.py b/swh/storage/tests/test_in_memory.py --- a/swh/storage/tests/test_in_memory.py +++ b/swh/storage/tests/test_in_memory.py @@ -8,11 +8,6 @@ from swh.storage import get_storage from swh.storage.tests.test_storage import ( # noqa TestStorage, TestStorageGeneratedData) -from swh.storage.in_memory import ENABLE_ORIGIN_IDS - - 
-TestStorage._test_origin_ids = ENABLE_ORIGIN_IDS -TestStorageGeneratedData._test_origin_ids = ENABLE_ORIGIN_IDS # tests are executed using imported classes (TestStorage and diff --git a/swh/storage/tests/test_storage.py b/swh/storage/tests/test_storage.py --- a/swh/storage/tests/test_storage.py +++ b/swh/storage/tests/test_storage.py @@ -78,7 +78,6 @@ class twice. """ maxDiff = None # type: ClassVar[Optional[int]] - _test_origin_ids = True def test_check_config(self, swh_storage): assert swh_storage.check_config(check_write=True) @@ -872,8 +871,6 @@ id = swh_storage.origin_add_one(data.origin) actual_origin = swh_storage.origin_get({'url': data.origin['url']}) - if self._test_origin_ids: - assert actual_origin['id'] == id assert actual_origin['url'] == data.origin['url'] id2 = swh_storage.origin_add_one(data.origin) @@ -889,15 +886,11 @@ actual_origin = swh_storage.origin_get([{ 'url': data.origin['url'], }])[0] - if self._test_origin_ids: - assert actual_origin['id'] == origin1['id'] assert actual_origin['url'] == origin1['url'] actual_origin2 = swh_storage.origin_get([{ 'url': data.origin2['url'], }])[0] - if self._test_origin_ids: - assert actual_origin2['id'] == origin2['id'] assert actual_origin2['url'] == origin2['url'] if 'id' in actual_origin: @@ -927,49 +920,21 @@ def test_origin_get_legacy(self, swh_storage): assert swh_storage.origin_get(data.origin) is None - id = swh_storage.origin_add_one(data.origin) + swh_storage.origin_add_one(data.origin) - # lookup per url (returns id) actual_origin0 = swh_storage.origin_get( {'url': data.origin['url']}) - if self._test_origin_ids: - assert actual_origin0['id'] == id assert actual_origin0['url'] == data.origin['url'] - # lookup per id (returns dict) - if self._test_origin_ids: - actual_origin1 = swh_storage.origin_get({'id': id}) - - assert actual_origin1 == {'id': id, - 'url': data.origin['url']} - def test_origin_get(self, swh_storage): assert swh_storage.origin_get(data.origin) is None - origin_id = 
swh_storage.origin_add_one(data.origin) + swh_storage.origin_add_one(data.origin) - # lookup per url (returns id) actual_origin0 = swh_storage.origin_get( [{'url': data.origin['url']}]) assert len(actual_origin0) == 1 assert actual_origin0[0]['url'] == data.origin['url'] - if self._test_origin_ids: - # lookup per id (returns dict) - actual_origin1 = swh_storage.origin_get([{'id': origin_id}]) - - assert len(actual_origin1) == 1 - assert actual_origin1[0] == {'id': origin_id, - 'url': data.origin['url']} - - def test_origin_get_consistency(self, swh_storage): - assert swh_storage.origin_get(data.origin) is None - id = swh_storage.origin_add_one(data.origin) - - with pytest.raises(ValueError): - swh_storage.origin_get([ - {'url': data.origin['url']}, - {'id': id}]) - def test_origin_search_single_result(self, swh_storage): found_origins = list(swh_storage.origin_search(data.origin['url'])) assert len(found_origins) == 0 @@ -1081,27 +1046,23 @@ # check both origins were returned assert found_origins0 != found_origins1 - @pytest.mark.parametrize('use_url', [True, False]) - def test_origin_visit_add(self, swh_storage, use_url): - if not self._test_origin_ids and not use_url: - return + def test_origin_visit_add(self, swh_storage): # given - origin_id = swh_storage.origin_add_one(data.origin2) - assert origin_id is not None + swh_storage.origin_add_one(data.origin2) - origin_id_or_url = data.origin2['url'] if use_url else origin_id + origin_url = data.origin2['url'] # when date_visit = datetime.datetime.now(datetime.timezone.utc) origin_visit1 = swh_storage.origin_visit_add( - origin_id_or_url, + origin_url, type=data.type_visit1, date=date_visit) actual_origin_visits = list(swh_storage.origin_visit_get( - origin_id_or_url)) + origin_url)) assert { - 'origin': origin_id, + 'origin': origin_url, 'date': date_visit, 'visit': origin_visit1['visit'], 'type': data.type_visit1, @@ -1126,41 +1087,35 @@ def test_origin_visit_get__unknown_origin(self, swh_storage): assert [] == 
list(swh_storage.origin_visit_get('foo')) - if self._test_origin_ids: - assert list(swh_storage.origin_visit_get(10)) == [] - @pytest.mark.parametrize('use_url', [True, False]) - def test_origin_visit_add_default_type(self, swh_storage, use_url): - if not self._test_origin_ids and not use_url: - return + def test_origin_visit_add_default_type(self, swh_storage): # given - origin_id = swh_storage.origin_add_one(data.origin2) - origin_id_or_url = data.origin2['url'] if use_url else origin_id - assert origin_id is not None + swh_storage.origin_add_one(data.origin2) + origin_url = data.origin2['url'] # when date_visit = datetime.datetime.now(datetime.timezone.utc) date_visit2 = date_visit + datetime.timedelta(minutes=1) origin_visit1 = swh_storage.origin_visit_add( - origin_id_or_url, + origin_url, date=date_visit, type=data.type_visit1, ) origin_visit2 = swh_storage.origin_visit_add( - origin_id_or_url, + origin_url, date=date_visit2, type=data.type_visit2, ) # then - assert origin_visit1['origin'] == origin_id + assert origin_visit1['origin'] == origin_url assert origin_visit1['visit'] is not None actual_origin_visits = list(swh_storage.origin_visit_get( - origin_id_or_url)) + origin_url)) expected_visits = [ { - 'origin': origin_id, + 'origin': origin_url, 'date': date_visit, 'visit': origin_visit1['visit'], 'type': data.type_visit1, @@ -1169,7 +1124,7 @@ 'snapshot': None, }, { - 'origin': origin_id, + 'origin': origin_url, 'date': date_visit2, 'visit': origin_visit2['visit'], 'type': data.type_visit2, @@ -1189,20 +1144,16 @@ assert ('origin_visit', visit) in objects def test_origin_visit_add_validation(self, swh_storage): - origin_id_or_url = swh_storage.origin_add_one(data.origin2) + origin_url = swh_storage.origin_add_one(data.origin2) with pytest.raises((TypeError, psycopg2.ProgrammingError)) as cm: - swh_storage.origin_visit_add(origin_id_or_url, date=[b'foo'], - type=data.type_visit1) + swh_storage.origin_visit_add(origin_url, date=[b'foo']) if type(cm.value) 
== psycopg2.ProgrammingError: assert cm.value.pgcode \ == psycopg2.errorcodes.UNDEFINED_FUNCTION - @pytest.mark.parametrize('use_url', [True, False]) - def test_origin_visit_update(self, swh_storage, use_url): - if not self._test_origin_ids and not use_url: - return + def test_origin_visit_update(self, swh_storage): # given swh_storage.origin_add_one(data.origin) origin_url = data.origin['url'] @@ -1363,9 +1314,10 @@ assert ('origin_visit', data5) in objects def test_origin_visit_update_validation(self, swh_storage): - origin_id = swh_storage.origin_add_one(data.origin) + origin_url = data.origin['url'] + swh_storage.origin_add_one(data.origin) visit = swh_storage.origin_visit_add( - origin_id, + origin_url, date=data.date_visit2, type=data.type_visit2, ) @@ -1373,7 +1325,7 @@ with pytest.raises((ValueError, psycopg2.DataError), match='status') as cm: swh_storage.origin_visit_update( - origin_id, visit['visit'], status='foobar') + origin_url, visit['visit'], status='foobar') if type(cm.value) == psycopg2.DataError: assert cm.value.pgcode == \ @@ -1414,29 +1366,26 @@ def test_origin_visit_find_by_date__unknown_origin(self, swh_storage): swh_storage.origin_visit_find_by_date('foo', data.date_visit2) - @pytest.mark.parametrize('use_url', [True, False]) - def test_origin_visit_update_missing_snapshot(self, swh_storage, use_url): - if not self._test_origin_ids and not use_url: - return + def test_origin_visit_update_missing_snapshot(self, swh_storage): # given - origin_id = swh_storage.origin_add_one(data.origin) - origin_id_or_url = data.origin['url'] if use_url else origin_id + swh_storage.origin_add_one(data.origin) + origin_url = data.origin['url'] origin_visit = swh_storage.origin_visit_add( - origin_id_or_url, + origin_url, date=data.date_visit1, type=data.type_visit1, ) # when swh_storage.origin_visit_update( - origin_id_or_url, + origin_url, origin_visit['visit'], snapshot=data.snapshot['id']) # then actual_origin_visit = swh_storage.origin_visit_get_by( - 
origin_id_or_url, + origin_url, origin_visit['visit']) assert actual_origin_visit['snapshot'] == data.snapshot['id'] @@ -1444,36 +1393,33 @@ swh_storage.snapshot_add([data.snapshot]) assert actual_origin_visit['snapshot'] == data.snapshot['id'] - @pytest.mark.parametrize('use_url', [True, False]) - def test_origin_visit_get_by(self, swh_storage, use_url): - if not self._test_origin_ids and not use_url: - return - origin_id = swh_storage.origin_add_one(data.origin) - origin_id2 = swh_storage.origin_add_one(data.origin2) + def test_origin_visit_get_by(self, swh_storage): + swh_storage.origin_add_one(data.origin) + swh_storage.origin_add_one(data.origin2) - origin_id_or_url = data.origin['url'] if use_url else origin_id - origin2_id_or_url = data.origin2['url'] if use_url else origin_id2 + origin_url = data.origin['url'] + origin2_url = data.origin2['url'] origin_visit1 = swh_storage.origin_visit_add( - origin_id_or_url, + origin_url, date=data.date_visit2, type=data.type_visit2, ) swh_storage.snapshot_add([data.snapshot]) swh_storage.origin_visit_update( - origin_id_or_url, + origin_url, origin_visit1['visit'], snapshot=data.snapshot['id']) # Add some other {origin, visit} entries swh_storage.origin_visit_add( - origin_id_or_url, + origin_url, date=data.date_visit3, type=data.type_visit3, ) swh_storage.origin_visit_add( - origin2_id_or_url, + origin2_url, date=data.date_visit3, type=data.type_visit3, ) @@ -1485,13 +1431,13 @@ } swh_storage.origin_visit_update( - origin_id_or_url, + origin_url, origin_visit1['visit'], status='full', metadata=visit1_metadata) expected_origin_visit = origin_visit1.copy() expected_origin_visit.update({ - 'origin': origin_id, + 'origin': origin_url, 'visit': origin_visit1['visit'], 'date': data.date_visit2, 'type': data.type_visit2, @@ -1502,25 +1448,19 @@ # when actual_origin_visit1 = swh_storage.origin_visit_get_by( - origin_id_or_url, + origin_url, origin_visit1['visit']) # then assert actual_origin_visit1 == expected_origin_visit def 
test_origin_visit_get_by__unknown_origin(self, swh_storage): - if self._test_origin_ids: - assert swh_storage.origin_visit_get_by(2, 10) is None assert swh_storage.origin_visit_get_by('foo', 10) is None - @pytest.mark.parametrize('use_url', [True, False]) - def test_origin_visit_upsert_new(self, swh_storage, use_url): - if not self._test_origin_ids and not use_url: - return + def test_origin_visit_upsert_new(self, swh_storage): # given - origin_id = swh_storage.origin_add_one(data.origin2) + swh_storage.origin_add_one(data.origin2) origin_url = data.origin2['url'] - assert origin_id is not None # when swh_storage.origin_visit_upsert([ @@ -1549,7 +1489,7 @@ origin_url)) assert actual_origin_visits == [ { - 'origin': origin_id, + 'origin': origin_url, 'date': data.date_visit2, 'visit': 123, 'type': data.type_visit2, @@ -1558,7 +1498,7 @@ 'snapshot': None, }, { - 'origin': origin_id, + 'origin': origin_url, 'date': data.date_visit3, 'visit': 1234, 'type': data.type_visit2, @@ -1592,14 +1532,10 @@ ('origin_visit', data1), ('origin_visit', data2)] - @pytest.mark.parametrize('use_url', [True, False]) - def test_origin_visit_upsert_existing(self, swh_storage, use_url): - if not self._test_origin_ids and not use_url: - return + def test_origin_visit_upsert_existing(self, swh_storage): # given - origin_id = swh_storage.origin_add_one(data.origin2) + swh_storage.origin_add_one(data.origin2) origin_url = data.origin2['url'] - assert origin_id is not None # when origin_visit1 = swh_storage.origin_visit_add( @@ -1618,14 +1554,14 @@ }]) # then - assert origin_visit1['origin'] == origin_id + assert origin_visit1['origin'] == origin_url assert origin_visit1['visit'] is not None actual_origin_visits = list(swh_storage.origin_visit_get( origin_url)) assert actual_origin_visits == [ { - 'origin': origin_id, + 'origin': origin_url, 'date': data.date_visit2, 'visit': origin_visit1['visit'], 'type': data.type_visit1, @@ -1659,20 +1595,12 @@ ('origin_visit', data2)] def 
test_origin_visit_get_by_no_result(self, swh_storage): - if self._test_origin_ids: - actual_origin_visit = swh_storage.origin_visit_get_by( - 10, 999) - assert actual_origin_visit is None - swh_storage.origin_add([data.origin]) actual_origin_visit = swh_storage.origin_visit_get_by( data.origin['url'], 999) assert actual_origin_visit is None - @pytest.mark.parametrize('use_url', [True, False]) - def test_origin_visit_get_latest(self, swh_storage, use_url): - if not self._test_origin_ids and not use_url: - return + def test_origin_visit_get_latest(self, swh_storage): swh_storage.origin_add_one(data.origin) origin_url = data.origin['url'] origin_visit1 = swh_storage.origin_visit_add( @@ -1799,9 +1727,10 @@ assert revisions[0]['committer'] == revisions[1]['committer'] def test_snapshot_add_get_empty(self, swh_storage): - origin_id = swh_storage.origin_add_one(data.origin) + origin_url = data.origin['url'] + swh_storage.origin_add_one(data.origin) origin_visit1 = swh_storage.origin_visit_add( - origin=origin_id, + origin=origin_url, date=data.date_visit1, type=data.type_visit1, ) @@ -1811,12 +1740,12 @@ assert actual_result == {'snapshot:add': 1} swh_storage.origin_visit_update( - origin_id, visit_id, snapshot=data.empty_snapshot['id']) + origin_url, visit_id, snapshot=data.empty_snapshot['id']) by_id = swh_storage.snapshot_get(data.empty_snapshot['id']) assert by_id == {**data.empty_snapshot, 'next_branch': None} - by_ov = swh_storage.snapshot_get_by_origin_visit(origin_id, visit_id) + by_ov = swh_storage.snapshot_get_by_origin_visit(origin_url, visit_id) assert by_ov == {**data.empty_snapshot, 'next_branch': None} expected_origin = data.origin.copy() @@ -1845,9 +1774,10 @@ ('origin_visit', data2)] def test_snapshot_add_get_complete(self, swh_storage): - origin_id = swh_storage.origin_add_one(data.origin) + origin_url = data.origin['url'] + swh_storage.origin_add_one(data.origin) origin_visit1 = swh_storage.origin_visit_add( - origin=origin_id, + origin=origin_url, 
date=data.date_visit1, type=data.type_visit1, ) @@ -1855,13 +1785,13 @@ actual_result = swh_storage.snapshot_add([data.complete_snapshot]) swh_storage.origin_visit_update( - origin_id, visit_id, snapshot=data.complete_snapshot['id']) + origin_url, visit_id, snapshot=data.complete_snapshot['id']) assert actual_result == {'snapshot:add': 1} by_id = swh_storage.snapshot_get(data.complete_snapshot['id']) assert by_id == {**data.complete_snapshot, 'next_branch': None} - by_ov = swh_storage.snapshot_get_by_origin_visit(origin_id, visit_id) + by_ov = swh_storage.snapshot_get_by_origin_visit(origin_url, visit_id) assert by_ov == {**data.complete_snapshot, 'next_branch': None} def test_snapshot_add_many(self, swh_storage): @@ -1988,9 +1918,10 @@ assert snapshot == expected_snapshot def test_snapshot_add_get_filtered(self, swh_storage): - origin_id = swh_storage.origin_add_one(data.origin) + origin_url = data.origin['url'] + swh_storage.origin_add_one(data.origin) origin_visit1 = swh_storage.origin_visit_add( - origin=origin_id, + origin=origin_url, date=data.date_visit1, type=data.type_visit1, ) @@ -1998,7 +1929,7 @@ swh_storage.snapshot_add([data.complete_snapshot]) swh_storage.origin_visit_update( - origin_id, visit_id, snapshot=data.complete_snapshot['id']) + origin_url, visit_id, snapshot=data.complete_snapshot['id']) snp_id = data.complete_snapshot['id'] branches = data.complete_snapshot['branches'] @@ -2106,9 +2037,10 @@ assert snapshot == expected_snapshot def test_snapshot_add_get(self, swh_storage): - origin_id = swh_storage.origin_add_one(data.origin) + origin_url = data.origin['url'] + swh_storage.origin_add_one(data.origin) origin_visit1 = swh_storage.origin_visit_add( - origin=origin_id, + origin=origin_url, date=data.date_visit1, type=data.type_visit1, ) @@ -2116,20 +2048,21 @@ swh_storage.snapshot_add([data.snapshot]) swh_storage.origin_visit_update( - origin_id, visit_id, snapshot=data.snapshot['id']) + origin_url, visit_id, snapshot=data.snapshot['id']) 
by_id = swh_storage.snapshot_get(data.snapshot['id']) assert by_id == {**data.snapshot, 'next_branch': None} - by_ov = swh_storage.snapshot_get_by_origin_visit(origin_id, visit_id) + by_ov = swh_storage.snapshot_get_by_origin_visit(origin_url, visit_id) assert by_ov == {**data.snapshot, 'next_branch': None} origin_visit_info = swh_storage.origin_visit_get_by( - origin_id, visit_id) + origin_url, visit_id) assert origin_visit_info['snapshot'] == data.snapshot['id'] def test_snapshot_add_nonexistent_visit(self, swh_storage): - origin_id = swh_storage.origin_add_one(data.origin) + origin_url = data.origin['url'] + swh_storage.origin_add_one(data.origin) visit_id = 54164461156 swh_storage.journal_writer.objects[:] = [] @@ -2138,29 +2071,30 @@ with pytest.raises(ValueError): swh_storage.origin_visit_update( - origin_id, visit_id, snapshot=data.snapshot['id']) + origin_url, visit_id, snapshot=data.snapshot['id']) assert list(swh_storage.journal_writer.objects) == [ ('snapshot', data.snapshot)] def test_snapshot_add_twice__by_origin_visit(self, swh_storage): - origin_id = swh_storage.origin_add_one(data.origin) + origin_url = data.origin['url'] + swh_storage.origin_add_one(data.origin) origin_visit1 = swh_storage.origin_visit_add( - origin=origin_id, + origin=origin_url, date=data.date_visit1, type=data.type_visit1, ) visit1_id = origin_visit1['visit'] swh_storage.snapshot_add([data.snapshot]) swh_storage.origin_visit_update( - origin_id, visit1_id, snapshot=data.snapshot['id']) + origin_url, visit1_id, snapshot=data.snapshot['id']) by_ov1 = swh_storage.snapshot_get_by_origin_visit( - origin_id, visit1_id) + origin_url, visit1_id) assert by_ov1 == {**data.snapshot, 'next_branch': None} origin_visit2 = swh_storage.origin_visit_add( - origin=origin_id, + origin=origin_url, date=data.date_visit2, type=data.type_visit2, ) @@ -2168,10 +2102,10 @@ swh_storage.snapshot_add([data.snapshot]) swh_storage.origin_visit_update( - origin_id, visit2_id, snapshot=data.snapshot['id']) + 
origin_url, visit2_id, snapshot=data.snapshot['id']) by_ov2 = swh_storage.snapshot_get_by_origin_visit( - origin_id, visit2_id) + origin_url, visit2_id) assert by_ov2 == {**data.snapshot, 'next_branch': None} expected_origin = data.origin.copy() @@ -2219,20 +2153,18 @@ ('origin_visit', data3), ('origin_visit', data4)] - @pytest.mark.parametrize('use_url', [True, False]) - def test_snapshot_get_latest(self, swh_storage, use_url): - if not self._test_origin_ids and not use_url: - return - origin_id = swh_storage.origin_add_one(data.origin) + def test_snapshot_get_latest(self, swh_storage): + origin_url = data.origin['url'] + swh_storage.origin_add_one(data.origin) origin_url = data.origin['url'] origin_visit1 = swh_storage.origin_visit_add( - origin=origin_id, + origin=origin_url, date=data.date_visit1, type=data.type_visit1, ) visit1_id = origin_visit1['visit'] origin_visit2 = swh_storage.origin_visit_add( - origin=origin_id, + origin=origin_url, date=data.date_visit2, type=data.type_visit2, ) @@ -2240,7 +2172,7 @@ # Add a visit with the same date as the previous one origin_visit3 = swh_storage.origin_visit_add( - origin=origin_id, + origin=origin_url, date=data.date_visit2, type=data.type_visit3, ) @@ -2252,7 +2184,7 @@ # Add snapshot to visit1, latest snapshot = visit 1 snapshot swh_storage.snapshot_add([data.complete_snapshot]) swh_storage.origin_visit_update( - origin_id, visit1_id, snapshot=data.complete_snapshot['id']) + origin_url, visit1_id, snapshot=data.complete_snapshot['id']) assert {**data.complete_snapshot, 'next_branch': None} \ == swh_storage.snapshot_get_latest(origin_url) @@ -2263,7 +2195,7 @@ allowed_statuses=['full']) is None # Mark the first visit as completed and check status filter again - swh_storage.origin_visit_update(origin_id, visit1_id, status='full') + swh_storage.origin_visit_update(origin_url, visit1_id, status='full') assert {**data.complete_snapshot, 'next_branch': None} \ == swh_storage.snapshot_get_latest( origin_url, @@ -2272,9 
+2204,9 @@ # Add snapshot to visit2 and check that the new snapshot is returned swh_storage.snapshot_add([data.empty_snapshot]) swh_storage.origin_visit_update( - origin_id, visit2_id, snapshot=data.empty_snapshot['id']) + origin_url, visit2_id, snapshot=data.empty_snapshot['id']) assert {**data.empty_snapshot, 'next_branch': None} \ - == swh_storage.snapshot_get_latest(origin_id) + == swh_storage.snapshot_get_latest(origin_url) # Check that the status filter is still working assert {**data.complete_snapshot, 'next_branch': None} \ @@ -2286,14 +2218,11 @@ # the new snapshot is returned swh_storage.snapshot_add([data.complete_snapshot]) swh_storage.origin_visit_update( - origin_id, visit3_id, snapshot=data.complete_snapshot['id']) + origin_url, visit3_id, snapshot=data.complete_snapshot['id']) assert {**data.complete_snapshot, 'next_branch': None} \ == swh_storage.snapshot_get_latest(origin_url) - @pytest.mark.parametrize('use_url', [True, False]) - def test_snapshot_get_latest__missing_snapshot(self, swh_storage, use_url): - if not self._test_origin_ids and not use_url: - return + def test_snapshot_get_latest__missing_snapshot(self, swh_storage): # Origin does not exist origin_url = data.origin['url'] assert swh_storage.snapshot_get_latest(origin_url) is None @@ -2821,13 +2750,10 @@ # then assert provider_id, actual_provider['id'] - @pytest.mark.parametrize('use_url', [True, False]) - def test_origin_metadata_add(self, swh_storage, use_url): - if not self._test_origin_ids: - pytest.skip('requires origin id') - + def test_origin_metadata_add(self, swh_storage): # given - origin = swh_storage.origin_add([data.origin])[0] + origin = data.origin + swh_storage.origin_add([origin])[0] tools = swh_storage.tool_add([data.metadata_tool]) tool = tools[0] @@ -2843,32 +2769,30 @@ }) # when adding for the same origin 2 metadatas - origin = origin['url' if use_url else 'id'] - - n_om = len(list(swh_storage.origin_metadata_get_by(origin))) + n_om = 
len(list(swh_storage.origin_metadata_get_by(origin['url']))) swh_storage.origin_metadata_add( - origin, + origin['url'], data.origin_metadata['discovery_date'], provider['id'], tool['id'], data.origin_metadata['metadata']) swh_storage.origin_metadata_add( - origin, + origin['url'], '2015-01-01 23:00:00+00', provider['id'], tool['id'], data.origin_metadata2['metadata']) - n_actual_om = len(list(swh_storage.origin_metadata_get_by(origin))) + n_actual_om = len(list( + swh_storage.origin_metadata_get_by(origin['url']))) # then assert n_actual_om == n_om + 2 def test_origin_metadata_get(self, swh_storage): - if not self._test_origin_ids: - pytest.skip('requires origin id') - # given - origin_id = swh_storage.origin_add([data.origin])[0]['id'] - origin_id2 = swh_storage.origin_add([data.origin2])[0]['id'] + origin_url = data.origin['url'] + origin_url2 = data.origin2['url'] + swh_storage.origin_add([data.origin]) + swh_storage.origin_add([data.origin2]) swh_storage.metadata_provider_add(data.provider['name'], data.provider['type'], @@ -2881,29 +2805,29 @@ tool = swh_storage.tool_add([data.metadata_tool])[0] # when adding for the same origin 2 metadatas swh_storage.origin_metadata_add( - origin_id, + origin_url, data.origin_metadata['discovery_date'], provider['id'], tool['id'], data.origin_metadata['metadata']) swh_storage.origin_metadata_add( - origin_id2, + origin_url2, data.origin_metadata2['discovery_date'], provider['id'], tool['id'], data.origin_metadata2['metadata']) swh_storage.origin_metadata_add( - origin_id, + origin_url, data.origin_metadata2['discovery_date'], provider['id'], tool['id'], data.origin_metadata2['metadata']) all_metadatas = list(sorted(swh_storage.origin_metadata_get_by( - origin_id), key=lambda x: x['discovery_date'])) + origin_url), key=lambda x: x['discovery_date'])) metadatas_for_origin2 = list(swh_storage.origin_metadata_get_by( - origin_id2)) + origin_url2)) expected_results = [{ - 'origin_id': origin_id, + 'origin_url': origin_url, 
'discovery_date': datetime.datetime( 2015, 1, 1, 23, 0, tzinfo=datetime.timezone.utc), @@ -2917,7 +2841,7 @@ 'provider_url': 'http:///hal/inria', 'tool_id': tool['id'] }, { - 'origin_id': origin_id, + 'origin_url': origin_url, 'discovery_date': datetime.datetime( 2017, 1, 1, 23, 0, tzinfo=datetime.timezone.utc), @@ -2956,11 +2880,10 @@ def test_origin_metadata_get_by_provider_type(self, swh_storage): # given - if not self._test_origin_ids: - pytest.skip('reauires origin id') - - origin_id = swh_storage.origin_add([data.origin])[0]['id'] - origin_id2 = swh_storage.origin_add([data.origin2])[0]['id'] + origin_url = data.origin['url'] + origin_url2 = data.origin2['url'] + swh_storage.origin_add([data.origin]) + swh_storage.origin_add([data.origin2]) provider1_id = swh_storage.metadata_provider_add( data.provider['name'], data.provider['type'], @@ -2990,26 +2913,26 @@ # when adding for the same origin 2 metadatas swh_storage.origin_metadata_add( - origin_id, + origin_url, data.origin_metadata['discovery_date'], provider1['id'], tool['id'], data.origin_metadata['metadata']) swh_storage.origin_metadata_add( - origin_id2, + origin_url2, data.origin_metadata2['discovery_date'], provider2['id'], tool['id'], data.origin_metadata2['metadata']) provider_type = 'registry' m_by_provider = list(swh_storage.origin_metadata_get_by( - origin_id2, + origin_url2, provider_type)) for item in m_by_provider: if 'id' in item: del item['id'] expected_results = [{ - 'origin_id': origin_id2, + 'origin_url': origin_url2, 'discovery_date': datetime.datetime( 2017, 1, 1, 23, 0, tzinfo=datetime.timezone.utc), @@ -3030,8 +2953,6 @@ class TestStorageGeneratedData: - _test_origin_ids = True - def assert_contents_ok(self, expected_contents, actual_contents, keys_to_check={'sha1', 'data'}): """Assert that a given list of contents matches on a given set of keys. 
@@ -3180,28 +3101,7 @@ self.assert_contents_ok( [contents_map[get_sha1s[-1]]], actual_contents2, ['sha1']) - def test_origin_get_invalid_id_legacy(self, swh_storage): - if self._test_origin_ids: - invalid_origin_id = 1 - origin_info = swh_storage.origin_get({'id': invalid_origin_id}) - assert origin_info is None - - origin_visits = list(swh_storage.origin_visit_get( - invalid_origin_id)) - assert origin_visits == [] - - def test_origin_get_invalid_id(self, swh_storage): - if self._test_origin_ids: - origin_info = swh_storage.origin_get([{'id': 1}, {'id': 2}]) - assert origin_info == [None, None] - - origin_visits = list(swh_storage.origin_visit_get(1)) - assert origin_visits == [] - def test_origin_get_range(self, swh_storage, swh_origins): - if not self._test_origin_ids: - pytest.skip('requires origin id') - actual_origins = list( swh_storage.origin_get_range(origin_from=0, origin_count=0)) @@ -3212,33 +3112,41 @@ origin_count=1)) assert len(actual_origins) == 1 assert actual_origins[0]['id'] == 1 + assert actual_origins[0]['url'] == swh_origins[0]['url'] actual_origins = list( swh_storage.origin_get_range(origin_from=1, origin_count=1)) assert len(actual_origins) == 1 assert actual_origins[0]['id'] == 1 + assert actual_origins[0]['url'] == swh_origins[0]['url'] actual_origins = list( swh_storage.origin_get_range(origin_from=1, origin_count=10)) assert len(actual_origins) == 10 assert actual_origins[0]['id'] == 1 + assert actual_origins[0]['url'] == swh_origins[0]['url'] assert actual_origins[-1]['id'] == 10 + assert actual_origins[-1]['url'] == swh_origins[9]['url'] actual_origins = list( swh_storage.origin_get_range(origin_from=1, origin_count=20)) assert len(actual_origins) == 20 assert actual_origins[0]['id'] == 1 + assert actual_origins[0]['url'] == swh_origins[0]['url'] assert actual_origins[-1]['id'] == 20 + assert actual_origins[-1]['url'] == swh_origins[19]['url'] actual_origins = list( swh_storage.origin_get_range(origin_from=1, origin_count=101)) assert 
len(actual_origins) == 100 assert actual_origins[0]['id'] == 1 + assert actual_origins[0]['url'] == swh_origins[0]['url'] assert actual_origins[-1]['id'] == 100 + assert actual_origins[-1]['url'] == swh_origins[99]['url'] actual_origins = list( swh_storage.origin_get_range(origin_from=11, @@ -3250,7 +3158,9 @@ origin_count=10)) assert len(actual_origins) == 10 assert actual_origins[0]['id'] == 11 + assert actual_origins[0]['url'] == swh_origins[10]['url'] assert actual_origins[-1]['id'] == 20 + assert actual_origins[-1]['url'] == swh_origins[19]['url'] actual_origins = list( swh_storage.origin_get_range(origin_from=91, @@ -3258,6 +3168,8 @@ assert len(actual_origins) == 10 assert actual_origins[0]['id'] == 91 assert actual_origins[-1]['id'] == 100 + assert actual_origins[0]['url'] == swh_origins[90]['url'] + assert actual_origins[-1]['url'] == swh_origins[99]['url'] def test_origin_count(self, swh_storage): new_origins = [ @@ -3298,11 +3212,12 @@ for (obj_type, obj) in objects: obj = obj.to_dict() if obj_type == 'origin_visit': - origin_id = swh_storage.origin_add_one(obj.pop('origin')) + origin = obj.pop('origin') + swh_storage.origin_add_one(origin) if 'visit' in obj: del obj['visit'] swh_storage.origin_visit_add( - origin_id, obj['date'], obj['type']) + origin['url'], obj['date'], obj['type']) else: method = getattr(swh_storage, obj_type + '_add') try: @@ -3314,8 +3229,6 @@ @pytest.mark.db class TestLocalStorage: """Test the local storage""" - _test_origin_ids = True - # This test is only relevant on the local storage, with an actual # objstorage raising an exception def test_content_add_objstorage_exception(self, swh_storage): diff --git a/tox.ini b/tox.ini --- a/tox.ini +++ b/tox.ini @@ -1,5 +1,5 @@ [tox] -envlist=flake8,mypy,py3-no-origin-ids,py3 +envlist=flake8,mypy,py3 [testenv:py3] deps = @@ -23,15 +23,6 @@ {envsitepackagesdir}/swh/storage \ --cov-branch {posargs} 
-[testenv:py3-no-origin-ids] -deps = - .[testing] - pytest-cov -setenv = - SWH_STORAGE_IN_MEMORY_ENABLE_ORIGIN_IDS=false -commands = - pytest --hypothesis-profile=fast {posargs} {envsitepackagesdir}/swh/storage/tests/test_in_memory.py - [testenv:flake8] skip_install = true deps =