diff --git a/revisions/client.py b/revisions/client.py
index 5bc30d1..c617b70 100755
--- a/revisions/client.py
+++ b/revisions/client.py
@@ -1,168 +1,171 @@
 #!/usr/bin/env python

 import logging
 import logging.handlers
 import os
 import sys
 import time
 from datetime import timezone
 from multiprocessing import Process
 from threading import Thread
 from typing import Any, Dict

 import iso8601
 import yaml
 import zmq

 from swh.core import config
 from swh.model.hashutil import hash_to_bytes
 from swh.provenance import get_archive, get_provenance
 from swh.provenance.archive import ArchiveInterface
 from swh.provenance.revision import RevisionEntry, revision_add

 # All generic config code should reside in swh.core.config
 CONFIG_ENVVAR = "SWH_CONFIG_FILENAME"
 DEFAULT_PATH = os.environ.get(CONFIG_ENVVAR, None)


 class Client(Process):
     def __init__(
         self,
         idx: int,
         threads: int,
         conf: Dict[str, Any],
         trackall: bool,
         lower: bool,
         mindepth: int,
     ):
         super().__init__()
         self.idx = idx
         self.threads = threads
         self.conf = conf
         self.trackall = trackall
         self.lower = lower
         self.mindepth = mindepth

     def run(self):
         # Using the same archive object for every worker to share internal caches.
         archive = get_archive(**self.conf["archive"])

         # Launch as many threads as requested
         workers = []
         for idx in range(self.threads):
             logging.info(f"Process {self.idx}: launching thread {idx}")
             worker = Worker(
                 idx, archive, self.conf, self.trackall, self.lower, self.mindepth
             )
             worker.start()
             workers.append(worker)

         # Wait for all threads to complete their work
         for idx, worker in enumerate(workers):
             logging.info(f"Process {self.idx}: waiting for thread {idx} to finish")
             worker.join()
             logging.info(f"Process {self.idx}: thread {idx} finished executing")


 class Worker(Thread):
     def __init__(
         self,
         idx: int,
         archive: ArchiveInterface,
         conf: Dict[str, Any],
         trackall: bool,
         lower: bool,
         mindepth: int,
     ):
         super().__init__()
         self.idx = idx
         self.archive = archive
         self.storage_conf = conf["storage"]
         self.url = f"tcp://{conf['rev_server']['host']}:{conf['rev_server']['port']}"
         # Each worker has its own provenance object to isolate
         # the processing of each revision.
         # self.provenance = get_provenance(**storage_conf)
         self.trackall = trackall
         self.lower = lower
         self.mindepth = mindepth
         logging.info(
             f"Worker {self.idx} created ({self.trackall}, {self.lower}, {self.mindepth})"
         )

     def run(self):
         context = zmq.Context()
         socket = context.socket(zmq.REQ)
         socket.connect(self.url)
         with get_provenance(**self.storage_conf) as provenance:
             while True:
                 socket.send(b"NEXT")
                 response = socket.recv_json()
                 if response is None:
                     break

-                # Ensure date has a valid timezone
-                date = iso8601.parse_date(response["date"])
-                if date.tzinfo is None:
-                    date = date.replace(tzinfo=timezone.utc)
-
-                revision = RevisionEntry(
-                    hash_to_bytes(response["rev"]),
-                    date=date,
-                    root=hash_to_bytes(response["root"]),
-                )
+                revisions = []
+                for revision in response:
+                    # Ensure date has a valid timezone
+                    date = iso8601.parse_date(revision["date"])
+                    if date.tzinfo is None:
+                        date = date.replace(tzinfo=timezone.utc)
+                    revisions.append(
+                        RevisionEntry(
+                            hash_to_bytes(revision["rev"]),
+                            date=date,
+                            root=hash_to_bytes(revision["root"]),
+                        )
+                    )
                 revision_add(
                     provenance,
                     self.archive,
-                    [revision],
+                    revisions,
                     trackall=self.trackall,
                     lower=self.lower,
                     mindepth=self.mindepth,
                 )


 if __name__ == "__main__":
     # Check parameters
     if len(sys.argv) != 5:
         print("usage: client <processes> <trackall> <lower> <mindepth>")
         exit(-1)

     processes = int(sys.argv[1])
     threads = 1  # int(sys.argv[2])
     trackall = sys.argv[2].lower() != "false"
     lower = sys.argv[3].lower() != "false"
     mindepth = int(sys.argv[4])

     config_file = None  # TODO: Add as a cli option
     if (
         config_file is None
         and DEFAULT_PATH is not None
         and config.config_exists(DEFAULT_PATH)
     ):
         config_file = DEFAULT_PATH

     if config_file is None or not os.path.exists(config_file):
         print("No configuration provided")
         exit(-1)

     conf = yaml.safe_load(open(config_file, "rb"))["provenance"]

     # Start counter
     start = time.time()

     # Launch as many clients as requested
     clients = []
     for idx in range(processes):
         logging.info(f"MAIN: launching process {idx}")
         client = Client(idx, threads, conf, trackall, lower, mindepth)
         client.start()
         clients.append(client)

     # Wait for all processes to complete their work
     for idx, client in enumerate(clients):
         logging.info(f"MAIN: waiting for process {idx} to finish")
         client.join()
         logging.info(f"MAIN: process {idx} finished executing")

     # Stop counter and report elapsed time
     stop = time.time()
     print("Elapsed time:", stop - start, "seconds")
diff --git a/revisions/server.py b/revisions/server.py
index 51b3d4a..b12251f 100755
--- a/revisions/server.py
+++ b/revisions/server.py
@@ -1,180 +1,270 @@
 #!/usr/bin/env python

 import gzip
 import io
 import os
 import queue
 import sys
 import threading
+import time
 from datetime import datetime, timezone
 from enum import Enum
-from typing import Any, Dict, List
+from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional

 import iso8601
 import yaml
 import zmq

 from swh.core import config
 from swh.provenance import get_provenance
 from swh.provenance.postgresql.provenance import ProvenanceStoragePostgreSql
 from swh.provenance.provenance import ProvenanceInterface

 # All generic config code should reside in swh.core.config
 CONFIG_ENVVAR = "SWH_CONFIG_FILENAME"
 DEFAULT_PATH = os.environ.get(CONFIG_ENVVAR, None)

 UTCEPOCH = datetime.fromtimestamp(0, timezone.utc)


-class StatsCommand(Enum):
-    GET = "get"
-    EXIT = "exit"
-
-
-class StatsWorker(threading.Thread):
-    def __init__(self, filename: str, storage_conf: Dict[str, Any]) -> None:
-        super().__init__()
-        self.filename = filename
-        self.queue = queue.Queue()
-        self.storage_conf = storage_conf
-
-    def run(self) -> None:
-        tables = init_stats(self.filename)
-        with get_provenance(**self.storage_conf) as provenance:
-            while True:
-                try:
-                    cmd, idx = self.queue.get(timeout=1)
-                    if cmd == StatsCommand.EXIT:
-                        break
-                    elif cmd == StatsCommand.GET:
-                        write_stats(
-                            self.filename, idx, get_tables_stats(provenance, tables)
-                        )
-                except queue.Empty:
-                    continue
-
-
+# TODO: move these functions to the StatsWorker class
 def get_tables_stats(
     provenance: ProvenanceInterface, tables: List[str]
 ) -> Dict[str, int]:
     # TODO: use ProvenanceStorageInterface instead!
     assert isinstance(provenance.storage, ProvenanceStoragePostgreSql)
     stats = {}
     for table in tables:
         with provenance.storage.transaction(readonly=True) as cursor:
             cursor.execute(f"SELECT COUNT(*) AS count FROM {table}")
             stats[table] = cursor.fetchone()["count"]
     return stats


 def init_stats(filename: str) -> List[str]:
     tables = [
         "content",
         "content_in_revision",
         "content_in_directory",
         "directory",
         "directory_in_revision",
         "location",
         "revision",
     ]
-    header = ["revisions count", "datetime"]
+    header = ["datetime"]
     for table in tables:
         header.append(f"{table} rows")
     with io.open(filename, "w") as outfile:
         outfile.write(",".join(header))
         outfile.write("\n")
     return tables


-def write_stats(filename: str, count: int, stats: Dict[str, int]) -> None:
-    line = [str(count), str(datetime.now())]
-    for table, count in stats.items():
+def write_stats(filename: str, stats: Dict[str, int]) -> None:
+    line = [str(datetime.now())]
+    for _, count in stats.items():
         line.append(str(count))
     with io.open(filename, "a") as outfile:
         outfile.write(",".join(line))
         outfile.write("\n")


+class Command(Enum):
+    TERMINATE = "terminate"
+
+
+class StatsWorker(threading.Thread):
+    def __init__(
+        self,
+        filename: str,
+        storage_conf: Dict[str, Any],
+        timeout: float = 300,
+        group: None = None,
+        target: Optional[Callable[..., Any]] = ...,
+        name: Optional[str] = ...,
+        args: Iterable[Any] = ...,
+        kwargs: Optional[Mapping[str, Any]] = ...,
+        *,
+        daemon: Optional[bool] = ...,
+    ) -> None:
+        super().__init__(
+            group=group,
+            target=target,
+            name=name,
+            args=args,
+            kwargs=kwargs,
+            daemon=daemon,
+        )
+        self.filename = filename
+        self.queue = queue.Queue()
+        self.storage_conf = storage_conf
+        self.timeout = timeout
+
+    def run(self) -> None:
+        tables = init_stats(self.filename)
+        start = time.monotonic()
+        with get_provenance(**self.storage_conf) as provenance:
+            while True:
+                now = time.monotonic()
+                if now - start > self.timeout:
+                    write_stats(self.filename, get_tables_stats(provenance, tables))
+                    start = now
+                try:
+                    cmd = self.queue.get(timeout=1)
+                    if cmd == Command.TERMINATE:
+                        break
+                except queue.Empty:
+                    continue
+
+    def stop(self) -> None:
+        self.queue.put(Command.TERMINATE)
+        self.join()
+
+
+class RevisionWorker(threading.Thread):
+    def __init__(
+        self,
+        filename: str,
+        url: str,
+        limit: Optional[int] = None,
+        size: int = 1,
+        skip: int = 0,
+        group: None = None,
+        target: Optional[Callable[..., Any]] = ...,
+        name: Optional[str] = ...,
+        args: Iterable[Any] = ...,
+        kwargs: Optional[Mapping[str, Any]] = ...,
+        *,
+        daemon: Optional[bool] = ...,
+    ) -> None:
+        super().__init__(
+            group=group,
+            target=target,
+            name=name,
+            args=args,
+            kwargs=kwargs,
+            daemon=daemon,
+        )
+        self.filename = filename
+        self.limit = limit
+        self.queue = queue.Queue()
+        self.size = size
+        self.skip = skip
+        self.url = url
+
+    def run(self) -> None:
+        context = zmq.Context()
+        socket: zmq.Socket = context.socket(zmq.REP)
+        socket.bind(self.url)
+
+        # TODO: improve this using a context manager
+        file = (
+            io.open(self.filename, "r")
+            if os.path.splitext(self.filename)[1] == ".csv"
+            else gzip.open(self.filename, "rt")
+        )
+        provider = (line.strip().split(",") for line in file if line.strip())
+
+        count = 0
+        while True:
+            if self.limit is not None and count > self.limit:
+                break
+
+            response = []
+            for rev, date, root in provider:
+                count += 1
+                if count <= self.skip or iso8601.parse_date(date) <= UTCEPOCH:
+                    continue
+                response.append({"rev": rev, "date": date, "root": root})
+                if len(response) == self.size:
+                    break
+            if not response:
+                break
+
+            # Wait for next request from client
+            # (TODO: make it non-blocking or add a timeout)
+            socket.recv()
+            socket.send_json(response)
+
+            try:
+                cmd = self.queue.get(block=False)
+                if cmd == Command.TERMINATE:
+                    break
+            except queue.Empty:
+                continue
+
+        while True:  # TODO: improve shutdown logic
+            socket.recv()
+            socket.send_json(None)
+        # context.term()
+
+    def stop(self) -> None:
+        self.queue.put(Command.TERMINATE)
+        self.join()
+
+
 if __name__ == "__main__":
+    # TODO: improve command line parsing
     if len(sys.argv) < 2:
         print("usage: server <filename>")
         print("where")
         print(
             "    filename     : csv file containing the list of revisions to be iterated (one per"
         )
         print(
             "                   line): revision sha1, date in ISO format, root directory sha1."
         )
         exit(-1)

-    filename = sys.argv[1]
-
     config_file = None  # TODO: Add as a cli option
     if (
         config_file is None
         and DEFAULT_PATH is not None
         and config.config_exists(DEFAULT_PATH)
     ):
         config_file = DEFAULT_PATH

     if config_file is None or not os.path.exists(config_file):
         print("No configuration provided")
         exit(-1)

     conf = yaml.safe_load(open(config_file, "rb"))["provenance"]

-    context = zmq.Context()
-    socket = context.socket(zmq.REP)
-    socket.bind(f"tcp://*:{conf['rev_server']['port']}")
-
-    stats = conf["rev_server"].get("stats", False)
-    if stats:
+    # Init stats
+    statsfile = conf["rev_server"].get("stats_file")
+    if statsfile is not None:
         storage_conf = (
-            conf["storage"]
-            if conf["storage"]["cls"] == "postgresql"
-            else conf["storage"]["storage_config"]
+            conf["storage"]["storage_config"]
+            if conf["storage"]["cls"] == "rabbitmq"
+            else conf["storage"]
         )
-        dbname = storage_conf["db"].get("dbname", storage_conf["db"].get("service"))
-        statsfile = f"stats_{dbname}_{datetime.now()}.csv"
-        worker = StatsWorker(statsfile, storage_conf)
-        worker.start()
-
-    revisions_provider = (
-        (line.strip().split(",") for line in io.open(filename, "r") if line.strip())
-        if os.path.splitext(filename)[1] == ".csv"
-        else (
-            line.strip().split(",")
-            for line in gzip.open(filename, "rt")
-            if line.strip()
+        statsfile = f"stats_{datetime.now()}_{statsfile}"
+        statsworker = StatsWorker(
+            statsfile, storage_conf, timeout=conf["rev_server"].get("stats_rate", 300)
         )
-    )
-
-    limit = conf["rev_server"].get("limit")
-    skip = conf["rev_server"].get("skip", 0)
-    for idx, (rev, date, root) in enumerate(revisions_provider):
-        if iso8601.parse_date(date) <= UTCEPOCH:
-            continue
-
-        if limit is not None and limit <= idx:
-            break
-
-        if stats and idx > skip and idx % stats == 0:
-            worker.queue.put((StatsCommand.GET, idx))
+        statsworker.start()
-        # Wait for next request from client
-        request = socket.recv()
-        response = {
-            "rev": rev,
-            "date": date,
-            "root": root,
-        }
-        socket.send_json(response)
-
-    if stats:
-        worker.queue.put((StatsCommand.GET, 0))
-        worker.queue.put((StatsCommand.EXIT, None))
-        worker.join()
+    # Init revision provider
+    revsfile = sys.argv[1]
+    url = f"tcp://*:{conf['rev_server']['port']}"
+    revsworker = RevisionWorker(
+        revsfile,
+        url,
+        limit=conf["rev_server"].get("limit"),
+        size=conf["rev_server"].get("size", 1),
+        skip=conf["rev_server"].get("skip", 0),
+    )
+    revsworker.start()

+    # Wait for user commands
     while True:
-        # Force all clients to exit
-        request = socket.recv()
-        socket.send_json(None)
+        try:
+            command = input("Enter EXIT to stop service: ")
+            if command.lower() == "exit":
+                break
+        except KeyboardInterrupt:
+            pass
+
+    # Release resources
+    revsworker.stop()
+    if statsfile:
+        statsworker.stop()
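
Reviewer note (not part of the patch): both scripts read their settings from the YAML file pointed to by SWH_CONFIG_FILENAME, under the "provenance" key. The sketch below only illustrates the keys referenced in this change, shown as the Python dict that conf holds after yaml.safe_load(...)["provenance"]; the archive cls value and the nested db settings are assumptions for illustration, not something this patch prescribes.

# Minimal sketch of the expected configuration (values are assumptions):
conf = {
    "archive": {"cls": "api"},  # assumed; anything accepted by get_archive()
    "storage": {
        "cls": "postgresql",  # with cls "rabbitmq", the server reads storage_config instead
        "db": {"service": "provenance"},  # assumed connection settings
    },
    "rev_server": {
        "host": "localhost",  # used by client.py to reach the server
        "port": 5556,
        "stats_file": "stats.csv",  # optional; omit to disable the StatsWorker
        "stats_rate": 300,  # seconds between stats snapshots (default 300)
        "limit": 1000000,  # optional cap on revisions served
        "size": 100,  # revisions per batch sent to each client
        "skip": 0,  # leading revisions to skip in the input file
    },
}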