Changeset View
Changeset View
Standalone View
Standalone View
swh/storage/storage.py
Show First 20 Lines • Show All 796 Lines • ▼ Show 20 Lines | def snapshot_get_random(self, db=None, cur=None): | ||||
return db.snapshot_get_random(cur) | return db.snapshot_get_random(cur) | ||||
@timed | @timed | ||||
@db_transaction() | @db_transaction() | ||||
def origin_visit_add( | def origin_visit_add( | ||||
self, visits: Iterable[OriginVisit], db=None, cur=None | self, visits: Iterable[OriginVisit], db=None, cur=None | ||||
) -> Iterable[OriginVisit]: | ) -> Iterable[OriginVisit]: | ||||
for visit in visits: | for visit in visits: | ||||
origin = self.origin_get({"url": visit.origin}, db=db, cur=cur) | origin = self.origin_get([visit.origin], db=db, cur=cur)[0] | ||||
if not origin: # Cannot add a visit without an origin | if not origin: # Cannot add a visit without an origin | ||||
raise StorageArgumentException("Unknown origin %s", visit.origin) | raise StorageArgumentException("Unknown origin %s", visit.origin) | ||||
all_visits = [] | all_visits = [] | ||||
nb_visits = 0 | nb_visits = 0 | ||||
for visit in visits: | for visit in visits: | ||||
nb_visits += 1 | nb_visits += 1 | ||||
if not visit.visit: | if not visit.visit: | ||||
Show All 32 Lines | class Storage: | ||||
@timed | @timed | ||||
@db_transaction() | @db_transaction() | ||||
def origin_visit_status_add( | def origin_visit_status_add( | ||||
self, visit_statuses: Iterable[OriginVisitStatus], db=None, cur=None, | self, visit_statuses: Iterable[OriginVisitStatus], db=None, cur=None, | ||||
) -> None: | ) -> None: | ||||
# First round to check existence (fail early if any is ko) | # First round to check existence (fail early if any is ko) | ||||
for visit_status in visit_statuses: | for visit_status in visit_statuses: | ||||
origin_url = self.origin_get({"url": visit_status.origin}, db=db, cur=cur) | origin_url = self.origin_get([visit_status.origin], db=db, cur=cur)[0] | ||||
if not origin_url: | if not origin_url: | ||||
raise StorageArgumentException(f"Unknown origin {visit_status.origin}") | raise StorageArgumentException(f"Unknown origin {visit_status.origin}") | ||||
for visit_status in visit_statuses: | for visit_status in visit_statuses: | ||||
self._origin_visit_status_add(visit_status, db, cur) | self._origin_visit_status_add(visit_status, db, cur) | ||||
@timed | @timed | ||||
@db_transaction() | @db_transaction() | ||||
▲ Show 20 Lines • Show All 105 Lines • ▼ Show 20 Lines | def object_find_by_sha1_git(self, ids, db=None, cur=None): | ||||
ret[retval[0]].append( | ret[retval[0]].append( | ||||
dict(zip(db.object_find_by_sha1_git_cols, retval)) | dict(zip(db.object_find_by_sha1_git_cols, retval)) | ||||
) | ) | ||||
return ret | return ret | ||||
@timed | @timed | ||||
@db_transaction(statement_timeout=500) | @db_transaction(statement_timeout=500) | ||||
def origin_get(self, origins, db=None, cur=None): | def origin_get( | ||||
if isinstance(origins, dict): | self, origins: Iterable[str], db=None, cur=None | ||||
# Old API | ) -> Iterable[Optional[Origin]]: | ||||
return_single = True | origin_urls = list(origins) | ||||
origins = [origins] | rows = db.origin_get_by_url(origin_urls, cur) | ||||
elif len(origins) == 0: | result: List[Optional[Origin]] = [] | ||||
return [] | for row in rows: | ||||
else: | origin_d = dict(zip(db.origin_cols, row)) | ||||
return_single = False | url = origin_d["url"] | ||||
result.append(None if url is None else Origin(url=url)) | |||||
origin_urls = [origin["url"] for origin in origins] | return result | ||||
results = db.origin_get_by_url(origin_urls, cur) | |||||
results = [dict(zip(db.origin_cols, result)) for result in results] | |||||
if return_single: | |||||
assert len(results) == 1 | |||||
if results[0]["url"] is not None: | |||||
return results[0] | |||||
else: | |||||
return None | |||||
else: | |||||
return [None if res["url"] is None else res for res in results] | |||||
@timed | @timed | ||||
@db_transaction_generator(statement_timeout=500) | @db_transaction_generator(statement_timeout=500) | ||||
def origin_get_by_sha1(self, sha1s, db=None, cur=None): | def origin_get_by_sha1(self, sha1s, db=None, cur=None): | ||||
for line in db.origin_get_by_sha1(sha1s, cur): | for line in db.origin_get_by_sha1(sha1s, cur): | ||||
if line[0] is not None: | if line[0] is not None: | ||||
yield dict(zip(db.origin_cols, line)) | yield dict(zip(db.origin_cols, line)) | ||||
else: | else: | ||||
▲ Show 20 Lines • Show All 324 Lines • Show Last 20 Lines |