Changeset View
Changeset View
Standalone View
Standalone View
swh/provenance/mongo/backend.py
# Copyright (C) 2021 The Software Heritage developers | # Copyright (C) 2021 The Software Heritage developers | ||||
# See the AUTHORS file at the top-level directory of this distribution | # See the AUTHORS file at the top-level directory of this distribution | ||||
# License: GNU General Public License version 3, or any later version | # License: GNU General Public License version 3, or any later version | ||||
# See top-level LICENSE file for more information | # See top-level LICENSE file for more information | ||||
from __future__ import annotations | from __future__ import annotations | ||||
from datetime import datetime, timezone | from datetime import datetime, timezone | ||||
import os | import os | ||||
from types import TracebackType | from types import TracebackType | ||||
from typing import Any, Dict, Generator, Iterable, List, Optional, Set, Type, Union | from typing import Any, Dict, Generator, Iterable, List, Optional, Set, Type, Union | ||||
from bson import ObjectId | from bson import ObjectId | ||||
import mongomock | import mongomock | ||||
import pymongo | import pymongo | ||||
from pymongo import UpdateOne | |||||
from swh.core.statsd import statsd | from swh.core.statsd import statsd | ||||
from swh.model.model import Sha1Git | from swh.model.model import Sha1Git | ||||
from ..interface import ( | from ..interface import ( | ||||
EntityType, | EntityType, | ||||
ProvenanceResult, | ProvenanceResult, | ||||
ProvenanceStorageInterface, | ProvenanceStorageInterface, | ||||
RelationData, | RelationData, | ||||
RelationType, | RelationType, | ||||
RevisionData, | RevisionData, | ||||
) | ) | ||||
from .entity import Entity | |||||
STORAGE_DURATION_METRIC = "swh_provenance_storage_mongodb_duration_seconds" | STORAGE_DURATION_METRIC = "swh_provenance_storage_mongodb_duration_seconds" | ||||
class ProvenanceStorageMongoDb: | class ProvenanceStorageMongoDb: | ||||
def __init__(self, engine: str, **kwargs): | def __init__(self, engine: str, **kwargs): | ||||
self.engine = engine | self.engine = engine | ||||
self.dbname = kwargs.pop("dbname") | self.dbname = kwargs.pop("dbname") | ||||
Show All 10 Lines | def __exit__( | ||||
exc_tb: Optional[TracebackType], | exc_tb: Optional[TracebackType], | ||||
) -> None: | ) -> None: | ||||
self.close() | self.close() | ||||
@statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "close"}) | @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "close"}) | ||||
def close(self) -> None: | def close(self) -> None: | ||||
self.db.client.close() | self.db.client.close() | ||||
def _format_data(self, data: Union[Iterable[Sha1Git], Dict[Sha1Git, datetime]]): | |||||
return data if isinstance(data, dict) else dict.fromkeys(data) | |||||
def _generate_date_upserts(self, sha1, date): | |||||
ts = datetime.timestamp(date) if date is not None else None | |||||
# update only those with date either as None or later than the given one | |||||
return UpdateOne( | |||||
{"$and": [{"sha1": sha1}, {"$or": [{"ts": None}, {"ts": {"$gt": ts}}]}]}, | |||||
{"$set": {"ts": ts, "sha1": sha1}}, | |||||
upsert=True, | |||||
) | |||||
@statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "content_add"}) | @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "content_add"}) | ||||
def content_add( | def content_add( | ||||
self, cnts: Union[Iterable[Sha1Git], Dict[Sha1Git, Optional[datetime]]] | self, cnts: Union[Iterable[Sha1Git], Dict[Sha1Git, Optional[datetime]]] | ||||
) -> bool: | ) -> bool: | ||||
data = cnts if isinstance(cnts, dict) else dict.fromkeys(cnts) | writes = [ | ||||
existing = { | self._generate_date_upserts(sha1, date) | ||||
x["sha1"]: x | for sha1, date in self._format_data(cnts).items() | ||||
for x in self.db.content.find( | ] | ||||
{"sha1": {"$in": list(data)}}, {"sha1": 1, "ts": 1, "_id": 1} | Entity.factory("content").bulk_write(self.db, writes) | ||||
) | |||||
} | |||||
for sha1, date in data.items(): | |||||
ts = datetime.timestamp(date) if date is not None else None | |||||
if sha1 in existing: | |||||
cnt = existing[sha1] | |||||
if ts is not None and (cnt["ts"] is None or ts < cnt["ts"]): | |||||
self.db.content.update_one( | |||||
{"_id": cnt["_id"]}, {"$set": {"ts": ts}} | |||||
) | |||||
else: | |||||
self.db.content.insert_one( | |||||
{ | |||||
"sha1": sha1, | |||||
"ts": ts, | |||||
"revision": {}, | |||||
"directory": {}, | |||||
} | |||||
) | |||||
return True | return True | ||||
@statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "directory_add"}) | |||||
def directory_add( | |||||
self, dirs: Union[Iterable[Sha1Git], Dict[Sha1Git, Optional[datetime]]] | |||||
) -> bool: | |||||
writes = [ | |||||
self._generate_date_upserts(sha1, date) | |||||
for sha1, date in self._format_data(dirs).items() | |||||
] | |||||
Entity.factory("directory").bulk_write(self.db, writes) | |||||
return True | |||||
def _get_oldest_revision_from_content(self, content): | |||||
# FIXME, returing with the assumption that content has all the revisons in the array | |||||
# Change to seperate collection content_in_revision if gets too big | |||||
return self.db.revision.find_one({"_id": {"$in": [ObjectId(obj_id) for obj_id in content["revision"]]}, | |||||
"ts": content["ts"]}) | |||||
def _get_preferred_origin(self, revision): | |||||
return self.db.origin.find_one({"sha1": revision["preferred"]}) | |||||
@statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "content_find_first"}) | @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "content_find_first"}) | ||||
def content_find_first(self, id: Sha1Git) -> Optional[ProvenanceResult]: | def content_find_first(self, id: Sha1Git) -> Optional[ProvenanceResult]: | ||||
# get all the revisions | # get all the revisions | ||||
# iterate and find the earliest | # iterate and find the earliest | ||||
content = self.db.content.find_one({"sha1": id}) | content = self.db.content.find_one({"sha1": id}) | ||||
if not content: | if not content: | ||||
return None | return None | ||||
occurs = [] | oldest_revision = self._get_oldest_revision_from_content(content) | ||||
for revision in self.db.revision.find( | origin = self._get_preferred_origin(oldest_revision) | ||||
{"_id": {"$in": [ObjectId(obj_id) for obj_id in content["revision"]]}} | |||||
): | |||||
if revision["preferred"] is not None: | |||||
origin = self.db.origin.find_one({"sha1": revision["preferred"]}) | |||||
else: | |||||
origin = {"url": None} | |||||
for path in content["revision"][str(revision["_id"])]: | return ProvenanceResult( | ||||
occurs.append( | |||||
ProvenanceResult( | |||||
content=id, | content=id, | ||||
revision=revision["sha1"], | revision=oldest_revision["sha1"], | ||||
date=datetime.fromtimestamp(revision["ts"], timezone.utc), | date=datetime.fromtimestamp(oldest_revision["ts"], timezone.utc), | ||||
origin=origin["url"], | origin=origin["url"] if origin else None, | ||||
path=path, | path='', | ||||
) | |||||
) | ) | ||||
return sorted(occurs, key=lambda x: (x.date, x.revision, x.origin, x.path))[0] | |||||
@statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "content_find_all"}) | @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "content_find_all"}) | ||||
def content_find_all( | def content_find_all( | ||||
self, id: Sha1Git, limit: Optional[int] = None | self, id: Sha1Git, limit: Optional[int] = None | ||||
) -> Generator[ProvenanceResult, None, None]: | ) -> Generator[ProvenanceResult, None, None]: | ||||
content = self.db.content.find_one({"sha1": id}) | content = self.db.content.find_one({"sha1": id}) | ||||
if not content: | if not content: | ||||
return None | return None | ||||
Show All 13 Lines | ) -> Generator[ProvenanceResult, None, None]: | ||||
content=id, | content=id, | ||||
revision=revision["sha1"], | revision=revision["sha1"], | ||||
date=datetime.fromtimestamp(revision["ts"], timezone.utc), | date=datetime.fromtimestamp(revision["ts"], timezone.utc), | ||||
origin=origin["url"], | origin=origin["url"], | ||||
path=path, | path=path, | ||||
) | ) | ||||
) | ) | ||||
for directory in self.db.directory.find( | for directory in self.db.directory.find( | ||||
{"_id": {"$in": [ObjectId(obj_id) for obj_id in content["directory"]]}} | { | ||||
"_id": { | |||||
"$in": [ObjectId(obj_id) for obj_id in content.get("directory", {})] | |||||
} | |||||
} | |||||
): | ): | ||||
for revision in self.db.revision.find( | for revision in self.db.revision.find( | ||||
{"_id": {"$in": [ObjectId(obj_id) for obj_id in directory["revision"]]}} | {"_id": {"$in": [ObjectId(obj_id) for obj_id in directory["revision"]]}} | ||||
): | ): | ||||
if revision["preferred"] is not None: | if revision["preferred"] is not None: | ||||
origin = self.db.origin.find_one({"sha1": revision["preferred"]}) | origin = self.db.origin.find_one({"sha1": revision["preferred"]}) | ||||
else: | else: | ||||
origin = {"url": None} | origin = {"url": None} | ||||
Show All 23 Lines | def content_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, datetime]: | ||||
return { | return { | ||||
x["sha1"]: datetime.fromtimestamp(x["ts"], timezone.utc) | x["sha1"]: datetime.fromtimestamp(x["ts"], timezone.utc) | ||||
for x in self.db.content.find( | for x in self.db.content.find( | ||||
{"sha1": {"$in": list(ids)}, "ts": {"$ne": None}}, | {"sha1": {"$in": list(ids)}, "ts": {"$ne": None}}, | ||||
{"sha1": 1, "ts": 1, "_id": 0}, | {"sha1": 1, "ts": 1, "_id": 0}, | ||||
) | ) | ||||
} | } | ||||
@statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "directory_add"}) | |||||
def directory_add( | |||||
self, dirs: Union[Iterable[Sha1Git], Dict[Sha1Git, Optional[datetime]]] | |||||
) -> bool: | |||||
data = dirs if isinstance(dirs, dict) else dict.fromkeys(dirs) | |||||
existing = { | |||||
x["sha1"]: x | |||||
for x in self.db.directory.find( | |||||
{"sha1": {"$in": list(data)}}, {"sha1": 1, "ts": 1, "_id": 1} | |||||
) | |||||
} | |||||
for sha1, date in data.items(): | |||||
ts = datetime.timestamp(date) if date is not None else None | |||||
if sha1 in existing: | |||||
dir = existing[sha1] | |||||
if ts is not None and (dir["ts"] is None or ts < dir["ts"]): | |||||
self.db.directory.update_one( | |||||
{"_id": dir["_id"]}, {"$set": {"ts": ts}} | |||||
) | |||||
else: | |||||
self.db.directory.insert_one({"sha1": sha1, "ts": ts, "revision": {}}) | |||||
return True | |||||
@statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "directory_get"}) | @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "directory_get"}) | ||||
def directory_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, datetime]: | def directory_get(self, ids: Iterable[Sha1Git]) -> Dict[Sha1Git, datetime]: | ||||
return { | return { | ||||
x["sha1"]: datetime.fromtimestamp(x["ts"], timezone.utc) | x["sha1"]: datetime.fromtimestamp(x["ts"], timezone.utc) | ||||
for x in self.db.directory.find( | for x in self.db.directory.find( | ||||
{"sha1": {"$in": list(ids)}, "ts": {"$ne": None}}, | {"sha1": {"$in": list(ids)}, "ts": {"$ne": None}}, | ||||
{"sha1": 1, "ts": 1, "_id": 0}, | {"sha1": 1, "ts": 1, "_id": 0}, | ||||
▲ Show 20 Lines • Show All 151 Lines • ▼ Show 20 Lines | ) -> bool: | ||||
) | ) | ||||
} | } | ||||
for sha1, dsts in denorm.items(): | for sha1, dsts in denorm.items(): | ||||
# update | # update | ||||
if src_relation != "revision": | if src_relation != "revision": | ||||
k = { | k = { | ||||
obj_id: list(set(paths + dsts.get(obj_id, []))) | obj_id: list(set(paths + dsts.get(obj_id, []))) | ||||
for obj_id, paths in src_objs[sha1][dst_relation].items() | for obj_id, paths in src_objs[sha1].get(dst_relation, {}).items() | ||||
} | } | ||||
self.db.get_collection(src_relation).update_one( | self.db.get_collection(src_relation).update_one( | ||||
{"_id": src_objs[sha1]["_id"]}, | {"_id": src_objs[sha1]["_id"]}, | ||||
{"$set": {dst_relation: dict(dsts, **k)}}, | {"$set": {dst_relation: dict(dsts, **k)}}, | ||||
) | ) | ||||
else: | else: | ||||
self.db.get_collection(src_relation).update_one( | self.db.get_collection(src_relation).update_one( | ||||
{"_id": src_objs[sha1]["_id"]}, | {"_id": src_objs[sha1]["_id"]}, | ||||
▲ Show 20 Lines • Show All 53 Lines • ▼ Show 20 Lines | ) -> Dict[Sha1Git, Set[RelationData]]: | ||||
else: | else: | ||||
dst_objs = { | dst_objs = { | ||||
x["sha1"]: x["_id"] | x["sha1"]: x["_id"] | ||||
for x in self.db.get_collection(dst).find( | for x in self.db.get_collection(dst).find( | ||||
{"sha1": {"$in": list(sha1s)}}, {"_id": 1, "sha1": 1} | {"sha1": {"$in": list(sha1s)}}, {"_id": 1, "sha1": 1} | ||||
) | ) | ||||
} | } | ||||
src_objs = { | src_objs = { | ||||
x["sha1"]: x[dst] | x["sha1"]: x.get(dst, {}) | ||||
for x in self.db.get_collection(src).find( | for x in self.db.get_collection(src).find( | ||||
{}, {"_id": 0, "sha1": 1, dst: 1} | {}, {"_id": 0, "sha1": 1, dst: 1} | ||||
) | ) | ||||
} | } | ||||
result: Dict[Sha1Git, Set[RelationData]] = {} | result: Dict[Sha1Git, Set[RelationData]] = {} | ||||
if src != "revision": | if src != "revision": | ||||
for dst_sha1, dst_obj_id in dst_objs.items(): | for dst_sha1, dst_obj_id in dst_objs.items(): | ||||
for src_sha1, denorm in src_objs.items(): | for src_sha1, denorm in src_objs.items(): | ||||
Show All 16 Lines | class ProvenanceStorageMongoDb: | ||||
@statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "relation_get_all"}) | @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "relation_get_all"}) | ||||
def relation_get_all( | def relation_get_all( | ||||
self, relation: RelationType | self, relation: RelationType | ||||
) -> Dict[Sha1Git, Set[RelationData]]: | ) -> Dict[Sha1Git, Set[RelationData]]: | ||||
src, *_, dst = relation.value.split("_") | src, *_, dst = relation.value.split("_") | ||||
empty: Union[Dict[str, bytes], List[str]] = {} if src != "revision" else [] | empty: Union[Dict[str, bytes], List[str]] = {} if src != "revision" else [] | ||||
src_objs = { | src_objs = { | ||||
x["sha1"]: x[dst] | x["sha1"]: x.get(dst, {}) | ||||
for x in self.db.get_collection(src).find( | for x in self.db.get_collection(src).find( | ||||
{dst: {"$ne": empty}}, {"_id": 0, "sha1": 1, dst: 1} | {dst: {"$ne": empty}}, {"_id": 0, "sha1": 1, dst: 1} | ||||
) | ) | ||||
} | } | ||||
dst_ids = list( | dst_ids = list( | ||||
{ObjectId(obj_id) for _, value in src_objs.items() for obj_id in value} | {ObjectId(obj_id) for _, value in src_objs.items() for obj_id in value} | ||||
) | ) | ||||
dst_objs = { | dst_objs = { | ||||
Show All 30 Lines |