diff --git a/.deepsource.toml b/.deepsource.toml new file mode 100644 index 00000000..bb9f19a4 --- /dev/null +++ b/.deepsource.toml @@ -0,0 +1,12 @@ +version = 1 + +test_patterns = ["dulwich/**test_*.py"] + +exclude_patterns = ["examples/**"] + +[[analyzers]] +name = "python" +enabled = true + + [analyzers.meta] + runtime_version = "3.x.x" \ No newline at end of file diff --git a/dulwich/cli.py b/dulwich/cli.py index 6e2c5fdd..cf1074a3 100755 --- a/dulwich/cli.py +++ b/dulwich/cli.py @@ -1,757 +1,755 @@ #!/usr/bin/python3 -u # # dulwich - Simple command-line interface to Dulwich # Copyright (C) 2008-2011 Jelmer Vernooij # vim: expandtab # # Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU # General Public License as public by the Free Software Foundation; version 2.0 # or (at your option) any later version. You can redistribute it and/or # modify it under the terms of either of these two licenses. # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # # You should have received a copy of the licenses; if not, see # for a copy of the GNU General Public License # and for a copy of the Apache # License, Version 2.0. # """Simple command-line interface to Dulwich> This is a very simple command-line wrapper for Dulwich. It is by no means intended to be a full-blown Git command-line interface but just a way to test Dulwich. """ import os import sys from getopt import getopt import argparse import optparse import signal from typing import Dict, Type from dulwich import porcelain from dulwich.client import get_transport_and_path from dulwich.errors import ApplyDeltaError from dulwich.index import Index from dulwich.pack import Pack, sha_to_hex from dulwich.patch import write_tree_diff from dulwich.repo import Repo def signal_int(signal, frame): sys.exit(1) def signal_quit(signal, frame): import pdb pdb.set_trace() class Command(object): """A Dulwich subcommand.""" def run(self, args): """Run the command.""" raise NotImplementedError(self.run) class cmd_archive(Command): def run(self, args): parser = argparse.ArgumentParser() parser.add_argument( "--remote", type=str, help="Retrieve archive from specified remote repo", ) parser.add_argument('committish', type=str, nargs='?') args = parser.parse_args(args) if args.remote: client, path = get_transport_and_path(args.remote) client.archive( path, args.committish, sys.stdout.write, write_error=sys.stderr.write, ) else: porcelain.archive( ".", args.committish, outstream=sys.stdout.buffer, errstream=sys.stderr ) class cmd_add(Command): def run(self, argv): parser = argparse.ArgumentParser() args = parser.parse_args(argv) porcelain.add(".", paths=args) class cmd_rm(Command): def run(self, argv): parser = argparse.ArgumentParser() args = parser.parse_args(argv) porcelain.rm(".", paths=args) class cmd_fetch_pack(Command): def run(self, argv): parser = argparse.ArgumentParser() parser.add_argument('--all', action='store_true') parser.add_argument('location', nargs='?', type=str) args = parser.parse_args(argv) client, path = get_transport_and_path(args.location) r = Repo(".") if args.all: determine_wants = r.object_store.determine_wants_all else: def determine_wants(x): return [y for y in args if y not in r.object_store] client.fetch(path, r, determine_wants) class cmd_fetch(Command): def run(self, args): opts, args = getopt(args, "", []) opts = dict(opts) client, path = get_transport_and_path(args.pop(0)) r = Repo(".") refs = client.fetch(path, r, progress=sys.stdout.write) print("Remote refs:") for item in refs.items(): print("%s -> %s" % item) class cmd_fsck(Command): def run(self, args): opts, args = getopt(args, "", []) opts = dict(opts) for (obj, msg) in porcelain.fsck("."): print("%s: %s" % (obj, msg)) class cmd_log(Command): def run(self, args): parser = optparse.OptionParser() parser.add_option( "--reverse", dest="reverse", action="store_true", help="Reverse order in which entries are printed", ) parser.add_option( "--name-status", dest="name_status", action="store_true", help="Print name/status for each changed file", ) options, args = parser.parse_args(args) porcelain.log( ".", paths=args, reverse=options.reverse, name_status=options.name_status, outstream=sys.stdout, ) class cmd_diff(Command): def run(self, args): opts, args = getopt(args, "", []) if args == []: print("Usage: dulwich diff COMMITID") sys.exit(1) r = Repo(".") commit_id = args[0] commit = r[commit_id] parent_commit = r[commit.parents[0]] write_tree_diff(sys.stdout, r.object_store, parent_commit.tree, commit.tree) class cmd_dump_pack(Command): def run(self, args): opts, args = getopt(args, "", []) if args == []: print("Usage: dulwich dump-pack FILENAME") sys.exit(1) basename, _ = os.path.splitext(args[0]) x = Pack(basename) print("Object names checksum: %s" % x.name()) print("Checksum: %s" % sha_to_hex(x.get_stored_checksum())) if not x.check(): print("CHECKSUM DOES NOT MATCH") print("Length: %d" % len(x)) for name in x: try: print("\t%s" % x[name]) except KeyError as k: print("\t%s: Unable to resolve base %s" % (name, k)) except ApplyDeltaError as e: print("\t%s: Unable to apply delta: %r" % (name, e)) class cmd_dump_index(Command): def run(self, args): opts, args = getopt(args, "", []) if args == []: print("Usage: dulwich dump-index FILENAME") sys.exit(1) filename = args[0] idx = Index(filename) for o in idx: print(o, idx[o]) class cmd_init(Command): def run(self, args): opts, args = getopt(args, "", ["bare"]) opts = dict(opts) if args == []: path = os.getcwd() else: path = args[0] porcelain.init(path, bare=("--bare" in opts)) class cmd_clone(Command): def run(self, args): parser = optparse.OptionParser() parser.add_option( "--bare", dest="bare", help="Whether to create a bare repository.", action="store_true", ) parser.add_option( "--depth", dest="depth", type=int, help="Depth at which to fetch" ) options, args = parser.parse_args(args) if args == []: print("usage: dulwich clone host:path [PATH]") sys.exit(1) source = args.pop(0) if len(args) > 0: target = args.pop(0) else: target = None porcelain.clone(source, target, bare=options.bare, depth=options.depth) class cmd_commit(Command): def run(self, args): opts, args = getopt(args, "", ["message"]) opts = dict(opts) porcelain.commit(".", message=opts["--message"]) class cmd_commit_tree(Command): def run(self, args): opts, args = getopt(args, "", ["message"]) if args == []: print("usage: dulwich commit-tree tree") sys.exit(1) opts = dict(opts) porcelain.commit_tree(".", tree=args[0], message=opts["--message"]) class cmd_update_server_info(Command): def run(self, args): porcelain.update_server_info(".") class cmd_symbolic_ref(Command): def run(self, args): opts, args = getopt(args, "", ["ref-name", "force"]) if not args: print("Usage: dulwich symbolic-ref REF_NAME [--force]") sys.exit(1) ref_name = args.pop(0) porcelain.symbolic_ref(".", ref_name=ref_name, force="--force" in args) class cmd_show(Command): def run(self, argv): parser = argparse.ArgumentParser() parser.add_argument('objectish', type=str, nargs='*') args = parser.parse_args(argv) porcelain.show(".", args.objectish or None) class cmd_diff_tree(Command): def run(self, args): opts, args = getopt(args, "", []) if len(args) < 2: print("Usage: dulwich diff-tree OLD-TREE NEW-TREE") sys.exit(1) porcelain.diff_tree(".", args[0], args[1]) class cmd_rev_list(Command): def run(self, args): opts, args = getopt(args, "", []) if len(args) < 1: print("Usage: dulwich rev-list COMMITID...") sys.exit(1) porcelain.rev_list(".", args) class cmd_tag(Command): def run(self, args): parser = optparse.OptionParser() parser.add_option( "-a", "--annotated", help="Create an annotated tag.", action="store_true", ) parser.add_option( "-s", "--sign", help="Sign the annotated tag.", action="store_true" ) options, args = parser.parse_args(args) porcelain.tag_create( ".", args[0], annotated=options.annotated, sign=options.sign ) class cmd_repack(Command): def run(self, args): opts, args = getopt(args, "", []) opts = dict(opts) porcelain.repack(".") class cmd_reset(Command): def run(self, args): opts, args = getopt(args, "", ["hard", "soft", "mixed"]) opts = dict(opts) mode = "" if "--hard" in opts: mode = "hard" elif "--soft" in opts: mode = "soft" elif "--mixed" in opts: mode = "mixed" porcelain.reset(".", mode=mode, *args) class cmd_daemon(Command): def run(self, args): from dulwich import log_utils from dulwich.protocol import TCP_GIT_PORT parser = optparse.OptionParser() parser.add_option( "-l", "--listen_address", dest="listen_address", default="localhost", help="Binding IP address.", ) parser.add_option( "-p", "--port", dest="port", type=int, default=TCP_GIT_PORT, help="Binding TCP port.", ) options, args = parser.parse_args(args) log_utils.default_logging_config() if len(args) >= 1: gitdir = args[0] else: gitdir = "." - from dulwich import porcelain porcelain.daemon(gitdir, address=options.listen_address, port=options.port) class cmd_web_daemon(Command): def run(self, args): from dulwich import log_utils parser = optparse.OptionParser() parser.add_option( "-l", "--listen_address", dest="listen_address", default="", help="Binding IP address.", ) parser.add_option( "-p", "--port", dest="port", type=int, default=8000, help="Binding TCP port.", ) options, args = parser.parse_args(args) log_utils.default_logging_config() if len(args) >= 1: gitdir = args[0] else: gitdir = "." - from dulwich import porcelain porcelain.web_daemon(gitdir, address=options.listen_address, port=options.port) class cmd_write_tree(Command): def run(self, args): parser = optparse.OptionParser() options, args = parser.parse_args(args) sys.stdout.write("%s\n" % porcelain.write_tree(".")) class cmd_receive_pack(Command): def run(self, args): parser = optparse.OptionParser() options, args = parser.parse_args(args) if len(args) >= 1: gitdir = args[0] else: gitdir = "." porcelain.receive_pack(gitdir) class cmd_upload_pack(Command): def run(self, args): parser = optparse.OptionParser() options, args = parser.parse_args(args) if len(args) >= 1: gitdir = args[0] else: gitdir = "." porcelain.upload_pack(gitdir) class cmd_status(Command): def run(self, args): parser = optparse.OptionParser() options, args = parser.parse_args(args) if len(args) >= 1: gitdir = args[0] else: gitdir = "." status = porcelain.status(gitdir) if any(names for (kind, names) in status.staged.items()): sys.stdout.write("Changes to be committed:\n\n") for kind, names in status.staged.items(): for name in names: sys.stdout.write( "\t%s: %s\n" % (kind, name.decode(sys.getfilesystemencoding())) ) sys.stdout.write("\n") if status.unstaged: sys.stdout.write("Changes not staged for commit:\n\n") for name in status.unstaged: sys.stdout.write("\t%s\n" % name.decode(sys.getfilesystemencoding())) sys.stdout.write("\n") if status.untracked: sys.stdout.write("Untracked files:\n\n") for name in status.untracked: sys.stdout.write("\t%s\n" % name) sys.stdout.write("\n") class cmd_ls_remote(Command): def run(self, args): opts, args = getopt(args, "", []) if len(args) < 1: print("Usage: dulwich ls-remote URL") sys.exit(1) refs = porcelain.ls_remote(args[0]) for ref in sorted(refs): sys.stdout.write("%s\t%s\n" % (ref, refs[ref])) class cmd_ls_tree(Command): def run(self, args): parser = optparse.OptionParser() parser.add_option( "-r", "--recursive", action="store_true", help="Recusively list tree contents.", ) parser.add_option("--name-only", action="store_true", help="Only display name.") options, args = parser.parse_args(args) try: treeish = args.pop(0) except IndexError: treeish = None porcelain.ls_tree( ".", treeish, outstream=sys.stdout, recursive=options.recursive, name_only=options.name_only, ) class cmd_pack_objects(Command): def run(self, args): opts, args = getopt(args, "", ["stdout"]) opts = dict(opts) if len(args) < 1 and "--stdout" not in args: print("Usage: dulwich pack-objects basename") sys.exit(1) object_ids = [line.strip() for line in sys.stdin.readlines()] basename = args[0] if "--stdout" in opts: packf = getattr(sys.stdout, "buffer", sys.stdout) idxf = None close = [] else: packf = open(basename + ".pack", "w") idxf = open(basename + ".idx", "w") close = [packf, idxf] porcelain.pack_objects(".", object_ids, packf, idxf) for f in close: f.close() class cmd_pull(Command): def run(self, args): parser = optparse.OptionParser() options, args = parser.parse_args(args) try: from_location = args[0] except IndexError: from_location = None porcelain.pull(".", from_location) class cmd_push(Command): def run(self, argv): parser = argparse.ArgumentParser() parser.add_argument('to_location', type=str) parser.add_argument('refspec', type=str, nargs='*') args = parser.parse_args(argv) porcelain.push('.', args.to_location, args.refspec or None) class cmd_remote_add(Command): def run(self, args): parser = optparse.OptionParser() options, args = parser.parse_args(args) porcelain.remote_add(".", args[0], args[1]) class SuperCommand(Command): subcommands = {} # type: Dict[str, Type[Command]] def run(self, args): if not args: print("Supported subcommands: %s" % ", ".join(self.subcommands.keys())) return False cmd = args[0] try: cmd_kls = self.subcommands[cmd] except KeyError: print("No such subcommand: %s" % args[0]) return False return cmd_kls().run(args[1:]) class cmd_remote(SuperCommand): subcommands = { "add": cmd_remote_add, } class cmd_check_ignore(Command): def run(self, args): parser = optparse.OptionParser() options, args = parser.parse_args(args) ret = 1 for path in porcelain.check_ignore(".", args): print(path) ret = 0 return ret class cmd_check_mailmap(Command): def run(self, args): parser = optparse.OptionParser() options, args = parser.parse_args(args) for arg in args: canonical_identity = porcelain.check_mailmap(".", arg) print(canonical_identity) class cmd_stash_list(Command): def run(self, args): parser = optparse.OptionParser() options, args = parser.parse_args(args) for i, entry in porcelain.stash_list("."): print("stash@{%d}: %s" % (i, entry.message.rstrip("\n"))) class cmd_stash_push(Command): def run(self, args): parser = optparse.OptionParser() options, args = parser.parse_args(args) porcelain.stash_push(".") print("Saved working directory and index state") class cmd_stash_pop(Command): def run(self, args): parser = optparse.OptionParser() options, args = parser.parse_args(args) porcelain.stash_pop(".") print("Restrored working directory and index state") class cmd_stash(SuperCommand): subcommands = { "list": cmd_stash_list, "pop": cmd_stash_pop, "push": cmd_stash_push, } class cmd_ls_files(Command): def run(self, args): parser = optparse.OptionParser() options, args = parser.parse_args(args) for name in porcelain.ls_files("."): print(name) class cmd_describe(Command): def run(self, args): parser = optparse.OptionParser() options, args = parser.parse_args(args) print(porcelain.describe(".")) class cmd_help(Command): def run(self, args): parser = optparse.OptionParser() parser.add_option( "-a", "--all", dest="all", action="store_true", help="List all commands.", ) options, args = parser.parse_args(args) if options.all: print("Available commands:") for cmd in sorted(commands): print(" %s" % cmd) else: print( """\ The dulwich command line tool is currently a very basic frontend for the Dulwich python module. For full functionality, please see the API reference. For a list of supported commands, see 'dulwich help -a'. """ ) commands = { "add": cmd_add, "archive": cmd_archive, "check-ignore": cmd_check_ignore, "check-mailmap": cmd_check_mailmap, "clone": cmd_clone, "commit": cmd_commit, "commit-tree": cmd_commit_tree, "describe": cmd_describe, "daemon": cmd_daemon, "diff": cmd_diff, "diff-tree": cmd_diff_tree, "dump-pack": cmd_dump_pack, "dump-index": cmd_dump_index, "fetch-pack": cmd_fetch_pack, "fetch": cmd_fetch, "fsck": cmd_fsck, "help": cmd_help, "init": cmd_init, "log": cmd_log, "ls-files": cmd_ls_files, "ls-remote": cmd_ls_remote, "ls-tree": cmd_ls_tree, "pack-objects": cmd_pack_objects, "pull": cmd_pull, "push": cmd_push, "receive-pack": cmd_receive_pack, "remote": cmd_remote, "repack": cmd_repack, "reset": cmd_reset, "rev-list": cmd_rev_list, "rm": cmd_rm, "show": cmd_show, "stash": cmd_stash, "status": cmd_status, "symbolic-ref": cmd_symbolic_ref, "tag": cmd_tag, "update-server-info": cmd_update_server_info, "upload-pack": cmd_upload_pack, "web-daemon": cmd_web_daemon, "write-tree": cmd_write_tree, } def main(argv=None): if argv is None: argv = sys.argv if len(argv) < 1: print("Usage: dulwich <%s> [OPTIONS...]" % ("|".join(commands.keys()))) return 1 cmd = argv[0] try: cmd_kls = commands[cmd] except KeyError: print("No such subcommand: %s" % cmd) return 1 # TODO(jelmer): Return non-0 on errors return cmd_kls().run(argv[1:]) if __name__ == "__main__": if "DULWICH_PDB" in os.environ and getattr(signal, "SIGQUIT", None): signal.signal(signal.SIGQUIT, signal_quit) # type: ignore signal.signal(signal.SIGINT, signal_int) sys.exit(main(sys.argv[1:])) diff --git a/dulwich/client.py b/dulwich/client.py index 6e32d411..ca03f251 100644 --- a/dulwich/client.py +++ b/dulwich/client.py @@ -1,2194 +1,2194 @@ # client.py -- Implementation of the client side git protocols # Copyright (C) 2008-2013 Jelmer Vernooij # # Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU # General Public License as public by the Free Software Foundation; version 2.0 # or (at your option) any later version. You can redistribute it and/or # modify it under the terms of either of these two licenses. # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # # You should have received a copy of the licenses; if not, see # for a copy of the GNU General Public License # and for a copy of the Apache # License, Version 2.0. # """Client side support for the Git protocol. The Dulwich client supports the following capabilities: * thin-pack * multi_ack_detailed * multi_ack * side-band-64k * ofs-delta * quiet * report-status * delete-refs * shallow Known capabilities that are not supported: * no-progress * include-tag """ from contextlib import closing from io import BytesIO, BufferedReader import logging import os import select import socket import subprocess import sys from typing import Optional, Dict, Callable, Set from urllib.parse import ( quote as urlquote, unquote as urlunquote, urlparse, urljoin, urlunsplit, urlunparse, ) import dulwich from dulwich.config import get_xdg_config_home_path from dulwich.errors import ( GitProtocolError, NotGitRepository, SendPackError, ) from dulwich.protocol import ( HangupException, _RBUFSIZE, agent_string, capability_agent, extract_capability_names, CAPABILITY_AGENT, CAPABILITY_DELETE_REFS, CAPABILITY_INCLUDE_TAG, CAPABILITY_MULTI_ACK, CAPABILITY_MULTI_ACK_DETAILED, CAPABILITY_OFS_DELTA, CAPABILITY_QUIET, CAPABILITY_REPORT_STATUS, CAPABILITY_SHALLOW, CAPABILITY_SYMREF, CAPABILITY_SIDE_BAND_64K, CAPABILITY_THIN_PACK, CAPABILITIES_REF, KNOWN_RECEIVE_CAPABILITIES, KNOWN_UPLOAD_CAPABILITIES, COMMAND_DEEPEN, COMMAND_SHALLOW, COMMAND_UNSHALLOW, COMMAND_DONE, COMMAND_HAVE, COMMAND_WANT, SIDE_BAND_CHANNEL_DATA, SIDE_BAND_CHANNEL_PROGRESS, SIDE_BAND_CHANNEL_FATAL, PktLineParser, Protocol, ProtocolFile, TCP_GIT_PORT, ZERO_SHA, extract_capabilities, parse_capability, ) from dulwich.pack import ( write_pack_data, write_pack_objects, ) from dulwich.refs import ( read_info_refs, ANNOTATED_TAG_SUFFIX, ) logger = logging.getLogger(__name__) class InvalidWants(Exception): """Invalid wants.""" def __init__(self, wants): Exception.__init__( self, "requested wants not in server provided refs: %r" % wants ) class HTTPUnauthorized(Exception): """Raised when authentication fails.""" def __init__(self, www_authenticate, url): Exception.__init__(self, "No valid credentials provided") self.www_authenticate = www_authenticate self.url = url def _fileno_can_read(fileno): """Check if a file descriptor is readable.""" return len(select.select([fileno], [], [], 0)[0]) > 0 def _win32_peek_avail(handle): """Wrapper around PeekNamedPipe to check how many bytes are available.""" from ctypes import byref, wintypes, windll c_avail = wintypes.DWORD() c_message = wintypes.DWORD() success = windll.kernel32.PeekNamedPipe( handle, None, 0, None, byref(c_avail), byref(c_message) ) if not success: raise OSError(wintypes.GetLastError()) return c_avail.value COMMON_CAPABILITIES = [CAPABILITY_OFS_DELTA, CAPABILITY_SIDE_BAND_64K] UPLOAD_CAPABILITIES = [ CAPABILITY_THIN_PACK, CAPABILITY_MULTI_ACK, CAPABILITY_MULTI_ACK_DETAILED, CAPABILITY_SHALLOW, ] + COMMON_CAPABILITIES RECEIVE_CAPABILITIES = [ CAPABILITY_REPORT_STATUS, CAPABILITY_DELETE_REFS, ] + COMMON_CAPABILITIES class ReportStatusParser(object): """Handle status as reported by servers with 'report-status' capability.""" def __init__(self): self._done = False self._pack_status = None self._ref_statuses = [] def check(self): """Check if there were any errors and, if so, raise exceptions. Raises: SendPackError: Raised when the server could not unpack Returns: iterator over refs """ if self._pack_status not in (b"unpack ok", None): raise SendPackError(self._pack_status) for status in self._ref_statuses: try: status, rest = status.split(b" ", 1) except ValueError: # malformed response, move on to the next one continue if status == b"ng": ref, error = rest.split(b" ", 1) yield ref, error.decode("utf-8") elif status == b"ok": yield rest, None else: raise GitProtocolError("invalid ref status %r" % status) def handle_packet(self, pkt): """Handle a packet. Raises: GitProtocolError: Raised when packets are received after a flush packet. """ if self._done: raise GitProtocolError("received more data after status report") if pkt is None: self._done = True return if self._pack_status is None: self._pack_status = pkt.strip() else: ref_status = pkt.strip() self._ref_statuses.append(ref_status) def read_pkt_refs(proto): server_capabilities = None refs = {} # Receive refs from server for pkt in proto.read_pkt_seq(): (sha, ref) = pkt.rstrip(b"\n").split(None, 1) if sha == b"ERR": raise GitProtocolError(ref.decode("utf-8", "replace")) if server_capabilities is None: (ref, server_capabilities) = extract_capabilities(ref) refs[ref] = sha if len(refs) == 0: return {}, set([]) if refs == {CAPABILITIES_REF: ZERO_SHA}: refs = {} return refs, set(server_capabilities) class FetchPackResult(object): """Result of a fetch-pack operation. Attributes: refs: Dictionary with all remote refs symrefs: Dictionary with remote symrefs agent: User agent string """ _FORWARDED_ATTRS = [ "clear", "copy", "fromkeys", "get", "items", "keys", "pop", "popitem", "setdefault", "update", "values", "viewitems", "viewkeys", "viewvalues", ] def __init__(self, refs, symrefs, agent, new_shallow=None, new_unshallow=None): self.refs = refs self.symrefs = symrefs self.agent = agent self.new_shallow = new_shallow self.new_unshallow = new_unshallow def _warn_deprecated(self): import warnings warnings.warn( "Use FetchPackResult.refs instead.", DeprecationWarning, stacklevel=3, ) def __eq__(self, other): if isinstance(other, dict): self._warn_deprecated() return self.refs == other return ( self.refs == other.refs and self.symrefs == other.symrefs and self.agent == other.agent ) def __contains__(self, name): self._warn_deprecated() return name in self.refs def __getitem__(self, name): self._warn_deprecated() return self.refs[name] def __len__(self): self._warn_deprecated() return len(self.refs) def __iter__(self): self._warn_deprecated() return iter(self.refs) def __getattribute__(self, name): if name in type(self)._FORWARDED_ATTRS: self._warn_deprecated() return getattr(self.refs, name) return super(FetchPackResult, self).__getattribute__(name) def __repr__(self): return "%s(%r, %r, %r)" % ( self.__class__.__name__, self.refs, self.symrefs, self.agent, ) class SendPackResult(object): """Result of a upload-pack operation. Attributes: refs: Dictionary with all remote refs agent: User agent string ref_status: Optional dictionary mapping ref name to error message (if it failed to update), or None if it was updated successfully """ _FORWARDED_ATTRS = [ "clear", "copy", "fromkeys", "get", "items", "keys", "pop", "popitem", "setdefault", "update", "values", "viewitems", "viewkeys", "viewvalues", ] def __init__(self, refs, agent=None, ref_status=None): self.refs = refs self.agent = agent self.ref_status = ref_status def _warn_deprecated(self): import warnings warnings.warn( "Use SendPackResult.refs instead.", DeprecationWarning, stacklevel=3, ) def __eq__(self, other): if isinstance(other, dict): self._warn_deprecated() return self.refs == other return self.refs == other.refs and self.agent == other.agent def __contains__(self, name): self._warn_deprecated() return name in self.refs def __getitem__(self, name): self._warn_deprecated() return self.refs[name] def __len__(self): self._warn_deprecated() return len(self.refs) def __iter__(self): self._warn_deprecated() return iter(self.refs) def __getattribute__(self, name): if name in type(self)._FORWARDED_ATTRS: self._warn_deprecated() return getattr(self.refs, name) return super(SendPackResult, self).__getattribute__(name) def __repr__(self): return "%s(%r, %r)" % (self.__class__.__name__, self.refs, self.agent) def _read_shallow_updates(proto): new_shallow = set() new_unshallow = set() for pkt in proto.read_pkt_seq(): cmd, sha = pkt.split(b" ", 1) if cmd == COMMAND_SHALLOW: new_shallow.add(sha.strip()) elif cmd == COMMAND_UNSHALLOW: new_unshallow.add(sha.strip()) else: raise GitProtocolError("unknown command %s" % pkt) return (new_shallow, new_unshallow) # TODO(durin42): this doesn't correctly degrade if the server doesn't # support some capabilities. This should work properly with servers # that don't support multi_ack. class GitClient(object): """Git smart server client.""" def __init__( self, thin_packs=True, report_activity=None, quiet=False, include_tags=False, ): """Create a new GitClient instance. Args: thin_packs: Whether or not thin packs should be retrieved report_activity: Optional callback for reporting transport activity. include_tags: send annotated tags when sending the objects they point to """ self._report_activity = report_activity self._report_status_parser = None self._fetch_capabilities = set(UPLOAD_CAPABILITIES) self._fetch_capabilities.add(capability_agent()) self._send_capabilities = set(RECEIVE_CAPABILITIES) self._send_capabilities.add(capability_agent()) if quiet: self._send_capabilities.add(CAPABILITY_QUIET) if not thin_packs: self._fetch_capabilities.remove(CAPABILITY_THIN_PACK) if include_tags: self._fetch_capabilities.add(CAPABILITY_INCLUDE_TAG) def get_url(self, path): """Retrieves full url to given path. Args: path: Repository path (as string) Returns: Url to path (as string) """ raise NotImplementedError(self.get_url) @classmethod def from_parsedurl(cls, parsedurl, **kwargs): """Create an instance of this client from a urlparse.parsed object. Args: parsedurl: Result of urlparse() Returns: A `GitClient` object """ raise NotImplementedError(cls.from_parsedurl) def send_pack(self, path, update_refs, generate_pack_data, progress=None): """Upload a pack to a remote repository. Args: path: Repository path (as bytestring) update_refs: Function to determine changes to remote refs. Receive dict with existing remote refs, returns dict with changed refs (name -> sha, where sha=ZERO_SHA for deletions) generate_pack_data: Function that can return a tuple with number of objects and list of pack data to include progress: Optional progress function Returns: SendPackResult object Raises: SendPackError: if server rejects the pack data """ raise NotImplementedError(self.send_pack) def fetch(self, path, target, determine_wants=None, progress=None, depth=None): """Fetch into a target repository. Args: path: Path to fetch from (as bytestring) target: Target repository to fetch into determine_wants: Optional function to determine what refs to fetch. Receives dictionary of name->sha, should return list of shas to fetch. Defaults to all shas. progress: Optional progress function depth: Depth to fetch at Returns: Dictionary with all remote refs (not just those fetched) """ if determine_wants is None: determine_wants = target.object_store.determine_wants_all if CAPABILITY_THIN_PACK in self._fetch_capabilities: # TODO(jelmer): Avoid reading entire file into memory and # only processing it after the whole file has been fetched. f = BytesIO() def commit(): if f.tell(): f.seek(0) target.object_store.add_thin_pack(f.read, None) def abort(): pass else: f, commit, abort = target.object_store.add_pack() try: result = self.fetch_pack( path, determine_wants, target.get_graph_walker(), f.write, progress=progress, depth=depth, ) except BaseException: abort() raise else: commit() target.update_shallow(result.new_shallow, result.new_unshallow) return result def fetch_pack( self, path, determine_wants, graph_walker, pack_data, progress=None, depth=None, ): """Retrieve a pack from a git smart server. Args: path: Remote path to fetch from determine_wants: Function determine what refs to fetch. Receives dictionary of name->sha, should return list of shas to fetch. graph_walker: Object with next() and ack(). pack_data: Callback called for each bit of data in the pack progress: Callback for progress reports (strings) depth: Shallow fetch depth Returns: FetchPackResult object """ raise NotImplementedError(self.fetch_pack) def get_refs(self, path): """Retrieve the current refs from a git smart server. Args: path: Path to the repo to fetch from. (as bytestring) Returns: """ raise NotImplementedError(self.get_refs) def _read_side_band64k_data(self, proto, channel_callbacks): """Read per-channel data. This requires the side-band-64k capability. Args: proto: Protocol object to read from channel_callbacks: Dictionary mapping channels to packet handlers to use. None for a callback discards channel data. """ for pkt in proto.read_pkt_seq(): channel = ord(pkt[:1]) pkt = pkt[1:] try: cb = channel_callbacks[channel] except KeyError: raise AssertionError("Invalid sideband channel %d" % channel) else: if cb is not None: cb(pkt) @staticmethod def _should_send_pack(new_refs): # The packfile MUST NOT be sent if the only command used is delete. return any(sha != ZERO_SHA for sha in new_refs.values()) def _handle_receive_pack_head(self, proto, capabilities, old_refs, new_refs): """Handle the head of a 'git-receive-pack' request. Args: proto: Protocol object to read from capabilities: List of negotiated capabilities old_refs: Old refs, as received from the server new_refs: Refs to change Returns: (have, want) tuple """ want = [] have = [x for x in old_refs.values() if not x == ZERO_SHA] sent_capabilities = False for refname in new_refs: if not isinstance(refname, bytes): raise TypeError("refname is not a bytestring: %r" % refname) old_sha1 = old_refs.get(refname, ZERO_SHA) if not isinstance(old_sha1, bytes): raise TypeError( "old sha1 for %s is not a bytestring: %r" % (refname, old_sha1) ) new_sha1 = new_refs.get(refname, ZERO_SHA) if not isinstance(new_sha1, bytes): raise TypeError( "old sha1 for %s is not a bytestring %r" % (refname, new_sha1) ) if old_sha1 != new_sha1: logger.debug( 'Sending updated ref %r: %r -> %r', refname, old_sha1, new_sha1) if sent_capabilities: proto.write_pkt_line(old_sha1 + b" " + new_sha1 + b" " + refname) else: proto.write_pkt_line( old_sha1 + b" " + new_sha1 + b" " + refname + b"\0" + b" ".join(sorted(capabilities)) ) sent_capabilities = True if new_sha1 not in have and new_sha1 != ZERO_SHA: want.append(new_sha1) proto.write_pkt_line(None) return (have, want) def _negotiate_receive_pack_capabilities(self, server_capabilities): negotiated_capabilities = self._send_capabilities & server_capabilities agent = None for capability in server_capabilities: k, v = parse_capability(capability) if k == CAPABILITY_AGENT: agent = v unknown_capabilities = ( # noqa: F841 extract_capability_names(server_capabilities) - KNOWN_RECEIVE_CAPABILITIES ) # TODO(jelmer): warn about unknown capabilities return negotiated_capabilities, agent def _handle_receive_pack_tail( self, proto: Protocol, capabilities: Set[bytes], progress: Callable[[bytes], None] = None, ) -> Optional[Dict[bytes, Optional[str]]]: """Handle the tail of a 'git-receive-pack' request. Args: proto: Protocol object to read from capabilities: List of negotiated capabilities progress: Optional progress reporting function Returns: dict mapping ref name to: error message if the ref failed to update None if it was updated successfully """ if CAPABILITY_SIDE_BAND_64K in capabilities: if progress is None: def progress(x): pass channel_callbacks = {2: progress} if CAPABILITY_REPORT_STATUS in capabilities: channel_callbacks[1] = PktLineParser( self._report_status_parser.handle_packet ).parse self._read_side_band64k_data(proto, channel_callbacks) else: if CAPABILITY_REPORT_STATUS in capabilities: for pkt in proto.read_pkt_seq(): self._report_status_parser.handle_packet(pkt) if self._report_status_parser is not None: return dict(self._report_status_parser.check()) return None def _negotiate_upload_pack_capabilities(self, server_capabilities): unknown_capabilities = ( # noqa: F841 extract_capability_names(server_capabilities) - KNOWN_UPLOAD_CAPABILITIES ) # TODO(jelmer): warn about unknown capabilities symrefs = {} agent = None for capability in server_capabilities: k, v = parse_capability(capability) if k == CAPABILITY_SYMREF: (src, dst) = v.split(b":", 1) symrefs[src] = dst if k == CAPABILITY_AGENT: agent = v negotiated_capabilities = self._fetch_capabilities & server_capabilities return (negotiated_capabilities, symrefs, agent) def _handle_upload_pack_head( self, proto, capabilities, graph_walker, wants, can_read, depth ): """Handle the head of a 'git-upload-pack' request. Args: proto: Protocol object to read from capabilities: List of negotiated capabilities graph_walker: GraphWalker instance to call .ack() on wants: List of commits to fetch can_read: function that returns a boolean that indicates whether there is extra graph data to read on proto depth: Depth for request Returns: """ assert isinstance(wants, list) and isinstance(wants[0], bytes) proto.write_pkt_line( COMMAND_WANT + b" " + wants[0] + b" " + b" ".join(sorted(capabilities)) + b"\n" ) for want in wants[1:]: proto.write_pkt_line(COMMAND_WANT + b" " + want + b"\n") if depth not in (0, None) or getattr(graph_walker, "shallow", None): if CAPABILITY_SHALLOW not in capabilities: raise GitProtocolError( "server does not support shallow capability required for " "depth" ) for sha in graph_walker.shallow: proto.write_pkt_line(COMMAND_SHALLOW + b" " + sha + b"\n") if depth is not None: proto.write_pkt_line( COMMAND_DEEPEN + b" " + str(depth).encode("ascii") + b"\n" ) proto.write_pkt_line(None) if can_read is not None: (new_shallow, new_unshallow) = _read_shallow_updates(proto) else: new_shallow = new_unshallow = None else: new_shallow = new_unshallow = set() proto.write_pkt_line(None) have = next(graph_walker) while have: proto.write_pkt_line(COMMAND_HAVE + b" " + have + b"\n") if can_read is not None and can_read(): pkt = proto.read_pkt_line() parts = pkt.rstrip(b"\n").split(b" ") if parts[0] == b"ACK": graph_walker.ack(parts[1]) if parts[2] in (b"continue", b"common"): pass elif parts[2] == b"ready": break else: raise AssertionError( "%s not in ('continue', 'ready', 'common)" % parts[2] ) have = next(graph_walker) proto.write_pkt_line(COMMAND_DONE + b"\n") return (new_shallow, new_unshallow) def _handle_upload_pack_tail( self, proto, capabilities, graph_walker, pack_data, progress=None, rbufsize=_RBUFSIZE, ): """Handle the tail of a 'git-upload-pack' request. Args: proto: Protocol object to read from capabilities: List of negotiated capabilities graph_walker: GraphWalker instance to call .ack() on pack_data: Function to call with pack data progress: Optional progress reporting function rbufsize: Read buffer size Returns: """ pkt = proto.read_pkt_line() while pkt: parts = pkt.rstrip(b"\n").split(b" ") if parts[0] == b"ACK": graph_walker.ack(parts[1]) if len(parts) < 3 or parts[2] not in ( b"ready", b"continue", b"common", ): break pkt = proto.read_pkt_line() if CAPABILITY_SIDE_BAND_64K in capabilities: if progress is None: # Just ignore progress data def progress(x): pass self._read_side_band64k_data( proto, { SIDE_BAND_CHANNEL_DATA: pack_data, SIDE_BAND_CHANNEL_PROGRESS: progress, }, ) else: while True: data = proto.read(rbufsize) if data == b"": break pack_data(data) def check_wants(wants, refs): """Check that a set of wants is valid. Args: wants: Set of object SHAs to fetch refs: Refs dictionary to check against Returns: """ missing = set(wants) - { v for (k, v) in refs.items() if not k.endswith(ANNOTATED_TAG_SUFFIX) } if missing: raise InvalidWants(missing) def _remote_error_from_stderr(stderr): if stderr is None: return HangupException() lines = [line.rstrip(b"\n") for line in stderr.readlines()] for line in lines: if line.startswith(b"ERROR: "): return GitProtocolError(line[len(b"ERROR: ") :].decode("utf-8", "replace")) return HangupException(lines) class TraditionalGitClient(GitClient): """Traditional Git client.""" DEFAULT_ENCODING = "utf-8" def __init__(self, path_encoding=DEFAULT_ENCODING, **kwargs): self._remote_path_encoding = path_encoding super(TraditionalGitClient, self).__init__(**kwargs) async def _connect(self, cmd, path): """Create a connection to the server. This method is abstract - concrete implementations should implement their own variant which connects to the server and returns an initialized Protocol object with the service ready for use and a can_read function which may be used to see if reads would block. Args: cmd: The git service name to which we should connect. path: The path we should pass to the service. (as bytestirng) """ raise NotImplementedError() def send_pack(self, path, update_refs, generate_pack_data, progress=None): """Upload a pack to a remote repository. Args: path: Repository path (as bytestring) update_refs: Function to determine changes to remote refs. Receive dict with existing remote refs, returns dict with changed refs (name -> sha, where sha=ZERO_SHA for deletions) generate_pack_data: Function that can return a tuple with number of objects and pack data to upload. progress: Optional callback called with progress updates Returns: SendPackResult Raises: SendPackError: if server rejects the pack data """ proto, unused_can_read, stderr = self._connect(b"receive-pack", path) with proto: try: old_refs, server_capabilities = read_pkt_refs(proto) except HangupException: raise _remote_error_from_stderr(stderr) ( negotiated_capabilities, agent, ) = self._negotiate_receive_pack_capabilities(server_capabilities) if CAPABILITY_REPORT_STATUS in negotiated_capabilities: self._report_status_parser = ReportStatusParser() report_status_parser = self._report_status_parser try: new_refs = orig_new_refs = update_refs(dict(old_refs)) except BaseException: proto.write_pkt_line(None) raise if set(new_refs.items()).issubset(set(old_refs.items())): proto.write_pkt_line(None) return SendPackResult(new_refs, agent=agent, ref_status={}) if CAPABILITY_DELETE_REFS not in server_capabilities: # Server does not support deletions. Fail later. new_refs = dict(orig_new_refs) for ref, sha in orig_new_refs.items(): if sha == ZERO_SHA: if CAPABILITY_REPORT_STATUS in negotiated_capabilities: report_status_parser._ref_statuses.append( b"ng " + ref + b" remote does not support deleting refs" ) report_status_parser._ref_status_ok = False del new_refs[ref] if new_refs is None: proto.write_pkt_line(None) return SendPackResult(old_refs, agent=agent, ref_status={}) - if len(new_refs) == 0 and len(orig_new_refs): + if len(new_refs) == 0 and orig_new_refs: # NOOP - Original new refs filtered out by policy proto.write_pkt_line(None) if report_status_parser is not None: ref_status = dict(report_status_parser.check()) else: ref_status = None return SendPackResult(old_refs, agent=agent, ref_status=ref_status) (have, want) = self._handle_receive_pack_head( proto, negotiated_capabilities, old_refs, new_refs ) pack_data_count, pack_data = generate_pack_data( have, want, ofs_delta=(CAPABILITY_OFS_DELTA in negotiated_capabilities), ) if self._should_send_pack(new_refs): write_pack_data(proto.write_file(), pack_data_count, pack_data) ref_status = self._handle_receive_pack_tail( proto, negotiated_capabilities, progress ) return SendPackResult(new_refs, agent=agent, ref_status=ref_status) def fetch_pack( self, path, determine_wants, graph_walker, pack_data, progress=None, depth=None, ): """Retrieve a pack from a git smart server. Args: path: Remote path to fetch from determine_wants: Function determine what refs to fetch. Receives dictionary of name->sha, should return list of shas to fetch. graph_walker: Object with next() and ack(). pack_data: Callback called for each bit of data in the pack progress: Callback for progress reports (strings) depth: Shallow fetch depth Returns: FetchPackResult object """ proto, can_read, stderr = self._connect(b"upload-pack", path) with proto: try: refs, server_capabilities = read_pkt_refs(proto) except HangupException: raise _remote_error_from_stderr(stderr) ( negotiated_capabilities, symrefs, agent, ) = self._negotiate_upload_pack_capabilities(server_capabilities) if refs is None: proto.write_pkt_line(None) return FetchPackResult(refs, symrefs, agent) try: wants = determine_wants(refs) except BaseException: proto.write_pkt_line(None) raise if wants is not None: wants = [cid for cid in wants if cid != ZERO_SHA] if not wants: proto.write_pkt_line(None) return FetchPackResult(refs, symrefs, agent) (new_shallow, new_unshallow) = self._handle_upload_pack_head( proto, negotiated_capabilities, graph_walker, wants, can_read, depth=depth, ) self._handle_upload_pack_tail( proto, negotiated_capabilities, graph_walker, pack_data, progress, ) return FetchPackResult(refs, symrefs, agent, new_shallow, new_unshallow) def get_refs(self, path): """Retrieve the current refs from a git smart server.""" # stock `git ls-remote` uses upload-pack proto, _, stderr = self._connect(b"upload-pack", path) with proto: try: refs, _ = read_pkt_refs(proto) except HangupException: raise _remote_error_from_stderr(stderr) proto.write_pkt_line(None) return refs def archive( self, path, committish, write_data, progress=None, write_error=None, format=None, subdirs=None, prefix=None, ): proto, can_read, stderr = self._connect(b"upload-archive", path) with proto: if format is not None: proto.write_pkt_line(b"argument --format=" + format) proto.write_pkt_line(b"argument " + committish) if subdirs is not None: for subdir in subdirs: proto.write_pkt_line(b"argument " + subdir) if prefix is not None: proto.write_pkt_line(b"argument --prefix=" + prefix) proto.write_pkt_line(None) try: pkt = proto.read_pkt_line() except HangupException: raise _remote_error_from_stderr(stderr) if pkt == b"NACK\n" or pkt == b"NACK": return elif pkt == b"ACK\n" or pkt == b"ACK": pass elif pkt.startswith(b"ERR "): raise GitProtocolError(pkt[4:].rstrip(b"\n").decode("utf-8", "replace")) else: raise AssertionError("invalid response %r" % pkt) ret = proto.read_pkt_line() if ret is not None: raise AssertionError("expected pkt tail") self._read_side_band64k_data( proto, { SIDE_BAND_CHANNEL_DATA: write_data, SIDE_BAND_CHANNEL_PROGRESS: progress, SIDE_BAND_CHANNEL_FATAL: write_error, }, ) class TCPGitClient(TraditionalGitClient): """A Git Client that works over TCP directly (i.e. git://).""" def __init__(self, host, port=None, **kwargs): if port is None: port = TCP_GIT_PORT self._host = host self._port = port super(TCPGitClient, self).__init__(**kwargs) @classmethod def from_parsedurl(cls, parsedurl, **kwargs): return cls(parsedurl.hostname, port=parsedurl.port, **kwargs) def get_url(self, path): netloc = self._host if self._port is not None and self._port != TCP_GIT_PORT: netloc += ":%d" % self._port return urlunsplit(("git", netloc, path, "", "")) def _connect(self, cmd, path): if not isinstance(cmd, bytes): raise TypeError(cmd) if not isinstance(path, bytes): path = path.encode(self._remote_path_encoding) sockaddrs = socket.getaddrinfo( self._host, self._port, socket.AF_UNSPEC, socket.SOCK_STREAM ) s = None err = socket.error("no address found for %s" % self._host) for (family, socktype, proto, canonname, sockaddr) in sockaddrs: s = socket.socket(family, socktype, proto) s.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) try: s.connect(sockaddr) break except socket.error as e: err = e if s is not None: s.close() s = None if s is None: raise err # -1 means system default buffering rfile = s.makefile("rb", -1) # 0 means unbuffered wfile = s.makefile("wb", 0) def close(): rfile.close() wfile.close() s.close() proto = Protocol( rfile.read, wfile.write, close, report_activity=self._report_activity, ) if path.startswith(b"/~"): path = path[1:] # TODO(jelmer): Alternative to ascii? proto.send_cmd(b"git-" + cmd, path, b"host=" + self._host.encode("ascii")) return proto, lambda: _fileno_can_read(s), None class SubprocessWrapper(object): """A socket-like object that talks to a subprocess via pipes.""" def __init__(self, proc): self.proc = proc self.read = BufferedReader(proc.stdout).read self.write = proc.stdin.write @property def stderr(self): return self.proc.stderr def can_read(self): if sys.platform == "win32": from msvcrt import get_osfhandle handle = get_osfhandle(self.proc.stdout.fileno()) return _win32_peek_avail(handle) != 0 else: return _fileno_can_read(self.proc.stdout.fileno()) def close(self): self.proc.stdin.close() self.proc.stdout.close() if self.proc.stderr: self.proc.stderr.close() self.proc.wait() def find_git_command(): """Find command to run for system Git (usually C Git).""" if sys.platform == "win32": # support .exe, .bat and .cmd try: # to avoid overhead import win32api except ImportError: # run through cmd.exe with some overhead return ["cmd", "/c", "git"] else: status, git = win32api.FindExecutable("git") return [git] else: return ["git"] class SubprocessGitClient(TraditionalGitClient): """Git client that talks to a server using a subprocess.""" @classmethod def from_parsedurl(cls, parsedurl, **kwargs): return cls(**kwargs) git_command = None def _connect(self, service, path): if not isinstance(service, bytes): raise TypeError(service) if isinstance(path, bytes): path = path.decode(self._remote_path_encoding) if self.git_command is None: git_command = find_git_command() argv = git_command + [service.decode("ascii"), path] p = subprocess.Popen( argv, bufsize=0, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE, ) pw = SubprocessWrapper(p) return ( Protocol( pw.read, pw.write, pw.close, report_activity=self._report_activity, ), pw.can_read, p.stderr, ) class LocalGitClient(GitClient): """Git Client that just uses a local Repo.""" def __init__(self, thin_packs=True, report_activity=None, config=None): """Create a new LocalGitClient instance. Args: thin_packs: Whether or not thin packs should be retrieved report_activity: Optional callback for reporting transport activity. """ self._report_activity = report_activity # Ignore the thin_packs argument def get_url(self, path): return urlunsplit(("file", "", path, "", "")) @classmethod def from_parsedurl(cls, parsedurl, **kwargs): return cls(**kwargs) @classmethod def _open_repo(cls, path): from dulwich.repo import Repo if not isinstance(path, str): path = os.fsdecode(path) return closing(Repo(path)) def send_pack(self, path, update_refs, generate_pack_data, progress=None): """Upload a pack to a remote repository. Args: path: Repository path (as bytestring) update_refs: Function to determine changes to remote refs. Receive dict with existing remote refs, returns dict with changed refs (name -> sha, where sha=ZERO_SHA for deletions) with number of items and pack data to upload. progress: Optional progress function Returns: SendPackResult Raises: SendPackError: if server rejects the pack data """ if not progress: def progress(x): pass with self._open_repo(path) as target: old_refs = target.get_refs() new_refs = update_refs(dict(old_refs)) have = [sha1 for sha1 in old_refs.values() if sha1 != ZERO_SHA] want = [] for refname, new_sha1 in new_refs.items(): if ( new_sha1 not in have and new_sha1 not in want and new_sha1 != ZERO_SHA ): want.append(new_sha1) if not want and set(new_refs.items()).issubset(set(old_refs.items())): return SendPackResult(new_refs, ref_status={}) target.object_store.add_pack_data( *generate_pack_data(have, want, ofs_delta=True) ) ref_status = {} for refname, new_sha1 in new_refs.items(): old_sha1 = old_refs.get(refname, ZERO_SHA) if new_sha1 != ZERO_SHA: if not target.refs.set_if_equals(refname, old_sha1, new_sha1): msg = "unable to set %s to %s" % (refname, new_sha1) progress(msg) ref_status[refname] = msg else: if not target.refs.remove_if_equals(refname, old_sha1): progress("unable to remove %s" % refname) ref_status[refname] = "unable to remove" return SendPackResult(new_refs, ref_status=ref_status) def fetch(self, path, target, determine_wants=None, progress=None, depth=None): """Fetch into a target repository. Args: path: Path to fetch from (as bytestring) target: Target repository to fetch into determine_wants: Optional function determine what refs to fetch. Receives dictionary of name->sha, should return list of shas to fetch. Defaults to all shas. progress: Optional progress function depth: Shallow fetch depth Returns: FetchPackResult object """ with self._open_repo(path) as r: refs = r.fetch( target, determine_wants=determine_wants, progress=progress, depth=depth, ) return FetchPackResult(refs, r.refs.get_symrefs(), agent_string()) def fetch_pack( self, path, determine_wants, graph_walker, pack_data, progress=None, depth=None, ): """Retrieve a pack from a git smart server. Args: path: Remote path to fetch from determine_wants: Function determine what refs to fetch. Receives dictionary of name->sha, should return list of shas to fetch. graph_walker: Object with next() and ack(). pack_data: Callback called for each bit of data in the pack progress: Callback for progress reports (strings) depth: Shallow fetch depth Returns: FetchPackResult object """ with self._open_repo(path) as r: objects_iter = r.fetch_objects( determine_wants, graph_walker, progress=progress, depth=depth ) symrefs = r.refs.get_symrefs() agent = agent_string() # Did the process short-circuit (e.g. in a stateless RPC call)? # Note that the client still expects a 0-object pack in most cases. if objects_iter is None: return FetchPackResult(None, symrefs, agent) protocol = ProtocolFile(None, pack_data) write_pack_objects(protocol, objects_iter) return FetchPackResult(r.get_refs(), symrefs, agent) def get_refs(self, path): """Retrieve the current refs from a git smart server.""" with self._open_repo(path) as target: return target.get_refs() # What Git client to use for local access default_local_git_client_cls = LocalGitClient class SSHVendor(object): """A client side SSH implementation.""" def connect_ssh( self, host, command, username=None, port=None, password=None, key_filename=None, ): # This function was deprecated in 0.9.1 import warnings warnings.warn( "SSHVendor.connect_ssh has been renamed to SSHVendor.run_command", DeprecationWarning, ) return self.run_command( host, command, username=username, port=port, password=password, key_filename=key_filename, ) def run_command( self, host, command, username=None, port=None, password=None, key_filename=None, ): """Connect to an SSH server. Run a command remotely and return a file-like object for interaction with the remote command. Args: host: Host name command: Command to run (as argv array) username: Optional ame of user to log in as port: Optional SSH port to use password: Optional ssh password for login or private key key_filename: Optional path to private keyfile Returns: """ raise NotImplementedError(self.run_command) class StrangeHostname(Exception): """Refusing to connect to strange SSH hostname.""" def __init__(self, hostname): super(StrangeHostname, self).__init__(hostname) class SubprocessSSHVendor(SSHVendor): """SSH vendor that shells out to the local 'ssh' command.""" def run_command( self, host, command, username=None, port=None, password=None, key_filename=None, ): if password is not None: raise NotImplementedError( "Setting password not supported by SubprocessSSHVendor." ) args = ["ssh", "-x"] if port: args.extend(["-p", str(port)]) if key_filename: args.extend(["-i", str(key_filename)]) if username: host = "%s@%s" % (username, host) if host.startswith("-"): raise StrangeHostname(hostname=host) args.append(host) proc = subprocess.Popen( args + [command], bufsize=0, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE, ) return SubprocessWrapper(proc) class PLinkSSHVendor(SSHVendor): """SSH vendor that shells out to the local 'plink' command.""" def run_command( self, host, command, username=None, port=None, password=None, key_filename=None, ): if sys.platform == "win32": args = ["plink.exe", "-ssh"] else: args = ["plink", "-ssh"] if password is not None: import warnings warnings.warn( "Invoking PLink with a password exposes the password in the " "process list." ) args.extend(["-pw", str(password)]) if port: args.extend(["-P", str(port)]) if key_filename: args.extend(["-i", str(key_filename)]) if username: host = "%s@%s" % (username, host) if host.startswith("-"): raise StrangeHostname(hostname=host) args.append(host) proc = subprocess.Popen( args + [command], bufsize=0, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE, ) return SubprocessWrapper(proc) def ParamikoSSHVendor(**kwargs): import warnings warnings.warn( "ParamikoSSHVendor has been moved to dulwich.contrib.paramiko_vendor.", DeprecationWarning, ) from dulwich.contrib.paramiko_vendor import ParamikoSSHVendor return ParamikoSSHVendor(**kwargs) # Can be overridden by users get_ssh_vendor = SubprocessSSHVendor class SSHGitClient(TraditionalGitClient): def __init__( self, host, port=None, username=None, vendor=None, config=None, password=None, key_filename=None, **kwargs ): self.host = host self.port = port self.username = username self.password = password self.key_filename = key_filename super(SSHGitClient, self).__init__(**kwargs) self.alternative_paths = {} if vendor is not None: self.ssh_vendor = vendor else: self.ssh_vendor = get_ssh_vendor() def get_url(self, path): netloc = self.host if self.port is not None: netloc += ":%d" % self.port if self.username is not None: netloc = urlquote(self.username, "@/:") + "@" + netloc return urlunsplit(("ssh", netloc, path, "", "")) @classmethod def from_parsedurl(cls, parsedurl, **kwargs): return cls( host=parsedurl.hostname, port=parsedurl.port, username=parsedurl.username, **kwargs ) def _get_cmd_path(self, cmd): cmd = self.alternative_paths.get(cmd, b"git-" + cmd) assert isinstance(cmd, bytes) return cmd def _connect(self, cmd, path): if not isinstance(cmd, bytes): raise TypeError(cmd) if isinstance(path, bytes): path = path.decode(self._remote_path_encoding) if path.startswith("/~"): path = path[1:] argv = ( self._get_cmd_path(cmd).decode(self._remote_path_encoding) + " '" + path + "'" ) kwargs = {} if self.password is not None: kwargs["password"] = self.password if self.key_filename is not None: kwargs["key_filename"] = self.key_filename con = self.ssh_vendor.run_command( self.host, argv, port=self.port, username=self.username, **kwargs ) return ( Protocol( con.read, con.write, con.close, report_activity=self._report_activity, ), con.can_read, getattr(con, "stderr", None), ) def default_user_agent_string(): # Start user agent with "git/", because GitHub requires this. :-( See # https://github.com/jelmer/dulwich/issues/562 for details. return "git/dulwich/%s" % ".".join([str(x) for x in dulwich.__version__]) def default_urllib3_manager( # noqa: C901 config, pool_manager_cls=None, proxy_manager_cls=None, **override_kwargs ): """Return `urllib3` connection pool manager. Honour detected proxy configurations. Args: config: dulwich.config.ConfigDict` instance with Git configuration. kwargs: Additional arguments for urllib3.ProxyManager Returns: `pool_manager_cls` (defaults to `urllib3.ProxyManager`) instance for proxy configurations, `proxy_manager_cls` (defaults to `urllib3.PoolManager`) instance otherwise. """ proxy_server = user_agent = None ca_certs = ssl_verify = None if proxy_server is None: for proxyname in ("https_proxy", "http_proxy", "all_proxy"): proxy_server = os.environ.get(proxyname) if proxy_server is not None: break if config is not None: if proxy_server is None: try: proxy_server = config.get(b"http", b"proxy") except KeyError: pass try: user_agent = config.get(b"http", b"useragent") except KeyError: pass # TODO(jelmer): Support per-host settings try: ssl_verify = config.get_boolean(b"http", b"sslVerify") except KeyError: ssl_verify = True try: ca_certs = config.get(b"http", b"sslCAInfo") except KeyError: ca_certs = None if user_agent is None: user_agent = default_user_agent_string() headers = {"User-agent": user_agent} kwargs = {} if ssl_verify is True: kwargs["cert_reqs"] = "CERT_REQUIRED" elif ssl_verify is False: kwargs["cert_reqs"] = "CERT_NONE" else: # Default to SSL verification kwargs["cert_reqs"] = "CERT_REQUIRED" if ca_certs is not None: kwargs["ca_certs"] = ca_certs kwargs.update(override_kwargs) # Try really hard to find a SSL certificate path if "ca_certs" not in kwargs and kwargs.get("cert_reqs") != "CERT_NONE": try: import certifi except ImportError: pass else: kwargs["ca_certs"] = certifi.where() import urllib3 if proxy_server is not None: if proxy_manager_cls is None: proxy_manager_cls = urllib3.ProxyManager # `urllib3` requires a `str` object in both Python 2 and 3, while # `ConfigDict` coerces entries to `bytes` on Python 3. Compensate. if not isinstance(proxy_server, str): proxy_server = proxy_server.decode() manager = proxy_manager_cls(proxy_server, headers=headers, **kwargs) else: if pool_manager_cls is None: pool_manager_cls = urllib3.PoolManager manager = pool_manager_cls(headers=headers, **kwargs) return manager class HttpGitClient(GitClient): def __init__( self, base_url, dumb=None, pool_manager=None, config=None, username=None, password=None, **kwargs ): self._base_url = base_url.rstrip("/") + "/" self._username = username self._password = password self.dumb = dumb if pool_manager is None: self.pool_manager = default_urllib3_manager(config) else: self.pool_manager = pool_manager if username is not None: # No escaping needed: ":" is not allowed in username: # https://tools.ietf.org/html/rfc2617#section-2 credentials = "%s:%s" % (username, password) import urllib3.util basic_auth = urllib3.util.make_headers(basic_auth=credentials) self.pool_manager.headers.update(basic_auth) GitClient.__init__(self, **kwargs) def get_url(self, path): return self._get_url(path).rstrip("/") @classmethod def from_parsedurl(cls, parsedurl, **kwargs): password = parsedurl.password if password is not None: kwargs["password"] = urlunquote(password) username = parsedurl.username if username is not None: kwargs["username"] = urlunquote(username) netloc = parsedurl.hostname if parsedurl.port: netloc = "%s:%s" % (netloc, parsedurl.port) if parsedurl.username: netloc = "%s@%s" % (parsedurl.username, netloc) parsedurl = parsedurl._replace(netloc=netloc) return cls(urlunparse(parsedurl), **kwargs) def __repr__(self): return "%s(%r, dumb=%r)" % ( type(self).__name__, self._base_url, self.dumb, ) def _get_url(self, path): if not isinstance(path, str): # urllib3.util.url._encode_invalid_chars() converts the path back # to bytes using the utf-8 codec. path = path.decode("utf-8") return urljoin(self._base_url, path).rstrip("/") + "/" def _http_request(self, url, headers=None, data=None, allow_compression=False): """Perform HTTP request. Args: url: Request URL. headers: Optional custom headers to override defaults. data: Request data. allow_compression: Allow GZipped communication. Returns: Tuple (`response`, `read`), where response is an `urllib3` response object with additional `content_type` and `redirect_location` properties, and `read` is a consumable read method for the response data. """ req_headers = self.pool_manager.headers.copy() if headers is not None: req_headers.update(headers) req_headers["Pragma"] = "no-cache" if allow_compression: req_headers["Accept-Encoding"] = "gzip" else: req_headers["Accept-Encoding"] = "identity" if data is None: resp = self.pool_manager.request("GET", url, headers=req_headers) else: resp = self.pool_manager.request( "POST", url, headers=req_headers, body=data ) if resp.status == 404: raise NotGitRepository() if resp.status == 401: raise HTTPUnauthorized(resp.getheader("WWW-Authenticate"), url) if resp.status != 200: raise GitProtocolError( "unexpected http resp %d for %s" % (resp.status, url) ) # TODO: Optimization available by adding `preload_content=False` to the # request and just passing the `read` method on instead of going via # `BytesIO`, if we can guarantee that the entire response is consumed # before issuing the next to still allow for connection reuse from the # pool. read = BytesIO(resp.data).read resp.content_type = resp.getheader("Content-Type") # Check if geturl() is available (urllib3 version >= 1.23) try: resp_url = resp.geturl() except AttributeError: # get_redirect_location() is available for urllib3 >= 1.1 resp.redirect_location = resp.get_redirect_location() else: resp.redirect_location = resp_url if resp_url != url else "" return resp, read def _discover_references(self, service, base_url): assert base_url[-1] == "/" tail = "info/refs" headers = {"Accept": "*/*"} if self.dumb is not True: tail += "?service=%s" % service.decode("ascii") url = urljoin(base_url, tail) resp, read = self._http_request(url, headers, allow_compression=True) if resp.redirect_location: # Something changed (redirect!), so let's update the base URL if not resp.redirect_location.endswith(tail): raise GitProtocolError( "Redirected from URL %s to URL %s without %s" % (url, resp.redirect_location, tail) ) base_url = resp.redirect_location[: -len(tail)] try: self.dumb = not resp.content_type.startswith("application/x-git-") if not self.dumb: proto = Protocol(read, None) # The first line should mention the service try: [pkt] = list(proto.read_pkt_seq()) except ValueError: raise GitProtocolError("unexpected number of packets received") if pkt.rstrip(b"\n") != (b"# service=" + service): raise GitProtocolError( "unexpected first line %r from smart server" % pkt ) return read_pkt_refs(proto) + (base_url,) else: return read_info_refs(resp), set(), base_url finally: resp.close() def _smart_request(self, service, url, data): assert url[-1] == "/" url = urljoin(url, service) result_content_type = "application/x-%s-result" % service headers = { "Content-Type": "application/x-%s-request" % service, "Accept": result_content_type, "Content-Length": str(len(data)), } resp, read = self._http_request(url, headers, data) if resp.content_type != result_content_type: raise GitProtocolError( "Invalid content-type from server: %s" % resp.content_type ) return resp, read def send_pack(self, path, update_refs, generate_pack_data, progress=None): """Upload a pack to a remote repository. Args: path: Repository path (as bytestring) update_refs: Function to determine changes to remote refs. Receives dict with existing remote refs, returns dict with changed refs (name -> sha, where sha=ZERO_SHA for deletions) generate_pack_data: Function that can return a tuple with number of elements and pack data to upload. progress: Optional progress function Returns: SendPackResult Raises: SendPackError: if server rejects the pack data """ url = self._get_url(path) old_refs, server_capabilities, url = self._discover_references( b"git-receive-pack", url ) ( negotiated_capabilities, agent, ) = self._negotiate_receive_pack_capabilities(server_capabilities) negotiated_capabilities.add(capability_agent()) if CAPABILITY_REPORT_STATUS in negotiated_capabilities: self._report_status_parser = ReportStatusParser() new_refs = update_refs(dict(old_refs)) if new_refs is None: # Determine wants function is aborting the push. return SendPackResult(old_refs, agent=agent, ref_status={}) if set(new_refs.items()).issubset(set(old_refs.items())): return SendPackResult(new_refs, agent=agent, ref_status={}) if self.dumb: raise NotImplementedError(self.fetch_pack) req_data = BytesIO() req_proto = Protocol(None, req_data.write) (have, want) = self._handle_receive_pack_head( req_proto, negotiated_capabilities, old_refs, new_refs ) pack_data_count, pack_data = generate_pack_data( have, want, ofs_delta=(CAPABILITY_OFS_DELTA in negotiated_capabilities), ) if self._should_send_pack(new_refs): write_pack_data(req_proto.write_file(), pack_data_count, pack_data) resp, read = self._smart_request( "git-receive-pack", url, data=req_data.getvalue() ) try: resp_proto = Protocol(read, None) ref_status = self._handle_receive_pack_tail( resp_proto, negotiated_capabilities, progress ) return SendPackResult(new_refs, agent=agent, ref_status=ref_status) finally: resp.close() def fetch_pack( self, path, determine_wants, graph_walker, pack_data, progress=None, depth=None, ): """Retrieve a pack from a git smart server. Args: path: Path to fetch from determine_wants: Callback that returns list of commits to fetch graph_walker: Object with next() and ack(). pack_data: Callback called for each bit of data in the pack progress: Callback for progress reports (strings) depth: Depth for request Returns: FetchPackResult object """ url = self._get_url(path) refs, server_capabilities, url = self._discover_references( b"git-upload-pack", url ) ( negotiated_capabilities, symrefs, agent, ) = self._negotiate_upload_pack_capabilities(server_capabilities) wants = determine_wants(refs) if wants is not None: wants = [cid for cid in wants if cid != ZERO_SHA] if not wants: return FetchPackResult(refs, symrefs, agent) if self.dumb: raise NotImplementedError(self.fetch_pack) req_data = BytesIO() req_proto = Protocol(None, req_data.write) (new_shallow, new_unshallow) = self._handle_upload_pack_head( req_proto, negotiated_capabilities, graph_walker, wants, can_read=None, depth=depth, ) resp, read = self._smart_request( "git-upload-pack", url, data=req_data.getvalue() ) try: resp_proto = Protocol(read, None) if new_shallow is None and new_unshallow is None: (new_shallow, new_unshallow) = _read_shallow_updates(resp_proto) self._handle_upload_pack_tail( resp_proto, negotiated_capabilities, graph_walker, pack_data, progress, ) return FetchPackResult(refs, symrefs, agent, new_shallow, new_unshallow) finally: resp.close() def get_refs(self, path): """Retrieve the current refs from a git smart server.""" url = self._get_url(path) refs, _, _ = self._discover_references(b"git-upload-pack", url) return refs def get_transport_and_path_from_url(url, config=None, **kwargs): """Obtain a git client from a URL. Args: url: URL to open (a unicode string) config: Optional config object thin_packs: Whether or not thin packs should be retrieved report_activity: Optional callback for reporting transport activity. Returns: Tuple with client instance and relative path. """ parsed = urlparse(url) if parsed.scheme == "git": return (TCPGitClient.from_parsedurl(parsed, **kwargs), parsed.path) elif parsed.scheme in ("git+ssh", "ssh"): return SSHGitClient.from_parsedurl(parsed, **kwargs), parsed.path elif parsed.scheme in ("http", "https"): return ( HttpGitClient.from_parsedurl(parsed, config=config, **kwargs), parsed.path, ) elif parsed.scheme == "file": return ( default_local_git_client_cls.from_parsedurl(parsed, **kwargs), parsed.path, ) raise ValueError("unknown scheme '%s'" % parsed.scheme) def parse_rsync_url(location): """Parse a rsync-style URL.""" if ":" in location and "@" not in location: # SSH with no user@, zero or one leading slash. (host, path) = location.split(":", 1) user = None elif ":" in location: # SSH with user@host:foo. user_host, path = location.split(":", 1) if "@" in user_host: user, host = user_host.rsplit("@", 1) else: user = None host = user_host else: raise ValueError("not a valid rsync-style URL") return (user, host, path) def get_transport_and_path(location, **kwargs): """Obtain a git client from a URL. Args: location: URL or path (a string) config: Optional config object thin_packs: Whether or not thin packs should be retrieved report_activity: Optional callback for reporting transport activity. Returns: Tuple with client instance and relative path. """ # First, try to parse it as a URL try: return get_transport_and_path_from_url(location, **kwargs) except ValueError: pass if sys.platform == "win32" and location[0].isalpha() and location[1:3] == ":\\": # Windows local path return default_local_git_client_cls(**kwargs), location try: (username, hostname, path) = parse_rsync_url(location) except ValueError: # Otherwise, assume it's a local path. return default_local_git_client_cls(**kwargs), location else: return SSHGitClient(hostname, username=username, **kwargs), path DEFAULT_GIT_CREDENTIALS_PATHS = [ os.path.expanduser("~/.git-credentials"), get_xdg_config_home_path("git", "credentials"), ] def get_credentials_from_store( scheme, hostname, username=None, fnames=DEFAULT_GIT_CREDENTIALS_PATHS ): for fname in fnames: try: with open(fname, "rb") as f: for line in f: parsed_line = urlparse(line.strip()) if ( parsed_line.scheme == scheme and parsed_line.hostname == hostname and (username is None or parsed_line.username == username) ): return parsed_line.username, parsed_line.password except FileNotFoundError: # If the file doesn't exist, try the next one. continue diff --git a/dulwich/contrib/swift.py b/dulwich/contrib/swift.py index b407ecfb..3d76acbe 100644 --- a/dulwich/contrib/swift.py +++ b/dulwich/contrib/swift.py @@ -1,1079 +1,1079 @@ # swift.py -- Repo implementation atop OpenStack SWIFT # Copyright (C) 2013 eNovance SAS # # Author: Fabien Boucher # # Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU # General Public License as public by the Free Software Foundation; version 2.0 # or (at your option) any later version. You can redistribute it and/or # modify it under the terms of either of these two licenses. # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # # You should have received a copy of the licenses; if not, see # for a copy of the GNU General Public License # and for a copy of the Apache # License, Version 2.0. # """Repo implementation atop OpenStack SWIFT.""" # TODO: Refactor to share more code with dulwich/repo.py. # TODO(fbo): Second attempt to _send() must be notified via real log # TODO(fbo): More logs for operations import os import stat import zlib import tempfile import posixpath import urllib.parse as urlparse from io import BytesIO from configparser import ConfigParser from geventhttpclient import HTTPClient from dulwich.greenthreads import ( GreenThreadsMissingObjectFinder, GreenThreadsObjectStoreIterator, ) from dulwich.lru_cache import LRUSizeCache from dulwich.objects import ( Blob, Commit, Tree, Tag, S_ISGITLINK, ) from dulwich.object_store import ( PackBasedObjectStore, PACKDIR, INFODIR, ) from dulwich.pack import ( PackData, Pack, PackIndexer, PackStreamCopier, write_pack_header, compute_file_sha, iter_sha1, write_pack_index_v2, load_pack_index_file, read_pack_header, _compute_object_size, unpack_object, write_pack_object, ) from dulwich.protocol import TCP_GIT_PORT from dulwich.refs import ( InfoRefsContainer, read_info_refs, write_info_refs, ) from dulwich.repo import ( BaseRepo, OBJECTDIR, ) from dulwich.server import ( Backend, TCPGitServer, ) import json import sys """ # Configuration file sample [swift] # Authentication URL (Keystone or Swift) auth_url = http://127.0.0.1:5000/v2.0 # Authentication version to use auth_ver = 2 # The tenant and username separated by a semicolon username = admin;admin # The user password password = pass # The Object storage region to use (auth v2) (Default RegionOne) region_name = RegionOne # The Object storage endpoint URL to use (auth v2) (Default internalURL) endpoint_type = internalURL # Concurrency to use for parallel tasks (Default 10) concurrency = 10 # Size of the HTTP pool (Default 10) http_pool_length = 10 # Timeout delay for HTTP connections (Default 20) http_timeout = 20 # Chunk size to read from pack (Bytes) (Default 12228) chunk_length = 12228 # Cache size (MBytes) (Default 20) cache_length = 20 """ class PackInfoObjectStoreIterator(GreenThreadsObjectStoreIterator): def __len__(self): - while len(self.finder.objects_to_send): + while self.finder.objects_to_send: for _ in range(0, len(self.finder.objects_to_send)): sha = self.finder.next() self._shas.append(sha) return len(self._shas) class PackInfoMissingObjectFinder(GreenThreadsMissingObjectFinder): def next(self): while True: if not self.objects_to_send: return None (sha, name, leaf) = self.objects_to_send.pop() if sha not in self.sha_done: break if not leaf: info = self.object_store.pack_info_get(sha) if info[0] == Commit.type_num: self.add_todo([(info[2], "", False)]) elif info[0] == Tree.type_num: self.add_todo([tuple(i) for i in info[1]]) elif info[0] == Tag.type_num: self.add_todo([(info[1], None, False)]) if sha in self._tagged: self.add_todo([(self._tagged[sha], None, True)]) self.sha_done.add(sha) self.progress("counting objects: %d\r" % len(self.sha_done)) return (sha, name) def load_conf(path=None, file=None): """Load configuration in global var CONF Args: path: The path to the configuration file file: If provided read instead the file like object """ conf = ConfigParser() if file: try: conf.read_file(file, path) except AttributeError: # read_file only exists in Python3 conf.readfp(file) return conf confpath = None if not path: try: confpath = os.environ["DULWICH_SWIFT_CFG"] except KeyError: raise Exception("You need to specify a configuration file") else: confpath = path if not os.path.isfile(confpath): raise Exception("Unable to read configuration file %s" % confpath) conf.read(confpath) return conf def swift_load_pack_index(scon, filename): """Read a pack index file from Swift Args: scon: a `SwiftConnector` instance filename: Path to the index file objectise Returns: a `PackIndexer` instance """ with scon.get_object(filename) as f: return load_pack_index_file(filename, f) def pack_info_create(pack_data, pack_index): pack = Pack.from_objects(pack_data, pack_index) info = {} for obj in pack.iterobjects(): # Commit if obj.type_num == Commit.type_num: info[obj.id] = (obj.type_num, obj.parents, obj.tree) # Tree elif obj.type_num == Tree.type_num: shas = [ (s, n, not stat.S_ISDIR(m)) for n, m, s in obj.items() if not S_ISGITLINK(m) ] info[obj.id] = (obj.type_num, shas) # Blob elif obj.type_num == Blob.type_num: info[obj.id] = None # Tag elif obj.type_num == Tag.type_num: info[obj.id] = (obj.type_num, obj.object[1]) return zlib.compress(json.dumps(info)) def load_pack_info(filename, scon=None, file=None): if not file: f = scon.get_object(filename) else: f = file if not f: return None try: return json.loads(zlib.decompress(f.read())) finally: f.close() class SwiftException(Exception): pass class SwiftConnector(object): """A Connector to swift that manage authentication and errors catching""" def __init__(self, root, conf): """Initialize a SwiftConnector Args: root: The swift container that will act as Git bare repository conf: A ConfigParser Object """ self.conf = conf self.auth_ver = self.conf.get("swift", "auth_ver") if self.auth_ver not in ["1", "2"]: raise NotImplementedError("Wrong authentication version use either 1 or 2") self.auth_url = self.conf.get("swift", "auth_url") self.user = self.conf.get("swift", "username") self.password = self.conf.get("swift", "password") self.concurrency = self.conf.getint("swift", "concurrency") or 10 self.http_timeout = self.conf.getint("swift", "http_timeout") or 20 self.http_pool_length = self.conf.getint("swift", "http_pool_length") or 10 self.region_name = self.conf.get("swift", "region_name") or "RegionOne" self.endpoint_type = self.conf.get("swift", "endpoint_type") or "internalURL" self.cache_length = self.conf.getint("swift", "cache_length") or 20 self.chunk_length = self.conf.getint("swift", "chunk_length") or 12228 self.root = root block_size = 1024 * 12 # 12KB if self.auth_ver == "1": self.storage_url, self.token = self.swift_auth_v1() else: self.storage_url, self.token = self.swift_auth_v2() token_header = {"X-Auth-Token": str(self.token)} self.httpclient = HTTPClient.from_url( str(self.storage_url), concurrency=self.http_pool_length, block_size=block_size, connection_timeout=self.http_timeout, network_timeout=self.http_timeout, headers=token_header, ) self.base_path = str( posixpath.join(urlparse.urlparse(self.storage_url).path, self.root) ) def swift_auth_v1(self): self.user = self.user.replace(";", ":") auth_httpclient = HTTPClient.from_url( self.auth_url, connection_timeout=self.http_timeout, network_timeout=self.http_timeout, ) headers = {"X-Auth-User": self.user, "X-Auth-Key": self.password} path = urlparse.urlparse(self.auth_url).path ret = auth_httpclient.request("GET", path, headers=headers) # Should do something with redirections (301 in my case) if ret.status_code < 200 or ret.status_code >= 300: raise SwiftException( "AUTH v1.0 request failed on " + "%s with error code %s (%s)" % ( str(auth_httpclient.get_base_url()) + path, ret.status_code, str(ret.items()), ) ) storage_url = ret["X-Storage-Url"] token = ret["X-Auth-Token"] return storage_url, token def swift_auth_v2(self): self.tenant, self.user = self.user.split(";") auth_dict = {} auth_dict["auth"] = { "passwordCredentials": { "username": self.user, "password": self.password, }, "tenantName": self.tenant, } auth_json = json.dumps(auth_dict) headers = {"Content-Type": "application/json"} auth_httpclient = HTTPClient.from_url( self.auth_url, connection_timeout=self.http_timeout, network_timeout=self.http_timeout, ) path = urlparse.urlparse(self.auth_url).path if not path.endswith("tokens"): path = posixpath.join(path, "tokens") ret = auth_httpclient.request("POST", path, body=auth_json, headers=headers) if ret.status_code < 200 or ret.status_code >= 300: raise SwiftException( "AUTH v2.0 request failed on " + "%s with error code %s (%s)" % ( str(auth_httpclient.get_base_url()) + path, ret.status_code, str(ret.items()), ) ) auth_ret_json = json.loads(ret.read()) token = auth_ret_json["access"]["token"]["id"] catalogs = auth_ret_json["access"]["serviceCatalog"] object_store = [ o_store for o_store in catalogs if o_store["type"] == "object-store" ][0] endpoints = object_store["endpoints"] endpoint = [endp for endp in endpoints if endp["region"] == self.region_name][0] return endpoint[self.endpoint_type], token def test_root_exists(self): """Check that Swift container exist Returns: True if exist or None it not """ ret = self.httpclient.request("HEAD", self.base_path) if ret.status_code == 404: return None if ret.status_code < 200 or ret.status_code > 300: raise SwiftException( "HEAD request failed with error code %s" % ret.status_code ) return True def create_root(self): """Create the Swift container Raises: SwiftException: if unable to create """ if not self.test_root_exists(): ret = self.httpclient.request("PUT", self.base_path) if ret.status_code < 200 or ret.status_code > 300: raise SwiftException( "PUT request failed with error code %s" % ret.status_code ) def get_container_objects(self): """Retrieve objects list in a container Returns: A list of dict that describe objects or None if container does not exist """ qs = "?format=json" path = self.base_path + qs ret = self.httpclient.request("GET", path) if ret.status_code == 404: return None if ret.status_code < 200 or ret.status_code > 300: raise SwiftException( "GET request failed with error code %s" % ret.status_code ) content = ret.read() return json.loads(content) def get_object_stat(self, name): """Retrieve object stat Args: name: The object name Returns: A dict that describe the object or None if object does not exist """ path = self.base_path + "/" + name ret = self.httpclient.request("HEAD", path) if ret.status_code == 404: return None if ret.status_code < 200 or ret.status_code > 300: raise SwiftException( "HEAD request failed with error code %s" % ret.status_code ) resp_headers = {} for header, value in ret.items(): resp_headers[header.lower()] = value return resp_headers def put_object(self, name, content): """Put an object Args: name: The object name content: A file object Raises: SwiftException: if unable to create """ content.seek(0) data = content.read() path = self.base_path + "/" + name headers = {"Content-Length": str(len(data))} def _send(): ret = self.httpclient.request("PUT", path, body=data, headers=headers) return ret try: # Sometime got Broken Pipe - Dirty workaround ret = _send() except Exception: # Second attempt work ret = _send() if ret.status_code < 200 or ret.status_code > 300: raise SwiftException( "PUT request failed with error code %s" % ret.status_code ) def get_object(self, name, range=None): """Retrieve an object Args: name: The object name range: A string range like "0-10" to retrieve specified bytes in object content Returns: A file like instance or bytestring if range is specified """ headers = {} if range: headers["Range"] = "bytes=%s" % range path = self.base_path + "/" + name ret = self.httpclient.request("GET", path, headers=headers) if ret.status_code == 404: return None if ret.status_code < 200 or ret.status_code > 300: raise SwiftException( "GET request failed with error code %s" % ret.status_code ) content = ret.read() if range: return content return BytesIO(content) def del_object(self, name): """Delete an object Args: name: The object name Raises: SwiftException: if unable to delete """ path = self.base_path + "/" + name ret = self.httpclient.request("DELETE", path) if ret.status_code < 200 or ret.status_code > 300: raise SwiftException( "DELETE request failed with error code %s" % ret.status_code ) def del_root(self): """Delete the root container by removing container content Raises: SwiftException: if unable to delete """ for obj in self.get_container_objects(): self.del_object(obj["name"]) ret = self.httpclient.request("DELETE", self.base_path) if ret.status_code < 200 or ret.status_code > 300: raise SwiftException( "DELETE request failed with error code %s" % ret.status_code ) class SwiftPackReader(object): """A SwiftPackReader that mimic read and sync method The reader allows to read a specified amount of bytes from a given offset of a Swift object. A read offset is kept internaly. The reader will read from Swift a specified amount of data to complete its internal buffer. chunk_length specifiy the amount of data to read from Swift. """ def __init__(self, scon, filename, pack_length): """Initialize a SwiftPackReader Args: scon: a `SwiftConnector` instance filename: the pack filename pack_length: The size of the pack object """ self.scon = scon self.filename = filename self.pack_length = pack_length self.offset = 0 self.base_offset = 0 self.buff = b"" self.buff_length = self.scon.chunk_length def _read(self, more=False): if more: self.buff_length = self.buff_length * 2 offset = self.base_offset r = min(self.base_offset + self.buff_length, self.pack_length) ret = self.scon.get_object(self.filename, range="%s-%s" % (offset, r)) self.buff = ret def read(self, length): """Read a specified amount of Bytes form the pack object Args: length: amount of bytes to read Returns: a bytestring """ end = self.offset + length if self.base_offset + end > self.pack_length: data = self.buff[self.offset :] self.offset = end return data if end > len(self.buff): # Need to read more from swift self._read(more=True) return self.read(length) data = self.buff[self.offset : end] self.offset = end return data def seek(self, offset): """Seek to a specified offset Args: offset: the offset to seek to """ self.base_offset = offset self._read() self.offset = 0 def read_checksum(self): """Read the checksum from the pack Returns: the checksum bytestring """ return self.scon.get_object(self.filename, range="-20") class SwiftPackData(PackData): """The data contained in a packfile. We use the SwiftPackReader to read bytes from packs stored in Swift using the Range header feature of Swift. """ def __init__(self, scon, filename): """Initialize a SwiftPackReader Args: scon: a `SwiftConnector` instance filename: the pack filename """ self.scon = scon self._filename = filename self._header_size = 12 headers = self.scon.get_object_stat(self._filename) self.pack_length = int(headers["content-length"]) pack_reader = SwiftPackReader(self.scon, self._filename, self.pack_length) (version, self._num_objects) = read_pack_header(pack_reader.read) self._offset_cache = LRUSizeCache( 1024 * 1024 * self.scon.cache_length, compute_size=_compute_object_size, ) self.pack = None def get_object_at(self, offset): if offset in self._offset_cache: return self._offset_cache[offset] assert offset >= self._header_size pack_reader = SwiftPackReader(self.scon, self._filename, self.pack_length) pack_reader.seek(offset) unpacked, _ = unpack_object(pack_reader.read) return (unpacked.pack_type_num, unpacked._obj()) def get_stored_checksum(self): pack_reader = SwiftPackReader(self.scon, self._filename, self.pack_length) return pack_reader.read_checksum() def close(self): pass class SwiftPack(Pack): """A Git pack object. Same implementation as pack.Pack except that _idx_load and _data_load are bounded to Swift version of load_pack_index and PackData. """ def __init__(self, *args, **kwargs): self.scon = kwargs["scon"] del kwargs["scon"] super(SwiftPack, self).__init__(*args, **kwargs) self._pack_info_path = self._basename + ".info" self._pack_info = None self._pack_info_load = lambda: load_pack_info(self._pack_info_path, self.scon) self._idx_load = lambda: swift_load_pack_index(self.scon, self._idx_path) self._data_load = lambda: SwiftPackData(self.scon, self._data_path) @property def pack_info(self): """The pack data object being used.""" if self._pack_info is None: self._pack_info = self._pack_info_load() return self._pack_info class SwiftObjectStore(PackBasedObjectStore): """A Swift Object Store Allow to manage a bare Git repository from Openstack Swift. This object store only supports pack files and not loose objects. """ def __init__(self, scon): """Open a Swift object store. Args: scon: A `SwiftConnector` instance """ super(SwiftObjectStore, self).__init__() self.scon = scon self.root = self.scon.root self.pack_dir = posixpath.join(OBJECTDIR, PACKDIR) self._alternates = None def _update_pack_cache(self): objects = self.scon.get_container_objects() pack_files = [ o["name"].replace(".pack", "") for o in objects if o["name"].endswith(".pack") ] ret = [] for basename in pack_files: pack = SwiftPack(basename, scon=self.scon) self._pack_cache[basename] = pack ret.append(pack) return ret def _iter_loose_objects(self): """Loose objects are not supported by this repository""" return [] def iter_shas(self, finder): """An iterator over pack's ObjectStore. Returns: a `ObjectStoreIterator` or `GreenThreadsObjectStoreIterator` instance if gevent is enabled """ shas = iter(finder.next, None) return PackInfoObjectStoreIterator(self, shas, finder, self.scon.concurrency) def find_missing_objects(self, *args, **kwargs): kwargs["concurrency"] = self.scon.concurrency return PackInfoMissingObjectFinder(self, *args, **kwargs) def pack_info_get(self, sha): for pack in self.packs: if sha in pack: return pack.pack_info[sha] def _collect_ancestors(self, heads, common=set()): def _find_parents(commit): for pack in self.packs: if commit in pack: try: parents = pack.pack_info[commit][1] except KeyError: # Seems to have no parents return [] return parents bases = set() commits = set() queue = [] queue.extend(heads) while queue: e = queue.pop(0) if e in common: bases.add(e) elif e not in commits: commits.add(e) parents = _find_parents(e) queue.extend(parents) return (commits, bases) def add_pack(self): """Add a new pack to this object store. Returns: Fileobject to write to and a commit function to call when the pack is finished. """ f = BytesIO() def commit(): f.seek(0) pack = PackData(file=f, filename="") entries = pack.sorted_entries() - if len(entries): + if entries: basename = posixpath.join( self.pack_dir, "pack-%s" % iter_sha1(entry[0] for entry in entries), ) index = BytesIO() write_pack_index_v2(index, entries, pack.get_stored_checksum()) self.scon.put_object(basename + ".pack", f) f.close() self.scon.put_object(basename + ".idx", index) index.close() final_pack = SwiftPack(basename, scon=self.scon) final_pack.check_length_and_checksum() self._add_cached_pack(basename, final_pack) return final_pack else: return None def abort(): pass return f, commit, abort def add_object(self, obj): self.add_objects( [ (obj, None), ] ) def _pack_cache_stale(self): return False def _get_loose_object(self, sha): return None def add_thin_pack(self, read_all, read_some): """Read a thin pack Read it from a stream and complete it in a temporary file. Then the pack and the corresponding index file are uploaded to Swift. """ fd, path = tempfile.mkstemp(prefix="tmp_pack_") f = os.fdopen(fd, "w+b") try: indexer = PackIndexer(f, resolve_ext_ref=self.get_raw) copier = PackStreamCopier(read_all, read_some, f, delta_iter=indexer) copier.verify() return self._complete_thin_pack(f, path, copier, indexer) finally: f.close() os.unlink(path) def _complete_thin_pack(self, f, path, copier, indexer): entries = list(indexer) # Update the header with the new number of objects. f.seek(0) write_pack_header(f, len(entries) + len(indexer.ext_refs())) # Must flush before reading (http://bugs.python.org/issue3207) f.flush() # Rescan the rest of the pack, computing the SHA with the new header. new_sha = compute_file_sha(f, end_ofs=-20) # Must reposition before writing (http://bugs.python.org/issue3207) f.seek(0, os.SEEK_CUR) # Complete the pack. for ext_sha in indexer.ext_refs(): assert len(ext_sha) == 20 type_num, data = self.get_raw(ext_sha) offset = f.tell() crc32 = write_pack_object(f, type_num, data, sha=new_sha) entries.append((ext_sha, offset, crc32)) pack_sha = new_sha.digest() f.write(pack_sha) f.flush() # Move the pack in. entries.sort() pack_base_name = posixpath.join( self.pack_dir, "pack-" + os.fsdecode(iter_sha1(e[0] for e in entries)), ) self.scon.put_object(pack_base_name + ".pack", f) # Write the index. filename = pack_base_name + ".idx" index_file = BytesIO() write_pack_index_v2(index_file, entries, pack_sha) self.scon.put_object(filename, index_file) # Write pack info. f.seek(0) pack_data = PackData(filename="", file=f) index_file.seek(0) pack_index = load_pack_index_file("", index_file) serialized_pack_info = pack_info_create(pack_data, pack_index) f.close() index_file.close() pack_info_file = BytesIO(serialized_pack_info) filename = pack_base_name + ".info" self.scon.put_object(filename, pack_info_file) pack_info_file.close() # Add the pack to the store and return it. final_pack = SwiftPack(pack_base_name, scon=self.scon) final_pack.check_length_and_checksum() self._add_cached_pack(pack_base_name, final_pack) return final_pack class SwiftInfoRefsContainer(InfoRefsContainer): """Manage references in info/refs object.""" def __init__(self, scon, store): self.scon = scon self.filename = "info/refs" self.store = store f = self.scon.get_object(self.filename) if not f: f = BytesIO(b"") super(SwiftInfoRefsContainer, self).__init__(f) def _load_check_ref(self, name, old_ref): self._check_refname(name) f = self.scon.get_object(self.filename) if not f: return {} refs = read_info_refs(f) if old_ref is not None: if refs[name] != old_ref: return False return refs def _write_refs(self, refs): f = BytesIO() f.writelines(write_info_refs(refs, self.store)) self.scon.put_object(self.filename, f) def set_if_equals(self, name, old_ref, new_ref): """Set a refname to new_ref only if it currently equals old_ref.""" if name == "HEAD": return True refs = self._load_check_ref(name, old_ref) if not isinstance(refs, dict): return False refs[name] = new_ref self._write_refs(refs) self._refs[name] = new_ref return True def remove_if_equals(self, name, old_ref): """Remove a refname only if it currently equals old_ref.""" if name == "HEAD": return True refs = self._load_check_ref(name, old_ref) if not isinstance(refs, dict): return False del refs[name] self._write_refs(refs) del self._refs[name] return True def allkeys(self): try: self._refs["HEAD"] = self._refs["refs/heads/master"] except KeyError: pass return self._refs.keys() class SwiftRepo(BaseRepo): def __init__(self, root, conf): """Init a Git bare Repository on top of a Swift container. References are managed in info/refs objects by `SwiftInfoRefsContainer`. The root attribute is the Swift container that contain the Git bare repository. Args: root: The container which contains the bare repo conf: A ConfigParser object """ self.root = root.lstrip("/") self.conf = conf self.scon = SwiftConnector(self.root, self.conf) objects = self.scon.get_container_objects() if not objects: raise Exception("There is not any GIT repo here : %s" % self.root) objects = [o["name"].split("/")[0] for o in objects] if OBJECTDIR not in objects: raise Exception("This repository (%s) is not bare." % self.root) self.bare = True self._controldir = self.root object_store = SwiftObjectStore(self.scon) refs = SwiftInfoRefsContainer(self.scon, object_store) BaseRepo.__init__(self, object_store, refs) def _determine_file_mode(self): """Probe the file-system to determine whether permissions can be trusted. Returns: True if permissions can be trusted, False otherwise. """ return False def _put_named_file(self, filename, contents): """Put an object in a Swift container Args: filename: the path to the object to put on Swift contents: the content as bytestring """ with BytesIO() as f: f.write(contents) self.scon.put_object(filename, f) @classmethod def init_bare(cls, scon, conf): """Create a new bare repository. Args: scon: a `SwiftConnector` instance conf: a ConfigParser object Returns: a `SwiftRepo` instance """ scon.create_root() for obj in [ posixpath.join(OBJECTDIR, PACKDIR), posixpath.join(INFODIR, "refs"), ]: scon.put_object(obj, BytesIO(b"")) ret = cls(scon.root, conf) ret._init_files(True) return ret class SwiftSystemBackend(Backend): def __init__(self, logger, conf): self.conf = conf self.logger = logger def open_repository(self, path): self.logger.info("opening repository at %s", path) return SwiftRepo(path, self.conf) def cmd_daemon(args): """Entry point for starting a TCP git server.""" import optparse parser = optparse.OptionParser() parser.add_option( "-l", "--listen_address", dest="listen_address", default="127.0.0.1", help="Binding IP address.", ) parser.add_option( "-p", "--port", dest="port", type=int, default=TCP_GIT_PORT, help="Binding TCP port.", ) parser.add_option( "-c", "--swift_config", dest="swift_config", default="", help="Path to the configuration file for Swift backend.", ) options, args = parser.parse_args(args) try: import gevent import geventhttpclient # noqa: F401 except ImportError: print( "gevent and geventhttpclient libraries are mandatory " " for use the Swift backend." ) sys.exit(1) import gevent.monkey gevent.monkey.patch_socket() from dulwich import log_utils logger = log_utils.getLogger(__name__) conf = load_conf(options.swift_config) backend = SwiftSystemBackend(logger, conf) log_utils.default_logging_config() server = TCPGitServer(backend, options.listen_address, port=options.port) server.serve_forever() def cmd_init(args): import optparse parser = optparse.OptionParser() parser.add_option( "-c", "--swift_config", dest="swift_config", default="", help="Path to the configuration file for Swift backend.", ) options, args = parser.parse_args(args) conf = load_conf(options.swift_config) if args == []: parser.error("missing repository name") repo = args[0] scon = SwiftConnector(repo, conf) SwiftRepo.init_bare(scon, conf) def main(argv=sys.argv): commands = { "init": cmd_init, "daemon": cmd_daemon, } if len(sys.argv) < 2: print("Usage: %s <%s> [OPTIONS...]" % (sys.argv[0], "|".join(commands.keys()))) sys.exit(1) cmd = sys.argv[1] if cmd not in commands: print("No such subcommand: %s" % cmd) sys.exit(1) commands[cmd](sys.argv[2:]) if __name__ == "__main__": main() diff --git a/dulwich/contrib/test_swift.py b/dulwich/contrib/test_swift.py index 35f6ba7c..8c44edd6 100644 --- a/dulwich/contrib/test_swift.py +++ b/dulwich/contrib/test_swift.py @@ -1,500 +1,500 @@ # test_swift.py -- Unittests for the Swift backend. # Copyright (C) 2013 eNovance SAS # # Author: Fabien Boucher # # Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU # General Public License as public by the Free Software Foundation; version 2.0 # or (at your option) any later version. You can redistribute it and/or # modify it under the terms of either of these two licenses. # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # # You should have received a copy of the licenses; if not, see # for a copy of the GNU General Public License # and for a copy of the Apache # License, Version 2.0. # """Tests for dulwich.contrib.swift.""" import posixpath from time import time from io import BytesIO, StringIO from unittest import skipIf from dulwich.tests import ( TestCase, ) from dulwich.tests.test_object_store import ( ObjectStoreTests, ) from dulwich.objects import ( Blob, Commit, Tree, Tag, parse_timezone, ) import json missing_libs = [] try: import gevent # noqa:F401 except ImportError: missing_libs.append("gevent") try: import geventhttpclient # noqa:F401 except ImportError: missing_libs.append("geventhttpclient") try: from unittest.mock import patch except ImportError: missing_libs.append("mock") skipmsg = "Required libraries are not installed (%r)" % missing_libs if not missing_libs: from dulwich.contrib import swift config_file = """[swift] auth_url = http://127.0.0.1:8080/auth/%(version_str)s auth_ver = %(version_int)s username = test;tester password = testing region_name = %(region_name)s endpoint_type = %(endpoint_type)s concurrency = %(concurrency)s chunk_length = %(chunk_length)s cache_length = %(cache_length)s http_pool_length = %(http_pool_length)s http_timeout = %(http_timeout)s """ def_config_file = { "version_str": "v1.0", "version_int": 1, "concurrency": 1, "chunk_length": 12228, "cache_length": 1, "region_name": "test", "endpoint_type": "internalURL", "http_pool_length": 1, "http_timeout": 1, } def create_swift_connector(store={}): return lambda root, conf: FakeSwiftConnector(root, conf=conf, store=store) class Response(object): def __init__(self, headers={}, status=200, content=None): self.headers = headers self.status_code = status self.content = content def __getitem__(self, key): return self.headers[key] def items(self): return self.headers.items() def read(self): return self.content def fake_auth_request_v1(*args, **kwargs): ret = Response( { "X-Storage-Url": "http://127.0.0.1:8080/v1.0/AUTH_fakeuser", "X-Auth-Token": "12" * 10, }, 200, ) return ret def fake_auth_request_v1_error(*args, **kwargs): ret = Response({}, 401) return ret def fake_auth_request_v2(*args, **kwargs): s_url = "http://127.0.0.1:8080/v1.0/AUTH_fakeuser" resp = { "access": { "token": {"id": "12" * 10}, "serviceCatalog": [ { "type": "object-store", "endpoints": [ { "region": "test", "internalURL": s_url, }, ], }, ], } } ret = Response(status=200, content=json.dumps(resp)) return ret def create_commit(data, marker=b"Default", blob=None): if not blob: blob = Blob.from_string(b"The blob content " + marker) tree = Tree() tree.add(b"thefile_" + marker, 0o100644, blob.id) cmt = Commit() if data: assert isinstance(data[-1], Commit) cmt.parents = [data[-1].id] cmt.tree = tree.id author = b"John Doe " + marker + b" " cmt.author = cmt.committer = author tz = parse_timezone(b"-0200")[0] cmt.commit_time = cmt.author_time = int(time()) cmt.commit_timezone = cmt.author_timezone = tz cmt.encoding = b"UTF-8" cmt.message = b"The commit message " + marker tag = Tag() tag.tagger = b"john@doe.net" tag.message = b"Annotated tag" tag.tag_timezone = parse_timezone(b"-0200")[0] tag.tag_time = cmt.author_time tag.object = (Commit, cmt.id) tag.name = b"v_" + marker + b"_0.1" return blob, tree, tag, cmt def create_commits(length=1, marker=b"Default"): data = [] for i in range(0, length): _marker = ("%s_%s" % (marker, i)).encode() blob, tree, tag, cmt = create_commit(data, _marker) data.extend([blob, tree, tag, cmt]) return data @skipIf(missing_libs, skipmsg) class FakeSwiftConnector(object): def __init__(self, root, conf, store=None): if store: self.store = store else: self.store = {} self.conf = conf self.root = root self.concurrency = 1 self.chunk_length = 12228 self.cache_length = 1 def put_object(self, name, content): name = posixpath.join(self.root, name) if hasattr(content, "seek"): content.seek(0) content = content.read() self.store[name] = content def get_object(self, name, range=None): name = posixpath.join(self.root, name) if not range: try: return BytesIO(self.store[name]) except KeyError: return None else: l, r = range.split("-") try: if not l: r = -int(r) return self.store[name][r:] else: return self.store[name][int(l) : int(r)] except KeyError: return None def get_container_objects(self): return [{"name": k.replace(self.root + "/", "")} for k in self.store] def create_root(self): if self.root in self.store.keys(): pass else: self.store[self.root] = "" def get_object_stat(self, name): name = posixpath.join(self.root, name) if name not in self.store: return None return {"content-length": len(self.store[name])} @skipIf(missing_libs, skipmsg) class TestSwiftRepo(TestCase): def setUp(self): super(TestSwiftRepo, self).setUp() self.conf = swift.load_conf(file=StringIO(config_file % def_config_file)) def test_init(self): store = {"fakerepo/objects/pack": ""} with patch( "dulwich.contrib.swift.SwiftConnector", new_callable=create_swift_connector, store=store, ): swift.SwiftRepo("fakerepo", conf=self.conf) def test_init_no_data(self): with patch( "dulwich.contrib.swift.SwiftConnector", new_callable=create_swift_connector, ): self.assertRaises(Exception, swift.SwiftRepo, "fakerepo", self.conf) def test_init_bad_data(self): store = {"fakerepo/.git/objects/pack": ""} with patch( "dulwich.contrib.swift.SwiftConnector", new_callable=create_swift_connector, store=store, ): self.assertRaises(Exception, swift.SwiftRepo, "fakerepo", self.conf) def test_put_named_file(self): store = {"fakerepo/objects/pack": ""} with patch( "dulwich.contrib.swift.SwiftConnector", new_callable=create_swift_connector, store=store, ): repo = swift.SwiftRepo("fakerepo", conf=self.conf) desc = b"Fake repo" repo._put_named_file("description", desc) self.assertEqual(repo.scon.store["fakerepo/description"], desc) def test_init_bare(self): fsc = FakeSwiftConnector("fakeroot", conf=self.conf) with patch( "dulwich.contrib.swift.SwiftConnector", new_callable=create_swift_connector, store=fsc.store, ): swift.SwiftRepo.init_bare(fsc, conf=self.conf) self.assertIn("fakeroot/objects/pack", fsc.store) self.assertIn("fakeroot/info/refs", fsc.store) self.assertIn("fakeroot/description", fsc.store) @skipIf(missing_libs, skipmsg) class TestSwiftInfoRefsContainer(TestCase): def setUp(self): super(TestSwiftInfoRefsContainer, self).setUp() content = ( b"22effb216e3a82f97da599b8885a6cadb488b4c5\trefs/heads/master\n" b"cca703b0e1399008b53a1a236d6b4584737649e4\trefs/heads/dev" ) self.store = {"fakerepo/info/refs": content} self.conf = swift.load_conf(file=StringIO(config_file % def_config_file)) self.fsc = FakeSwiftConnector("fakerepo", conf=self.conf) self.object_store = {} def test_init(self): """info/refs does not exists""" irc = swift.SwiftInfoRefsContainer(self.fsc, self.object_store) self.assertEqual(len(irc._refs), 0) self.fsc.store = self.store irc = swift.SwiftInfoRefsContainer(self.fsc, self.object_store) self.assertIn(b"refs/heads/dev", irc.allkeys()) self.assertIn(b"refs/heads/master", irc.allkeys()) def test_set_if_equals(self): self.fsc.store = self.store irc = swift.SwiftInfoRefsContainer(self.fsc, self.object_store) irc.set_if_equals( b"refs/heads/dev", b"cca703b0e1399008b53a1a236d6b4584737649e4", b"1" * 40, ) self.assertEqual(irc[b"refs/heads/dev"], b"1" * 40) def test_remove_if_equals(self): self.fsc.store = self.store irc = swift.SwiftInfoRefsContainer(self.fsc, self.object_store) irc.remove_if_equals( b"refs/heads/dev", b"cca703b0e1399008b53a1a236d6b4584737649e4" ) self.assertNotIn(b"refs/heads/dev", irc.allkeys()) @skipIf(missing_libs, skipmsg) class TestSwiftConnector(TestCase): def setUp(self): super(TestSwiftConnector, self).setUp() self.conf = swift.load_conf(file=StringIO(config_file % def_config_file)) with patch("geventhttpclient.HTTPClient.request", fake_auth_request_v1): self.conn = swift.SwiftConnector("fakerepo", conf=self.conf) def test_init_connector(self): self.assertEqual(self.conn.auth_ver, "1") self.assertEqual(self.conn.auth_url, "http://127.0.0.1:8080/auth/v1.0") self.assertEqual(self.conn.user, "test:tester") self.assertEqual(self.conn.password, "testing") self.assertEqual(self.conn.root, "fakerepo") self.assertEqual( self.conn.storage_url, "http://127.0.0.1:8080/v1.0/AUTH_fakeuser" ) self.assertEqual(self.conn.token, "12" * 10) self.assertEqual(self.conn.http_timeout, 1) self.assertEqual(self.conn.http_pool_length, 1) self.assertEqual(self.conn.concurrency, 1) self.conf.set("swift", "auth_ver", "2") self.conf.set("swift", "auth_url", "http://127.0.0.1:8080/auth/v2.0") with patch("geventhttpclient.HTTPClient.request", fake_auth_request_v2): conn = swift.SwiftConnector("fakerepo", conf=self.conf) self.assertEqual(conn.user, "tester") self.assertEqual(conn.tenant, "test") self.conf.set("swift", "auth_ver", "1") self.conf.set("swift", "auth_url", "http://127.0.0.1:8080/auth/v1.0") with patch("geventhttpclient.HTTPClient.request", fake_auth_request_v1_error): self.assertRaises( swift.SwiftException, lambda: swift.SwiftConnector("fakerepo", conf=self.conf), ) def test_root_exists(self): with patch("geventhttpclient.HTTPClient.request", lambda *args: Response()): self.assertEqual(self.conn.test_root_exists(), True) def test_root_not_exists(self): with patch( "geventhttpclient.HTTPClient.request", lambda *args: Response(status=404), ): self.assertEqual(self.conn.test_root_exists(), None) def test_create_root(self): with patch( "dulwich.contrib.swift.SwiftConnector.test_root_exists", lambda *args: None, ): with patch("geventhttpclient.HTTPClient.request", lambda *args: Response()): self.assertEqual(self.conn.create_root(), None) def test_create_root_fails(self): with patch( "dulwich.contrib.swift.SwiftConnector.test_root_exists", lambda *args: None, ): with patch( "geventhttpclient.HTTPClient.request", lambda *args: Response(status=404), ): - self.assertRaises(swift.SwiftException, lambda: self.conn.create_root()) + self.assertRaises(swift.SwiftException, self.conn.create_root) def test_get_container_objects(self): with patch( "geventhttpclient.HTTPClient.request", lambda *args: Response( content=json.dumps((({"name": "a"}, {"name": "b"}))) ), ): self.assertEqual(len(self.conn.get_container_objects()), 2) def test_get_container_objects_fails(self): with patch( "geventhttpclient.HTTPClient.request", lambda *args: Response(status=404), ): self.assertEqual(self.conn.get_container_objects(), None) def test_get_object_stat(self): with patch( "geventhttpclient.HTTPClient.request", lambda *args: Response(headers={"content-length": "10"}), ): self.assertEqual(self.conn.get_object_stat("a")["content-length"], "10") def test_get_object_stat_fails(self): with patch( "geventhttpclient.HTTPClient.request", lambda *args: Response(status=404), ): self.assertEqual(self.conn.get_object_stat("a"), None) def test_put_object(self): with patch( "geventhttpclient.HTTPClient.request", lambda *args, **kwargs: Response(), ): self.assertEqual(self.conn.put_object("a", BytesIO(b"content")), None) def test_put_object_fails(self): with patch( "geventhttpclient.HTTPClient.request", lambda *args, **kwargs: Response(status=400), ): self.assertRaises( swift.SwiftException, lambda: self.conn.put_object("a", BytesIO(b"content")), ) def test_get_object(self): with patch( "geventhttpclient.HTTPClient.request", lambda *args, **kwargs: Response(content=b"content"), ): self.assertEqual(self.conn.get_object("a").read(), b"content") with patch( "geventhttpclient.HTTPClient.request", lambda *args, **kwargs: Response(content=b"content"), ): self.assertEqual(self.conn.get_object("a", range="0-6"), b"content") def test_get_object_fails(self): with patch( "geventhttpclient.HTTPClient.request", lambda *args, **kwargs: Response(status=404), ): self.assertEqual(self.conn.get_object("a"), None) def test_del_object(self): with patch("geventhttpclient.HTTPClient.request", lambda *args: Response()): self.assertEqual(self.conn.del_object("a"), None) def test_del_root(self): with patch( "dulwich.contrib.swift.SwiftConnector.del_object", lambda *args: None, ): with patch( "dulwich.contrib.swift.SwiftConnector." "get_container_objects", lambda *args: ({"name": "a"}, {"name": "b"}), ): with patch( "geventhttpclient.HTTPClient.request", lambda *args: Response(), ): self.assertEqual(self.conn.del_root(), None) @skipIf(missing_libs, skipmsg) class SwiftObjectStoreTests(ObjectStoreTests, TestCase): def setUp(self): TestCase.setUp(self) conf = swift.load_conf(file=StringIO(config_file % def_config_file)) fsc = FakeSwiftConnector("fakerepo", conf=conf) self.store = swift.SwiftObjectStore(fsc) diff --git a/dulwich/diff_tree.py b/dulwich/diff_tree.py index 2962575f..6e53c52d 100644 --- a/dulwich/diff_tree.py +++ b/dulwich/diff_tree.py @@ -1,648 +1,648 @@ # diff_tree.py -- Utilities for diffing files and trees. # Copyright (C) 2010 Google, Inc. # # Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU # General Public License as public by the Free Software Foundation; version 2.0 # or (at your option) any later version. You can redistribute it and/or # modify it under the terms of either of these two licenses. # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # # You should have received a copy of the licenses; if not, see # for a copy of the GNU General Public License # and for a copy of the Apache # License, Version 2.0. # """Utilities for diffing files and trees.""" from collections import ( defaultdict, namedtuple, ) from io import BytesIO from itertools import chain import stat from dulwich.objects import ( S_ISGITLINK, TreeEntry, ) # TreeChange type constants. CHANGE_ADD = "add" CHANGE_MODIFY = "modify" CHANGE_DELETE = "delete" CHANGE_RENAME = "rename" CHANGE_COPY = "copy" CHANGE_UNCHANGED = "unchanged" RENAME_CHANGE_TYPES = (CHANGE_RENAME, CHANGE_COPY) _NULL_ENTRY = TreeEntry(None, None, None) _MAX_SCORE = 100 RENAME_THRESHOLD = 60 MAX_FILES = 200 REWRITE_THRESHOLD = None class TreeChange(namedtuple("TreeChange", ["type", "old", "new"])): """Named tuple a single change between two trees.""" @classmethod def add(cls, new): return cls(CHANGE_ADD, _NULL_ENTRY, new) @classmethod def delete(cls, old): return cls(CHANGE_DELETE, old, _NULL_ENTRY) def _tree_entries(path, tree): result = [] if not tree: return result for entry in tree.iteritems(name_order=True): result.append(entry.in_path(path)) return result def _merge_entries(path, tree1, tree2): """Merge the entries of two trees. Args: path: A path to prepend to all tree entry names. tree1: The first Tree object to iterate, or None. tree2: The second Tree object to iterate, or None. Returns: A list of pairs of TreeEntry objects for each pair of entries in the trees. If an entry exists in one tree but not the other, the other entry will have all attributes set to None. If neither entry's path is None, they are guaranteed to match. """ entries1 = _tree_entries(path, tree1) entries2 = _tree_entries(path, tree2) i1 = i2 = 0 len1 = len(entries1) len2 = len(entries2) result = [] while i1 < len1 and i2 < len2: entry1 = entries1[i1] entry2 = entries2[i2] if entry1.path < entry2.path: result.append((entry1, _NULL_ENTRY)) i1 += 1 elif entry1.path > entry2.path: result.append((_NULL_ENTRY, entry2)) i2 += 1 else: result.append((entry1, entry2)) i1 += 1 i2 += 1 for i in range(i1, len1): result.append((entries1[i], _NULL_ENTRY)) for i in range(i2, len2): result.append((_NULL_ENTRY, entries2[i])) return result def _is_tree(entry): mode = entry.mode if mode is None: return False return stat.S_ISDIR(mode) def walk_trees(store, tree1_id, tree2_id, prune_identical=False): """Recursively walk all the entries of two trees. Iteration is depth-first pre-order, as in e.g. os.walk. Args: store: An ObjectStore for looking up objects. tree1_id: The SHA of the first Tree object to iterate, or None. tree2_id: The SHA of the second Tree object to iterate, or None. param prune_identical: If True, identical subtrees will not be walked. Returns: Iterator over Pairs of TreeEntry objects for each pair of entries in the trees and their subtrees recursively. If an entry exists in one tree but not the other, the other entry will have all attributes set to None. If neither entry's path is None, they are guaranteed to match. """ # This could be fairly easily generalized to >2 trees if we find a use # case. mode1 = tree1_id and stat.S_IFDIR or None mode2 = tree2_id and stat.S_IFDIR or None todo = [(TreeEntry(b"", mode1, tree1_id), TreeEntry(b"", mode2, tree2_id))] while todo: entry1, entry2 = todo.pop() is_tree1 = _is_tree(entry1) is_tree2 = _is_tree(entry2) if prune_identical and is_tree1 and is_tree2 and entry1 == entry2: continue tree1 = is_tree1 and store[entry1.sha] or None tree2 = is_tree2 and store[entry2.sha] or None path = entry1.path or entry2.path todo.extend(reversed(_merge_entries(path, tree1, tree2))) yield entry1, entry2 def _skip_tree(entry, include_trees): if entry.mode is None or (not include_trees and stat.S_ISDIR(entry.mode)): return _NULL_ENTRY return entry def tree_changes( store, tree1_id, tree2_id, want_unchanged=False, rename_detector=None, include_trees=False, change_type_same=False, ): """Find the differences between the contents of two trees. Args: store: An ObjectStore for looking up objects. tree1_id: The SHA of the source tree. tree2_id: The SHA of the target tree. want_unchanged: If True, include TreeChanges for unmodified entries as well. include_trees: Whether to include trees rename_detector: RenameDetector object for detecting renames. change_type_same: Whether to report change types in the same entry or as delete+add. Returns: Iterator over TreeChange instances for each change between the source and target tree. """ if rename_detector is not None and tree1_id is not None and tree2_id is not None: for change in rename_detector.changes_with_renames( tree1_id, tree2_id, want_unchanged=want_unchanged, include_trees=include_trees, ): yield change return entries = walk_trees( store, tree1_id, tree2_id, prune_identical=(not want_unchanged) ) for entry1, entry2 in entries: if entry1 == entry2 and not want_unchanged: continue # Treat entries for trees as missing. entry1 = _skip_tree(entry1, include_trees) entry2 = _skip_tree(entry2, include_trees) if entry1 != _NULL_ENTRY and entry2 != _NULL_ENTRY: if ( stat.S_IFMT(entry1.mode) != stat.S_IFMT(entry2.mode) and not change_type_same ): # File type changed: report as delete/add. yield TreeChange.delete(entry1) entry1 = _NULL_ENTRY change_type = CHANGE_ADD elif entry1 == entry2: change_type = CHANGE_UNCHANGED else: change_type = CHANGE_MODIFY elif entry1 != _NULL_ENTRY: change_type = CHANGE_DELETE elif entry2 != _NULL_ENTRY: change_type = CHANGE_ADD else: # Both were None because at least one was a tree. continue yield TreeChange(change_type, entry1, entry2) def _all_eq(seq, key, value): for e in seq: if key(e) != value: return False return True def _all_same(seq, key): return _all_eq(seq[1:], key, key(seq[0])) def tree_changes_for_merge(store, parent_tree_ids, tree_id, rename_detector=None): """Get the tree changes for a merge tree relative to all its parents. Args: store: An ObjectStore for looking up objects. parent_tree_ids: An iterable of the SHAs of the parent trees. tree_id: The SHA of the merge tree. rename_detector: RenameDetector object for detecting renames. Returns: Iterator over lists of TreeChange objects, one per conflicted path in the merge. Each list contains one element per parent, with the TreeChange for that path relative to that parent. An element may be None if it never existed in one parent and was deleted in two others. A path is only included in the output if it is a conflict, i.e. its SHA in the merge tree is not found in any of the parents, or in the case of deletes, if not all of the old SHAs match. """ all_parent_changes = [ tree_changes(store, t, tree_id, rename_detector=rename_detector) for t in parent_tree_ids ] num_parents = len(parent_tree_ids) changes_by_path = defaultdict(lambda: [None] * num_parents) # Organize by path. for i, parent_changes in enumerate(all_parent_changes): for change in parent_changes: if change.type == CHANGE_DELETE: path = change.old.path else: path = change.new.path changes_by_path[path][i] = change def old_sha(c): return c.old.sha def change_type(c): return c.type # Yield only conflicting changes. for _, changes in sorted(changes_by_path.items()): assert len(changes) == num_parents have = [c for c in changes if c is not None] if _all_eq(have, change_type, CHANGE_DELETE): if not _all_same(have, old_sha): yield changes elif not _all_same(have, change_type): yield changes elif None not in changes: # If no change was found relative to one parent, that means the SHA # must have matched the SHA in that parent, so it is not a # conflict. yield changes _BLOCK_SIZE = 64 def _count_blocks(obj): """Count the blocks in an object. Splits the data into blocks either on lines or <=64-byte chunks of lines. Args: obj: The object to count blocks for. Returns: A dict of block hashcode -> total bytes occurring. """ block_counts = defaultdict(int) block = BytesIO() n = 0 # Cache attrs as locals to avoid expensive lookups in the inner loop. block_write = block.write block_seek = block.seek block_truncate = block.truncate block_getvalue = block.getvalue for c in chain.from_iterable(obj.as_raw_chunks()): c = c.to_bytes(1, "big") block_write(c) n += 1 if c == b"\n" or n == _BLOCK_SIZE: value = block_getvalue() block_counts[hash(value)] += len(value) block_seek(0) block_truncate() n = 0 if n > 0: last_block = block_getvalue() block_counts[hash(last_block)] += len(last_block) return block_counts def _common_bytes(blocks1, blocks2): """Count the number of common bytes in two block count dicts. Args: block1: The first dict of block hashcode -> total bytes. block2: The second dict of block hashcode -> total bytes. Returns: The number of bytes in common between blocks1 and blocks2. This is only approximate due to possible hash collisions. """ # Iterate over the smaller of the two dicts, since this is symmetrical. if len(blocks1) > len(blocks2): blocks1, blocks2 = blocks2, blocks1 score = 0 for block, count1 in blocks1.items(): count2 = blocks2.get(block) if count2: score += min(count1, count2) return score def _similarity_score(obj1, obj2, block_cache=None): """Compute a similarity score for two objects. Args: obj1: The first object to score. obj2: The second object to score. block_cache: An optional dict of SHA to block counts to cache results between calls. Returns: The similarity score between the two objects, defined as the number of bytes in common between the two objects divided by the maximum size, scaled to the range 0-100. """ if block_cache is None: block_cache = {} if obj1.id not in block_cache: block_cache[obj1.id] = _count_blocks(obj1) if obj2.id not in block_cache: block_cache[obj2.id] = _count_blocks(obj2) common_bytes = _common_bytes(block_cache[obj1.id], block_cache[obj2.id]) max_size = max(obj1.raw_length(), obj2.raw_length()) if not max_size: return _MAX_SCORE return int(float(common_bytes) * _MAX_SCORE / max_size) def _tree_change_key(entry): # Sort by old path then new path. If only one exists, use it for both keys. path1 = entry.old.path path2 = entry.new.path if path1 is None: path1 = path2 if path2 is None: path2 = path1 return (path1, path2) class RenameDetector(object): """Object for handling rename detection between two trees.""" def __init__( self, store, rename_threshold=RENAME_THRESHOLD, max_files=MAX_FILES, rewrite_threshold=REWRITE_THRESHOLD, find_copies_harder=False, ): """Initialize the rename detector. Args: store: An ObjectStore for looking up objects. rename_threshold: The threshold similarity score for considering an add/delete pair to be a rename/copy; see _similarity_score. max_files: The maximum number of adds and deletes to consider, or None for no limit. The detector is guaranteed to compare no more than max_files ** 2 add/delete pairs. This limit is provided because rename detection can be quadratic in the project size. If the limit is exceeded, no content rename detection is attempted. rewrite_threshold: The threshold similarity score below which a modify should be considered a delete/add, or None to not break modifies; see _similarity_score. find_copies_harder: If True, consider unmodified files when detecting copies. """ self._store = store self._rename_threshold = rename_threshold self._rewrite_threshold = rewrite_threshold self._max_files = max_files self._find_copies_harder = find_copies_harder self._want_unchanged = False def _reset(self): self._adds = [] self._deletes = [] self._changes = [] def _should_split(self, change): if ( self._rewrite_threshold is None or change.type != CHANGE_MODIFY or change.old.sha == change.new.sha ): return False old_obj = self._store[change.old.sha] new_obj = self._store[change.new.sha] return _similarity_score(old_obj, new_obj) < self._rewrite_threshold def _add_change(self, change): if change.type == CHANGE_ADD: self._adds.append(change) elif change.type == CHANGE_DELETE: self._deletes.append(change) elif self._should_split(change): self._deletes.append(TreeChange.delete(change.old)) self._adds.append(TreeChange.add(change.new)) elif ( self._find_copies_harder and change.type == CHANGE_UNCHANGED ) or change.type == CHANGE_MODIFY: # Treat all modifies as potential deletes for rename detection, # but don't split them (to avoid spurious renames). Setting # find_copies_harder means we treat unchanged the same as # modified. self._deletes.append(change) else: self._changes.append(change) def _collect_changes(self, tree1_id, tree2_id): want_unchanged = self._find_copies_harder or self._want_unchanged for change in tree_changes( self._store, tree1_id, tree2_id, want_unchanged=want_unchanged, include_trees=self._include_trees, ): self._add_change(change) def _prune(self, add_paths, delete_paths): self._adds = [a for a in self._adds if a.new.path not in add_paths] self._deletes = [d for d in self._deletes if d.old.path not in delete_paths] def _find_exact_renames(self): add_map = defaultdict(list) for add in self._adds: add_map[add.new.sha].append(add.new) delete_map = defaultdict(list) for delete in self._deletes: # Keep track of whether the delete was actually marked as a delete. # If not, it needs to be marked as a copy. is_delete = delete.type == CHANGE_DELETE delete_map[delete.old.sha].append((delete.old, is_delete)) add_paths = set() delete_paths = set() for sha, sha_deletes in delete_map.items(): sha_adds = add_map[sha] for (old, is_delete), new in zip(sha_deletes, sha_adds): if stat.S_IFMT(old.mode) != stat.S_IFMT(new.mode): continue if is_delete: delete_paths.add(old.path) add_paths.add(new.path) new_type = is_delete and CHANGE_RENAME or CHANGE_COPY self._changes.append(TreeChange(new_type, old, new)) num_extra_adds = len(sha_adds) - len(sha_deletes) # TODO(dborowitz): Less arbitrary way of dealing with extra copies. old = sha_deletes[0][0] if num_extra_adds > 0: for new in sha_adds[-num_extra_adds:]: add_paths.add(new.path) self._changes.append(TreeChange(CHANGE_COPY, old, new)) self._prune(add_paths, delete_paths) def _should_find_content_renames(self): return len(self._adds) * len(self._deletes) <= self._max_files ** 2 def _rename_type(self, check_paths, delete, add): if check_paths and delete.old.path == add.new.path: # If the paths match, this must be a split modify, so make sure it # comes out as a modify. return CHANGE_MODIFY elif delete.type != CHANGE_DELETE: # If it's in deletes but not marked as a delete, it must have been # added due to find_copies_harder, and needs to be marked as a # copy. return CHANGE_COPY return CHANGE_RENAME def _find_content_rename_candidates(self): candidates = self._candidates = [] # TODO: Optimizations: # - Compare object sizes before counting blocks. # - Skip if delete's S_IFMT differs from all adds. # - Skip if adds or deletes is empty. # Match C git's behavior of not attempting to find content renames if # the matrix size exceeds the threshold. if not self._should_find_content_renames(): return block_cache = {} check_paths = self._rename_threshold is not None for delete in self._deletes: if S_ISGITLINK(delete.old.mode): continue # Git links don't exist in this repo. old_sha = delete.old.sha old_obj = self._store[old_sha] block_cache[old_sha] = _count_blocks(old_obj) for add in self._adds: if stat.S_IFMT(delete.old.mode) != stat.S_IFMT(add.new.mode): continue new_obj = self._store[add.new.sha] score = _similarity_score(old_obj, new_obj, block_cache=block_cache) if score > self._rename_threshold: new_type = self._rename_type(check_paths, delete, add) rename = TreeChange(new_type, delete.old, add.new) candidates.append((-score, rename)) def _choose_content_renames(self): # Sort scores from highest to lowest, but keep names in ascending # order. self._candidates.sort() delete_paths = set() add_paths = set() for _, change in self._candidates: new_path = change.new.path if new_path in add_paths: continue old_path = change.old.path orig_type = change.type if old_path in delete_paths: change = TreeChange(CHANGE_COPY, change.old, change.new) # If the candidate was originally a copy, that means it came from a # modified or unchanged path, so we don't want to prune it. if orig_type != CHANGE_COPY: delete_paths.add(old_path) add_paths.add(new_path) self._changes.append(change) self._prune(add_paths, delete_paths) def _join_modifies(self): if self._rewrite_threshold is None: return modifies = {} - delete_map = dict((d.old.path, d) for d in self._deletes) + delete_map = {d.old.path: d for d in self._deletes} for add in self._adds: path = add.new.path delete = delete_map.get(path) if delete is not None and stat.S_IFMT(delete.old.mode) == stat.S_IFMT( add.new.mode ): modifies[path] = TreeChange(CHANGE_MODIFY, delete.old, add.new) self._adds = [a for a in self._adds if a.new.path not in modifies] self._deletes = [a for a in self._deletes if a.new.path not in modifies] self._changes += modifies.values() def _sorted_changes(self): result = [] result.extend(self._adds) result.extend(self._deletes) result.extend(self._changes) result.sort(key=_tree_change_key) return result def _prune_unchanged(self): if self._want_unchanged: return self._deletes = [d for d in self._deletes if d.type != CHANGE_UNCHANGED] def changes_with_renames( self, tree1_id, tree2_id, want_unchanged=False, include_trees=False ): """Iterate TreeChanges between two tree SHAs, with rename detection.""" self._reset() self._want_unchanged = want_unchanged self._include_trees = include_trees self._collect_changes(tree1_id, tree2_id) self._find_exact_renames() self._find_content_rename_candidates() self._choose_content_renames() self._join_modifies() self._prune_unchanged() return self._sorted_changes() # Hold on to the pure-python implementations for testing. _is_tree_py = _is_tree _merge_entries_py = _merge_entries _count_blocks_py = _count_blocks try: # Try to import C versions from dulwich._diff_tree import ( # type: ignore _is_tree, _merge_entries, _count_blocks, ) except ImportError: pass diff --git a/dulwich/greenthreads.py b/dulwich/greenthreads.py index 9ed49054..ec89e8f3 100644 --- a/dulwich/greenthreads.py +++ b/dulwich/greenthreads.py @@ -1,146 +1,146 @@ # greenthreads.py -- Utility module for querying an ObjectStore with gevent # Copyright (C) 2013 eNovance SAS # # Author: Fabien Boucher # # Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU # General Public License as public by the Free Software Foundation; version 2.0 # or (at your option) any later version. You can redistribute it and/or # modify it under the terms of either of these two licenses. # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # # You should have received a copy of the licenses; if not, see # for a copy of the GNU General Public License # and for a copy of the Apache # License, Version 2.0. # """Utility module for querying an ObjectStore with gevent.""" import gevent from gevent import pool from dulwich.objects import ( Commit, Tag, ) from dulwich.object_store import ( MissingObjectFinder, _collect_filetree_revs, ObjectStoreIterator, ) def _split_commits_and_tags(obj_store, lst, ignore_unknown=False, pool=None): """Split object id list into two list with commit SHA1s and tag SHA1s. Same implementation as object_store._split_commits_and_tags except we use gevent to parallelize object retrieval. """ commits = set() tags = set() def find_commit_type(sha): try: o = obj_store[sha] except KeyError: if not ignore_unknown: raise else: if isinstance(o, Commit): commits.add(sha) elif isinstance(o, Tag): tags.add(sha) commits.add(o.object[1]) else: raise KeyError("Not a commit or a tag: %s" % sha) jobs = [pool.spawn(find_commit_type, s) for s in lst] gevent.joinall(jobs) return (commits, tags) class GreenThreadsMissingObjectFinder(MissingObjectFinder): """Find the objects missing from another object store. Same implementation as object_store.MissingObjectFinder except we use gevent to parallelize object retrieval. """ def __init__( self, object_store, haves, wants, progress=None, get_tagged=None, concurrency=1, get_parents=None, ): def collect_tree_sha(sha): self.sha_done.add(sha) cmt = object_store[sha] _collect_filetree_revs(object_store, cmt.tree, self.sha_done) self.object_store = object_store p = pool.Pool(size=concurrency) have_commits, have_tags = _split_commits_and_tags(object_store, haves, True, p) want_commits, want_tags = _split_commits_and_tags(object_store, wants, False, p) all_ancestors = object_store._collect_ancestors(have_commits)[0] missing_commits, common_commits = object_store._collect_ancestors( want_commits, all_ancestors ) self.sha_done = set() jobs = [p.spawn(collect_tree_sha, c) for c in common_commits] gevent.joinall(jobs) for t in have_tags: self.sha_done.add(t) missing_tags = want_tags.difference(have_tags) wants = missing_commits.union(missing_tags) self.objects_to_send = set([(w, None, False) for w in wants]) if progress is None: self.progress = lambda x: None else: self.progress = progress self._tagged = get_tagged and get_tagged() or {} class GreenThreadsObjectStoreIterator(ObjectStoreIterator): """ObjectIterator that works on top of an ObjectStore. Same implementation as object_store.ObjectStoreIterator except we use gevent to parallelize object retrieval. """ def __init__(self, store, shas, finder, concurrency=1): self.finder = finder self.p = pool.Pool(size=concurrency) super(GreenThreadsObjectStoreIterator, self).__init__(store, shas) def retrieve(self, args): sha, path = args return self.store[sha], path def __iter__(self): for sha, path in self.p.imap_unordered(self.retrieve, self.itershas()): yield sha, path def __len__(self): if len(self._shas) > 0: return len(self._shas) - while len(self.finder.objects_to_send): + while self.finder.objects_to_send: jobs = [] for _ in range(0, len(self.finder.objects_to_send)): jobs.append(self.p.spawn(self.finder.next)) gevent.joinall(jobs) for j in jobs: if j.value is not None: self._shas.append(j.value) return len(self._shas) diff --git a/dulwich/lru_cache.py b/dulwich/lru_cache.py index 24b7c18f..be18766f 100644 --- a/dulwich/lru_cache.py +++ b/dulwich/lru_cache.py @@ -1,383 +1,383 @@ # lru_cache.py -- Simple LRU cache for dulwich # Copyright (C) 2006, 2008 Canonical Ltd # # Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU # General Public License as public by the Free Software Foundation; version 2.0 # or (at your option) any later version. You can redistribute it and/or # modify it under the terms of either of these two licenses. # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # # You should have received a copy of the licenses; if not, see # for a copy of the GNU General Public License # and for a copy of the Apache # License, Version 2.0. # """A simple least-recently-used (LRU) cache.""" _null_key = object() class _LRUNode(object): """This maintains the linked-list which is the lru internals.""" __slots__ = ("prev", "next_key", "key", "value", "cleanup", "size") def __init__(self, key, value, cleanup=None): self.prev = None self.next_key = _null_key self.key = key self.value = value self.cleanup = cleanup # TODO: We could compute this 'on-the-fly' like we used to, and remove # one pointer from this object, we just need to decide if it # actually costs us much of anything in normal usage self.size = None def __repr__(self): if self.prev is None: prev_key = None else: prev_key = self.prev.key return "%s(%r n:%r p:%r)" % ( self.__class__.__name__, self.key, self.next_key, prev_key, ) def run_cleanup(self): if self.cleanup is not None: self.cleanup(self.key, self.value) self.cleanup = None # Just make sure to break any refcycles, etc self.value = None class LRUCache(object): """A class which manages a cache of entries, removing unused ones.""" def __init__(self, max_cache=100, after_cleanup_count=None): self._cache = {} # The "HEAD" of the lru linked list self._most_recently_used = None # The "TAIL" of the lru linked list self._least_recently_used = None self._update_max_cache(max_cache, after_cleanup_count) def __contains__(self, key): return key in self._cache def __getitem__(self, key): cache = self._cache node = cache[key] # Inlined from _record_access to decrease the overhead of __getitem__ # We also have more knowledge about structure if __getitem__ is # succeeding, then we know that self._most_recently_used must not be # None, etc. mru = self._most_recently_used if node is mru: # Nothing to do, this node is already at the head of the queue return node.value # Remove this node from the old location node_prev = node.prev next_key = node.next_key # benchmarking shows that the lookup of _null_key in globals is faster # than the attribute lookup for (node is self._least_recently_used) if next_key is _null_key: # 'node' is the _least_recently_used, because it doesn't have a # 'next' item. So move the current lru to the previous node. self._least_recently_used = node_prev else: node_next = cache[next_key] node_next.prev = node_prev node_prev.next_key = next_key # Insert this node at the front of the list node.next_key = mru.key mru.prev = node self._most_recently_used = node node.prev = None return node.value def __len__(self): return len(self._cache) def _walk_lru(self): """Walk the LRU list, only meant to be used in tests.""" node = self._most_recently_used if node is not None: if node.prev is not None: raise AssertionError( "the _most_recently_used entry is not" " supposed to have a previous entry" " %s" % (node,) ) while node is not None: if node.next_key is _null_key: if node is not self._least_recently_used: raise AssertionError( "only the last node should have" " no next value: %s" % (node,) ) node_next = None else: node_next = self._cache[node.next_key] if node_next.prev is not node: raise AssertionError( "inconsistency found, node.next.prev" " != node: %s" % (node,) ) if node.prev is None: if node is not self._most_recently_used: raise AssertionError( "only the _most_recently_used should" " not have a previous node: %s" % (node,) ) else: if node.prev.next_key != node.key: raise AssertionError( "inconsistency found, node.prev.next" " != node: %s" % (node,) ) yield node node = node_next def add(self, key, value, cleanup=None): """Add a new value to the cache. Also, if the entry is ever removed from the cache, call cleanup(key, value). Args: key: The key to store it under value: The object to store cleanup: None or a function taking (key, value) to indicate 'value' should be cleaned up. """ if key is _null_key: raise ValueError("cannot use _null_key as a key") if key in self._cache: node = self._cache[key] node.run_cleanup() node.value = value node.cleanup = cleanup else: node = _LRUNode(key, value, cleanup=cleanup) self._cache[key] = node self._record_access(node) if len(self._cache) > self._max_cache: # Trigger the cleanup self.cleanup() def cache_size(self): """Get the number of entries we will cache.""" return self._max_cache def get(self, key, default=None): node = self._cache.get(key, None) if node is None: return default self._record_access(node) return node.value def keys(self): """Get the list of keys currently cached. Note that values returned here may not be available by the time you request them later. This is simply meant as a peak into the current state. Returns: An unordered list of keys that are currently cached. """ return self._cache.keys() def items(self): """Get the key:value pairs as a dict.""" - return dict((k, n.value) for k, n in self._cache.items()) + return {k: n.value for k, n in self._cache.items()} def cleanup(self): """Clear the cache until it shrinks to the requested size. This does not completely wipe the cache, just makes sure it is under the after_cleanup_count. """ # Make sure the cache is shrunk to the correct size while len(self._cache) > self._after_cleanup_count: self._remove_lru() def __setitem__(self, key, value): """Add a value to the cache, there will be no cleanup function.""" self.add(key, value, cleanup=None) def _record_access(self, node): """Record that key was accessed.""" # Move 'node' to the front of the queue if self._most_recently_used is None: self._most_recently_used = node self._least_recently_used = node return elif node is self._most_recently_used: # Nothing to do, this node is already at the head of the queue return # We've taken care of the tail pointer, remove the node, and insert it # at the front # REMOVE if node is self._least_recently_used: self._least_recently_used = node.prev if node.prev is not None: node.prev.next_key = node.next_key if node.next_key is not _null_key: node_next = self._cache[node.next_key] node_next.prev = node.prev # INSERT node.next_key = self._most_recently_used.key self._most_recently_used.prev = node self._most_recently_used = node node.prev = None def _remove_node(self, node): if node is self._least_recently_used: self._least_recently_used = node.prev self._cache.pop(node.key) # If we have removed all entries, remove the head pointer as well if self._least_recently_used is None: self._most_recently_used = None node.run_cleanup() # Now remove this node from the linked list if node.prev is not None: node.prev.next_key = node.next_key if node.next_key is not _null_key: node_next = self._cache[node.next_key] node_next.prev = node.prev # And remove this node's pointers node.prev = None node.next_key = _null_key def _remove_lru(self): """Remove one entry from the lru, and handle consequences. If there are no more references to the lru, then this entry should be removed from the cache. """ self._remove_node(self._least_recently_used) def clear(self): """Clear out all of the cache.""" # Clean up in LRU order while self._cache: self._remove_lru() def resize(self, max_cache, after_cleanup_count=None): """Change the number of entries that will be cached.""" self._update_max_cache(max_cache, after_cleanup_count=after_cleanup_count) def _update_max_cache(self, max_cache, after_cleanup_count=None): self._max_cache = max_cache if after_cleanup_count is None: self._after_cleanup_count = self._max_cache * 8 / 10 else: self._after_cleanup_count = min(after_cleanup_count, self._max_cache) self.cleanup() class LRUSizeCache(LRUCache): """An LRUCache that removes things based on the size of the values. This differs in that it doesn't care how many actual items there are, it just restricts the cache to be cleaned up after so much data is stored. The size of items added will be computed using compute_size(value), which defaults to len() if not supplied. """ def __init__( self, max_size=1024 * 1024, after_cleanup_size=None, compute_size=None ): """Create a new LRUSizeCache. Args: max_size: The max number of bytes to store before we start clearing out entries. after_cleanup_size: After cleaning up, shrink everything to this size. compute_size: A function to compute the size of the values. We use a function here, so that you can pass 'len' if you are just using simple strings, or a more complex function if you are using something like a list of strings, or even a custom object. The function should take the form "compute_size(value) => integer". If not supplied, it defaults to 'len()' """ self._value_size = 0 self._compute_size = compute_size if compute_size is None: self._compute_size = len self._update_max_size(max_size, after_cleanup_size=after_cleanup_size) LRUCache.__init__(self, max_cache=max(int(max_size / 512), 1)) def add(self, key, value, cleanup=None): """Add a new value to the cache. Also, if the entry is ever removed from the cache, call cleanup(key, value). Args: key: The key to store it under value: The object to store cleanup: None or a function taking (key, value) to indicate 'value' should be cleaned up. """ if key is _null_key: raise ValueError("cannot use _null_key as a key") node = self._cache.get(key, None) value_len = self._compute_size(value) if value_len >= self._after_cleanup_size: # The new value is 'too big to fit', as it would fill up/overflow # the cache all by itself if node is not None: # We won't be replacing the old node, so just remove it self._remove_node(node) if cleanup is not None: cleanup(key, value) return if node is None: node = _LRUNode(key, value, cleanup=cleanup) self._cache[key] = node else: self._value_size -= node.size node.size = value_len self._value_size += value_len self._record_access(node) if self._value_size > self._max_size: # Time to cleanup self.cleanup() def cleanup(self): """Clear the cache until it shrinks to the requested size. This does not completely wipe the cache, just makes sure it is under the after_cleanup_size. """ # Make sure the cache is shrunk to the correct size while self._value_size > self._after_cleanup_size: self._remove_lru() def _remove_node(self, node): self._value_size -= node.size LRUCache._remove_node(self, node) def resize(self, max_size, after_cleanup_size=None): """Change the number of bytes that will be cached.""" self._update_max_size(max_size, after_cleanup_size=after_cleanup_size) max_cache = max(int(max_size / 512), 1) self._update_max_cache(max_cache) def _update_max_size(self, max_size, after_cleanup_size=None): self._max_size = max_size if after_cleanup_size is None: self._after_cleanup_size = self._max_size * 8 // 10 else: self._after_cleanup_size = min(after_cleanup_size, self._max_size) diff --git a/dulwich/protocol.py b/dulwich/protocol.py index 0c4cae1f..9641d06f 100644 --- a/dulwich/protocol.py +++ b/dulwich/protocol.py @@ -1,583 +1,583 @@ # protocol.py -- Shared parts of the git protocols # Copyright (C) 2008 John Carr # Copyright (C) 2008-2012 Jelmer Vernooij # # Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU # General Public License as public by the Free Software Foundation; version 2.0 # or (at your option) any later version. You can redistribute it and/or # modify it under the terms of either of these two licenses. # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # # You should have received a copy of the licenses; if not, see # for a copy of the GNU General Public License # and for a copy of the Apache # License, Version 2.0. # """Generic functions for talking the git smart server protocol.""" from io import BytesIO from os import ( SEEK_END, ) import socket import dulwich from dulwich.errors import ( HangupException, GitProtocolError, ) TCP_GIT_PORT = 9418 ZERO_SHA = b"0" * 40 SINGLE_ACK = 0 MULTI_ACK = 1 MULTI_ACK_DETAILED = 2 # pack data SIDE_BAND_CHANNEL_DATA = 1 # progress messages SIDE_BAND_CHANNEL_PROGRESS = 2 # fatal error message just before stream aborts SIDE_BAND_CHANNEL_FATAL = 3 CAPABILITY_ATOMIC = b"atomic" CAPABILITY_DEEPEN_SINCE = b"deepen-since" CAPABILITY_DEEPEN_NOT = b"deepen-not" CAPABILITY_DEEPEN_RELATIVE = b"deepen-relative" CAPABILITY_DELETE_REFS = b"delete-refs" CAPABILITY_INCLUDE_TAG = b"include-tag" CAPABILITY_MULTI_ACK = b"multi_ack" CAPABILITY_MULTI_ACK_DETAILED = b"multi_ack_detailed" CAPABILITY_NO_DONE = b"no-done" CAPABILITY_NO_PROGRESS = b"no-progress" CAPABILITY_OFS_DELTA = b"ofs-delta" CAPABILITY_QUIET = b"quiet" CAPABILITY_REPORT_STATUS = b"report-status" CAPABILITY_SHALLOW = b"shallow" CAPABILITY_SIDE_BAND = b"side-band" CAPABILITY_SIDE_BAND_64K = b"side-band-64k" CAPABILITY_THIN_PACK = b"thin-pack" CAPABILITY_AGENT = b"agent" CAPABILITY_SYMREF = b"symref" CAPABILITY_ALLOW_TIP_SHA1_IN_WANT = b"allow-tip-sha1-in-want" CAPABILITY_ALLOW_REACHABLE_SHA1_IN_WANT = b"allow-reachable-sha1-in-want" # Magic ref that is used to attach capabilities to when # there are no refs. Should always be ste to ZERO_SHA. CAPABILITIES_REF = b"capabilities^{}" COMMON_CAPABILITIES = [ CAPABILITY_OFS_DELTA, CAPABILITY_SIDE_BAND, CAPABILITY_SIDE_BAND_64K, CAPABILITY_AGENT, CAPABILITY_NO_PROGRESS, ] KNOWN_UPLOAD_CAPABILITIES = set( COMMON_CAPABILITIES + [ CAPABILITY_THIN_PACK, CAPABILITY_MULTI_ACK, CAPABILITY_MULTI_ACK_DETAILED, CAPABILITY_INCLUDE_TAG, CAPABILITY_DEEPEN_SINCE, CAPABILITY_SYMREF, CAPABILITY_SHALLOW, CAPABILITY_DEEPEN_NOT, CAPABILITY_DEEPEN_RELATIVE, CAPABILITY_ALLOW_TIP_SHA1_IN_WANT, CAPABILITY_ALLOW_REACHABLE_SHA1_IN_WANT, ] ) KNOWN_RECEIVE_CAPABILITIES = set( COMMON_CAPABILITIES + [ CAPABILITY_REPORT_STATUS, CAPABILITY_DELETE_REFS, CAPABILITY_QUIET, CAPABILITY_ATOMIC, ] ) def agent_string(): return ("dulwich/%d.%d.%d" % dulwich.__version__).encode("ascii") def capability_agent(): return CAPABILITY_AGENT + b"=" + agent_string() def capability_symref(from_ref, to_ref): return CAPABILITY_SYMREF + b"=" + from_ref + b":" + to_ref def extract_capability_names(capabilities): - return set(parse_capability(c)[0] for c in capabilities) + return {parse_capability(c)[0] for c in capabilities} def parse_capability(capability): parts = capability.split(b"=", 1) if len(parts) == 1: return (parts[0], None) return tuple(parts) def symref_capabilities(symrefs): return [capability_symref(*k) for k in symrefs] COMMAND_DEEPEN = b"deepen" COMMAND_SHALLOW = b"shallow" COMMAND_UNSHALLOW = b"unshallow" COMMAND_DONE = b"done" COMMAND_WANT = b"want" COMMAND_HAVE = b"have" class ProtocolFile(object): """A dummy file for network ops that expect file-like objects.""" def __init__(self, read, write): self.read = read self.write = write def tell(self): pass def close(self): pass def format_cmd_pkt(cmd, *args): return cmd + b" " + b"".join([(a + b"\0") for a in args]) def parse_cmd_pkt(line): splice_at = line.find(b" ") cmd, args = line[:splice_at], line[splice_at + 1 :] assert args[-1:] == b"\x00" return cmd, args[:-1].split(b"\0") def pkt_line(data): """Wrap data in a pkt-line. Args: data: The data to wrap, as a str or None. Returns: The data prefixed with its length in pkt-line format; if data was None, returns the flush-pkt ('0000'). """ if data is None: return b"0000" return ("%04x" % (len(data) + 4)).encode("ascii") + data class Protocol(object): """Class for interacting with a remote git process over the wire. Parts of the git wire protocol use 'pkt-lines' to communicate. A pkt-line consists of the length of the line as a 4-byte hex string, followed by the payload data. The length includes the 4-byte header. The special line '0000' indicates the end of a section of input and is called a 'flush-pkt'. For details on the pkt-line format, see the cgit distribution: Documentation/technical/protocol-common.txt """ def __init__(self, read, write, close=None, report_activity=None): self.read = read self.write = write self._close = close self.report_activity = report_activity self._readahead = None def close(self): if self._close: self._close() def __enter__(self): return self def __exit__(self, exc_type, exc_val, exc_tb): self.close() def read_pkt_line(self): """Reads a pkt-line from the remote git process. This method may read from the readahead buffer; see unread_pkt_line. Returns: The next string from the stream, without the length prefix, or None for a flush-pkt ('0000'). """ if self._readahead is None: read = self.read else: read = self._readahead.read self._readahead = None try: sizestr = read(4) if not sizestr: raise HangupException() size = int(sizestr, 16) if size == 0: if self.report_activity: self.report_activity(4, "read") return None if self.report_activity: self.report_activity(size, "read") pkt_contents = read(size - 4) except socket.error as e: raise GitProtocolError(e) else: if len(pkt_contents) + 4 != size: raise GitProtocolError( "Length of pkt read %04x does not match length prefix %04x" % (len(pkt_contents) + 4, size) ) return pkt_contents def eof(self): """Test whether the protocol stream has reached EOF. Note that this refers to the actual stream EOF and not just a flush-pkt. Returns: True if the stream is at EOF, False otherwise. """ try: next_line = self.read_pkt_line() except HangupException: return True self.unread_pkt_line(next_line) return False def unread_pkt_line(self, data): """Unread a single line of data into the readahead buffer. This method can be used to unread a single pkt-line into a fixed readahead buffer. Args: data: The data to unread, without the length prefix. Raises: ValueError: If more than one pkt-line is unread. """ if self._readahead is not None: raise ValueError("Attempted to unread multiple pkt-lines.") self._readahead = BytesIO(pkt_line(data)) def read_pkt_seq(self): """Read a sequence of pkt-lines from the remote git process. Returns: Yields each line of data up to but not including the next flush-pkt. """ pkt = self.read_pkt_line() while pkt: yield pkt pkt = self.read_pkt_line() def write_pkt_line(self, line): """Sends a pkt-line to the remote git process. Args: line: A string containing the data to send, without the length prefix. """ try: line = pkt_line(line) self.write(line) if self.report_activity: self.report_activity(len(line), "write") except socket.error as e: raise GitProtocolError(e) def write_file(self): """Return a writable file-like object for this protocol.""" class ProtocolFile(object): def __init__(self, proto): self._proto = proto self._offset = 0 def write(self, data): self._proto.write(data) self._offset += len(data) def tell(self): return self._offset def close(self): pass return ProtocolFile(self) def write_sideband(self, channel, blob): """Write multiplexed data to the sideband. Args: channel: An int specifying the channel to write to. blob: A blob of data (as a string) to send on this channel. """ # a pktline can be a max of 65520. a sideband line can therefore be # 65520-5 = 65515 # WTF: Why have the len in ASCII, but the channel in binary. while blob: self.write_pkt_line(bytes(bytearray([channel])) + blob[:65515]) blob = blob[65515:] def send_cmd(self, cmd, *args): """Send a command and some arguments to a git server. Only used for the TCP git protocol (git://). Args: cmd: The remote service to access. args: List of arguments to send to remove service. """ self.write_pkt_line(format_cmd_pkt(cmd, *args)) def read_cmd(self): """Read a command and some arguments from the git client Only used for the TCP git protocol (git://). Returns: A tuple of (command, [list of arguments]). """ line = self.read_pkt_line() return parse_cmd_pkt(line) _RBUFSIZE = 8192 # Default read buffer size. class ReceivableProtocol(Protocol): """Variant of Protocol that allows reading up to a size without blocking. This class has a recv() method that behaves like socket.recv() in addition to a read() method. If you want to read n bytes from the wire and block until exactly n bytes (or EOF) are read, use read(n). If you want to read at most n bytes from the wire but don't care if you get less, use recv(n). Note that recv(n) will still block until at least one byte is read. """ def __init__( self, recv, write, close=None, report_activity=None, rbufsize=_RBUFSIZE ): super(ReceivableProtocol, self).__init__( self.read, write, close=close, report_activity=report_activity ) self._recv = recv self._rbuf = BytesIO() self._rbufsize = rbufsize def read(self, size): # From _fileobj.read in socket.py in the Python 2.6.5 standard library, # with the following modifications: # - omit the size <= 0 branch # - seek back to start rather than 0 in case some buffer has been # consumed. # - use SEEK_END instead of the magic number. # Copyright (c) 2001-2010 Python Software Foundation; All Rights # Reserved # Licensed under the Python Software Foundation License. # TODO: see if buffer is more efficient than cBytesIO. assert size > 0 # Our use of BytesIO rather than lists of string objects returned by # recv() minimizes memory usage and fragmentation that occurs when # rbufsize is large compared to the typical return value of recv(). buf = self._rbuf start = buf.tell() buf.seek(0, SEEK_END) # buffer may have been partially consumed by recv() buf_len = buf.tell() - start if buf_len >= size: # Already have size bytes in our buffer? Extract and return. buf.seek(start) rv = buf.read(size) self._rbuf = BytesIO() self._rbuf.write(buf.read()) self._rbuf.seek(0) return rv self._rbuf = BytesIO() # reset _rbuf. we consume it via buf. while True: left = size - buf_len # recv() will malloc the amount of memory given as its # parameter even though it often returns much less data # than that. The returned data string is short lived # as we copy it into a BytesIO and free it. This avoids # fragmentation issues on many platforms. data = self._recv(left) if not data: break n = len(data) if n == size and not buf_len: # Shortcut. Avoid buffer data copies when: # - We have no data in our buffer. # AND # - Our call to recv returned exactly the # number of bytes we were asked to read. return data if n == left: buf.write(data) del data # explicit free break assert n <= left, "_recv(%d) returned %d bytes" % (left, n) buf.write(data) buf_len += n del data # explicit free # assert buf_len == buf.tell() buf.seek(start) return buf.read() def recv(self, size): assert size > 0 buf = self._rbuf start = buf.tell() buf.seek(0, SEEK_END) buf_len = buf.tell() buf.seek(start) left = buf_len - start if not left: # only read from the wire if our read buffer is exhausted data = self._recv(self._rbufsize) if len(data) == size: # shortcut: skip the buffer if we read exactly size bytes return data buf = BytesIO() buf.write(data) buf.seek(0) del data # explicit free self._rbuf = buf return buf.read(size) def extract_capabilities(text): """Extract a capabilities list from a string, if present. Args: text: String to extract from Returns: Tuple with text with capabilities removed and list of capabilities """ if b"\0" not in text: return text, [] text, capabilities = text.rstrip().split(b"\0") return (text, capabilities.strip().split(b" ")) def extract_want_line_capabilities(text): """Extract a capabilities list from a want line, if present. Note that want lines have capabilities separated from the rest of the line by a space instead of a null byte. Thus want lines have the form: want obj-id cap1 cap2 ... Args: text: Want line to extract from Returns: Tuple with text with capabilities removed and list of capabilities """ split_text = text.rstrip().split(b" ") if len(split_text) < 3: return text, [] return (b" ".join(split_text[:2]), split_text[2:]) def ack_type(capabilities): """Extract the ack type from a capabilities list.""" if b"multi_ack_detailed" in capabilities: return MULTI_ACK_DETAILED elif b"multi_ack" in capabilities: return MULTI_ACK return SINGLE_ACK class BufferedPktLineWriter(object): """Writer that wraps its data in pkt-lines and has an independent buffer. Consecutive calls to write() wrap the data in a pkt-line and then buffers it until enough lines have been written such that their total length (including length prefix) reach the buffer size. """ def __init__(self, write, bufsize=65515): """Initialize the BufferedPktLineWriter. Args: write: A write callback for the underlying writer. bufsize: The internal buffer size, including length prefixes. """ self._write = write self._bufsize = bufsize self._wbuf = BytesIO() self._buflen = 0 def write(self, data): """Write data, wrapping it in a pkt-line.""" line = pkt_line(data) line_len = len(line) over = self._buflen + line_len - self._bufsize if over >= 0: start = line_len - over self._wbuf.write(line[:start]) self.flush() else: start = 0 saved = line[start:] self._wbuf.write(saved) self._buflen += len(saved) def flush(self): """Flush all data from the buffer.""" data = self._wbuf.getvalue() if data: self._write(data) self._len = 0 self._wbuf = BytesIO() class PktLineParser(object): """Packet line parser that hands completed packets off to a callback.""" def __init__(self, handle_pkt): self.handle_pkt = handle_pkt self._readahead = BytesIO() def parse(self, data): """Parse a fragment of data and call back for any completed packets.""" self._readahead.write(data) buf = self._readahead.getvalue() if len(buf) < 4: return while len(buf) >= 4: size = int(buf[:4], 16) if size == 0: self.handle_pkt(None) buf = buf[4:] elif size <= len(buf): self.handle_pkt(buf[4:size]) buf = buf[size:] else: break self._readahead = BytesIO() self._readahead.write(buf) def get_tail(self): """Read back any unused data.""" return self._readahead.getvalue() diff --git a/dulwich/repo.py b/dulwich/repo.py index 10012e86..de446623 100644 --- a/dulwich/repo.py +++ b/dulwich/repo.py @@ -1,1630 +1,1629 @@ # repo.py -- For dealing with git repositories. # Copyright (C) 2007 James Westby # Copyright (C) 2008-2013 Jelmer Vernooij # # Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU # General Public License as public by the Free Software Foundation; version 2.0 # or (at your option) any later version. You can redistribute it and/or # modify it under the terms of either of these two licenses. # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # # You should have received a copy of the licenses; if not, see # for a copy of the GNU General Public License # and for a copy of the Apache # License, Version 2.0. # """Repository access. This module contains the base class for git repositories (BaseRepo) and an implementation which uses a repository on local disk (Repo). """ from io import BytesIO import os import sys import stat import time from typing import Optional, Tuple, TYPE_CHECKING, List, Dict, Union, Iterable if TYPE_CHECKING: # There are no circular imports here, but we try to defer imports as long # as possible to reduce start-up time for anything that doesn't need # these imports. from dulwich.config import StackedConfig, ConfigFile from dulwich.index import Index from dulwich.errors import ( NoIndexPresent, NotBlobError, NotCommitError, NotGitRepository, NotTreeError, NotTagError, CommitError, RefFormatError, HookError, ) from dulwich.file import ( GitFile, ) from dulwich.object_store import ( DiskObjectStore, MemoryObjectStore, BaseObjectStore, ObjectStoreGraphWalker, ) from dulwich.objects import ( check_hexsha, valid_hexsha, Blob, Commit, ShaFile, Tag, Tree, ) from dulwich.pack import ( pack_objects_to_data, ) from dulwich.hooks import ( Hook, PreCommitShellHook, PostCommitShellHook, CommitMsgShellHook, PostReceiveShellHook, ) from dulwich.line_ending import BlobNormalizer from dulwich.refs import ( # noqa: F401 ANNOTATED_TAG_SUFFIX, check_ref_format, RefsContainer, DictRefsContainer, InfoRefsContainer, DiskRefsContainer, read_packed_refs, read_packed_refs_with_peeled, write_packed_refs, SYMREF, ) import warnings CONTROLDIR = ".git" OBJECTDIR = "objects" REFSDIR = "refs" REFSDIR_TAGS = "tags" REFSDIR_HEADS = "heads" INDEX_FILENAME = "index" COMMONDIR = "commondir" GITDIR = "gitdir" WORKTREES = "worktrees" BASE_DIRECTORIES = [ ["branches"], [REFSDIR], [REFSDIR, REFSDIR_TAGS], [REFSDIR, REFSDIR_HEADS], ["hooks"], ["info"], ] DEFAULT_REF = b"refs/heads/master" class InvalidUserIdentity(Exception): """User identity is not of the format 'user '""" def __init__(self, identity): self.identity = identity def _get_default_identity() -> Tuple[str, str]: import getpass import socket username = getpass.getuser() try: import pwd except ImportError: fullname = None else: try: gecos = pwd.getpwnam(username).pw_gecos except KeyError: fullname = None else: fullname = gecos.split(",")[0] if not fullname: fullname = username email = os.environ.get("EMAIL") if email is None: email = "{}@{}".format(username, socket.gethostname()) return (fullname, email) def get_user_identity(config: "StackedConfig", kind: Optional[str] = None) -> bytes: """Determine the identity to use for new commits. If kind is set, this first checks GIT_${KIND}_NAME and GIT_${KIND}_EMAIL. If those variables are not set, then it will fall back to reading the user.name and user.email settings from the specified configuration. If that also fails, then it will fall back to using the current users' identity as obtained from the host system (e.g. the gecos field, $EMAIL, $USER@$(hostname -f). Args: kind: Optional kind to return identity for, usually either "AUTHOR" or "COMMITTER". Returns: A user identity """ user = None # type: Optional[bytes] email = None # type: Optional[bytes] if kind: user_uc = os.environ.get("GIT_" + kind + "_NAME") if user_uc is not None: user = user_uc.encode("utf-8") email_uc = os.environ.get("GIT_" + kind + "_EMAIL") if email_uc is not None: email = email_uc.encode("utf-8") if user is None: try: user = config.get(("user",), "name") except KeyError: user = None if email is None: try: email = config.get(("user",), "email") except KeyError: email = None default_user, default_email = _get_default_identity() if user is None: user = default_user.encode("utf-8") if email is None: email = default_email.encode("utf-8") if email.startswith(b"<") and email.endswith(b">"): email = email[1:-1] return user + b" <" + email + b">" def check_user_identity(identity): """Verify that a user identity is formatted correctly. Args: identity: User identity bytestring Raises: InvalidUserIdentity: Raised when identity is invalid """ try: fst, snd = identity.split(b" <", 1) except ValueError: raise InvalidUserIdentity(identity) if b">" not in snd: raise InvalidUserIdentity(identity) def parse_graftpoints( graftpoints: Iterable[bytes], ) -> Dict[bytes, List[bytes]]: """Convert a list of graftpoints into a dict Args: graftpoints: Iterator of graftpoint lines Each line is formatted as: []* Resulting dictionary is: : [*] https://git.wiki.kernel.org/index.php/GraftPoint """ grafts = {} for line in graftpoints: raw_graft = line.split(None, 1) commit = raw_graft[0] if len(raw_graft) == 2: parents = raw_graft[1].split() else: parents = [] for sha in [commit] + parents: check_hexsha(sha, "Invalid graftpoint") grafts[commit] = parents return grafts def serialize_graftpoints(graftpoints: Dict[bytes, List[bytes]]) -> bytes: """Convert a dictionary of grafts into string The graft dictionary is: : [*] Each line is formatted as: []* https://git.wiki.kernel.org/index.php/GraftPoint """ graft_lines = [] for commit, parents in graftpoints.items(): if parents: graft_lines.append(commit + b" " + b" ".join(parents)) else: graft_lines.append(commit) return b"\n".join(graft_lines) def _set_filesystem_hidden(path): """Mark path as to be hidden if supported by platform and filesystem. On win32 uses SetFileAttributesW api: """ if sys.platform == "win32": import ctypes from ctypes.wintypes import BOOL, DWORD, LPCWSTR FILE_ATTRIBUTE_HIDDEN = 2 SetFileAttributesW = ctypes.WINFUNCTYPE(BOOL, LPCWSTR, DWORD)( ("SetFileAttributesW", ctypes.windll.kernel32) ) if isinstance(path, bytes): path = os.fsdecode(path) if not SetFileAttributesW(path, FILE_ATTRIBUTE_HIDDEN): pass # Could raise or log `ctypes.WinError()` here # Could implement other platform specific filesytem hiding here class ParentsProvider(object): def __init__(self, store, grafts={}, shallows=[]): self.store = store self.grafts = grafts self.shallows = set(shallows) def get_parents(self, commit_id, commit=None): try: return self.grafts[commit_id] except KeyError: pass if commit_id in self.shallows: return [] if commit is None: commit = self.store[commit_id] return commit.parents class BaseRepo(object): """Base class for a git repository. :ivar object_store: Dictionary-like object for accessing the objects :ivar refs: Dictionary-like object with the refs in this repository """ def __init__(self, object_store: BaseObjectStore, refs: RefsContainer): """Open a repository. This shouldn't be called directly, but rather through one of the base classes, such as MemoryRepo or Repo. Args: object_store: Object store to use refs: Refs container to use """ self.object_store = object_store self.refs = refs self._graftpoints = {} # type: Dict[bytes, List[bytes]] self.hooks = {} # type: Dict[str, Hook] def _determine_file_mode(self) -> bool: """Probe the file-system to determine whether permissions can be trusted. Returns: True if permissions can be trusted, False otherwise. """ raise NotImplementedError(self._determine_file_mode) def _init_files(self, bare: bool) -> None: """Initialize a default set of named files.""" from dulwich.config import ConfigFile self._put_named_file("description", b"Unnamed repository") f = BytesIO() cf = ConfigFile() cf.set("core", "repositoryformatversion", "0") if self._determine_file_mode(): cf.set("core", "filemode", True) else: cf.set("core", "filemode", False) cf.set("core", "bare", bare) cf.set("core", "logallrefupdates", True) cf.write_to_file(f) self._put_named_file("config", f.getvalue()) self._put_named_file(os.path.join("info", "exclude"), b"") def get_named_file(self, path): """Get a file from the control dir with a specific name. Although the filename should be interpreted as a filename relative to the control dir in a disk-based Repo, the object returned need not be pointing to a file in that location. Args: path: The path to the file, relative to the control dir. Returns: An open file object, or None if the file does not exist. """ raise NotImplementedError(self.get_named_file) def _put_named_file(self, path, contents): """Write a file to the control dir with the given name and contents. Args: path: The path to the file, relative to the control dir. contents: A string to write to the file. """ raise NotImplementedError(self._put_named_file) def _del_named_file(self, path): """Delete a file in the contrl directory with the given name.""" raise NotImplementedError(self._del_named_file) def open_index(self): """Open the index for this repository. Raises: NoIndexPresent: If no index is present Returns: The matching `Index` """ raise NotImplementedError(self.open_index) def fetch(self, target, determine_wants=None, progress=None, depth=None): """Fetch objects into another repository. Args: target: The target repository determine_wants: Optional function to determine what refs to fetch. progress: Optional progress function depth: Optional shallow fetch depth Returns: The local refs """ if determine_wants is None: determine_wants = target.object_store.determine_wants_all count, pack_data = self.fetch_pack_data( determine_wants, target.get_graph_walker(), progress=progress, depth=depth, ) target.object_store.add_pack_data(count, pack_data, progress) return self.get_refs() def fetch_pack_data( self, determine_wants, graph_walker, progress, get_tagged=None, depth=None, ): """Fetch the pack data required for a set of revisions. Args: determine_wants: Function that takes a dictionary with heads and returns the list of heads to fetch. graph_walker: Object that can iterate over the list of revisions to fetch and has an "ack" method that will be called to acknowledge that a revision is present. progress: Simple progress function that will be called with updated progress strings. get_tagged: Function that returns a dict of pointed-to sha -> tag sha for including tags. depth: Shallow fetch depth Returns: count and iterator over pack data """ # TODO(jelmer): Fetch pack data directly, don't create objects first. objects = self.fetch_objects( determine_wants, graph_walker, progress, get_tagged, depth=depth ) return pack_objects_to_data(objects) def fetch_objects( self, determine_wants, graph_walker, progress, get_tagged=None, depth=None, ): """Fetch the missing objects required for a set of revisions. Args: determine_wants: Function that takes a dictionary with heads and returns the list of heads to fetch. graph_walker: Object that can iterate over the list of revisions to fetch and has an "ack" method that will be called to acknowledge that a revision is present. progress: Simple progress function that will be called with updated progress strings. get_tagged: Function that returns a dict of pointed-to sha -> tag sha for including tags. depth: Shallow fetch depth Returns: iterator over objects, with __len__ implemented """ if depth not in (None, 0): raise NotImplementedError("depth not supported yet") refs = {} for ref, sha in self.get_refs().items(): try: obj = self.object_store[sha] except KeyError: warnings.warn( "ref %s points at non-present sha %s" % (ref.decode("utf-8", "replace"), sha.decode("ascii")), UserWarning, ) continue else: if isinstance(obj, Tag): refs[ref + ANNOTATED_TAG_SUFFIX] = obj.object[1] refs[ref] = sha wants = determine_wants(refs) if not isinstance(wants, list): raise TypeError("determine_wants() did not return a list") shallows = getattr(graph_walker, "shallow", frozenset()) unshallows = getattr(graph_walker, "unshallow", frozenset()) if wants == []: # TODO(dborowitz): find a way to short-circuit that doesn't change # this interface. if shallows or unshallows: # Do not send a pack in shallow short-circuit path return None return [] # If the graph walker is set up with an implementation that can # ACK/NAK to the wire, it will write data to the client through # this call as a side-effect. haves = self.object_store.find_common_revisions(graph_walker) # Deal with shallow requests separately because the haves do # not reflect what objects are missing if shallows or unshallows: # TODO: filter the haves commits from iter_shas. the specific # commits aren't missing. haves = [] parents_provider = ParentsProvider(self.object_store, shallows=shallows) def get_parents(commit): return parents_provider.get_parents(commit.id, commit) return self.object_store.iter_shas( self.object_store.find_missing_objects( haves, wants, self.get_shallow(), progress, get_tagged, get_parents=get_parents, ) ) def generate_pack_data(self, have, want, progress=None, ofs_delta=None): """Generate pack data objects for a set of wants/haves. Args: have: List of SHA1s of objects that should not be sent want: List of SHA1s of objects that should be sent ofs_delta: Whether OFS deltas can be included progress: Optional progress reporting method """ return self.object_store.generate_pack_data( have, want, shallow=self.get_shallow(), progress=progress, ofs_delta=ofs_delta, ) def get_graph_walker(self, heads=None): """Retrieve a graph walker. A graph walker is used by a remote repository (or proxy) to find out which objects are present in this repository. Args: heads: Repository heads to use (optional) Returns: A graph walker object """ if heads is None: heads = [ sha for sha in self.refs.as_dict(b"refs/heads").values() if sha in self.object_store ] parents_provider = ParentsProvider(self.object_store) return ObjectStoreGraphWalker( heads, parents_provider.get_parents, shallow=self.get_shallow() ) def get_refs(self) -> Dict[bytes, bytes]: """Get dictionary with all refs. Returns: A ``dict`` mapping ref names to SHA1s """ return self.refs.as_dict() def head(self) -> bytes: """Return the SHA1 pointed at by HEAD.""" return self.refs[b"HEAD"] def _get_object(self, sha, cls): assert len(sha) in (20, 40) ret = self.get_object(sha) if not isinstance(ret, cls): if cls is Commit: raise NotCommitError(ret) elif cls is Blob: raise NotBlobError(ret) elif cls is Tree: raise NotTreeError(ret) elif cls is Tag: raise NotTagError(ret) else: raise Exception( "Type invalid: %r != %r" % (ret.type_name, cls.type_name) ) return ret def get_object(self, sha: bytes) -> ShaFile: """Retrieve the object with the specified SHA. Args: sha: SHA to retrieve Returns: A ShaFile object Raises: KeyError: when the object can not be found """ return self.object_store[sha] def parents_provider(self): return ParentsProvider( self.object_store, grafts=self._graftpoints, shallows=self.get_shallow(), ) def get_parents(self, sha: bytes, commit: Commit = None) -> List[bytes]: """Retrieve the parents of a specific commit. If the specific commit is a graftpoint, the graft parents will be returned instead. Args: sha: SHA of the commit for which to retrieve the parents commit: Optional commit matching the sha Returns: List of parents """ return self.parents_provider().get_parents(sha, commit) def get_config(self): """Retrieve the config object. Returns: `ConfigFile` object for the ``.git/config`` file. """ raise NotImplementedError(self.get_config) def get_description(self): """Retrieve the description for this repository. Returns: String with the description of the repository as set by the user. """ raise NotImplementedError(self.get_description) def set_description(self, description): """Set the description for this repository. Args: description: Text to set as description for this repository. """ raise NotImplementedError(self.set_description) def get_config_stack(self) -> "StackedConfig": """Return a config stack for this repository. This stack accesses the configuration for both this repository itself (.git/config) and the global configuration, which usually lives in ~/.gitconfig. Returns: `Config` instance for this repository """ from dulwich.config import StackedConfig backends = [self.get_config()] + StackedConfig.default_backends() return StackedConfig(backends, writable=backends[0]) def get_shallow(self): """Get the set of shallow commits. Returns: Set of shallow commits. """ f = self.get_named_file("shallow") if f is None: return set() with f: - return set(line.strip() for line in f) + return {line.strip() for line in f} def update_shallow(self, new_shallow, new_unshallow): """Update the list of shallow objects. Args: new_shallow: Newly shallow objects new_unshallow: Newly no longer shallow objects """ shallow = self.get_shallow() if new_shallow: shallow.update(new_shallow) if new_unshallow: shallow.difference_update(new_unshallow) self._put_named_file("shallow", b"".join([sha + b"\n" for sha in shallow])) def get_peeled(self, ref): """Get the peeled value of a ref. Args: ref: The refname to peel. Returns: The fully-peeled SHA1 of a tag object, after peeling all intermediate tags; if the original ref does not point to a tag, this will equal the original SHA1. """ cached = self.refs.get_peeled(ref) if cached is not None: return cached return self.object_store.peel_sha(self.refs[ref]).id def get_walker(self, include=None, *args, **kwargs): """Obtain a walker for this repository. Args: include: Iterable of SHAs of commits to include along with their ancestors. Defaults to [HEAD] exclude: Iterable of SHAs of commits to exclude along with their ancestors, overriding includes. order: ORDER_* constant specifying the order of results. Anything other than ORDER_DATE may result in O(n) memory usage. reverse: If True, reverse the order of output, requiring O(n) memory. max_entries: The maximum number of entries to yield, or None for no limit. paths: Iterable of file or subtree paths to show entries for. rename_detector: diff.RenameDetector object for detecting renames. follow: If True, follow path across renames/copies. Forces a default rename_detector. since: Timestamp to list commits after. until: Timestamp to list commits before. queue_cls: A class to use for a queue of commits, supporting the iterator protocol. The constructor takes a single argument, the Walker. Returns: A `Walker` object """ from dulwich.walk import Walker if include is None: include = [self.head()] if isinstance(include, str): include = [include] kwargs["get_parents"] = lambda commit: self.get_parents(commit.id, commit) return Walker(self.object_store, include, *args, **kwargs) def __getitem__(self, name): """Retrieve a Git object by SHA1 or ref. Args: name: A Git object SHA1 or a ref name Returns: A `ShaFile` object, such as a Commit or Blob Raises: KeyError: when the specified ref or object does not exist """ if not isinstance(name, bytes): raise TypeError( "'name' must be bytestring, not %.80s" % type(name).__name__ ) if len(name) in (20, 40): try: return self.object_store[name] except (KeyError, ValueError): pass try: return self.object_store[self.refs[name]] except RefFormatError: raise KeyError(name) def __contains__(self, name: bytes) -> bool: """Check if a specific Git object or ref is present. Args: name: Git object SHA1 or ref name """ if len(name) == 20 or (len(name) == 40 and valid_hexsha(name)): return name in self.object_store or name in self.refs else: return name in self.refs def __setitem__(self, name: bytes, value: Union[ShaFile, bytes]): """Set a ref. Args: name: ref name value: Ref value - either a ShaFile object, or a hex sha """ if name.startswith(b"refs/") or name == b"HEAD": if isinstance(value, ShaFile): self.refs[name] = value.id elif isinstance(value, bytes): self.refs[name] = value else: raise TypeError(value) else: raise ValueError(name) def __delitem__(self, name: bytes): """Remove a ref. Args: name: Name of the ref to remove """ if name.startswith(b"refs/") or name == b"HEAD": del self.refs[name] else: raise ValueError(name) def _get_user_identity(self, config: "StackedConfig", kind: str = None) -> bytes: """Determine the identity to use for new commits.""" # TODO(jelmer): Deprecate this function in favor of get_user_identity return get_user_identity(config) def _add_graftpoints(self, updated_graftpoints: Dict[bytes, List[bytes]]): """Add or modify graftpoints Args: updated_graftpoints: Dict of commit shas to list of parent shas """ # Simple validation for commit, parents in updated_graftpoints.items(): for sha in [commit] + parents: check_hexsha(sha, "Invalid graftpoint") self._graftpoints.update(updated_graftpoints) def _remove_graftpoints(self, to_remove: List[bytes] = []) -> None: """Remove graftpoints Args: to_remove: List of commit shas """ for sha in to_remove: del self._graftpoints[sha] def _read_heads(self, name): f = self.get_named_file(name) if f is None: return [] with f: return [line.strip() for line in f.readlines() if line.strip()] def do_commit( # noqa: C901 self, message=None, committer=None, author=None, commit_timestamp=None, commit_timezone=None, author_timestamp=None, author_timezone=None, tree=None, encoding=None, ref=b"HEAD", merge_heads=None, no_verify=False, ): """Create a new commit. If not specified, `committer` and `author` default to get_user_identity(..., 'COMMITTER') and get_user_identity(..., 'AUTHOR') respectively. Args: message: Commit message committer: Committer fullname author: Author fullname commit_timestamp: Commit timestamp (defaults to now) commit_timezone: Commit timestamp timezone (defaults to GMT) author_timestamp: Author timestamp (defaults to commit timestamp) author_timezone: Author timestamp timezone (defaults to commit timestamp timezone) tree: SHA1 of the tree root to use (if not specified the current index will be committed). encoding: Encoding ref: Optional ref to commit to (defaults to current branch) merge_heads: Merge heads (defaults to .git/MERGE_HEADS) no_verify: Skip pre-commit and commit-msg hooks Returns: New commit SHA1 """ - import time c = Commit() if tree is None: index = self.open_index() c.tree = index.commit(self.object_store) else: if len(tree) != 40: raise ValueError("tree must be a 40-byte hex sha string") c.tree = tree try: if not no_verify: self.hooks["pre-commit"].execute() except HookError as e: raise CommitError(e) except KeyError: # no hook defined, silent fallthrough pass config = self.get_config_stack() if merge_heads is None: merge_heads = self._read_heads("MERGE_HEADS") if committer is None: committer = get_user_identity(config, kind="COMMITTER") check_user_identity(committer) c.committer = committer if commit_timestamp is None: # FIXME: Support GIT_COMMITTER_DATE environment variable commit_timestamp = time.time() c.commit_time = int(commit_timestamp) if commit_timezone is None: # FIXME: Use current user timezone rather than UTC commit_timezone = 0 c.commit_timezone = commit_timezone if author is None: author = get_user_identity(config, kind="AUTHOR") c.author = author check_user_identity(author) if author_timestamp is None: # FIXME: Support GIT_AUTHOR_DATE environment variable author_timestamp = commit_timestamp c.author_time = int(author_timestamp) if author_timezone is None: author_timezone = commit_timezone c.author_timezone = author_timezone if encoding is None: try: encoding = config.get(("i18n",), "commitEncoding") except KeyError: pass # No dice if encoding is not None: c.encoding = encoding if message is None: # FIXME: Try to read commit message from .git/MERGE_MSG raise ValueError("No commit message specified") try: if no_verify: c.message = message else: c.message = self.hooks["commit-msg"].execute(message) if c.message is None: c.message = message except HookError as e: raise CommitError(e) except KeyError: # no hook defined, message not modified c.message = message if ref is None: # Create a dangling commit c.parents = merge_heads self.object_store.add_object(c) else: try: old_head = self.refs[ref] c.parents = [old_head] + merge_heads self.object_store.add_object(c) ok = self.refs.set_if_equals( ref, old_head, c.id, message=b"commit: " + message, committer=committer, timestamp=commit_timestamp, timezone=commit_timezone, ) except KeyError: c.parents = merge_heads self.object_store.add_object(c) ok = self.refs.add_if_new( ref, c.id, message=b"commit: " + message, committer=committer, timestamp=commit_timestamp, timezone=commit_timezone, ) if not ok: # Fail if the atomic compare-and-swap failed, leaving the # commit and all its objects as garbage. raise CommitError("%s changed during commit" % (ref,)) self._del_named_file("MERGE_HEADS") try: self.hooks["post-commit"].execute() except HookError as e: # silent failure warnings.warn("post-commit hook failed: %s" % e, UserWarning) except KeyError: # no hook defined, silent fallthrough pass return c.id def read_gitfile(f): """Read a ``.git`` file. The first line of the file should start with "gitdir: " Args: f: File-like object to read from Returns: A path """ cs = f.read() if not cs.startswith("gitdir: "): raise ValueError("Expected file to start with 'gitdir: '") return cs[len("gitdir: ") :].rstrip("\n") class UnsupportedVersion(Exception): """Unsupported repository version.""" def __init__(self, version): self.version = version class Repo(BaseRepo): """A git repository backed by local disk. To open an existing repository, call the contructor with the path of the repository. To create a new repository, use the Repo.init class method. """ def __init__(self, root, object_store=None, bare=None): hidden_path = os.path.join(root, CONTROLDIR) if bare is None: if (os.path.isfile(hidden_path) or os.path.isdir(os.path.join(hidden_path, OBJECTDIR))): bare = False elif (os.path.isdir(os.path.join(root, OBJECTDIR)) and os.path.isdir(os.path.join(root, REFSDIR))): bare = True else: raise NotGitRepository( "No git repository was found at %(path)s" % dict(path=root) ) self.bare = bare if bare is False: if os.path.isfile(hidden_path): with open(hidden_path, "r") as f: path = read_gitfile(f) self.bare = False self._controldir = os.path.join(root, path) else: self._controldir = hidden_path else: self._controldir = root commondir = self.get_named_file(COMMONDIR) if commondir is not None: with commondir: self._commondir = os.path.join( self.controldir(), os.fsdecode(commondir.read().rstrip(b"\r\n")), ) else: self._commondir = self._controldir self.path = root config = self.get_config() try: format_version = int(config.get("core", "repositoryformatversion")) except KeyError: format_version = 0 if format_version != 0: raise UnsupportedVersion(format_version) if object_store is None: object_store = DiskObjectStore.from_config( os.path.join(self.commondir(), OBJECTDIR), config ) refs = DiskRefsContainer( self.commondir(), self._controldir, logger=self._write_reflog ) BaseRepo.__init__(self, object_store, refs) self._graftpoints = {} graft_file = self.get_named_file( os.path.join("info", "grafts"), basedir=self.commondir() ) if graft_file: with graft_file: self._graftpoints.update(parse_graftpoints(graft_file)) graft_file = self.get_named_file("shallow", basedir=self.commondir()) if graft_file: with graft_file: self._graftpoints.update(parse_graftpoints(graft_file)) self.hooks["pre-commit"] = PreCommitShellHook(self.controldir()) self.hooks["commit-msg"] = CommitMsgShellHook(self.controldir()) self.hooks["post-commit"] = PostCommitShellHook(self.controldir()) self.hooks["post-receive"] = PostReceiveShellHook(self.controldir()) def _write_reflog( self, ref, old_sha, new_sha, committer, timestamp, timezone, message ): from .reflog import format_reflog_line path = os.path.join(self.controldir(), "logs", os.fsdecode(ref)) try: os.makedirs(os.path.dirname(path)) except FileExistsError: pass if committer is None: config = self.get_config_stack() committer = self._get_user_identity(config) check_user_identity(committer) if timestamp is None: timestamp = int(time.time()) if timezone is None: timezone = 0 # FIXME with open(path, "ab") as f: f.write( format_reflog_line( old_sha, new_sha, committer, timestamp, timezone, message ) + b"\n" ) @classmethod def discover(cls, start="."): """Iterate parent directories to discover a repository Return a Repo object for the first parent directory that looks like a Git repository. Args: start: The directory to start discovery from (defaults to '.') """ remaining = True path = os.path.abspath(start) while remaining: try: return cls(path) except NotGitRepository: path, remaining = os.path.split(path) raise NotGitRepository( "No git repository was found at %(path)s" % dict(path=start) ) def controldir(self): """Return the path of the control directory.""" return self._controldir def commondir(self): """Return the path of the common directory. For a main working tree, it is identical to controldir(). For a linked working tree, it is the control directory of the main working tree.""" return self._commondir def _determine_file_mode(self): """Probe the file-system to determine whether permissions can be trusted. Returns: True if permissions can be trusted, False otherwise. """ fname = os.path.join(self.path, ".probe-permissions") with open(fname, "w") as f: f.write("") st1 = os.lstat(fname) try: os.chmod(fname, st1.st_mode ^ stat.S_IXUSR) except PermissionError: return False st2 = os.lstat(fname) os.unlink(fname) mode_differs = st1.st_mode != st2.st_mode st2_has_exec = (st2.st_mode & stat.S_IXUSR) != 0 return mode_differs and st2_has_exec def _put_named_file(self, path, contents): """Write a file to the control dir with the given name and contents. Args: path: The path to the file, relative to the control dir. contents: A string to write to the file. """ path = path.lstrip(os.path.sep) with GitFile(os.path.join(self.controldir(), path), "wb") as f: f.write(contents) def _del_named_file(self, path): try: os.unlink(os.path.join(self.controldir(), path)) except FileNotFoundError: return def get_named_file(self, path, basedir=None): """Get a file from the control dir with a specific name. Although the filename should be interpreted as a filename relative to the control dir in a disk-based Repo, the object returned need not be pointing to a file in that location. Args: path: The path to the file, relative to the control dir. basedir: Optional argument that specifies an alternative to the control dir. Returns: An open file object, or None if the file does not exist. """ # TODO(dborowitz): sanitize filenames, since this is used directly by # the dumb web serving code. if basedir is None: basedir = self.controldir() path = path.lstrip(os.path.sep) try: return open(os.path.join(basedir, path), "rb") except FileNotFoundError: return None def index_path(self): """Return path to the index file.""" return os.path.join(self.controldir(), INDEX_FILENAME) def open_index(self) -> "Index": """Open the index for this repository. Raises: NoIndexPresent: If no index is present Returns: The matching `Index` """ from dulwich.index import Index if not self.has_index(): raise NoIndexPresent() return Index(self.index_path()) def has_index(self): """Check if an index is present.""" # Bare repos must never have index files; non-bare repos may have a # missing index file, which is treated as empty. return not self.bare def stage(self, fs_paths): """Stage a set of paths. Args: fs_paths: List of paths, relative to the repository path """ root_path_bytes = os.fsencode(self.path) if not isinstance(fs_paths, list): fs_paths = [fs_paths] from dulwich.index import ( blob_from_path_and_stat, index_entry_from_stat, _fs_to_tree_path, ) index = self.open_index() blob_normalizer = self.get_blob_normalizer() for fs_path in fs_paths: if not isinstance(fs_path, bytes): fs_path = os.fsencode(fs_path) if os.path.isabs(fs_path): raise ValueError( "path %r should be relative to " "repository root, not absolute" % fs_path ) tree_path = _fs_to_tree_path(fs_path) full_path = os.path.join(root_path_bytes, fs_path) try: st = os.lstat(full_path) except OSError: # File no longer exists try: del index[tree_path] except KeyError: pass # already removed else: if not stat.S_ISREG(st.st_mode) and not stat.S_ISLNK(st.st_mode): try: del index[tree_path] except KeyError: pass else: blob = blob_from_path_and_stat(full_path, st) blob = blob_normalizer.checkin_normalize(blob, fs_path) self.object_store.add_object(blob) index[tree_path] = index_entry_from_stat(st, blob.id, 0) index.write() def clone( self, target_path, mkdir=True, bare=False, origin=b"origin", checkout=None, ): """Clone this repository. Args: target_path: Target path mkdir: Create the target directory bare: Whether to create a bare repository origin: Base name for refs in target repository cloned from this repository Returns: Created repository as `Repo` """ if not bare: target = self.init(target_path, mkdir=mkdir) else: if checkout: raise ValueError("checkout and bare are incompatible") target = self.init_bare(target_path, mkdir=mkdir) self.fetch(target) encoded_path = self.path if not isinstance(encoded_path, bytes): encoded_path = os.fsencode(encoded_path) ref_message = b"clone: from " + encoded_path target.refs.import_refs( b"refs/remotes/" + origin, self.refs.as_dict(b"refs/heads"), message=ref_message, ) target.refs.import_refs( b"refs/tags", self.refs.as_dict(b"refs/tags"), message=ref_message ) try: target.refs.add_if_new( DEFAULT_REF, self.refs[DEFAULT_REF], message=ref_message ) except KeyError: pass target_config = target.get_config() target_config.set(("remote", "origin"), "url", encoded_path) target_config.set( ("remote", "origin"), "fetch", "+refs/heads/*:refs/remotes/origin/*", ) target_config.write_to_path() # Update target head head_chain, head_sha = self.refs.follow(b"HEAD") if head_chain and head_sha is not None: target.refs.set_symbolic_ref(b"HEAD", head_chain[-1], message=ref_message) target[b"HEAD"] = head_sha if checkout is None: checkout = not bare if checkout: # Checkout HEAD to target dir target.reset_index() return target def reset_index(self, tree=None): """Reset the index back to a specific tree. Args: tree: Tree SHA to reset to, None for current HEAD tree. """ from dulwich.index import ( build_index_from_tree, validate_path_element_default, validate_path_element_ntfs, ) if tree is None: tree = self[b"HEAD"].tree config = self.get_config() honor_filemode = config.get_boolean(b"core", b"filemode", os.name != "nt") if config.get_boolean(b"core", b"core.protectNTFS", os.name == "nt"): validate_path_element = validate_path_element_ntfs else: validate_path_element = validate_path_element_default return build_index_from_tree( self.path, self.index_path(), self.object_store, tree, honor_filemode=honor_filemode, validate_path_element=validate_path_element, ) def get_config(self) -> "ConfigFile": """Retrieve the config object. Returns: `ConfigFile` object for the ``.git/config`` file. """ from dulwich.config import ConfigFile path = os.path.join(self._controldir, "config") try: return ConfigFile.from_path(path) except FileNotFoundError: ret = ConfigFile() ret.path = path return ret def get_description(self): """Retrieve the description of this repository. Returns: A string describing the repository or None. """ path = os.path.join(self._controldir, "description") try: with GitFile(path, "rb") as f: return f.read() except FileNotFoundError: return None def __repr__(self): return "" % self.path def set_description(self, description): """Set the description for this repository. Args: description: Text to set as description for this repository. """ self._put_named_file("description", description) @classmethod def _init_maybe_bare(cls, path, controldir, bare, object_store=None): for d in BASE_DIRECTORIES: os.mkdir(os.path.join(controldir, *d)) if object_store is None: object_store = DiskObjectStore.init(os.path.join(controldir, OBJECTDIR)) ret = cls(path, bare=bare, object_store=object_store) ret.refs.set_symbolic_ref(b"HEAD", DEFAULT_REF) ret._init_files(bare) return ret @classmethod def init(cls, path, mkdir=False): """Create a new repository. Args: path: Path in which to create the repository mkdir: Whether to create the directory Returns: `Repo` instance """ if mkdir: os.mkdir(path) controldir = os.path.join(path, CONTROLDIR) os.mkdir(controldir) _set_filesystem_hidden(controldir) return cls._init_maybe_bare(path, controldir, False) @classmethod def _init_new_working_directory(cls, path, main_repo, identifier=None, mkdir=False): """Create a new working directory linked to a repository. Args: path: Path in which to create the working tree. main_repo: Main repository to reference identifier: Worktree identifier mkdir: Whether to create the directory Returns: `Repo` instance """ if mkdir: os.mkdir(path) if identifier is None: identifier = os.path.basename(path) main_worktreesdir = os.path.join(main_repo.controldir(), WORKTREES) worktree_controldir = os.path.join(main_worktreesdir, identifier) gitdirfile = os.path.join(path, CONTROLDIR) with open(gitdirfile, "wb") as f: f.write(b"gitdir: " + os.fsencode(worktree_controldir) + b"\n") try: os.mkdir(main_worktreesdir) except FileExistsError: pass try: os.mkdir(worktree_controldir) except FileExistsError: pass with open(os.path.join(worktree_controldir, GITDIR), "wb") as f: f.write(os.fsencode(gitdirfile) + b"\n") with open(os.path.join(worktree_controldir, COMMONDIR), "wb") as f: f.write(b"../..\n") with open(os.path.join(worktree_controldir, "HEAD"), "wb") as f: f.write(main_repo.head() + b"\n") r = cls(path) r.reset_index() return r @classmethod def init_bare(cls, path, mkdir=False, object_store=None): """Create a new bare repository. ``path`` should already exist and be an empty directory. Args: path: Path to create bare repository in Returns: a `Repo` instance """ if mkdir: os.mkdir(path) return cls._init_maybe_bare(path, path, True, object_store=object_store) create = init_bare def close(self): """Close any files opened by this repository.""" self.object_store.close() def __enter__(self): return self def __exit__(self, exc_type, exc_val, exc_tb): self.close() def get_blob_normalizer(self): """Return a BlobNormalizer object""" # TODO Parse the git attributes files git_attributes = {} return BlobNormalizer(self.get_config_stack(), git_attributes) class MemoryRepo(BaseRepo): """Repo that stores refs, objects, and named files in memory. MemoryRepos are always bare: they have no working tree and no index, since those have a stronger dependency on the filesystem. """ def __init__(self): from dulwich.config import ConfigFile self._reflog = [] refs_container = DictRefsContainer({}, logger=self._append_reflog) BaseRepo.__init__(self, MemoryObjectStore(), refs_container) self._named_files = {} self.bare = True self._config = ConfigFile() self._description = None def _append_reflog(self, *args): self._reflog.append(args) def set_description(self, description): self._description = description def get_description(self): return self._description def _determine_file_mode(self): """Probe the file-system to determine whether permissions can be trusted. Returns: True if permissions can be trusted, False otherwise. """ return sys.platform != "win32" def _put_named_file(self, path, contents): """Write a file to the control dir with the given name and contents. Args: path: The path to the file, relative to the control dir. contents: A string to write to the file. """ self._named_files[path] = contents def _del_named_file(self, path): try: del self._named_files[path] except KeyError: pass def get_named_file(self, path, basedir=None): """Get a file from the control dir with a specific name. Although the filename should be interpreted as a filename relative to the control dir in a disk-baked Repo, the object returned need not be pointing to a file in that location. Args: path: The path to the file, relative to the control dir. Returns: An open file object, or None if the file does not exist. """ contents = self._named_files.get(path, None) if contents is None: return None return BytesIO(contents) def open_index(self): """Fail to open index for this repo, since it is bare. Raises: NoIndexPresent: Raised when no index is present """ raise NoIndexPresent() def get_config(self): """Retrieve the config object. Returns: `ConfigFile` object. """ return self._config @classmethod def init_bare(cls, objects, refs): """Create a new bare repository in memory. Args: objects: Objects for the new repository, as iterable refs: Refs as dictionary, mapping names to object SHA1s """ ret = cls() for obj in objects: ret.object_store.add_object(obj) for refname, sha in refs.items(): ret.refs.add_if_new(refname, sha) ret._init_files(bare=True) return ret diff --git a/dulwich/tests/compat/test_pack.py b/dulwich/tests/compat/test_pack.py index 8ae614f5..6da2f0ff 100644 --- a/dulwich/tests/compat/test_pack.py +++ b/dulwich/tests/compat/test_pack.py @@ -1,172 +1,172 @@ # test_pack.py -- Compatibility tests for git packs. # Copyright (C) 2010 Google, Inc. # # Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU # General Public License as public by the Free Software Foundation; version 2.0 # or (at your option) any later version. You can redistribute it and/or # modify it under the terms of either of these two licenses. # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # # You should have received a copy of the licenses; if not, see # for a copy of the GNU General Public License # and for a copy of the Apache # License, Version 2.0. # """Compatibility tests for git packs.""" import binascii import os import re import shutil import tempfile from dulwich.pack import ( write_pack, ) from dulwich.objects import ( Blob, ) from dulwich.tests import ( SkipTest, ) from dulwich.tests.test_pack import ( a_sha, pack1_sha, PackTests, ) from dulwich.tests.compat.utils import ( require_git_version, run_git_or_fail, ) _NON_DELTA_RE = re.compile(b"non delta: (?P\\d+) objects") def _git_verify_pack_object_list(output): pack_shas = set() for line in output.splitlines(): sha = line[:40] try: binascii.unhexlify(sha) except (TypeError, binascii.Error): continue # non-sha line pack_shas.add(sha) return pack_shas class TestPack(PackTests): """Compatibility tests for reading and writing pack files.""" def setUp(self): require_git_version((1, 5, 0)) super(TestPack, self).setUp() self._tempdir = tempfile.mkdtemp() self.addCleanup(shutil.rmtree, self._tempdir) def test_copy(self): with self.get_pack(pack1_sha) as origpack: self.assertSucceeds(origpack.index.check) pack_path = os.path.join(self._tempdir, "Elch") write_pack(pack_path, origpack.pack_tuples()) output = run_git_or_fail(["verify-pack", "-v", pack_path]) - orig_shas = set(o.id for o in origpack.iterobjects()) + orig_shas = {o.id for o in origpack.iterobjects()} self.assertEqual(orig_shas, _git_verify_pack_object_list(output)) def test_deltas_work(self): with self.get_pack(pack1_sha) as orig_pack: orig_blob = orig_pack[a_sha] new_blob = Blob() new_blob.data = orig_blob.data + b"x" all_to_pack = list(orig_pack.pack_tuples()) + [(new_blob, None)] pack_path = os.path.join(self._tempdir, "pack_with_deltas") write_pack(pack_path, all_to_pack, deltify=True) output = run_git_or_fail(["verify-pack", "-v", pack_path]) self.assertEqual( - set(x[0].id for x in all_to_pack), + {x[0].id for x in all_to_pack}, _git_verify_pack_object_list(output), ) # We specifically made a new blob that should be a delta # against the blob a_sha, so make sure we really got only 3 # non-delta objects: got_non_delta = int(_NON_DELTA_RE.search(output).group("non_delta")) self.assertEqual( 3, got_non_delta, "Expected 3 non-delta objects, got %d" % got_non_delta, ) def test_delta_medium_object(self): # This tests an object set that will have a copy operation # 2**20 in size. with self.get_pack(pack1_sha) as orig_pack: orig_blob = orig_pack[a_sha] new_blob = Blob() new_blob.data = orig_blob.data + (b"x" * 2 ** 20) new_blob_2 = Blob() new_blob_2.data = new_blob.data + b"y" all_to_pack = list(orig_pack.pack_tuples()) + [ (new_blob, None), (new_blob_2, None), ] pack_path = os.path.join(self._tempdir, "pack_with_deltas") write_pack(pack_path, all_to_pack, deltify=True) output = run_git_or_fail(["verify-pack", "-v", pack_path]) self.assertEqual( - set(x[0].id for x in all_to_pack), + {x[0].id for x in all_to_pack}, _git_verify_pack_object_list(output), ) # We specifically made a new blob that should be a delta # against the blob a_sha, so make sure we really got only 3 # non-delta objects: got_non_delta = int(_NON_DELTA_RE.search(output).group("non_delta")) self.assertEqual( 3, got_non_delta, "Expected 3 non-delta objects, got %d" % got_non_delta, ) # We expect one object to have a delta chain length of two # (new_blob_2), so let's verify that actually happens: self.assertIn(b"chain length = 2", output) # This test is SUPER slow: over 80 seconds on a 2012-era # laptop. This is because SequenceMatcher is worst-case quadratic # on the input size. It's impractical to produce deltas for # objects this large, but it's still worth doing the right thing # when it happens. def test_delta_large_object(self): # This tests an object set that will have a copy operation # 2**25 in size. This is a copy large enough that it requires # two copy operations in git's binary delta format. raise SkipTest("skipping slow, large test") with self.get_pack(pack1_sha) as orig_pack: new_blob = Blob() new_blob.data = "big blob" + ("x" * 2 ** 25) new_blob_2 = Blob() new_blob_2.data = new_blob.data + "y" all_to_pack = list(orig_pack.pack_tuples()) + [ (new_blob, None), (new_blob_2, None), ] pack_path = os.path.join(self._tempdir, "pack_with_deltas") write_pack(pack_path, all_to_pack, deltify=True) output = run_git_or_fail(["verify-pack", "-v", pack_path]) self.assertEqual( - set(x[0].id for x in all_to_pack), + {x[0].id for x in all_to_pack}, _git_verify_pack_object_list(output), ) # We specifically made a new blob that should be a delta # against the blob a_sha, so make sure we really got only 4 # non-delta objects: got_non_delta = int(_NON_DELTA_RE.search(output).group("non_delta")) self.assertEqual( 4, got_non_delta, "Expected 4 non-delta objects, got %d" % got_non_delta, ) diff --git a/dulwich/tests/compat/test_repository.py b/dulwich/tests/compat/test_repository.py index 828a03c0..e7fa282c 100644 --- a/dulwich/tests/compat/test_repository.py +++ b/dulwich/tests/compat/test_repository.py @@ -1,219 +1,219 @@ # test_repo.py -- Git repo compatibility tests # Copyright (C) 2010 Google, Inc. # # Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU # General Public License as public by the Free Software Foundation; version 2.0 # or (at your option) any later version. You can redistribute it and/or # modify it under the terms of either of these two licenses. # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # # You should have received a copy of the licenses; if not, see # for a copy of the GNU General Public License # and for a copy of the Apache # License, Version 2.0. # """Compatibility tests for dulwich repositories.""" from io import BytesIO from itertools import chain import os import tempfile from dulwich.objects import ( hex_to_sha, ) from dulwich.repo import ( check_ref_format, Repo, ) from dulwich.tests.compat.utils import ( require_git_version, rmtree_ro, run_git_or_fail, CompatTestCase, ) class ObjectStoreTestCase(CompatTestCase): """Tests for git repository compatibility.""" def setUp(self): super(ObjectStoreTestCase, self).setUp() self._repo = self.import_repo("server_new.export") def _run_git(self, args): return run_git_or_fail(args, cwd=self._repo.path) def _parse_refs(self, output): refs = {} for line in BytesIO(output): fields = line.rstrip(b"\n").split(b" ") self.assertEqual(3, len(fields)) refname, type_name, sha = fields check_ref_format(refname[5:]) hex_to_sha(sha) refs[refname] = (type_name, sha) return refs def _parse_objects(self, output): - return set(s.rstrip(b"\n").split(b" ")[0] for s in BytesIO(output)) + return {s.rstrip(b"\n").split(b" ")[0] for s in BytesIO(output)} def test_bare(self): self.assertTrue(self._repo.bare) self.assertFalse(os.path.exists(os.path.join(self._repo.path, ".git"))) def test_head(self): output = self._run_git(["rev-parse", "HEAD"]) head_sha = output.rstrip(b"\n") hex_to_sha(head_sha) self.assertEqual(head_sha, self._repo.refs[b"HEAD"]) def test_refs(self): output = self._run_git( ["for-each-ref", "--format=%(refname) %(objecttype) %(objectname)"] ) expected_refs = self._parse_refs(output) actual_refs = {} for refname, sha in self._repo.refs.as_dict().items(): if refname == b"HEAD": continue # handled in test_head obj = self._repo[sha] self.assertEqual(sha, obj.id) actual_refs[refname] = (obj.type_name, obj.id) self.assertEqual(expected_refs, actual_refs) # TODO(dborowitz): peeled ref tests def _get_loose_shas(self): output = self._run_git(["rev-list", "--all", "--objects", "--unpacked"]) return self._parse_objects(output) def _get_all_shas(self): output = self._run_git(["rev-list", "--all", "--objects"]) return self._parse_objects(output) def assertShasMatch(self, expected_shas, actual_shas_iter): actual_shas = set() for sha in actual_shas_iter: obj = self._repo[sha] self.assertEqual(sha, obj.id) actual_shas.add(sha) self.assertEqual(expected_shas, actual_shas) def test_loose_objects(self): # TODO(dborowitz): This is currently not very useful since # fast-imported repos only contained packed objects. expected_shas = self._get_loose_shas() self.assertShasMatch( expected_shas, self._repo.object_store._iter_loose_objects() ) def test_packed_objects(self): expected_shas = self._get_all_shas() - self._get_loose_shas() self.assertShasMatch( expected_shas, chain.from_iterable(self._repo.object_store.packs) ) def test_all_objects(self): expected_shas = self._get_all_shas() self.assertShasMatch(expected_shas, iter(self._repo.object_store)) class WorkingTreeTestCase(ObjectStoreTestCase): """Test for compatibility with git-worktree.""" min_git_version = (2, 5, 0) def create_new_worktree(self, repo_dir, branch): """Create a new worktree using git-worktree. Args: repo_dir: The directory of the main working tree. branch: The branch or commit to checkout in the new worktree. Returns: The path to the new working tree. """ temp_dir = tempfile.mkdtemp() run_git_or_fail(["worktree", "add", temp_dir, branch], cwd=repo_dir) self.addCleanup(rmtree_ro, temp_dir) return temp_dir def setUp(self): super(WorkingTreeTestCase, self).setUp() self._worktree_path = self.create_new_worktree(self._repo.path, "branch") self._worktree_repo = Repo(self._worktree_path) self.addCleanup(self._worktree_repo.close) self._mainworktree_repo = self._repo self._number_of_working_tree = 2 self._repo = self._worktree_repo def test_refs(self): super(WorkingTreeTestCase, self).test_refs() self.assertEqual( self._mainworktree_repo.refs.allkeys(), self._repo.refs.allkeys() ) def test_head_equality(self): self.assertNotEqual( self._repo.refs[b"HEAD"], self._mainworktree_repo.refs[b"HEAD"] ) def test_bare(self): self.assertFalse(self._repo.bare) self.assertTrue(os.path.isfile(os.path.join(self._repo.path, ".git"))) def _parse_worktree_list(self, output): worktrees = [] for line in BytesIO(output): fields = line.rstrip(b"\n").split() worktrees.append(tuple(f.decode() for f in fields)) return worktrees def test_git_worktree_list(self): # 'git worktree list' was introduced in 2.7.0 require_git_version((2, 7, 0)) output = run_git_or_fail(["worktree", "list"], cwd=self._repo.path) worktrees = self._parse_worktree_list(output) self.assertEqual(len(worktrees), self._number_of_working_tree) self.assertEqual(worktrees[0][1], "(bare)") self.assertTrue(os.path.samefile(worktrees[0][0], self._mainworktree_repo.path)) output = run_git_or_fail(["worktree", "list"], cwd=self._mainworktree_repo.path) worktrees = self._parse_worktree_list(output) self.assertEqual(len(worktrees), self._number_of_working_tree) self.assertEqual(worktrees[0][1], "(bare)") self.assertTrue(os.path.samefile(worktrees[0][0], self._mainworktree_repo.path)) class InitNewWorkingDirectoryTestCase(WorkingTreeTestCase): """Test compatibility of Repo.init_new_working_directory.""" min_git_version = (2, 5, 0) def setUp(self): super(InitNewWorkingDirectoryTestCase, self).setUp() self._other_worktree = self._repo worktree_repo_path = tempfile.mkdtemp() self.addCleanup(rmtree_ro, worktree_repo_path) self._repo = Repo._init_new_working_directory( worktree_repo_path, self._mainworktree_repo ) self.addCleanup(self._repo.close) self._number_of_working_tree = 3 def test_head_equality(self): self.assertEqual( self._repo.refs[b"HEAD"], self._mainworktree_repo.refs[b"HEAD"] ) def test_bare(self): self.assertFalse(self._repo.bare) self.assertTrue(os.path.isfile(os.path.join(self._repo.path, ".git"))) diff --git a/dulwich/tests/test_pack.py b/dulwich/tests/test_pack.py index 54f3ebdc..ded79b24 100644 --- a/dulwich/tests/test_pack.py +++ b/dulwich/tests/test_pack.py @@ -1,1236 +1,1236 @@ # test_pack.py -- Tests for the handling of git packs. # Copyright (C) 2007 James Westby # Copyright (C) 2008 Jelmer Vernooij # # Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU # General Public License as public by the Free Software Foundation; version 2.0 # or (at your option) any later version. You can redistribute it and/or # modify it under the terms of either of these two licenses. # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # # You should have received a copy of the licenses; if not, see # for a copy of the GNU General Public License # and for a copy of the Apache # License, Version 2.0. # """Tests for Dulwich packs.""" from io import BytesIO from hashlib import sha1 import os import shutil import tempfile import zlib from dulwich.errors import ( ApplyDeltaError, ChecksumMismatch, ) from dulwich.file import ( GitFile, ) from dulwich.object_store import ( MemoryObjectStore, ) from dulwich.objects import ( hex_to_sha, sha_to_hex, Commit, Tree, Blob, ) from dulwich.pack import ( OFS_DELTA, REF_DELTA, MemoryPackIndex, Pack, PackData, apply_delta, create_delta, deltify_pack_objects, load_pack_index, UnpackedObject, read_zlib_chunks, write_pack_header, write_pack_index_v1, write_pack_index_v2, write_pack_object, write_pack, unpack_object, compute_file_sha, PackStreamReader, DeltaChainIterator, _delta_encode_size, _encode_copy_operation, ) from dulwich.tests import ( TestCase, ) from dulwich.tests.utils import ( make_object, build_pack, ) pack1_sha = b"bc63ddad95e7321ee734ea11a7a62d314e0d7481" a_sha = b"6f670c0fb53f9463760b7295fbb814e965fb20c8" tree_sha = b"b2a2766a2879c209ab1176e7e778b81ae422eeaa" commit_sha = b"f18faa16531ac570a3fdc8c7ca16682548dafd12" class PackTests(TestCase): """Base class for testing packs""" def setUp(self): super(PackTests, self).setUp() self.tempdir = tempfile.mkdtemp() self.addCleanup(shutil.rmtree, self.tempdir) datadir = os.path.abspath(os.path.join(os.path.dirname(__file__), "data/packs")) def get_pack_index(self, sha): """Returns a PackIndex from the datadir with the given sha""" return load_pack_index( os.path.join(self.datadir, "pack-%s.idx" % sha.decode("ascii")) ) def get_pack_data(self, sha): """Returns a PackData object from the datadir with the given sha""" return PackData( os.path.join(self.datadir, "pack-%s.pack" % sha.decode("ascii")) ) def get_pack(self, sha): return Pack(os.path.join(self.datadir, "pack-%s" % sha.decode("ascii"))) def assertSucceeds(self, func, *args, **kwargs): try: func(*args, **kwargs) except ChecksumMismatch as e: self.fail(e) class PackIndexTests(PackTests): """Class that tests the index of packfiles""" def test_object_index(self): """Tests that the correct object offset is returned from the index.""" p = self.get_pack_index(pack1_sha) self.assertRaises(KeyError, p.object_index, pack1_sha) self.assertEqual(p.object_index(a_sha), 178) self.assertEqual(p.object_index(tree_sha), 138) self.assertEqual(p.object_index(commit_sha), 12) def test_object_sha1(self): """Tests that the correct object offset is returned from the index.""" p = self.get_pack_index(pack1_sha) self.assertRaises(KeyError, p.object_sha1, 876) self.assertEqual(p.object_sha1(178), hex_to_sha(a_sha)) self.assertEqual(p.object_sha1(138), hex_to_sha(tree_sha)) self.assertEqual(p.object_sha1(12), hex_to_sha(commit_sha)) def test_index_len(self): p = self.get_pack_index(pack1_sha) self.assertEqual(3, len(p)) def test_get_stored_checksum(self): p = self.get_pack_index(pack1_sha) self.assertEqual( b"f2848e2ad16f329ae1c92e3b95e91888daa5bd01", sha_to_hex(p.get_stored_checksum()), ) self.assertEqual( b"721980e866af9a5f93ad674144e1459b8ba3e7b7", sha_to_hex(p.get_pack_checksum()), ) def test_index_check(self): p = self.get_pack_index(pack1_sha) self.assertSucceeds(p.check) def test_iterentries(self): p = self.get_pack_index(pack1_sha) entries = [(sha_to_hex(s), o, c) for s, o, c in p.iterentries()] self.assertEqual( [ (b"6f670c0fb53f9463760b7295fbb814e965fb20c8", 178, None), (b"b2a2766a2879c209ab1176e7e778b81ae422eeaa", 138, None), (b"f18faa16531ac570a3fdc8c7ca16682548dafd12", 12, None), ], entries, ) def test_iter(self): p = self.get_pack_index(pack1_sha) self.assertEqual(set([tree_sha, commit_sha, a_sha]), set(p)) class TestPackDeltas(TestCase): test_string1 = b"The answer was flailing in the wind" test_string2 = b"The answer was falling down the pipe" test_string3 = b"zzzzz" test_string_empty = b"" test_string_big = b"Z" * 8192 test_string_huge = b"Z" * 100000 def _test_roundtrip(self, base, target): self.assertEqual( target, b"".join(apply_delta(base, create_delta(base, target))) ) def test_nochange(self): self._test_roundtrip(self.test_string1, self.test_string1) def test_nochange_huge(self): self._test_roundtrip(self.test_string_huge, self.test_string_huge) def test_change(self): self._test_roundtrip(self.test_string1, self.test_string2) def test_rewrite(self): self._test_roundtrip(self.test_string1, self.test_string3) def test_empty_to_big(self): self._test_roundtrip(self.test_string_empty, self.test_string_big) def test_empty_to_huge(self): self._test_roundtrip(self.test_string_empty, self.test_string_huge) def test_huge_copy(self): self._test_roundtrip( self.test_string_huge + self.test_string1, self.test_string_huge + self.test_string2, ) def test_dest_overflow(self): self.assertRaises( ApplyDeltaError, apply_delta, b"a" * 0x10000, b"\x80\x80\x04\x80\x80\x04\x80" + b"a" * 0x10000, ) self.assertRaises( ApplyDeltaError, apply_delta, b"", b"\x00\x80\x02\xb0\x11\x11" ) def test_pypy_issue(self): # Test for https://github.com/jelmer/dulwich/issues/509 / # https://bitbucket.org/pypy/pypy/issues/2499/cpyext-pystring_asstring-doesnt-work chunks = [ b"tree 03207ccf58880a748188836155ceed72f03d65d6\n" b"parent 408fbab530fd4abe49249a636a10f10f44d07a21\n" b"author Victor Stinner " b"1421355207 +0100\n" b"committer Victor Stinner " b"1421355207 +0100\n" b"\n" b"Backout changeset 3a06020af8cf\n" b"\nStreamWriter: close() now clears the reference to the " b"transport\n" b"\nStreamWriter now raises an exception if it is closed: " b"write(), writelines(),\n" b"write_eof(), can_write_eof(), get_extra_info(), drain().\n" ] delta = [ b"\xcd\x03\xad\x03]tree ff3c181a393d5a7270cddc01ea863818a8621ca8\n" b"parent 20a103cc90135494162e819f98d0edfc1f1fba6b\x91]7\x0510738" b"\x91\x99@\x0b10738 +0100\x93\x04\x01\xc9" ] res = apply_delta(chunks, delta) expected = [ b"tree ff3c181a393d5a7270cddc01ea863818a8621ca8\n" b"parent 20a103cc90135494162e819f98d0edfc1f1fba6b", b"\nauthor Victor Stinner 14213", b"10738", b" +0100\ncommitter Victor Stinner " b"14213", b"10738 +0100", b"\n\nStreamWriter: close() now clears the reference to the " b"transport\n\n" b"StreamWriter now raises an exception if it is closed: " b"write(), writelines(),\n" b"write_eof(), can_write_eof(), get_extra_info(), drain().\n", ] self.assertEqual(b"".join(expected), b"".join(res)) class TestPackData(PackTests): """Tests getting the data from the packfile.""" def test_create_pack(self): self.get_pack_data(pack1_sha).close() def test_from_file(self): path = os.path.join(self.datadir, "pack-%s.pack" % pack1_sha.decode("ascii")) with open(path, "rb") as f: PackData.from_file(f, os.path.getsize(path)) def test_pack_len(self): with self.get_pack_data(pack1_sha) as p: self.assertEqual(3, len(p)) def test_index_check(self): with self.get_pack_data(pack1_sha) as p: self.assertSucceeds(p.check) def test_iterobjects(self): with self.get_pack_data(pack1_sha) as p: commit_data = ( b"tree b2a2766a2879c209ab1176e7e778b81ae422eeaa\n" b"author James Westby " b"1174945067 +0100\n" b"committer James Westby " b"1174945067 +0100\n" b"\n" b"Test commit\n" ) blob_sha = b"6f670c0fb53f9463760b7295fbb814e965fb20c8" tree_data = b"100644 a\0" + hex_to_sha(blob_sha) actual = [] for offset, type_num, chunks, crc32 in p.iterobjects(): actual.append((offset, type_num, b"".join(chunks), crc32)) self.assertEqual( [ (12, 1, commit_data, 3775879613), (138, 2, tree_data, 912998690), (178, 3, b"test 1\n", 1373561701), ], actual, ) def test_iterentries(self): with self.get_pack_data(pack1_sha) as p: - entries = set((sha_to_hex(s), o, c) for s, o, c in p.iterentries()) + entries = {(sha_to_hex(s), o, c) for s, o, c in p.iterentries()} self.assertEqual( set( [ ( b"6f670c0fb53f9463760b7295fbb814e965fb20c8", 178, 1373561701, ), ( b"b2a2766a2879c209ab1176e7e778b81ae422eeaa", 138, 912998690, ), ( b"f18faa16531ac570a3fdc8c7ca16682548dafd12", 12, 3775879613, ), ] ), entries, ) def test_create_index_v1(self): with self.get_pack_data(pack1_sha) as p: filename = os.path.join(self.tempdir, "v1test.idx") p.create_index_v1(filename) idx1 = load_pack_index(filename) idx2 = self.get_pack_index(pack1_sha) self.assertEqual(idx1, idx2) def test_create_index_v2(self): with self.get_pack_data(pack1_sha) as p: filename = os.path.join(self.tempdir, "v2test.idx") p.create_index_v2(filename) idx1 = load_pack_index(filename) idx2 = self.get_pack_index(pack1_sha) self.assertEqual(idx1, idx2) def test_compute_file_sha(self): f = BytesIO(b"abcd1234wxyz") self.assertEqual( sha1(b"abcd1234wxyz").hexdigest(), compute_file_sha(f).hexdigest() ) self.assertEqual( sha1(b"abcd1234wxyz").hexdigest(), compute_file_sha(f, buffer_size=5).hexdigest(), ) self.assertEqual( sha1(b"abcd1234").hexdigest(), compute_file_sha(f, end_ofs=-4).hexdigest(), ) self.assertEqual( sha1(b"1234wxyz").hexdigest(), compute_file_sha(f, start_ofs=4).hexdigest(), ) self.assertEqual( sha1(b"1234").hexdigest(), compute_file_sha(f, start_ofs=4, end_ofs=-4).hexdigest(), ) def test_compute_file_sha_short_file(self): f = BytesIO(b"abcd1234wxyz") self.assertRaises(AssertionError, compute_file_sha, f, end_ofs=-20) self.assertRaises(AssertionError, compute_file_sha, f, end_ofs=20) self.assertRaises( AssertionError, compute_file_sha, f, start_ofs=10, end_ofs=-12 ) class TestPack(PackTests): def test_len(self): with self.get_pack(pack1_sha) as p: self.assertEqual(3, len(p)) def test_contains(self): with self.get_pack(pack1_sha) as p: self.assertTrue(tree_sha in p) def test_get(self): with self.get_pack(pack1_sha) as p: self.assertEqual(type(p[tree_sha]), Tree) def test_iter(self): with self.get_pack(pack1_sha) as p: self.assertEqual(set([tree_sha, commit_sha, a_sha]), set(p)) def test_iterobjects(self): with self.get_pack(pack1_sha) as p: expected = set([p[s] for s in [commit_sha, tree_sha, a_sha]]) self.assertEqual(expected, set(list(p.iterobjects()))) def test_pack_tuples(self): with self.get_pack(pack1_sha) as p: tuples = p.pack_tuples() expected = set([(p[s], None) for s in [commit_sha, tree_sha, a_sha]]) self.assertEqual(expected, set(list(tuples))) self.assertEqual(expected, set(list(tuples))) self.assertEqual(3, len(tuples)) def test_get_object_at(self): """Tests random access for non-delta objects""" with self.get_pack(pack1_sha) as p: obj = p[a_sha] self.assertEqual(obj.type_name, b"blob") self.assertEqual(obj.sha().hexdigest().encode("ascii"), a_sha) obj = p[tree_sha] self.assertEqual(obj.type_name, b"tree") self.assertEqual(obj.sha().hexdigest().encode("ascii"), tree_sha) obj = p[commit_sha] self.assertEqual(obj.type_name, b"commit") self.assertEqual(obj.sha().hexdigest().encode("ascii"), commit_sha) def test_copy(self): with self.get_pack(pack1_sha) as origpack: self.assertSucceeds(origpack.index.check) basename = os.path.join(self.tempdir, "Elch") write_pack(basename, origpack.pack_tuples()) with Pack(basename) as newpack: self.assertEqual(origpack, newpack) self.assertSucceeds(newpack.index.check) self.assertEqual(origpack.name(), newpack.name()) self.assertEqual( origpack.index.get_pack_checksum(), newpack.index.get_pack_checksum(), ) wrong_version = origpack.index.version != newpack.index.version orig_checksum = origpack.index.get_stored_checksum() new_checksum = newpack.index.get_stored_checksum() self.assertTrue(wrong_version or orig_checksum == new_checksum) def test_commit_obj(self): with self.get_pack(pack1_sha) as p: commit = p[commit_sha] self.assertEqual(b"James Westby ", commit.author) self.assertEqual([], commit.parents) def _copy_pack(self, origpack): basename = os.path.join(self.tempdir, "somepack") write_pack(basename, origpack.pack_tuples()) return Pack(basename) def test_keep_no_message(self): with self.get_pack(pack1_sha) as p: p = self._copy_pack(p) with p: keepfile_name = p.keep() # file should exist self.assertTrue(os.path.exists(keepfile_name)) with open(keepfile_name, "r") as f: buf = f.read() self.assertEqual("", buf) def test_keep_message(self): with self.get_pack(pack1_sha) as p: p = self._copy_pack(p) msg = b"some message" with p: keepfile_name = p.keep(msg) # file should exist self.assertTrue(os.path.exists(keepfile_name)) # and contain the right message, with a linefeed with open(keepfile_name, "rb") as f: buf = f.read() self.assertEqual(msg + b"\n", buf) def test_name(self): with self.get_pack(pack1_sha) as p: self.assertEqual(pack1_sha, p.name()) def test_length_mismatch(self): with self.get_pack_data(pack1_sha) as data: index = self.get_pack_index(pack1_sha) Pack.from_objects(data, index).check_length_and_checksum() data._file.seek(12) bad_file = BytesIO() write_pack_header(bad_file, 9999) bad_file.write(data._file.read()) bad_file = BytesIO(bad_file.getvalue()) bad_data = PackData("", file=bad_file) bad_pack = Pack.from_lazy_objects(lambda: bad_data, lambda: index) self.assertRaises(AssertionError, lambda: bad_pack.data) self.assertRaises( - AssertionError, lambda: bad_pack.check_length_and_checksum() + AssertionError, bad_pack.check_length_and_checksum ) def test_checksum_mismatch(self): with self.get_pack_data(pack1_sha) as data: index = self.get_pack_index(pack1_sha) Pack.from_objects(data, index).check_length_and_checksum() data._file.seek(0) bad_file = BytesIO(data._file.read()[:-20] + (b"\xff" * 20)) bad_data = PackData("", file=bad_file) bad_pack = Pack.from_lazy_objects(lambda: bad_data, lambda: index) self.assertRaises(ChecksumMismatch, lambda: bad_pack.data) self.assertRaises( - ChecksumMismatch, lambda: bad_pack.check_length_and_checksum() + ChecksumMismatch, bad_pack.check_length_and_checksum ) def test_iterobjects_2(self): with self.get_pack(pack1_sha) as p: - objs = dict((o.id, o) for o in p.iterobjects()) + objs = {o.id: o for o in p.iterobjects()} self.assertEqual(3, len(objs)) self.assertEqual(sorted(objs), sorted(p.index)) self.assertTrue(isinstance(objs[a_sha], Blob)) self.assertTrue(isinstance(objs[tree_sha], Tree)) self.assertTrue(isinstance(objs[commit_sha], Commit)) class TestThinPack(PackTests): def setUp(self): super(TestThinPack, self).setUp() self.store = MemoryObjectStore() self.blobs = {} for blob in (b"foo", b"bar", b"foo1234", b"bar2468"): self.blobs[blob] = make_object(Blob, data=blob) self.store.add_object(self.blobs[b"foo"]) self.store.add_object(self.blobs[b"bar"]) # Build a thin pack. 'foo' is as an external reference, 'bar' an # internal reference. self.pack_dir = tempfile.mkdtemp() self.addCleanup(shutil.rmtree, self.pack_dir) self.pack_prefix = os.path.join(self.pack_dir, "pack") with open(self.pack_prefix + ".pack", "wb") as f: build_pack( f, [ (REF_DELTA, (self.blobs[b"foo"].id, b"foo1234")), (Blob.type_num, b"bar"), (REF_DELTA, (self.blobs[b"bar"].id, b"bar2468")), ], store=self.store, ) # Index the new pack. with self.make_pack(True) as pack: with PackData(pack._data_path) as data: data.pack = pack data.create_index(self.pack_prefix + ".idx") del self.store[self.blobs[b"bar"].id] def make_pack(self, resolve_ext_ref): return Pack( self.pack_prefix, resolve_ext_ref=self.store.get_raw if resolve_ext_ref else None, ) def test_get_raw(self): with self.make_pack(False) as p: self.assertRaises(KeyError, p.get_raw, self.blobs[b"foo1234"].id) with self.make_pack(True) as p: self.assertEqual((3, b"foo1234"), p.get_raw(self.blobs[b"foo1234"].id)) def test_get_raw_unresolved(self): with self.make_pack(False) as p: self.assertEqual( ( 7, b"\x19\x10(\x15f=#\xf8\xb7ZG\xe7\xa0\x19e\xdc\xdc\x96F\x8c", [b"x\x9ccf\x9f\xc0\xccbhdl\x02\x00\x06f\x01l"], ), p.get_raw_unresolved(self.blobs[b"foo1234"].id), ) with self.make_pack(True) as p: self.assertEqual( ( 7, b"\x19\x10(\x15f=#\xf8\xb7ZG\xe7\xa0\x19e\xdc\xdc\x96F\x8c", [b"x\x9ccf\x9f\xc0\xccbhdl\x02\x00\x06f\x01l"], ), p.get_raw_unresolved(self.blobs[b"foo1234"].id), ) def test_iterobjects(self): with self.make_pack(False) as p: self.assertRaises(KeyError, list, p.iterobjects()) with self.make_pack(True) as p: self.assertEqual( sorted( [ self.blobs[b"foo1234"].id, self.blobs[b"bar"].id, self.blobs[b"bar2468"].id, ] ), sorted(o.id for o in p.iterobjects()), ) class WritePackTests(TestCase): def test_write_pack_header(self): f = BytesIO() write_pack_header(f, 42) self.assertEqual(b"PACK\x00\x00\x00\x02\x00\x00\x00*", f.getvalue()) def test_write_pack_object(self): f = BytesIO() f.write(b"header") offset = f.tell() crc32 = write_pack_object(f, Blob.type_num, b"blob") self.assertEqual(crc32, zlib.crc32(f.getvalue()[6:]) & 0xFFFFFFFF) f.write(b"x") # unpack_object needs extra trailing data. f.seek(offset) unpacked, unused = unpack_object(f.read, compute_crc32=True) self.assertEqual(Blob.type_num, unpacked.pack_type_num) self.assertEqual(Blob.type_num, unpacked.obj_type_num) self.assertEqual([b"blob"], unpacked.decomp_chunks) self.assertEqual(crc32, unpacked.crc32) self.assertEqual(b"x", unused) def test_write_pack_object_sha(self): f = BytesIO() f.write(b"header") offset = f.tell() sha_a = sha1(b"foo") sha_b = sha_a.copy() write_pack_object(f, Blob.type_num, b"blob", sha=sha_a) self.assertNotEqual(sha_a.digest(), sha_b.digest()) sha_b.update(f.getvalue()[offset:]) self.assertEqual(sha_a.digest(), sha_b.digest()) def test_write_pack_object_compression_level(self): f = BytesIO() f.write(b"header") offset = f.tell() sha_a = sha1(b"foo") sha_b = sha_a.copy() write_pack_object(f, Blob.type_num, b"blob", sha=sha_a, compression_level=6) self.assertNotEqual(sha_a.digest(), sha_b.digest()) sha_b.update(f.getvalue()[offset:]) self.assertEqual(sha_a.digest(), sha_b.digest()) pack_checksum = hex_to_sha("721980e866af9a5f93ad674144e1459b8ba3e7b7") class BaseTestPackIndexWriting(object): def assertSucceeds(self, func, *args, **kwargs): try: func(*args, **kwargs) except ChecksumMismatch as e: self.fail(e) def index(self, filename, entries, pack_checksum): raise NotImplementedError(self.index) def test_empty(self): idx = self.index("empty.idx", [], pack_checksum) self.assertEqual(idx.get_pack_checksum(), pack_checksum) self.assertEqual(0, len(idx)) def test_large(self): entry1_sha = hex_to_sha("4e6388232ec39792661e2e75db8fb117fc869ce6") entry2_sha = hex_to_sha("e98f071751bd77f59967bfa671cd2caebdccc9a2") entries = [ (entry1_sha, 0xF2972D0830529B87, 24), (entry2_sha, (~0xF2972D0830529B87) & (2 ** 64 - 1), 92), ] if not self._supports_large: self.assertRaises( TypeError, self.index, "single.idx", entries, pack_checksum ) return idx = self.index("single.idx", entries, pack_checksum) self.assertEqual(idx.get_pack_checksum(), pack_checksum) self.assertEqual(2, len(idx)) actual_entries = list(idx.iterentries()) self.assertEqual(len(entries), len(actual_entries)) for mine, actual in zip(entries, actual_entries): my_sha, my_offset, my_crc = mine actual_sha, actual_offset, actual_crc = actual self.assertEqual(my_sha, actual_sha) self.assertEqual(my_offset, actual_offset) if self._has_crc32_checksum: self.assertEqual(my_crc, actual_crc) else: self.assertTrue(actual_crc is None) def test_single(self): entry_sha = hex_to_sha("6f670c0fb53f9463760b7295fbb814e965fb20c8") my_entries = [(entry_sha, 178, 42)] idx = self.index("single.idx", my_entries, pack_checksum) self.assertEqual(idx.get_pack_checksum(), pack_checksum) self.assertEqual(1, len(idx)) actual_entries = list(idx.iterentries()) self.assertEqual(len(my_entries), len(actual_entries)) for mine, actual in zip(my_entries, actual_entries): my_sha, my_offset, my_crc = mine actual_sha, actual_offset, actual_crc = actual self.assertEqual(my_sha, actual_sha) self.assertEqual(my_offset, actual_offset) if self._has_crc32_checksum: self.assertEqual(my_crc, actual_crc) else: self.assertTrue(actual_crc is None) class BaseTestFilePackIndexWriting(BaseTestPackIndexWriting): def setUp(self): self.tempdir = tempfile.mkdtemp() def tearDown(self): shutil.rmtree(self.tempdir) def index(self, filename, entries, pack_checksum): path = os.path.join(self.tempdir, filename) self.writeIndex(path, entries, pack_checksum) idx = load_pack_index(path) self.assertSucceeds(idx.check) self.assertEqual(idx.version, self._expected_version) return idx def writeIndex(self, filename, entries, pack_checksum): # FIXME: Write to BytesIO instead rather than hitting disk ? with GitFile(filename, "wb") as f: self._write_fn(f, entries, pack_checksum) class TestMemoryIndexWriting(TestCase, BaseTestPackIndexWriting): def setUp(self): TestCase.setUp(self) self._has_crc32_checksum = True self._supports_large = True def index(self, filename, entries, pack_checksum): return MemoryPackIndex(entries, pack_checksum) def tearDown(self): TestCase.tearDown(self) class TestPackIndexWritingv1(TestCase, BaseTestFilePackIndexWriting): def setUp(self): TestCase.setUp(self) BaseTestFilePackIndexWriting.setUp(self) self._has_crc32_checksum = False self._expected_version = 1 self._supports_large = False self._write_fn = write_pack_index_v1 def tearDown(self): TestCase.tearDown(self) BaseTestFilePackIndexWriting.tearDown(self) class TestPackIndexWritingv2(TestCase, BaseTestFilePackIndexWriting): def setUp(self): TestCase.setUp(self) BaseTestFilePackIndexWriting.setUp(self) self._has_crc32_checksum = True self._supports_large = True self._expected_version = 2 self._write_fn = write_pack_index_v2 def tearDown(self): TestCase.tearDown(self) BaseTestFilePackIndexWriting.tearDown(self) class ReadZlibTests(TestCase): decomp = ( b"tree 4ada885c9196b6b6fa08744b5862bf92896fc002\n" b"parent None\n" b"author Jelmer Vernooij 1228980214 +0000\n" b"committer Jelmer Vernooij 1228980214 +0000\n" b"\n" b"Provide replacement for mmap()'s offset argument." ) comp = zlib.compress(decomp) extra = b"nextobject" def setUp(self): super(ReadZlibTests, self).setUp() self.read = BytesIO(self.comp + self.extra).read self.unpacked = UnpackedObject(Tree.type_num, None, len(self.decomp), 0) def test_decompress_size(self): good_decomp_len = len(self.decomp) self.unpacked.decomp_len = -1 self.assertRaises(ValueError, read_zlib_chunks, self.read, self.unpacked) self.unpacked.decomp_len = good_decomp_len - 1 self.assertRaises(zlib.error, read_zlib_chunks, self.read, self.unpacked) self.unpacked.decomp_len = good_decomp_len + 1 self.assertRaises(zlib.error, read_zlib_chunks, self.read, self.unpacked) def test_decompress_truncated(self): read = BytesIO(self.comp[:10]).read self.assertRaises(zlib.error, read_zlib_chunks, read, self.unpacked) read = BytesIO(self.comp).read self.assertRaises(zlib.error, read_zlib_chunks, read, self.unpacked) def test_decompress_empty(self): unpacked = UnpackedObject(Tree.type_num, None, 0, None) comp = zlib.compress(b"") read = BytesIO(comp + self.extra).read unused = read_zlib_chunks(read, unpacked) self.assertEqual(b"", b"".join(unpacked.decomp_chunks)) self.assertNotEqual(b"", unused) self.assertEqual(self.extra, unused + read()) def test_decompress_no_crc32(self): self.unpacked.crc32 = None read_zlib_chunks(self.read, self.unpacked) self.assertEqual(None, self.unpacked.crc32) def _do_decompress_test(self, buffer_size, **kwargs): unused = read_zlib_chunks( self.read, self.unpacked, buffer_size=buffer_size, **kwargs ) self.assertEqual(self.decomp, b"".join(self.unpacked.decomp_chunks)) self.assertEqual(zlib.crc32(self.comp), self.unpacked.crc32) self.assertNotEqual(b"", unused) self.assertEqual(self.extra, unused + self.read()) def test_simple_decompress(self): self._do_decompress_test(4096) self.assertEqual(None, self.unpacked.comp_chunks) # These buffer sizes are not intended to be realistic, but rather simulate # larger buffer sizes that may end at various places. def test_decompress_buffer_size_1(self): self._do_decompress_test(1) def test_decompress_buffer_size_2(self): self._do_decompress_test(2) def test_decompress_buffer_size_3(self): self._do_decompress_test(3) def test_decompress_buffer_size_4(self): self._do_decompress_test(4) def test_decompress_include_comp(self): self._do_decompress_test(4096, include_comp=True) self.assertEqual(self.comp, b"".join(self.unpacked.comp_chunks)) class DeltifyTests(TestCase): def test_empty(self): self.assertEqual([], list(deltify_pack_objects([]))) def test_single(self): b = Blob.from_string(b"foo") self.assertEqual( [(b.type_num, b.sha().digest(), None, b.as_raw_string())], list(deltify_pack_objects([(b, b"")])), ) def test_simple_delta(self): b1 = Blob.from_string(b"a" * 101) b2 = Blob.from_string(b"a" * 100) delta = create_delta(b1.as_raw_string(), b2.as_raw_string()) self.assertEqual( [ (b1.type_num, b1.sha().digest(), None, b1.as_raw_string()), (b2.type_num, b2.sha().digest(), b1.sha().digest(), delta), ], list(deltify_pack_objects([(b1, b""), (b2, b"")])), ) class TestPackStreamReader(TestCase): def test_read_objects_emtpy(self): f = BytesIO() build_pack(f, []) reader = PackStreamReader(f.read) self.assertEqual(0, len(list(reader.read_objects()))) def test_read_objects(self): f = BytesIO() entries = build_pack( f, [ (Blob.type_num, b"blob"), (OFS_DELTA, (0, b"blob1")), ], ) reader = PackStreamReader(f.read) objects = list(reader.read_objects(compute_crc32=True)) self.assertEqual(2, len(objects)) unpacked_blob, unpacked_delta = objects self.assertEqual(entries[0][0], unpacked_blob.offset) self.assertEqual(Blob.type_num, unpacked_blob.pack_type_num) self.assertEqual(Blob.type_num, unpacked_blob.obj_type_num) self.assertEqual(None, unpacked_blob.delta_base) self.assertEqual(b"blob", b"".join(unpacked_blob.decomp_chunks)) self.assertEqual(entries[0][4], unpacked_blob.crc32) self.assertEqual(entries[1][0], unpacked_delta.offset) self.assertEqual(OFS_DELTA, unpacked_delta.pack_type_num) self.assertEqual(None, unpacked_delta.obj_type_num) self.assertEqual( unpacked_delta.offset - unpacked_blob.offset, unpacked_delta.delta_base, ) delta = create_delta(b"blob", b"blob1") self.assertEqual(delta, b"".join(unpacked_delta.decomp_chunks)) self.assertEqual(entries[1][4], unpacked_delta.crc32) def test_read_objects_buffered(self): f = BytesIO() build_pack( f, [ (Blob.type_num, b"blob"), (OFS_DELTA, (0, b"blob1")), ], ) reader = PackStreamReader(f.read, zlib_bufsize=4) self.assertEqual(2, len(list(reader.read_objects()))) def test_read_objects_empty(self): reader = PackStreamReader(BytesIO().read) self.assertEqual([], list(reader.read_objects())) class TestPackIterator(DeltaChainIterator): _compute_crc32 = True def __init__(self, *args, **kwargs): super(TestPackIterator, self).__init__(*args, **kwargs) self._unpacked_offsets = set() def _result(self, unpacked): """Return entries in the same format as build_pack.""" return ( unpacked.offset, unpacked.obj_type_num, b"".join(unpacked.obj_chunks), unpacked.sha(), unpacked.crc32, ) def _resolve_object(self, offset, pack_type_num, base_chunks): assert offset not in self._unpacked_offsets, ( "Attempted to re-inflate offset %i" % offset ) self._unpacked_offsets.add(offset) return super(TestPackIterator, self)._resolve_object( offset, pack_type_num, base_chunks ) class DeltaChainIteratorTests(TestCase): def setUp(self): super(DeltaChainIteratorTests, self).setUp() self.store = MemoryObjectStore() self.fetched = set() def store_blobs(self, blobs_data): blobs = [] for data in blobs_data: blob = make_object(Blob, data=data) blobs.append(blob) self.store.add_object(blob) return blobs def get_raw_no_repeat(self, bin_sha): """Wrapper around store.get_raw that doesn't allow repeat lookups.""" hex_sha = sha_to_hex(bin_sha) self.assertFalse( hex_sha in self.fetched, "Attempted to re-fetch object %s" % hex_sha, ) self.fetched.add(hex_sha) return self.store.get_raw(hex_sha) def make_pack_iter(self, f, thin=None): if thin is None: thin = bool(list(self.store)) resolve_ext_ref = thin and self.get_raw_no_repeat or None data = PackData("test.pack", file=f) return TestPackIterator.for_pack_data(data, resolve_ext_ref=resolve_ext_ref) def assertEntriesMatch(self, expected_indexes, entries, pack_iter): expected = [entries[i] for i in expected_indexes] self.assertEqual(expected, list(pack_iter._walk_all_chains())) def test_no_deltas(self): f = BytesIO() entries = build_pack( f, [ (Commit.type_num, b"commit"), (Blob.type_num, b"blob"), (Tree.type_num, b"tree"), ], ) self.assertEntriesMatch([0, 1, 2], entries, self.make_pack_iter(f)) def test_ofs_deltas(self): f = BytesIO() entries = build_pack( f, [ (Blob.type_num, b"blob"), (OFS_DELTA, (0, b"blob1")), (OFS_DELTA, (0, b"blob2")), ], ) self.assertEntriesMatch([0, 1, 2], entries, self.make_pack_iter(f)) def test_ofs_deltas_chain(self): f = BytesIO() entries = build_pack( f, [ (Blob.type_num, b"blob"), (OFS_DELTA, (0, b"blob1")), (OFS_DELTA, (1, b"blob2")), ], ) self.assertEntriesMatch([0, 1, 2], entries, self.make_pack_iter(f)) def test_ref_deltas(self): f = BytesIO() entries = build_pack( f, [ (REF_DELTA, (1, b"blob1")), (Blob.type_num, (b"blob")), (REF_DELTA, (1, b"blob2")), ], ) self.assertEntriesMatch([1, 0, 2], entries, self.make_pack_iter(f)) def test_ref_deltas_chain(self): f = BytesIO() entries = build_pack( f, [ (REF_DELTA, (2, b"blob1")), (Blob.type_num, (b"blob")), (REF_DELTA, (1, b"blob2")), ], ) self.assertEntriesMatch([1, 2, 0], entries, self.make_pack_iter(f)) def test_ofs_and_ref_deltas(self): # Deltas pending on this offset are popped before deltas depending on # this ref. f = BytesIO() entries = build_pack( f, [ (REF_DELTA, (1, b"blob1")), (Blob.type_num, (b"blob")), (OFS_DELTA, (1, b"blob2")), ], ) self.assertEntriesMatch([1, 2, 0], entries, self.make_pack_iter(f)) def test_mixed_chain(self): f = BytesIO() entries = build_pack( f, [ (Blob.type_num, b"blob"), (REF_DELTA, (2, b"blob2")), (OFS_DELTA, (0, b"blob1")), (OFS_DELTA, (1, b"blob3")), (OFS_DELTA, (0, b"bob")), ], ) self.assertEntriesMatch([0, 2, 4, 1, 3], entries, self.make_pack_iter(f)) def test_long_chain(self): n = 100 objects_spec = [(Blob.type_num, b"blob")] for i in range(n): objects_spec.append((OFS_DELTA, (i, b"blob" + str(i).encode("ascii")))) f = BytesIO() entries = build_pack(f, objects_spec) self.assertEntriesMatch(range(n + 1), entries, self.make_pack_iter(f)) def test_branchy_chain(self): n = 100 objects_spec = [(Blob.type_num, b"blob")] for i in range(n): objects_spec.append((OFS_DELTA, (0, b"blob" + str(i).encode("ascii")))) f = BytesIO() entries = build_pack(f, objects_spec) self.assertEntriesMatch(range(n + 1), entries, self.make_pack_iter(f)) def test_ext_ref(self): (blob,) = self.store_blobs([b"blob"]) f = BytesIO() entries = build_pack(f, [(REF_DELTA, (blob.id, b"blob1"))], store=self.store) pack_iter = self.make_pack_iter(f) self.assertEntriesMatch([0], entries, pack_iter) self.assertEqual([hex_to_sha(blob.id)], pack_iter.ext_refs()) def test_ext_ref_chain(self): (blob,) = self.store_blobs([b"blob"]) f = BytesIO() entries = build_pack( f, [ (REF_DELTA, (1, b"blob2")), (REF_DELTA, (blob.id, b"blob1")), ], store=self.store, ) pack_iter = self.make_pack_iter(f) self.assertEntriesMatch([1, 0], entries, pack_iter) self.assertEqual([hex_to_sha(blob.id)], pack_iter.ext_refs()) def test_ext_ref_chain_degenerate(self): # Test a degenerate case where the sender is sending a REF_DELTA # object that expands to an object already in the repository. (blob,) = self.store_blobs([b"blob"]) (blob2,) = self.store_blobs([b"blob2"]) assert blob.id < blob2.id f = BytesIO() entries = build_pack( f, [ (REF_DELTA, (blob.id, b"blob2")), (REF_DELTA, (0, b"blob3")), ], store=self.store, ) pack_iter = self.make_pack_iter(f) self.assertEntriesMatch([0, 1], entries, pack_iter) self.assertEqual([hex_to_sha(blob.id)], pack_iter.ext_refs()) def test_ext_ref_multiple_times(self): (blob,) = self.store_blobs([b"blob"]) f = BytesIO() entries = build_pack( f, [ (REF_DELTA, (blob.id, b"blob1")), (REF_DELTA, (blob.id, b"blob2")), ], store=self.store, ) pack_iter = self.make_pack_iter(f) self.assertEntriesMatch([0, 1], entries, pack_iter) self.assertEqual([hex_to_sha(blob.id)], pack_iter.ext_refs()) def test_multiple_ext_refs(self): b1, b2 = self.store_blobs([b"foo", b"bar"]) f = BytesIO() entries = build_pack( f, [ (REF_DELTA, (b1.id, b"foo1")), (REF_DELTA, (b2.id, b"bar2")), ], store=self.store, ) pack_iter = self.make_pack_iter(f) self.assertEntriesMatch([0, 1], entries, pack_iter) self.assertEqual([hex_to_sha(b1.id), hex_to_sha(b2.id)], pack_iter.ext_refs()) def test_bad_ext_ref_non_thin_pack(self): (blob,) = self.store_blobs([b"blob"]) f = BytesIO() build_pack(f, [(REF_DELTA, (blob.id, b"blob1"))], store=self.store) pack_iter = self.make_pack_iter(f, thin=False) try: list(pack_iter._walk_all_chains()) self.fail() except KeyError as e: self.assertEqual(([blob.id],), e.args) def test_bad_ext_ref_thin_pack(self): b1, b2, b3 = self.store_blobs([b"foo", b"bar", b"baz"]) f = BytesIO() build_pack( f, [ (REF_DELTA, (1, b"foo99")), (REF_DELTA, (b1.id, b"foo1")), (REF_DELTA, (b2.id, b"bar2")), (REF_DELTA, (b3.id, b"baz3")), ], store=self.store, ) del self.store[b2.id] del self.store[b3.id] pack_iter = self.make_pack_iter(f) try: list(pack_iter._walk_all_chains()) self.fail() except KeyError as e: self.assertEqual((sorted([b2.id, b3.id]),), (sorted(e.args[0]),)) class DeltaEncodeSizeTests(TestCase): def test_basic(self): self.assertEqual(b"\x00", _delta_encode_size(0)) self.assertEqual(b"\x01", _delta_encode_size(1)) self.assertEqual(b"\xfa\x01", _delta_encode_size(250)) self.assertEqual(b"\xe8\x07", _delta_encode_size(1000)) self.assertEqual(b"\xa0\x8d\x06", _delta_encode_size(100000)) class EncodeCopyOperationTests(TestCase): def test_basic(self): self.assertEqual(b"\x80", _encode_copy_operation(0, 0)) self.assertEqual(b"\x91\x01\x0a", _encode_copy_operation(1, 10)) self.assertEqual(b"\xb1\x64\xe8\x03", _encode_copy_operation(100, 1000)) self.assertEqual(b"\x93\xe8\x03\x01", _encode_copy_operation(1000, 1)) diff --git a/dulwich/tests/test_walk.py b/dulwich/tests/test_walk.py index 86a35ab4..d77ede05 100644 --- a/dulwich/tests/test_walk.py +++ b/dulwich/tests/test_walk.py @@ -1,632 +1,632 @@ # test_walk.py -- Tests for commit walking functionality. # Copyright (C) 2010 Google, Inc. # # Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU # General Public License as public by the Free Software Foundation; version 2.0 # or (at your option) any later version. You can redistribute it and/or # modify it under the terms of either of these two licenses. # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # # You should have received a copy of the licenses; if not, see # for a copy of the GNU General Public License # and for a copy of the Apache # License, Version 2.0. # """Tests for commit walking functionality.""" from itertools import ( permutations, ) from unittest import expectedFailure from dulwich.diff_tree import ( CHANGE_MODIFY, CHANGE_RENAME, TreeChange, RenameDetector, ) from dulwich.errors import ( MissingCommitError, ) from dulwich.object_store import ( MemoryObjectStore, ) from dulwich.objects import ( Commit, Blob, ) from dulwich.walk import ORDER_TOPO, WalkEntry, Walker, _topo_reorder from dulwich.tests import TestCase from dulwich.tests.utils import ( F, make_object, make_tag, build_commit_graph, ) class TestWalkEntry(object): def __init__(self, commit, changes): self.commit = commit self.changes = changes def __repr__(self): return "" % ( self.commit.id, self.changes, ) def __eq__(self, other): if not isinstance(other, WalkEntry) or self.commit != other.commit: return False if self.changes is None: return True return self.changes == other.changes() class WalkerTest(TestCase): def setUp(self): super(WalkerTest, self).setUp() self.store = MemoryObjectStore() def make_commits(self, commit_spec, **kwargs): times = kwargs.pop("times", []) attrs = kwargs.pop("attrs", {}) for i, t in enumerate(times): attrs.setdefault(i + 1, {})["commit_time"] = t return build_commit_graph(self.store, commit_spec, attrs=attrs, **kwargs) def make_linear_commits(self, num_commits, **kwargs): commit_spec = [] for i in range(1, num_commits + 1): c = [i] if i > 1: c.append(i - 1) commit_spec.append(c) return self.make_commits(commit_spec, **kwargs) def assertWalkYields(self, expected, *args, **kwargs): walker = Walker(self.store, *args, **kwargs) expected = list(expected) for i, entry in enumerate(expected): if isinstance(entry, Commit): expected[i] = TestWalkEntry(entry, None) actual = list(walker) self.assertEqual(expected, actual) def test_tag(self): c1, c2, c3 = self.make_linear_commits(3) t2 = make_tag(target=c2) self.store.add_object(t2) self.assertWalkYields([c2, c1], [t2.id]) def test_linear(self): c1, c2, c3 = self.make_linear_commits(3) self.assertWalkYields([c1], [c1.id]) self.assertWalkYields([c2, c1], [c2.id]) self.assertWalkYields([c3, c2, c1], [c3.id]) self.assertWalkYields([c3, c2, c1], [c3.id, c1.id]) self.assertWalkYields([c3, c2], [c3.id], exclude=[c1.id]) self.assertWalkYields([c3, c2], [c3.id, c1.id], exclude=[c1.id]) self.assertWalkYields([c3], [c3.id, c1.id], exclude=[c2.id]) def test_missing(self): cs = list(reversed(self.make_linear_commits(20))) self.assertWalkYields(cs, [cs[0].id]) # Exactly how close we can get to a missing commit depends on our # implementation (in particular the choice of _MAX_EXTRA_COMMITS), but # we should at least be able to walk some history in a broken repo. del self.store[cs[-1].id] for i in range(1, 11): self.assertWalkYields(cs[:i], [cs[0].id], max_entries=i) self.assertRaises(MissingCommitError, Walker, self.store, [cs[-1].id]) def test_branch(self): c1, x2, x3, y4 = self.make_commits([[1], [2, 1], [3, 2], [4, 1]]) self.assertWalkYields([x3, x2, c1], [x3.id]) self.assertWalkYields([y4, c1], [y4.id]) self.assertWalkYields([y4, x2, c1], [y4.id, x2.id]) self.assertWalkYields([y4, x2], [y4.id, x2.id], exclude=[c1.id]) self.assertWalkYields([y4, x3], [y4.id, x3.id], exclude=[x2.id]) self.assertWalkYields([y4], [y4.id], exclude=[x3.id]) self.assertWalkYields([x3, x2], [x3.id], exclude=[y4.id]) def test_merge(self): c1, c2, c3, c4 = self.make_commits([[1], [2, 1], [3, 1], [4, 2, 3]]) self.assertWalkYields([c4, c3, c2, c1], [c4.id]) self.assertWalkYields([c3, c1], [c3.id]) self.assertWalkYields([c2, c1], [c2.id]) self.assertWalkYields([c4, c3], [c4.id], exclude=[c2.id]) self.assertWalkYields([c4, c2], [c4.id], exclude=[c3.id]) def test_merge_of_new_branch_from_old_base(self): # The commit on the branch was made at a time after any of the # commits on master, but the branch was from an older commit. # See also test_merge_of_old_branch self.maxDiff = None c1, c2, c3, c4, c5 = self.make_commits( [[1], [2, 1], [3, 2], [4, 1], [5, 3, 4]], times=[1, 2, 3, 4, 5], ) self.assertWalkYields([c5, c4, c3, c2, c1], [c5.id]) self.assertWalkYields([c3, c2, c1], [c3.id]) self.assertWalkYields([c2, c1], [c2.id]) @expectedFailure def test_merge_of_old_branch(self): # The commit on the branch was made at a time before any of # the commits on master, but it was merged into master after # those commits. # See also test_merge_of_new_branch_from_old_base self.maxDiff = None c1, c2, c3, c4, c5 = self.make_commits( [[1], [2, 1], [3, 2], [4, 1], [5, 3, 4]], times=[1, 3, 4, 2, 5], ) self.assertWalkYields([c5, c4, c3, c2, c1], [c5.id]) self.assertWalkYields([c3, c2, c1], [c3.id]) self.assertWalkYields([c2, c1], [c2.id]) def test_reverse(self): c1, c2, c3 = self.make_linear_commits(3) self.assertWalkYields([c1, c2, c3], [c3.id], reverse=True) def test_max_entries(self): c1, c2, c3 = self.make_linear_commits(3) self.assertWalkYields([c3, c2, c1], [c3.id], max_entries=3) self.assertWalkYields([c3, c2], [c3.id], max_entries=2) self.assertWalkYields([c3], [c3.id], max_entries=1) def test_reverse_after_max_entries(self): c1, c2, c3 = self.make_linear_commits(3) self.assertWalkYields([c1, c2, c3], [c3.id], max_entries=3, reverse=True) self.assertWalkYields([c2, c3], [c3.id], max_entries=2, reverse=True) self.assertWalkYields([c3], [c3.id], max_entries=1, reverse=True) def test_changes_one_parent(self): blob_a1 = make_object(Blob, data=b"a1") blob_a2 = make_object(Blob, data=b"a2") blob_b2 = make_object(Blob, data=b"b2") c1, c2 = self.make_linear_commits( 2, trees={ 1: [(b"a", blob_a1)], 2: [(b"a", blob_a2), (b"b", blob_b2)], }, ) e1 = TestWalkEntry(c1, [TreeChange.add((b"a", F, blob_a1.id))]) e2 = TestWalkEntry( c2, [ TreeChange(CHANGE_MODIFY, (b"a", F, blob_a1.id), (b"a", F, blob_a2.id)), TreeChange.add((b"b", F, blob_b2.id)), ], ) self.assertWalkYields([e2, e1], [c2.id]) def test_changes_multiple_parents(self): blob_a1 = make_object(Blob, data=b"a1") blob_b2 = make_object(Blob, data=b"b2") blob_a3 = make_object(Blob, data=b"a3") c1, c2, c3 = self.make_commits( [[1], [2], [3, 1, 2]], trees={ 1: [(b"a", blob_a1)], 2: [(b"b", blob_b2)], 3: [(b"a", blob_a3), (b"b", blob_b2)], }, ) # a is a modify/add conflict and b is not conflicted. changes = [ [ TreeChange(CHANGE_MODIFY, (b"a", F, blob_a1.id), (b"a", F, blob_a3.id)), TreeChange.add((b"a", F, blob_a3.id)), ] ] self.assertWalkYields( [TestWalkEntry(c3, changes)], [c3.id], exclude=[c1.id, c2.id] ) def test_path_matches(self): walker = Walker(None, [], paths=[b"foo", b"bar", b"baz/quux"]) self.assertTrue(walker._path_matches(b"foo")) self.assertTrue(walker._path_matches(b"foo/a")) self.assertTrue(walker._path_matches(b"foo/a/b")) self.assertTrue(walker._path_matches(b"bar")) self.assertTrue(walker._path_matches(b"baz/quux")) self.assertTrue(walker._path_matches(b"baz/quux/a")) self.assertFalse(walker._path_matches(None)) self.assertFalse(walker._path_matches(b"oops")) self.assertFalse(walker._path_matches(b"fool")) self.assertFalse(walker._path_matches(b"baz")) self.assertFalse(walker._path_matches(b"baz/quu")) def test_paths(self): blob_a1 = make_object(Blob, data=b"a1") blob_b2 = make_object(Blob, data=b"b2") blob_a3 = make_object(Blob, data=b"a3") blob_b3 = make_object(Blob, data=b"b3") c1, c2, c3 = self.make_linear_commits( 3, trees={ 1: [(b"a", blob_a1)], 2: [(b"a", blob_a1), (b"x/b", blob_b2)], 3: [(b"a", blob_a3), (b"x/b", blob_b3)], }, ) self.assertWalkYields([c3, c2, c1], [c3.id]) self.assertWalkYields([c3, c1], [c3.id], paths=[b"a"]) self.assertWalkYields([c3, c2], [c3.id], paths=[b"x/b"]) # All changes are included, not just for requested paths. changes = [ TreeChange(CHANGE_MODIFY, (b"a", F, blob_a1.id), (b"a", F, blob_a3.id)), TreeChange(CHANGE_MODIFY, (b"x/b", F, blob_b2.id), (b"x/b", F, blob_b3.id)), ] self.assertWalkYields( [TestWalkEntry(c3, changes)], [c3.id], max_entries=1, paths=[b"a"] ) def test_paths_subtree(self): blob_a = make_object(Blob, data=b"a") blob_b = make_object(Blob, data=b"b") c1, c2, c3 = self.make_linear_commits( 3, trees={ 1: [(b"x/a", blob_a)], 2: [(b"b", blob_b), (b"x/a", blob_a)], 3: [(b"b", blob_b), (b"x/a", blob_a), (b"x/b", blob_b)], }, ) self.assertWalkYields([c2], [c3.id], paths=[b"b"]) self.assertWalkYields([c3, c1], [c3.id], paths=[b"x"]) def test_paths_max_entries(self): blob_a = make_object(Blob, data=b"a") blob_b = make_object(Blob, data=b"b") c1, c2 = self.make_linear_commits( 2, trees={1: [(b"a", blob_a)], 2: [(b"a", blob_a), (b"b", blob_b)]} ) self.assertWalkYields([c2], [c2.id], paths=[b"b"], max_entries=1) self.assertWalkYields([c1], [c1.id], paths=[b"a"], max_entries=1) def test_paths_merge(self): blob_a1 = make_object(Blob, data=b"a1") blob_a2 = make_object(Blob, data=b"a2") blob_a3 = make_object(Blob, data=b"a3") x1, y2, m3, m4 = self.make_commits( [[1], [2], [3, 1, 2], [4, 1, 2]], trees={ 1: [(b"a", blob_a1)], 2: [(b"a", blob_a2)], 3: [(b"a", blob_a3)], 4: [(b"a", blob_a1)], }, ) # Non-conflicting self.assertWalkYields([m3, y2, x1], [m3.id], paths=[b"a"]) self.assertWalkYields([y2, x1], [m4.id], paths=[b"a"]) def test_changes_with_renames(self): blob = make_object(Blob, data=b"blob") c1, c2 = self.make_linear_commits( 2, trees={1: [(b"a", blob)], 2: [(b"b", blob)]} ) entry_a = (b"a", F, blob.id) entry_b = (b"b", F, blob.id) changes_without_renames = [ TreeChange.delete(entry_a), TreeChange.add(entry_b), ] changes_with_renames = [TreeChange(CHANGE_RENAME, entry_a, entry_b)] self.assertWalkYields( [TestWalkEntry(c2, changes_without_renames)], [c2.id], max_entries=1, ) detector = RenameDetector(self.store) self.assertWalkYields( [TestWalkEntry(c2, changes_with_renames)], [c2.id], max_entries=1, rename_detector=detector, ) def test_follow_rename(self): blob = make_object(Blob, data=b"blob") names = [b"a", b"a", b"b", b"b", b"c", b"c"] - trees = dict((i + 1, [(n, blob, F)]) for i, n in enumerate(names)) + trees = {i + 1: [(n, blob, F)] for i, n in enumerate(names)} c1, c2, c3, c4, c5, c6 = self.make_linear_commits(6, trees=trees) self.assertWalkYields([c5], [c6.id], paths=[b"c"]) def e(n): return (n, F, blob.id) self.assertWalkYields( [ TestWalkEntry(c5, [TreeChange(CHANGE_RENAME, e(b"b"), e(b"c"))]), TestWalkEntry(c3, [TreeChange(CHANGE_RENAME, e(b"a"), e(b"b"))]), TestWalkEntry(c1, [TreeChange.add(e(b"a"))]), ], [c6.id], paths=[b"c"], follow=True, ) def test_follow_rename_remove_path(self): blob = make_object(Blob, data=b"blob") _, _, _, c4, c5, c6 = self.make_linear_commits( 6, trees={ 1: [(b"a", blob), (b"c", blob)], 2: [], 3: [], 4: [(b"b", blob)], 5: [(b"a", blob)], 6: [(b"c", blob)], }, ) def e(n): return (n, F, blob.id) # Once the path changes to b, we aren't interested in a or c anymore. self.assertWalkYields( [ TestWalkEntry(c6, [TreeChange(CHANGE_RENAME, e(b"a"), e(b"c"))]), TestWalkEntry(c5, [TreeChange(CHANGE_RENAME, e(b"b"), e(b"a"))]), TestWalkEntry(c4, [TreeChange.add(e(b"b"))]), ], [c6.id], paths=[b"c"], follow=True, ) def test_since(self): c1, c2, c3 = self.make_linear_commits(3) self.assertWalkYields([c3, c2, c1], [c3.id], since=-1) self.assertWalkYields([c3, c2, c1], [c3.id], since=0) self.assertWalkYields([c3, c2], [c3.id], since=1) self.assertWalkYields([c3, c2], [c3.id], since=99) self.assertWalkYields([c3, c2], [c3.id], since=100) self.assertWalkYields([c3], [c3.id], since=101) self.assertWalkYields([c3], [c3.id], since=199) self.assertWalkYields([c3], [c3.id], since=200) self.assertWalkYields([], [c3.id], since=201) self.assertWalkYields([], [c3.id], since=300) def test_until(self): c1, c2, c3 = self.make_linear_commits(3) self.assertWalkYields([], [c3.id], until=-1) self.assertWalkYields([c1], [c3.id], until=0) self.assertWalkYields([c1], [c3.id], until=1) self.assertWalkYields([c1], [c3.id], until=99) self.assertWalkYields([c2, c1], [c3.id], until=100) self.assertWalkYields([c2, c1], [c3.id], until=101) self.assertWalkYields([c2, c1], [c3.id], until=199) self.assertWalkYields([c3, c2, c1], [c3.id], until=200) self.assertWalkYields([c3, c2, c1], [c3.id], until=201) self.assertWalkYields([c3, c2, c1], [c3.id], until=300) def test_since_until(self): c1, c2, c3 = self.make_linear_commits(3) self.assertWalkYields([], [c3.id], since=100, until=99) self.assertWalkYields([c3, c2, c1], [c3.id], since=-1, until=201) self.assertWalkYields([c2], [c3.id], since=100, until=100) self.assertWalkYields([c2], [c3.id], since=50, until=150) def test_since_over_scan(self): commits = self.make_linear_commits(11, times=[9, 0, 1, 2, 3, 4, 5, 8, 6, 7, 9]) c8, _, c10, c11 = commits[-4:] del self.store[commits[0].id] # c9 is older than we want to walk, but is out of order with its # parent, so we need to walk past it to get to c8. # c1 would also match, but we've deleted it, and it should get pruned # even with over-scanning. self.assertWalkYields([c11, c10, c8], [c11.id], since=7) def assertTopoOrderEqual(self, expected_commits, commits): entries = [TestWalkEntry(c, None) for c in commits] actual_ids = [e.commit.id for e in list(_topo_reorder(entries))] self.assertEqual([c.id for c in expected_commits], actual_ids) def test_topo_reorder_linear(self): commits = self.make_linear_commits(5) commits.reverse() for perm in permutations(commits): self.assertTopoOrderEqual(commits, perm) def test_topo_reorder_multiple_parents(self): c1, c2, c3 = self.make_commits([[1], [2], [3, 1, 2]]) # Already sorted, so totally FIFO. self.assertTopoOrderEqual([c3, c2, c1], [c3, c2, c1]) self.assertTopoOrderEqual([c3, c1, c2], [c3, c1, c2]) # c3 causes one parent to be yielded. self.assertTopoOrderEqual([c3, c2, c1], [c2, c3, c1]) self.assertTopoOrderEqual([c3, c1, c2], [c1, c3, c2]) # c3 causes both parents to be yielded. self.assertTopoOrderEqual([c3, c2, c1], [c1, c2, c3]) self.assertTopoOrderEqual([c3, c2, c1], [c2, c1, c3]) def test_topo_reorder_multiple_children(self): c1, c2, c3 = self.make_commits([[1], [2, 1], [3, 1]]) # c2 and c3 are FIFO but c1 moves to the end. self.assertTopoOrderEqual([c3, c2, c1], [c3, c2, c1]) self.assertTopoOrderEqual([c3, c2, c1], [c3, c1, c2]) self.assertTopoOrderEqual([c3, c2, c1], [c1, c3, c2]) self.assertTopoOrderEqual([c2, c3, c1], [c2, c3, c1]) self.assertTopoOrderEqual([c2, c3, c1], [c2, c1, c3]) self.assertTopoOrderEqual([c2, c3, c1], [c1, c2, c3]) def test_out_of_order_children(self): c1, c2, c3, c4, c5 = self.make_commits( [[1], [2, 1], [3, 2], [4, 1], [5, 3, 4]], times=[2, 1, 3, 4, 5] ) self.assertWalkYields([c5, c4, c3, c1, c2], [c5.id]) self.assertWalkYields([c5, c4, c3, c2, c1], [c5.id], order=ORDER_TOPO) def test_out_of_order_with_exclude(self): # Create the following graph: # c1-------x2---m6 # \ / # \-y3--y4-/--y5 # Due to skew, y5 is the oldest commit. c1, x2, y3, y4, y5, m6 = self.make_commits( [[1], [2, 1], [3, 1], [4, 3], [5, 4], [6, 2, 4]], times=[2, 3, 4, 5, 1, 6], ) self.assertWalkYields([m6, y4, y3, x2, c1], [m6.id]) # Ensure that c1..y4 get excluded even though they're popped from the # priority queue long before y5. self.assertWalkYields([m6, x2], [m6.id], exclude=[y5.id]) def test_empty_walk(self): c1, c2, c3 = self.make_linear_commits(3) self.assertWalkYields([], [c3.id], exclude=[c3.id]) class WalkEntryTest(TestCase): def setUp(self): super(WalkEntryTest, self).setUp() self.store = MemoryObjectStore() def make_commits(self, commit_spec, **kwargs): times = kwargs.pop("times", []) attrs = kwargs.pop("attrs", {}) for i, t in enumerate(times): attrs.setdefault(i + 1, {})["commit_time"] = t return build_commit_graph(self.store, commit_spec, attrs=attrs, **kwargs) def make_linear_commits(self, num_commits, **kwargs): commit_spec = [] for i in range(1, num_commits + 1): c = [i] if i > 1: c.append(i - 1) commit_spec.append(c) return self.make_commits(commit_spec, **kwargs) def test_all_changes(self): # Construct a commit with 2 files in different subdirectories. blob_a = make_object(Blob, data=b"a") blob_b = make_object(Blob, data=b"b") c1 = self.make_linear_commits( 1, trees={1: [(b"x/a", blob_a), (b"y/b", blob_b)]}, )[0] # Get the WalkEntry for the commit. walker = Walker(self.store, c1.id) walker_entry = list(walker)[0] changes = walker_entry.changes() # Compare the changes with the expected values. entry_a = (b"x/a", F, blob_a.id) entry_b = (b"y/b", F, blob_b.id) self.assertEqual( [TreeChange.add(entry_a), TreeChange.add(entry_b)], changes, ) def test_all_with_merge(self): blob_a = make_object(Blob, data=b"a") blob_a2 = make_object(Blob, data=b"a2") blob_b = make_object(Blob, data=b"b") blob_b2 = make_object(Blob, data=b"b2") x1, y2, m3 = self.make_commits( [[1], [2], [3, 1, 2]], trees={ 1: [(b"x/a", blob_a)], 2: [(b"y/b", blob_b)], 3: [(b"x/a", blob_a2), (b"y/b", blob_b2)], }, ) # Get the WalkEntry for the merge commit. walker = Walker(self.store, m3.id) entries = list(walker) walker_entry = entries[0] self.assertEqual(walker_entry.commit.id, m3.id) changes = walker_entry.changes() self.assertEqual(2, len(changes)) entry_a = (b"x/a", F, blob_a.id) entry_a2 = (b"x/a", F, blob_a2.id) entry_b = (b"y/b", F, blob_b.id) entry_b2 = (b"y/b", F, blob_b2.id) self.assertEqual( [ [ TreeChange(CHANGE_MODIFY, entry_a, entry_a2), TreeChange.add(entry_a2), ], [ TreeChange.add(entry_b2), TreeChange(CHANGE_MODIFY, entry_b, entry_b2), ], ], changes, ) def test_filter_changes(self): # Construct a commit with 2 files in different subdirectories. blob_a = make_object(Blob, data=b"a") blob_b = make_object(Blob, data=b"b") c1 = self.make_linear_commits( 1, trees={1: [(b"x/a", blob_a), (b"y/b", blob_b)]}, )[0] # Get the WalkEntry for the commit. walker = Walker(self.store, c1.id) walker_entry = list(walker)[0] changes = walker_entry.changes(path_prefix=b"x") # Compare the changes with the expected values. entry_a = (b"a", F, blob_a.id) self.assertEqual( [TreeChange.add(entry_a)], changes, ) def test_filter_with_merge(self): blob_a = make_object(Blob, data=b"a") blob_a2 = make_object(Blob, data=b"a2") blob_b = make_object(Blob, data=b"b") blob_b2 = make_object(Blob, data=b"b2") x1, y2, m3 = self.make_commits( [[1], [2], [3, 1, 2]], trees={ 1: [(b"x/a", blob_a)], 2: [(b"y/b", blob_b)], 3: [(b"x/a", blob_a2), (b"y/b", blob_b2)], }, ) # Get the WalkEntry for the merge commit. walker = Walker(self.store, m3.id) entries = list(walker) walker_entry = entries[0] self.assertEqual(walker_entry.commit.id, m3.id) changes = walker_entry.changes(b"x") self.assertEqual(1, len(changes)) entry_a = (b"a", F, blob_a.id) entry_a2 = (b"a", F, blob_a2.id) self.assertEqual( [[TreeChange(CHANGE_MODIFY, entry_a, entry_a2)]], changes, ) diff --git a/setup.py b/setup.py index e782c367..7f49272b 100755 --- a/setup.py +++ b/setup.py @@ -1,135 +1,135 @@ #!/usr/bin/python3 # encoding: utf-8 # Setup file for dulwich # Copyright (C) 2008-2016 Jelmer Vernooij try: from setuptools import setup, Extension except ImportError: from distutils.core import setup, Extension has_setuptools = False else: has_setuptools = True from distutils.core import Distribution import io import os import sys from typing import Dict, Any if sys.version_info < (3, 5): raise Exception( 'Dulwich only supports Python 3.5 and later. ' 'For 2.7 support, please install a version prior to 0.20') dulwich_version_string = '0.20.21' class DulwichDistribution(Distribution): def is_pure(self): if self.pure: return True def has_ext_modules(self): return not self.pure global_options = Distribution.global_options + [ ('pure', None, "use pure Python code instead of C " "extensions (slower on CPython)")] pure = False if sys.platform == 'darwin' and os.path.exists('/usr/bin/xcodebuild'): # XCode 4.0 dropped support for ppc architecture, which is hardcoded in # distutils.sysconfig import subprocess p = subprocess.Popen( ['/usr/bin/xcodebuild', '-version'], stdout=subprocess.PIPE, stderr=subprocess.PIPE, env={}) out, err = p.communicate() for line in out.splitlines(): line = line.decode("utf8") # Also parse only first digit, because 3.2.1 can't be parsed nicely if (line.startswith('Xcode') and int(line.split()[1].split('.')[0]) >= 4): os.environ['ARCHFLAGS'] = '' tests_require = ['fastimport'] -if '__pypy__' not in sys.modules and not sys.platform == 'win32': +if '__pypy__' not in sys.modules and sys.platform != 'win32': tests_require.extend([ 'gevent', 'geventhttpclient', 'setuptools>=17.1']) ext_modules = [ Extension('dulwich._objects', ['dulwich/_objects.c']), Extension('dulwich._pack', ['dulwich/_pack.c']), Extension('dulwich._diff_tree', ['dulwich/_diff_tree.c']), ] setup_kwargs = {} # type: Dict[str, Any] scripts = ['bin/dul-receive-pack', 'bin/dul-upload-pack'] if has_setuptools: setup_kwargs['extras_require'] = { 'fastimport': ['fastimport'], 'https': ['urllib3[secure]>=1.24.1'], 'pgp': ['gpg'], 'watch': ['pyinotify'], } setup_kwargs['install_requires'] = ['urllib3>=1.24.1', 'certifi'] setup_kwargs['include_package_data'] = True setup_kwargs['test_suite'] = 'dulwich.tests.test_suite' setup_kwargs['tests_require'] = tests_require setup_kwargs['entry_points'] = { "console_scripts": [ "dulwich=dulwich.cli:main", ]} setup_kwargs['python_requires'] = '>=3.5' else: scripts.append('bin/dulwich') with io.open(os.path.join(os.path.dirname(__file__), "README.rst"), encoding="utf-8") as f: description = f.read() setup(name='dulwich', author="Jelmer Vernooij", author_email="jelmer@jelmer.uk", url="https://www.dulwich.io/", long_description=description, description="Python Git Library", version=dulwich_version_string, license='Apachev2 or later or GPLv2', project_urls={ "Bug Tracker": "https://github.com/dulwich/dulwich/issues", "Repository": "https://www.dulwich.io/code/", "GitHub": "https://github.com/dulwich/dulwich", }, keywords="git vcs", packages=['dulwich', 'dulwich.tests', 'dulwich.tests.compat', 'dulwich.contrib'], package_data={'': ['../docs/tutorial/*.txt', 'py.typed']}, scripts=scripts, ext_modules=ext_modules, distclass=DulwichDistribution, classifiers=[ 'Development Status :: 4 - Beta', 'License :: OSI Approved :: Apache Software License', 'Programming Language :: Python :: 3.5', 'Programming Language :: Python :: 3.6', 'Programming Language :: Python :: 3.7', 'Programming Language :: Python :: 3.8', 'Programming Language :: Python :: 3.9', 'Programming Language :: Python :: Implementation :: CPython', 'Programming Language :: Python :: Implementation :: PyPy', 'Operating System :: POSIX', 'Operating System :: Microsoft :: Windows', 'Topic :: Software Development :: Version Control', ], **setup_kwargs )