NOTE(review): this patch had been collapsed onto two physical lines. Line
breaks were restored so that every hunk matches the old/new line counts in
its "@@" header. Leading indentation inside hunk lines and the position of
blank context lines are inferred from Python nesting, not recovered from the
original -- confirm against the upstream commit before applying (or apply
with whitespace-insensitive matching).

diff --git a/swh/provenance/mongo/backend.py b/swh/provenance/mongo/backend.py
--- a/swh/provenance/mongo/backend.py
+++ b/swh/provenance/mongo/backend.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2021 The Software Heritage developers
+# Copyright (C) 2021-2022 The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
@@ -8,11 +8,20 @@
 from datetime import datetime, timezone
 import os
 from types import TracebackType
-from typing import Any, Dict, Generator, Iterable, List, Optional, Set, Type, Union
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Dict,
+    Generator,
+    Iterable,
+    List,
+    Optional,
+    Set,
+    Type,
+    Union,
+)
 
 from bson import ObjectId
-import mongomock
-import pymongo
 
 from swh.core.statsd import statsd
 from swh.model.model import Sha1Git
@@ -29,6 +38,9 @@
 
 STORAGE_DURATION_METRIC = "swh_provenance_storage_mongodb_duration_seconds"
 
+if TYPE_CHECKING:
+    from pymongo.database import Database
+
 
 class ProvenanceStorageMongoDb:
     def __init__(self, engine: str, **kwargs):
@@ -96,6 +108,8 @@
             else:
                 origin = {"url": None}
 
+            assert origin is not None
+
             for path in content["revision"][str(revision["_id"])]:
                 occurs.append(
                     ProvenanceResult(
@@ -125,6 +139,8 @@
             else:
                 origin = {"url": None}
 
+            assert origin is not None
+
             for path in content["revision"][str(revision["_id"])]:
                 occurs.append(
                     ProvenanceResult(
@@ -146,6 +162,8 @@
                 else:
                     origin = {"url": None}
 
+                assert origin is not None
+
                 for suffix in content["directory"][str(directory["_id"])]:
                     for prefix in directory["revision"][str(revision["_id"])]:
                         path = (
@@ -241,10 +259,10 @@
     @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "open"})
     def open(self) -> None:
         if self.engine == "mongomock":
-            self.db = mongomock.MongoClient(**self.conn_args).get_database(self.dbname)
-        else:
-            # assume real MongoDB server by default
-            self.db = pymongo.MongoClient(**self.conn_args).get_database(self.dbname)
+            from mongomock import MongoClient as MongoClient
+        else:  # assume real MongoDB server by default
+            from pymongo import MongoClient
+        self.db: Database = MongoClient(**self.conn_args).get_database(self.dbname)
 
     @statsd.timed(metric=STORAGE_DURATION_METRIC, tags={"method": "origin_add"})
     def origin_add(self, orgs: Dict[Sha1Git, str]) -> bool: