diff --git a/dulwich/client.py b/dulwich/client.py index ca03f251..9fb5e854 100644 --- a/dulwich/client.py +++ b/dulwich/client.py @@ -1,2194 +1,2197 @@ # client.py -- Implementation of the client side git protocols # Copyright (C) 2008-2013 Jelmer Vernooij # # Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU # General Public License as public by the Free Software Foundation; version 2.0 # or (at your option) any later version. You can redistribute it and/or # modify it under the terms of either of these two licenses. # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # # You should have received a copy of the licenses; if not, see # for a copy of the GNU General Public License # and for a copy of the Apache # License, Version 2.0. # """Client side support for the Git protocol. The Dulwich client supports the following capabilities: * thin-pack * multi_ack_detailed * multi_ack * side-band-64k * ofs-delta * quiet * report-status * delete-refs * shallow Known capabilities that are not supported: * no-progress * include-tag """ from contextlib import closing from io import BytesIO, BufferedReader import logging import os import select import socket import subprocess import sys from typing import Optional, Dict, Callable, Set from urllib.parse import ( quote as urlquote, unquote as urlunquote, urlparse, urljoin, urlunsplit, urlunparse, ) import dulwich from dulwich.config import get_xdg_config_home_path from dulwich.errors import ( GitProtocolError, NotGitRepository, SendPackError, ) from dulwich.protocol import ( HangupException, _RBUFSIZE, agent_string, capability_agent, extract_capability_names, CAPABILITY_AGENT, CAPABILITY_DELETE_REFS, CAPABILITY_INCLUDE_TAG, CAPABILITY_MULTI_ACK, CAPABILITY_MULTI_ACK_DETAILED, CAPABILITY_OFS_DELTA, CAPABILITY_QUIET, CAPABILITY_REPORT_STATUS, CAPABILITY_SHALLOW, CAPABILITY_SYMREF, CAPABILITY_SIDE_BAND_64K, CAPABILITY_THIN_PACK, CAPABILITIES_REF, KNOWN_RECEIVE_CAPABILITIES, KNOWN_UPLOAD_CAPABILITIES, COMMAND_DEEPEN, COMMAND_SHALLOW, COMMAND_UNSHALLOW, COMMAND_DONE, COMMAND_HAVE, COMMAND_WANT, SIDE_BAND_CHANNEL_DATA, SIDE_BAND_CHANNEL_PROGRESS, SIDE_BAND_CHANNEL_FATAL, PktLineParser, Protocol, ProtocolFile, TCP_GIT_PORT, ZERO_SHA, extract_capabilities, parse_capability, ) from dulwich.pack import ( write_pack_data, write_pack_objects, ) from dulwich.refs import ( read_info_refs, ANNOTATED_TAG_SUFFIX, ) logger = logging.getLogger(__name__) class InvalidWants(Exception): """Invalid wants.""" def __init__(self, wants): Exception.__init__( self, "requested wants not in server provided refs: %r" % wants ) class HTTPUnauthorized(Exception): """Raised when authentication fails.""" def __init__(self, www_authenticate, url): Exception.__init__(self, "No valid credentials provided") self.www_authenticate = www_authenticate self.url = url def _fileno_can_read(fileno): """Check if a file descriptor is readable.""" return len(select.select([fileno], [], [], 0)[0]) > 0 def _win32_peek_avail(handle): """Wrapper around PeekNamedPipe to check how many bytes are available.""" from ctypes import byref, wintypes, windll c_avail = wintypes.DWORD() c_message = wintypes.DWORD() success = windll.kernel32.PeekNamedPipe( handle, None, 0, None, byref(c_avail), byref(c_message) ) if not success: 
raise OSError(wintypes.GetLastError()) return c_avail.value COMMON_CAPABILITIES = [CAPABILITY_OFS_DELTA, CAPABILITY_SIDE_BAND_64K] UPLOAD_CAPABILITIES = [ CAPABILITY_THIN_PACK, CAPABILITY_MULTI_ACK, CAPABILITY_MULTI_ACK_DETAILED, CAPABILITY_SHALLOW, ] + COMMON_CAPABILITIES RECEIVE_CAPABILITIES = [ CAPABILITY_REPORT_STATUS, CAPABILITY_DELETE_REFS, ] + COMMON_CAPABILITIES class ReportStatusParser(object): """Handle status as reported by servers with 'report-status' capability.""" def __init__(self): self._done = False self._pack_status = None self._ref_statuses = [] def check(self): """Check if there were any errors and, if so, raise exceptions. Raises: SendPackError: Raised when the server could not unpack Returns: iterator over refs """ if self._pack_status not in (b"unpack ok", None): raise SendPackError(self._pack_status) for status in self._ref_statuses: try: status, rest = status.split(b" ", 1) except ValueError: # malformed response, move on to the next one continue if status == b"ng": ref, error = rest.split(b" ", 1) yield ref, error.decode("utf-8") elif status == b"ok": yield rest, None else: raise GitProtocolError("invalid ref status %r" % status) def handle_packet(self, pkt): """Handle a packet. Raises: GitProtocolError: Raised when packets are received after a flush packet. """ if self._done: raise GitProtocolError("received more data after status report") if pkt is None: self._done = True return if self._pack_status is None: self._pack_status = pkt.strip() else: ref_status = pkt.strip() self._ref_statuses.append(ref_status) def read_pkt_refs(proto): server_capabilities = None refs = {} # Receive refs from server for pkt in proto.read_pkt_seq(): (sha, ref) = pkt.rstrip(b"\n").split(None, 1) if sha == b"ERR": raise GitProtocolError(ref.decode("utf-8", "replace")) if server_capabilities is None: (ref, server_capabilities) = extract_capabilities(ref) refs[ref] = sha if len(refs) == 0: return {}, set([]) if refs == {CAPABILITIES_REF: ZERO_SHA}: refs = {} return refs, set(server_capabilities) class FetchPackResult(object): """Result of a fetch-pack operation. 
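# Example (sketch): driving the ReportStatusParser above by hand. The pkt
# payloads are invented for illustration; a real client receives them from
# the server, possibly demultiplexed from sideband channel 1.
parser = ReportStatusParser()
parser.handle_packet(b"unpack ok")
parser.handle_packet(b"ok refs/heads/main")
parser.handle_packet(b"ng refs/heads/locked failed to lock")
parser.handle_packet(None)  # flush-pkt ends the status report
for ref, error in parser.check():
    # yields (b"refs/heads/main", None) and (b"refs/heads/locked", "failed to lock")
    print(ref, error)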
Attributes: refs: Dictionary with all remote refs symrefs: Dictionary with remote symrefs agent: User agent string """ _FORWARDED_ATTRS = [ "clear", "copy", "fromkeys", "get", "items", "keys", "pop", "popitem", "setdefault", "update", "values", "viewitems", "viewkeys", "viewvalues", ] def __init__(self, refs, symrefs, agent, new_shallow=None, new_unshallow=None): self.refs = refs self.symrefs = symrefs self.agent = agent self.new_shallow = new_shallow self.new_unshallow = new_unshallow def _warn_deprecated(self): import warnings warnings.warn( "Use FetchPackResult.refs instead.", DeprecationWarning, stacklevel=3, ) def __eq__(self, other): if isinstance(other, dict): self._warn_deprecated() return self.refs == other return ( self.refs == other.refs and self.symrefs == other.symrefs and self.agent == other.agent ) def __contains__(self, name): self._warn_deprecated() return name in self.refs def __getitem__(self, name): self._warn_deprecated() return self.refs[name] def __len__(self): self._warn_deprecated() return len(self.refs) def __iter__(self): self._warn_deprecated() return iter(self.refs) def __getattribute__(self, name): if name in type(self)._FORWARDED_ATTRS: self._warn_deprecated() return getattr(self.refs, name) return super(FetchPackResult, self).__getattribute__(name) def __repr__(self): return "%s(%r, %r, %r)" % ( self.__class__.__name__, self.refs, self.symrefs, self.agent, ) class SendPackResult(object): """Result of a send-pack operation. Attributes: refs: Dictionary with all remote refs agent: User agent string ref_status: Optional dictionary mapping ref name to error message (if it failed to update), or None if it was updated successfully """ _FORWARDED_ATTRS = [ "clear", "copy", "fromkeys", "get", "items", "keys", "pop", "popitem", "setdefault", "update", "values", "viewitems", "viewkeys", "viewvalues", ] def __init__(self, refs, agent=None, ref_status=None): self.refs = refs self.agent = agent self.ref_status = ref_status def _warn_deprecated(self): import warnings warnings.warn( "Use SendPackResult.refs instead.", DeprecationWarning, stacklevel=3, ) def __eq__(self, other): if isinstance(other, dict): self._warn_deprecated() return self.refs == other return self.refs == other.refs and self.agent == other.agent def __contains__(self, name): self._warn_deprecated() return name in self.refs def __getitem__(self, name): self._warn_deprecated() return self.refs[name] def __len__(self): self._warn_deprecated() return len(self.refs) def __iter__(self): self._warn_deprecated() return iter(self.refs) def __getattribute__(self, name): if name in type(self)._FORWARDED_ATTRS: self._warn_deprecated() return getattr(self.refs, name) return super(SendPackResult, self).__getattribute__(name) def __repr__(self): return "%s(%r, %r)" % (self.__class__.__name__, self.refs, self.agent) def _read_shallow_updates(proto): new_shallow = set() new_unshallow = set() for pkt in proto.read_pkt_seq(): cmd, sha = pkt.split(b" ", 1) if cmd == COMMAND_SHALLOW: new_shallow.add(sha.strip()) elif cmd == COMMAND_UNSHALLOW: new_unshallow.add(sha.strip()) else: raise GitProtocolError("unknown command %s" % pkt) return (new_shallow, new_unshallow) # TODO(durin42): this doesn't correctly degrade if the server doesn't # support some capabilities. This should work properly with servers # that don't support multi_ack. class GitClient(object): """Git smart server client.""" def __init__( self, thin_packs=True, report_activity=None, quiet=False, include_tags=False, ): """Create a new GitClient instance.
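# Usage note (sketch): both result classes still emulate the old dict API via
# _FORWARDED_ATTRS and the dunder methods above, but every such access emits a
# DeprecationWarning. Prefer the explicit attributes:
result = FetchPackResult({b"refs/heads/main": b"a" * 40}, {}, b"git/2.30.0")
sha = result.refs[b"refs/heads/main"]  # preferred
sha = result[b"refs/heads/main"]       # still works, but warns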
Args: thin_packs: Whether or not thin packs should be retrieved report_activity: Optional callback for reporting transport activity. include_tags: send annotated tags when sending the objects they point to """ self._report_activity = report_activity self._report_status_parser = None self._fetch_capabilities = set(UPLOAD_CAPABILITIES) self._fetch_capabilities.add(capability_agent()) self._send_capabilities = set(RECEIVE_CAPABILITIES) self._send_capabilities.add(capability_agent()) if quiet: self._send_capabilities.add(CAPABILITY_QUIET) if not thin_packs: self._fetch_capabilities.remove(CAPABILITY_THIN_PACK) if include_tags: self._fetch_capabilities.add(CAPABILITY_INCLUDE_TAG) def get_url(self, path): """Retrieves full url to given path. Args: path: Repository path (as string) Returns: Url to path (as string) """ raise NotImplementedError(self.get_url) @classmethod def from_parsedurl(cls, parsedurl, **kwargs): """Create an instance of this client from a urlparse.parsed object. Args: parsedurl: Result of urlparse() Returns: A `GitClient` object """ raise NotImplementedError(cls.from_parsedurl) def send_pack(self, path, update_refs, generate_pack_data, progress=None): """Upload a pack to a remote repository. Args: path: Repository path (as bytestring) update_refs: Function to determine changes to remote refs. Receive dict with existing remote refs, returns dict with changed refs (name -> sha, where sha=ZERO_SHA for deletions) generate_pack_data: Function that can return a tuple with number of objects and list of pack data to include progress: Optional progress function Returns: SendPackResult object Raises: SendPackError: if server rejects the pack data """ raise NotImplementedError(self.send_pack) def fetch(self, path, target, determine_wants=None, progress=None, depth=None): """Fetch into a target repository. Args: path: Path to fetch from (as bytestring) target: Target repository to fetch into determine_wants: Optional function to determine what refs to fetch. Receives dictionary of name->sha, should return list of shas to fetch. Defaults to all shas. progress: Optional progress function depth: Depth to fetch at Returns: Dictionary with all remote refs (not just those fetched) """ if determine_wants is None: - determine_wants = target.object_store.determine_wants_all + if depth is not None and depth > 1: + determine_wants = target.object_store.determine_wants_force + else: + determine_wants = target.object_store.determine_wants_all if CAPABILITY_THIN_PACK in self._fetch_capabilities: # TODO(jelmer): Avoid reading entire file into memory and # only processing it after the whole file has been fetched. f = BytesIO() def commit(): if f.tell(): f.seek(0) target.object_store.add_thin_pack(f.read, None) def abort(): pass else: f, commit, abort = target.object_store.add_pack() try: result = self.fetch_pack( path, determine_wants, target.get_graph_walker(), f.write, progress=progress, depth=depth, ) except BaseException: abort() raise else: commit() target.update_shallow(result.new_shallow, result.new_unshallow) return result def fetch_pack( self, path, determine_wants, graph_walker, pack_data, progress=None, depth=None, ): """Retrieve a pack from a git smart server. Args: path: Remote path to fetch from determine_wants: Function determine what refs to fetch. Receives dictionary of name->sha, should return list of shas to fetch. graph_walker: Object with next() and ack(). 
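# Sketch of why the hunk above picks a different wants function for depth > 1:
# determine_wants_all skips SHAs already present locally, while
# determine_wants_force (added to BaseObjectStore in the object_store.py hunk
# below) re-requests them, which a deepening shallow fetch needs. `store` is a
# hypothetical object store instance; the SHAs are placeholders.
refs = {b"refs/heads/main": b"a" * 40, b"refs/tags/v1" + ANNOTATED_TAG_SUFFIX: b"b" * 40}
wants = store.determine_wants_all(refs)    # omits SHAs already in `store`
wants = store.determine_wants_force(refs)  # includes them regardless
# Both variants skip peeled tag refs (the ^{} suffix) and ZERO_SHA entries.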
pack_data: Callback called for each bit of data in the pack progress: Callback for progress reports (strings) depth: Shallow fetch depth Returns: FetchPackResult object """ raise NotImplementedError(self.fetch_pack) def get_refs(self, path): """Retrieve the current refs from a git smart server. Args: path: Path to the repo to fetch from. (as bytestring) Returns: """ raise NotImplementedError(self.get_refs) def _read_side_band64k_data(self, proto, channel_callbacks): """Read per-channel data. This requires the side-band-64k capability. Args: proto: Protocol object to read from channel_callbacks: Dictionary mapping channels to packet handlers to use. Using None as a callback discards the data for that channel. """ for pkt in proto.read_pkt_seq(): channel = ord(pkt[:1]) pkt = pkt[1:] try: cb = channel_callbacks[channel] except KeyError: raise AssertionError("Invalid sideband channel %d" % channel) else: if cb is not None: cb(pkt) @staticmethod def _should_send_pack(new_refs): # The packfile MUST NOT be sent if the only command used is delete. return any(sha != ZERO_SHA for sha in new_refs.values()) def _handle_receive_pack_head(self, proto, capabilities, old_refs, new_refs): """Handle the head of a 'git-receive-pack' request. Args: proto: Protocol object to read from capabilities: List of negotiated capabilities old_refs: Old refs, as received from the server new_refs: Refs to change Returns: (have, want) tuple """ want = [] have = [x for x in old_refs.values() if not x == ZERO_SHA] sent_capabilities = False for refname in new_refs: if not isinstance(refname, bytes): raise TypeError("refname is not a bytestring: %r" % refname) old_sha1 = old_refs.get(refname, ZERO_SHA) if not isinstance(old_sha1, bytes): raise TypeError( "old sha1 for %s is not a bytestring: %r" % (refname, old_sha1) ) new_sha1 = new_refs.get(refname, ZERO_SHA) if not isinstance(new_sha1, bytes): raise TypeError( "new sha1 for %s is not a bytestring: %r" % (refname, new_sha1) ) if old_sha1 != new_sha1: logger.debug( 'Sending updated ref %r: %r -> %r', refname, old_sha1, new_sha1) if sent_capabilities: proto.write_pkt_line(old_sha1 + b" " + new_sha1 + b" " + refname) else: proto.write_pkt_line( old_sha1 + b" " + new_sha1 + b" " + refname + b"\0" + b" ".join(sorted(capabilities)) ) sent_capabilities = True if new_sha1 not in have and new_sha1 != ZERO_SHA: want.append(new_sha1) proto.write_pkt_line(None) return (have, want) def _negotiate_receive_pack_capabilities(self, server_capabilities): negotiated_capabilities = self._send_capabilities & server_capabilities agent = None for capability in server_capabilities: k, v = parse_capability(capability) if k == CAPABILITY_AGENT: agent = v unknown_capabilities = ( # noqa: F841 extract_capability_names(server_capabilities) - KNOWN_RECEIVE_CAPABILITIES ) # TODO(jelmer): warn about unknown capabilities return negotiated_capabilities, agent def _handle_receive_pack_tail( self, proto: Protocol, capabilities: Set[bytes], progress: Callable[[bytes], None] = None, ) -> Optional[Dict[bytes, Optional[str]]]: """Handle the tail of a 'git-receive-pack' request.
Args: proto: Protocol object to read from capabilities: List of negotiated capabilities progress: Optional progress reporting function Returns: dict mapping ref name to: error message if the ref failed to update None if it was updated successfully """ if CAPABILITY_SIDE_BAND_64K in capabilities: if progress is None: def progress(x): pass channel_callbacks = {2: progress} if CAPABILITY_REPORT_STATUS in capabilities: channel_callbacks[1] = PktLineParser( self._report_status_parser.handle_packet ).parse self._read_side_band64k_data(proto, channel_callbacks) else: if CAPABILITY_REPORT_STATUS in capabilities: for pkt in proto.read_pkt_seq(): self._report_status_parser.handle_packet(pkt) if self._report_status_parser is not None: return dict(self._report_status_parser.check()) return None def _negotiate_upload_pack_capabilities(self, server_capabilities): unknown_capabilities = ( # noqa: F841 extract_capability_names(server_capabilities) - KNOWN_UPLOAD_CAPABILITIES ) # TODO(jelmer): warn about unknown capabilities symrefs = {} agent = None for capability in server_capabilities: k, v = parse_capability(capability) if k == CAPABILITY_SYMREF: (src, dst) = v.split(b":", 1) symrefs[src] = dst if k == CAPABILITY_AGENT: agent = v negotiated_capabilities = self._fetch_capabilities & server_capabilities return (negotiated_capabilities, symrefs, agent) def _handle_upload_pack_head( self, proto, capabilities, graph_walker, wants, can_read, depth ): """Handle the head of a 'git-upload-pack' request. Args: proto: Protocol object to read from capabilities: List of negotiated capabilities graph_walker: GraphWalker instance to call .ack() on wants: List of commits to fetch can_read: function that returns a boolean that indicates whether there is extra graph data to read on proto depth: Depth for request Returns: """ assert isinstance(wants, list) and isinstance(wants[0], bytes) proto.write_pkt_line( COMMAND_WANT + b" " + wants[0] + b" " + b" ".join(sorted(capabilities)) + b"\n" ) for want in wants[1:]: proto.write_pkt_line(COMMAND_WANT + b" " + want + b"\n") if depth not in (0, None) or getattr(graph_walker, "shallow", None): if CAPABILITY_SHALLOW not in capabilities: raise GitProtocolError( "server does not support shallow capability required for " "depth" ) for sha in graph_walker.shallow: proto.write_pkt_line(COMMAND_SHALLOW + b" " + sha + b"\n") if depth is not None: proto.write_pkt_line( COMMAND_DEEPEN + b" " + str(depth).encode("ascii") + b"\n" ) proto.write_pkt_line(None) if can_read is not None: (new_shallow, new_unshallow) = _read_shallow_updates(proto) else: new_shallow = new_unshallow = None else: new_shallow = new_unshallow = set() proto.write_pkt_line(None) have = next(graph_walker) while have: proto.write_pkt_line(COMMAND_HAVE + b" " + have + b"\n") if can_read is not None and can_read(): pkt = proto.read_pkt_line() parts = pkt.rstrip(b"\n").split(b" ") if parts[0] == b"ACK": graph_walker.ack(parts[1]) if parts[2] in (b"continue", b"common"): pass elif parts[2] == b"ready": break else: raise AssertionError( "%s not in ('continue', 'ready', 'common')" % parts[2] ) have = next(graph_walker) proto.write_pkt_line(COMMAND_DONE + b"\n") return (new_shallow, new_unshallow) def _handle_upload_pack_tail( self, proto, capabilities, graph_walker, pack_data, progress=None, rbufsize=_RBUFSIZE, ): """Handle the tail of a 'git-upload-pack' request.
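# Illustration (sketch) of the request _handle_upload_pack_head writes: the
# first want line carries the negotiated capabilities, later wants are bare,
# and a flush-pkt closes the want section before the have/done negotiation.
# The SHAs are placeholders.
buf = BytesIO()
proto = Protocol(None, buf.write)  # write-only, as in the HTTP client path
proto.write_pkt_line(b"want " + b"a" * 40 + b" thin-pack ofs-delta side-band-64k\n")
proto.write_pkt_line(b"want " + b"b" * 40 + b"\n")
proto.write_pkt_line(None)  # flush-pkt
proto.write_pkt_line(b"done\n")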
Args: proto: Protocol object to read from capabilities: List of negotiated capabilities graph_walker: GraphWalker instance to call .ack() on pack_data: Function to call with pack data progress: Optional progress reporting function rbufsize: Read buffer size Returns: """ pkt = proto.read_pkt_line() while pkt: parts = pkt.rstrip(b"\n").split(b" ") if parts[0] == b"ACK": graph_walker.ack(parts[1]) if len(parts) < 3 or parts[2] not in ( b"ready", b"continue", b"common", ): break pkt = proto.read_pkt_line() if CAPABILITY_SIDE_BAND_64K in capabilities: if progress is None: # Just ignore progress data def progress(x): pass self._read_side_band64k_data( proto, { SIDE_BAND_CHANNEL_DATA: pack_data, SIDE_BAND_CHANNEL_PROGRESS: progress, }, ) else: while True: data = proto.read(rbufsize) if data == b"": break pack_data(data) def check_wants(wants, refs): """Check that a set of wants is valid. Args: wants: Set of object SHAs to fetch refs: Refs dictionary to check against Returns: """ missing = set(wants) - { v for (k, v) in refs.items() if not k.endswith(ANNOTATED_TAG_SUFFIX) } if missing: raise InvalidWants(missing) def _remote_error_from_stderr(stderr): if stderr is None: return HangupException() lines = [line.rstrip(b"\n") for line in stderr.readlines()] for line in lines: if line.startswith(b"ERROR: "): return GitProtocolError(line[len(b"ERROR: ") :].decode("utf-8", "replace")) return HangupException(lines) class TraditionalGitClient(GitClient): """Traditional Git client.""" DEFAULT_ENCODING = "utf-8" def __init__(self, path_encoding=DEFAULT_ENCODING, **kwargs): self._remote_path_encoding = path_encoding super(TraditionalGitClient, self).__init__(**kwargs) async def _connect(self, cmd, path): """Create a connection to the server. This method is abstract - concrete implementations should implement their own variant which connects to the server and returns an initialized Protocol object with the service ready for use and a can_read function which may be used to see if reads would block. Args: cmd: The git service name to which we should connect. path: The path we should pass to the service. (as bytestring) """ raise NotImplementedError() def send_pack(self, path, update_refs, generate_pack_data, progress=None): """Upload a pack to a remote repository. Args: path: Repository path (as bytestring) update_refs: Function to determine changes to remote refs. Receives dict with existing remote refs, returns dict with changed refs (name -> sha, where sha=ZERO_SHA for deletions) generate_pack_data: Function that can return a tuple with number of objects and pack data to upload.
progress: Optional callback called with progress updates Returns: SendPackResult Raises: SendPackError: if server rejects the pack data """ proto, unused_can_read, stderr = self._connect(b"receive-pack", path) with proto: try: old_refs, server_capabilities = read_pkt_refs(proto) except HangupException: raise _remote_error_from_stderr(stderr) ( negotiated_capabilities, agent, ) = self._negotiate_receive_pack_capabilities(server_capabilities) if CAPABILITY_REPORT_STATUS in negotiated_capabilities: self._report_status_parser = ReportStatusParser() report_status_parser = self._report_status_parser try: new_refs = orig_new_refs = update_refs(dict(old_refs)) except BaseException: proto.write_pkt_line(None) raise if set(new_refs.items()).issubset(set(old_refs.items())): proto.write_pkt_line(None) return SendPackResult(new_refs, agent=agent, ref_status={}) if CAPABILITY_DELETE_REFS not in server_capabilities: # Server does not support deletions. Fail later. new_refs = dict(orig_new_refs) for ref, sha in orig_new_refs.items(): if sha == ZERO_SHA: if CAPABILITY_REPORT_STATUS in negotiated_capabilities: report_status_parser._ref_statuses.append( b"ng " + ref + b" remote does not support deleting refs" ) report_status_parser._ref_status_ok = False del new_refs[ref] if new_refs is None: proto.write_pkt_line(None) return SendPackResult(old_refs, agent=agent, ref_status={}) if len(new_refs) == 0 and orig_new_refs: # NOOP - Original new refs filtered out by policy proto.write_pkt_line(None) if report_status_parser is not None: ref_status = dict(report_status_parser.check()) else: ref_status = None return SendPackResult(old_refs, agent=agent, ref_status=ref_status) (have, want) = self._handle_receive_pack_head( proto, negotiated_capabilities, old_refs, new_refs ) pack_data_count, pack_data = generate_pack_data( have, want, ofs_delta=(CAPABILITY_OFS_DELTA in negotiated_capabilities), ) if self._should_send_pack(new_refs): write_pack_data(proto.write_file(), pack_data_count, pack_data) ref_status = self._handle_receive_pack_tail( proto, negotiated_capabilities, progress ) return SendPackResult(new_refs, agent=agent, ref_status=ref_status) def fetch_pack( self, path, determine_wants, graph_walker, pack_data, progress=None, depth=None, ): """Retrieve a pack from a git smart server. Args: path: Remote path to fetch from determine_wants: Function determine what refs to fetch. Receives dictionary of name->sha, should return list of shas to fetch. graph_walker: Object with next() and ack(). 
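# Sketch of the update_refs callback contract used by send_pack above: it
# receives the server's current refs and returns the desired state, with
# ZERO_SHA marking a deletion. `local_sha` is a placeholder 40-byte hex SHA.
def update_refs(refs):
    refs[b"refs/heads/feature"] = local_sha  # create or update
    refs[b"refs/heads/obsolete"] = ZERO_SHA  # request deletion
    return refs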
pack_data: Callback called for each bit of data in the pack progress: Callback for progress reports (strings) depth: Shallow fetch depth Returns: FetchPackResult object """ proto, can_read, stderr = self._connect(b"upload-pack", path) with proto: try: refs, server_capabilities = read_pkt_refs(proto) except HangupException: raise _remote_error_from_stderr(stderr) ( negotiated_capabilities, symrefs, agent, ) = self._negotiate_upload_pack_capabilities(server_capabilities) if refs is None: proto.write_pkt_line(None) return FetchPackResult(refs, symrefs, agent) try: wants = determine_wants(refs) except BaseException: proto.write_pkt_line(None) raise if wants is not None: wants = [cid for cid in wants if cid != ZERO_SHA] if not wants: proto.write_pkt_line(None) return FetchPackResult(refs, symrefs, agent) (new_shallow, new_unshallow) = self._handle_upload_pack_head( proto, negotiated_capabilities, graph_walker, wants, can_read, depth=depth, ) self._handle_upload_pack_tail( proto, negotiated_capabilities, graph_walker, pack_data, progress, ) return FetchPackResult(refs, symrefs, agent, new_shallow, new_unshallow) def get_refs(self, path): """Retrieve the current refs from a git smart server.""" # stock `git ls-remote` uses upload-pack proto, _, stderr = self._connect(b"upload-pack", path) with proto: try: refs, _ = read_pkt_refs(proto) except HangupException: raise _remote_error_from_stderr(stderr) proto.write_pkt_line(None) return refs def archive( self, path, committish, write_data, progress=None, write_error=None, format=None, subdirs=None, prefix=None, ): proto, can_read, stderr = self._connect(b"upload-archive", path) with proto: if format is not None: proto.write_pkt_line(b"argument --format=" + format) proto.write_pkt_line(b"argument " + committish) if subdirs is not None: for subdir in subdirs: proto.write_pkt_line(b"argument " + subdir) if prefix is not None: proto.write_pkt_line(b"argument --prefix=" + prefix) proto.write_pkt_line(None) try: pkt = proto.read_pkt_line() except HangupException: raise _remote_error_from_stderr(stderr) if pkt == b"NACK\n" or pkt == b"NACK": return elif pkt == b"ACK\n" or pkt == b"ACK": pass elif pkt.startswith(b"ERR "): raise GitProtocolError(pkt[4:].rstrip(b"\n").decode("utf-8", "replace")) else: raise AssertionError("invalid response %r" % pkt) ret = proto.read_pkt_line() if ret is not None: raise AssertionError("expected pkt tail") self._read_side_band64k_data( proto, { SIDE_BAND_CHANNEL_DATA: write_data, SIDE_BAND_CHANNEL_PROGRESS: progress, SIDE_BAND_CHANNEL_FATAL: write_error, }, ) class TCPGitClient(TraditionalGitClient): """A Git Client that works over TCP directly (i.e. 
git://).""" def __init__(self, host, port=None, **kwargs): if port is None: port = TCP_GIT_PORT self._host = host self._port = port super(TCPGitClient, self).__init__(**kwargs) @classmethod def from_parsedurl(cls, parsedurl, **kwargs): return cls(parsedurl.hostname, port=parsedurl.port, **kwargs) def get_url(self, path): netloc = self._host if self._port is not None and self._port != TCP_GIT_PORT: netloc += ":%d" % self._port return urlunsplit(("git", netloc, path, "", "")) def _connect(self, cmd, path): if not isinstance(cmd, bytes): raise TypeError(cmd) if not isinstance(path, bytes): path = path.encode(self._remote_path_encoding) sockaddrs = socket.getaddrinfo( self._host, self._port, socket.AF_UNSPEC, socket.SOCK_STREAM ) s = None err = socket.error("no address found for %s" % self._host) for (family, socktype, proto, canonname, sockaddr) in sockaddrs: s = socket.socket(family, socktype, proto) s.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) try: s.connect(sockaddr) break except socket.error as e: err = e if s is not None: s.close() s = None if s is None: raise err # -1 means system default buffering rfile = s.makefile("rb", -1) # 0 means unbuffered wfile = s.makefile("wb", 0) def close(): rfile.close() wfile.close() s.close() proto = Protocol( rfile.read, wfile.write, close, report_activity=self._report_activity, ) if path.startswith(b"/~"): path = path[1:] # TODO(jelmer): Alternative to ascii? proto.send_cmd(b"git-" + cmd, path, b"host=" + self._host.encode("ascii")) return proto, lambda: _fileno_can_read(s), None class SubprocessWrapper(object): """A socket-like object that talks to a subprocess via pipes.""" def __init__(self, proc): self.proc = proc self.read = BufferedReader(proc.stdout).read self.write = proc.stdin.write @property def stderr(self): return self.proc.stderr def can_read(self): if sys.platform == "win32": from msvcrt import get_osfhandle handle = get_osfhandle(self.proc.stdout.fileno()) return _win32_peek_avail(handle) != 0 else: return _fileno_can_read(self.proc.stdout.fileno()) def close(self): self.proc.stdin.close() self.proc.stdout.close() if self.proc.stderr: self.proc.stderr.close() self.proc.wait() def find_git_command(): """Find command to run for system Git (usually C Git).""" if sys.platform == "win32": # support .exe, .bat and .cmd try: # to avoid overhead import win32api except ImportError: # run through cmd.exe with some overhead return ["cmd", "/c", "git"] else: status, git = win32api.FindExecutable("git") return [git] else: return ["git"] class SubprocessGitClient(TraditionalGitClient): """Git client that talks to a server using a subprocess.""" @classmethod def from_parsedurl(cls, parsedurl, **kwargs): return cls(**kwargs) git_command = None def _connect(self, service, path): if not isinstance(service, bytes): raise TypeError(service) if isinstance(path, bytes): path = path.decode(self._remote_path_encoding) if self.git_command is None: git_command = find_git_command() argv = git_command + [service.decode("ascii"), path] p = subprocess.Popen( argv, bufsize=0, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE, ) pw = SubprocessWrapper(p) return ( Protocol( pw.read, pw.write, pw.close, report_activity=self._report_activity, ), pw.can_read, p.stderr, ) class LocalGitClient(GitClient): """Git Client that just uses a local Repo.""" def __init__(self, thin_packs=True, report_activity=None, config=None): """Create a new LocalGitClient instance. 
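# What _connect above puts on the wire for git://, shown against a BytesIO
# instead of a socket: one pkt-line of the form "git-<service> <path>\0host=<host>\0".
buf = BytesIO()
proto = Protocol(None, buf.write)
proto.send_cmd(b"git-upload-pack", b"/project.git", b"host=example.com")
# buf.getvalue() == b"0032git-upload-pack /project.git\x00host=example.com\x00"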
Args: thin_packs: Whether or not thin packs should be retrieved report_activity: Optional callback for reporting transport activity. """ self._report_activity = report_activity # Ignore the thin_packs argument def get_url(self, path): return urlunsplit(("file", "", path, "", "")) @classmethod def from_parsedurl(cls, parsedurl, **kwargs): return cls(**kwargs) @classmethod def _open_repo(cls, path): from dulwich.repo import Repo if not isinstance(path, str): path = os.fsdecode(path) return closing(Repo(path)) def send_pack(self, path, update_refs, generate_pack_data, progress=None): """Upload a pack to a remote repository. Args: path: Repository path (as bytestring) update_refs: Function to determine changes to remote refs. Receive dict with existing remote refs, returns dict with changed refs (name -> sha, where sha=ZERO_SHA for deletions) with number of items and pack data to upload. progress: Optional progress function Returns: SendPackResult Raises: SendPackError: if server rejects the pack data """ if not progress: def progress(x): pass with self._open_repo(path) as target: old_refs = target.get_refs() new_refs = update_refs(dict(old_refs)) have = [sha1 for sha1 in old_refs.values() if sha1 != ZERO_SHA] want = [] for refname, new_sha1 in new_refs.items(): if ( new_sha1 not in have and new_sha1 not in want and new_sha1 != ZERO_SHA ): want.append(new_sha1) if not want and set(new_refs.items()).issubset(set(old_refs.items())): return SendPackResult(new_refs, ref_status={}) target.object_store.add_pack_data( *generate_pack_data(have, want, ofs_delta=True) ) ref_status = {} for refname, new_sha1 in new_refs.items(): old_sha1 = old_refs.get(refname, ZERO_SHA) if new_sha1 != ZERO_SHA: if not target.refs.set_if_equals(refname, old_sha1, new_sha1): msg = "unable to set %s to %s" % (refname, new_sha1) progress(msg) ref_status[refname] = msg else: if not target.refs.remove_if_equals(refname, old_sha1): progress("unable to remove %s" % refname) ref_status[refname] = "unable to remove" return SendPackResult(new_refs, ref_status=ref_status) def fetch(self, path, target, determine_wants=None, progress=None, depth=None): """Fetch into a target repository. Args: path: Path to fetch from (as bytestring) target: Target repository to fetch into determine_wants: Optional function determine what refs to fetch. Receives dictionary of name->sha, should return list of shas to fetch. Defaults to all shas. progress: Optional progress function depth: Shallow fetch depth Returns: FetchPackResult object """ with self._open_repo(path) as r: refs = r.fetch( target, determine_wants=determine_wants, progress=progress, depth=depth, ) return FetchPackResult(refs, r.refs.get_symrefs(), agent_string()) def fetch_pack( self, path, determine_wants, graph_walker, pack_data, progress=None, depth=None, ): """Retrieve a pack from a git smart server. Args: path: Remote path to fetch from determine_wants: Function determine what refs to fetch. Receives dictionary of name->sha, should return list of shas to fetch. graph_walker: Object with next() and ack(). pack_data: Callback called for each bit of data in the pack progress: Callback for progress reports (strings) depth: Shallow fetch depth Returns: FetchPackResult object """ with self._open_repo(path) as r: objects_iter = r.fetch_objects( determine_wants, graph_walker, progress=progress, depth=depth ) symrefs = r.refs.get_symrefs() agent = agent_string() # Did the process short-circuit (e.g. in a stateless RPC call)? 
# Note that the client still expects a 0-object pack in most cases. if objects_iter is None: return FetchPackResult(None, symrefs, agent) protocol = ProtocolFile(None, pack_data) write_pack_objects(protocol, objects_iter) return FetchPackResult(r.get_refs(), symrefs, agent) def get_refs(self, path): """Retrieve the current refs from a git smart server.""" with self._open_repo(path) as target: return target.get_refs() # What Git client to use for local access default_local_git_client_cls = LocalGitClient class SSHVendor(object): """A client side SSH implementation.""" def connect_ssh( self, host, command, username=None, port=None, password=None, key_filename=None, ): # This function was deprecated in 0.9.1 import warnings warnings.warn( "SSHVendor.connect_ssh has been renamed to SSHVendor.run_command", DeprecationWarning, ) return self.run_command( host, command, username=username, port=port, password=password, key_filename=key_filename, ) def run_command( self, host, command, username=None, port=None, password=None, key_filename=None, ): """Connect to an SSH server. Run a command remotely and return a file-like object for interaction with the remote command. Args: host: Host name command: Command to run (as argv array) username: Optional name of user to log in as port: Optional SSH port to use password: Optional ssh password for login or for the private key key_filename: Optional path to private keyfile Returns: """ raise NotImplementedError(self.run_command) class StrangeHostname(Exception): """Refusing to connect to strange SSH hostname.""" def __init__(self, hostname): super(StrangeHostname, self).__init__(hostname) class SubprocessSSHVendor(SSHVendor): """SSH vendor that shells out to the local 'ssh' command.""" def run_command( self, host, command, username=None, port=None, password=None, key_filename=None, ): if password is not None: raise NotImplementedError( "Setting password not supported by SubprocessSSHVendor." ) args = ["ssh", "-x"] if port: args.extend(["-p", str(port)]) if key_filename: args.extend(["-i", str(key_filename)]) if username: host = "%s@%s" % (username, host) if host.startswith("-"): raise StrangeHostname(hostname=host) args.append(host) proc = subprocess.Popen( args + [command], bufsize=0, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE, ) return SubprocessWrapper(proc) class PLinkSSHVendor(SSHVendor): """SSH vendor that shells out to the local 'plink' command.""" def run_command( self, host, command, username=None, port=None, password=None, key_filename=None, ): if sys.platform == "win32": args = ["plink.exe", "-ssh"] else: args = ["plink", "-ssh"] if password is not None: import warnings warnings.warn( "Invoking PLink with a password exposes the password in the " "process list."
) args.extend(["-pw", str(password)]) if port: args.extend(["-P", str(port)]) if key_filename: args.extend(["-i", str(key_filename)]) if username: host = "%s@%s" % (username, host) if host.startswith("-"): raise StrangeHostname(hostname=host) args.append(host) proc = subprocess.Popen( args + [command], bufsize=0, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE, ) return SubprocessWrapper(proc) def ParamikoSSHVendor(**kwargs): import warnings warnings.warn( "ParamikoSSHVendor has been moved to dulwich.contrib.paramiko_vendor.", DeprecationWarning, ) from dulwich.contrib.paramiko_vendor import ParamikoSSHVendor return ParamikoSSHVendor(**kwargs) # Can be overridden by users get_ssh_vendor = SubprocessSSHVendor class SSHGitClient(TraditionalGitClient): def __init__( self, host, port=None, username=None, vendor=None, config=None, password=None, key_filename=None, **kwargs ): self.host = host self.port = port self.username = username self.password = password self.key_filename = key_filename super(SSHGitClient, self).__init__(**kwargs) self.alternative_paths = {} if vendor is not None: self.ssh_vendor = vendor else: self.ssh_vendor = get_ssh_vendor() def get_url(self, path): netloc = self.host if self.port is not None: netloc += ":%d" % self.port if self.username is not None: netloc = urlquote(self.username, "@/:") + "@" + netloc return urlunsplit(("ssh", netloc, path, "", "")) @classmethod def from_parsedurl(cls, parsedurl, **kwargs): return cls( host=parsedurl.hostname, port=parsedurl.port, username=parsedurl.username, **kwargs ) def _get_cmd_path(self, cmd): cmd = self.alternative_paths.get(cmd, b"git-" + cmd) assert isinstance(cmd, bytes) return cmd def _connect(self, cmd, path): if not isinstance(cmd, bytes): raise TypeError(cmd) if isinstance(path, bytes): path = path.decode(self._remote_path_encoding) if path.startswith("/~"): path = path[1:] argv = ( self._get_cmd_path(cmd).decode(self._remote_path_encoding) + " '" + path + "'" ) kwargs = {} if self.password is not None: kwargs["password"] = self.password if self.key_filename is not None: kwargs["key_filename"] = self.key_filename con = self.ssh_vendor.run_command( self.host, argv, port=self.port, username=self.username, **kwargs ) return ( Protocol( con.read, con.write, con.close, report_activity=self._report_activity, ), con.can_read, getattr(con, "stderr", None), ) def default_user_agent_string(): # Start user agent with "git/", because GitHub requires this. :-( See # https://github.com/jelmer/dulwich/issues/562 for details. return "git/dulwich/%s" % ".".join([str(x) for x in dulwich.__version__]) def default_urllib3_manager( # noqa: C901 config, pool_manager_cls=None, proxy_manager_cls=None, **override_kwargs ): """Return `urllib3` connection pool manager. Honour detected proxy configurations. Args: config: dulwich.config.ConfigDict` instance with Git configuration. kwargs: Additional arguments for urllib3.ProxyManager Returns: `pool_manager_cls` (defaults to `urllib3.ProxyManager`) instance for proxy configurations, `proxy_manager_cls` (defaults to `urllib3.PoolManager`) instance otherwise. 
""" proxy_server = user_agent = None ca_certs = ssl_verify = None if proxy_server is None: for proxyname in ("https_proxy", "http_proxy", "all_proxy"): proxy_server = os.environ.get(proxyname) if proxy_server is not None: break if config is not None: if proxy_server is None: try: proxy_server = config.get(b"http", b"proxy") except KeyError: pass try: user_agent = config.get(b"http", b"useragent") except KeyError: pass # TODO(jelmer): Support per-host settings try: ssl_verify = config.get_boolean(b"http", b"sslVerify") except KeyError: ssl_verify = True try: ca_certs = config.get(b"http", b"sslCAInfo") except KeyError: ca_certs = None if user_agent is None: user_agent = default_user_agent_string() headers = {"User-agent": user_agent} kwargs = {} if ssl_verify is True: kwargs["cert_reqs"] = "CERT_REQUIRED" elif ssl_verify is False: kwargs["cert_reqs"] = "CERT_NONE" else: # Default to SSL verification kwargs["cert_reqs"] = "CERT_REQUIRED" if ca_certs is not None: kwargs["ca_certs"] = ca_certs kwargs.update(override_kwargs) # Try really hard to find a SSL certificate path if "ca_certs" not in kwargs and kwargs.get("cert_reqs") != "CERT_NONE": try: import certifi except ImportError: pass else: kwargs["ca_certs"] = certifi.where() import urllib3 if proxy_server is not None: if proxy_manager_cls is None: proxy_manager_cls = urllib3.ProxyManager # `urllib3` requires a `str` object in both Python 2 and 3, while # `ConfigDict` coerces entries to `bytes` on Python 3. Compensate. if not isinstance(proxy_server, str): proxy_server = proxy_server.decode() manager = proxy_manager_cls(proxy_server, headers=headers, **kwargs) else: if pool_manager_cls is None: pool_manager_cls = urllib3.PoolManager manager = pool_manager_cls(headers=headers, **kwargs) return manager class HttpGitClient(GitClient): def __init__( self, base_url, dumb=None, pool_manager=None, config=None, username=None, password=None, **kwargs ): self._base_url = base_url.rstrip("/") + "/" self._username = username self._password = password self.dumb = dumb if pool_manager is None: self.pool_manager = default_urllib3_manager(config) else: self.pool_manager = pool_manager if username is not None: # No escaping needed: ":" is not allowed in username: # https://tools.ietf.org/html/rfc2617#section-2 credentials = "%s:%s" % (username, password) import urllib3.util basic_auth = urllib3.util.make_headers(basic_auth=credentials) self.pool_manager.headers.update(basic_auth) GitClient.__init__(self, **kwargs) def get_url(self, path): return self._get_url(path).rstrip("/") @classmethod def from_parsedurl(cls, parsedurl, **kwargs): password = parsedurl.password if password is not None: kwargs["password"] = urlunquote(password) username = parsedurl.username if username is not None: kwargs["username"] = urlunquote(username) netloc = parsedurl.hostname if parsedurl.port: netloc = "%s:%s" % (netloc, parsedurl.port) if parsedurl.username: netloc = "%s@%s" % (parsedurl.username, netloc) parsedurl = parsedurl._replace(netloc=netloc) return cls(urlunparse(parsedurl), **kwargs) def __repr__(self): return "%s(%r, dumb=%r)" % ( type(self).__name__, self._base_url, self.dumb, ) def _get_url(self, path): if not isinstance(path, str): # urllib3.util.url._encode_invalid_chars() converts the path back # to bytes using the utf-8 codec. path = path.decode("utf-8") return urljoin(self._base_url, path).rstrip("/") + "/" def _http_request(self, url, headers=None, data=None, allow_compression=False): """Perform HTTP request. Args: url: Request URL. 
headers: Optional custom headers to override defaults. data: Request data. allow_compression: Allow GZipped communication. Returns: Tuple (`response`, `read`), where response is an `urllib3` response object with additional `content_type` and `redirect_location` properties, and `read` is a consumable read method for the response data. """ req_headers = self.pool_manager.headers.copy() if headers is not None: req_headers.update(headers) req_headers["Pragma"] = "no-cache" if allow_compression: req_headers["Accept-Encoding"] = "gzip" else: req_headers["Accept-Encoding"] = "identity" if data is None: resp = self.pool_manager.request("GET", url, headers=req_headers) else: resp = self.pool_manager.request( "POST", url, headers=req_headers, body=data ) if resp.status == 404: raise NotGitRepository() if resp.status == 401: raise HTTPUnauthorized(resp.getheader("WWW-Authenticate"), url) if resp.status != 200: raise GitProtocolError( "unexpected http resp %d for %s" % (resp.status, url) ) # TODO: Optimization available by adding `preload_content=False` to the # request and just passing the `read` method on instead of going via # `BytesIO`, if we can guarantee that the entire response is consumed # before issuing the next to still allow for connection reuse from the # pool. read = BytesIO(resp.data).read resp.content_type = resp.getheader("Content-Type") # Check if geturl() is available (urllib3 version >= 1.23) try: resp_url = resp.geturl() except AttributeError: # get_redirect_location() is available for urllib3 >= 1.1 resp.redirect_location = resp.get_redirect_location() else: resp.redirect_location = resp_url if resp_url != url else "" return resp, read def _discover_references(self, service, base_url): assert base_url[-1] == "/" tail = "info/refs" headers = {"Accept": "*/*"} if self.dumb is not True: tail += "?service=%s" % service.decode("ascii") url = urljoin(base_url, tail) resp, read = self._http_request(url, headers, allow_compression=True) if resp.redirect_location: # Something changed (redirect!), so let's update the base URL if not resp.redirect_location.endswith(tail): raise GitProtocolError( "Redirected from URL %s to URL %s without %s" % (url, resp.redirect_location, tail) ) base_url = resp.redirect_location[: -len(tail)] try: self.dumb = not resp.content_type.startswith("application/x-git-") if not self.dumb: proto = Protocol(read, None) # The first line should mention the service try: [pkt] = list(proto.read_pkt_seq()) except ValueError: raise GitProtocolError("unexpected number of packets received") if pkt.rstrip(b"\n") != (b"# service=" + service): raise GitProtocolError( "unexpected first line %r from smart server" % pkt ) return read_pkt_refs(proto) + (base_url,) else: return read_info_refs(resp), set(), base_url finally: resp.close() def _smart_request(self, service, url, data): assert url[-1] == "/" url = urljoin(url, service) result_content_type = "application/x-%s-result" % service headers = { "Content-Type": "application/x-%s-request" % service, "Accept": result_content_type, "Content-Length": str(len(data)), } resp, read = self._http_request(url, headers, data) if resp.content_type != result_content_type: raise GitProtocolError( "Invalid content-type from server: %s" % resp.content_type ) return resp, read def send_pack(self, path, update_refs, generate_pack_data, progress=None): """Upload a pack to a remote repository. Args: path: Repository path (as bytestring) update_refs: Function to determine changes to remote refs. 
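# Concretely, the discovery above GETs <base>/info/refs?service=git-upload-pack
# and, for a smart server, expects a pkt-line body whose Content-Type starts
# with "application/x-git-". A hedged sketch against a hypothetical repo URL,
# using the private helpers shown in this class:
client = HttpGitClient("https://example.com/project.git")
refs, capabilities, base_url = client._discover_references(
    b"git-upload-pack", client._get_url("")
)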
Receives dict with existing remote refs, returns dict with changed refs (name -> sha, where sha=ZERO_SHA for deletions) generate_pack_data: Function that can return a tuple with number of elements and pack data to upload. progress: Optional progress function Returns: SendPackResult Raises: SendPackError: if server rejects the pack data """ url = self._get_url(path) old_refs, server_capabilities, url = self._discover_references( b"git-receive-pack", url ) ( negotiated_capabilities, agent, ) = self._negotiate_receive_pack_capabilities(server_capabilities) negotiated_capabilities.add(capability_agent()) if CAPABILITY_REPORT_STATUS in negotiated_capabilities: self._report_status_parser = ReportStatusParser() new_refs = update_refs(dict(old_refs)) if new_refs is None: # Determine wants function is aborting the push. return SendPackResult(old_refs, agent=agent, ref_status={}) if set(new_refs.items()).issubset(set(old_refs.items())): return SendPackResult(new_refs, agent=agent, ref_status={}) if self.dumb: raise NotImplementedError(self.fetch_pack) req_data = BytesIO() req_proto = Protocol(None, req_data.write) (have, want) = self._handle_receive_pack_head( req_proto, negotiated_capabilities, old_refs, new_refs ) pack_data_count, pack_data = generate_pack_data( have, want, ofs_delta=(CAPABILITY_OFS_DELTA in negotiated_capabilities), ) if self._should_send_pack(new_refs): write_pack_data(req_proto.write_file(), pack_data_count, pack_data) resp, read = self._smart_request( "git-receive-pack", url, data=req_data.getvalue() ) try: resp_proto = Protocol(read, None) ref_status = self._handle_receive_pack_tail( resp_proto, negotiated_capabilities, progress ) return SendPackResult(new_refs, agent=agent, ref_status=ref_status) finally: resp.close() def fetch_pack( self, path, determine_wants, graph_walker, pack_data, progress=None, depth=None, ): """Retrieve a pack from a git smart server. Args: path: Path to fetch from determine_wants: Callback that returns list of commits to fetch graph_walker: Object with next() and ack(). 
pack_data: Callback called for each bit of data in the pack progress: Callback for progress reports (strings) depth: Depth for request Returns: FetchPackResult object """ url = self._get_url(path) refs, server_capabilities, url = self._discover_references( b"git-upload-pack", url ) ( negotiated_capabilities, symrefs, agent, ) = self._negotiate_upload_pack_capabilities(server_capabilities) wants = determine_wants(refs) if wants is not None: wants = [cid for cid in wants if cid != ZERO_SHA] if not wants: return FetchPackResult(refs, symrefs, agent) if self.dumb: raise NotImplementedError(self.fetch_pack) req_data = BytesIO() req_proto = Protocol(None, req_data.write) (new_shallow, new_unshallow) = self._handle_upload_pack_head( req_proto, negotiated_capabilities, graph_walker, wants, can_read=None, depth=depth, ) resp, read = self._smart_request( "git-upload-pack", url, data=req_data.getvalue() ) try: resp_proto = Protocol(read, None) if new_shallow is None and new_unshallow is None: (new_shallow, new_unshallow) = _read_shallow_updates(resp_proto) self._handle_upload_pack_tail( resp_proto, negotiated_capabilities, graph_walker, pack_data, progress, ) return FetchPackResult(refs, symrefs, agent, new_shallow, new_unshallow) finally: resp.close() def get_refs(self, path): """Retrieve the current refs from a git smart server.""" url = self._get_url(path) refs, _, _ = self._discover_references(b"git-upload-pack", url) return refs def get_transport_and_path_from_url(url, config=None, **kwargs): """Obtain a git client from a URL. Args: url: URL to open (a unicode string) config: Optional config object thin_packs: Whether or not thin packs should be retrieved report_activity: Optional callback for reporting transport activity. Returns: Tuple with client instance and relative path. """ parsed = urlparse(url) if parsed.scheme == "git": return (TCPGitClient.from_parsedurl(parsed, **kwargs), parsed.path) elif parsed.scheme in ("git+ssh", "ssh"): return SSHGitClient.from_parsedurl(parsed, **kwargs), parsed.path elif parsed.scheme in ("http", "https"): return ( HttpGitClient.from_parsedurl(parsed, config=config, **kwargs), parsed.path, ) elif parsed.scheme == "file": return ( default_local_git_client_cls.from_parsedurl(parsed, **kwargs), parsed.path, ) raise ValueError("unknown scheme '%s'" % parsed.scheme) def parse_rsync_url(location): """Parse a rsync-style URL.""" if ":" in location and "@" not in location: # SSH with no user@, zero or one leading slash. (host, path) = location.split(":", 1) user = None elif ":" in location: # SSH with user@host:foo. user_host, path = location.split(":", 1) if "@" in user_host: user, host = user_host.rsplit("@", 1) else: user = None host = user_host else: raise ValueError("not a valid rsync-style URL") return (user, host, path) def get_transport_and_path(location, **kwargs): """Obtain a git client from a URL. Args: location: URL or path (a string) config: Optional config object thin_packs: Whether or not thin packs should be retrieved report_activity: Optional callback for reporting transport activity. Returns: Tuple with client instance and relative path. """ # First, try to parse it as a URL try: return get_transport_and_path_from_url(location, **kwargs) except ValueError: pass if sys.platform == "win32" and location[0].isalpha() and location[1:3] == ":\\": # Windows local path return default_local_git_client_cls(**kwargs), location try: (username, hostname, path) = parse_rsync_url(location) except ValueError: # Otherwise, assume it's a local path. 
return default_local_git_client_cls(**kwargs), location else: return SSHGitClient(hostname, username=username, **kwargs), path DEFAULT_GIT_CREDENTIALS_PATHS = [ os.path.expanduser("~/.git-credentials"), get_xdg_config_home_path("git", "credentials"), ] def get_credentials_from_store( scheme, hostname, username=None, fnames=DEFAULT_GIT_CREDENTIALS_PATHS ): for fname in fnames: try: with open(fname, "rb") as f: for line in f: parsed_line = urlparse(line.strip()) if ( parsed_line.scheme == scheme and parsed_line.hostname == hostname and (username is None or parsed_line.username == username) ): return parsed_line.username, parsed_line.password except FileNotFoundError: # If the file doesn't exist, try the next one. continue diff --git a/dulwich/object_store.py b/dulwich/object_store.py index a5d07e09..2075fcf9 100644 --- a/dulwich/object_store.py +++ b/dulwich/object_store.py @@ -1,1556 +1,1562 @@ # object_store.py -- Object store for git objects # Copyright (C) 2008-2013 Jelmer Vernooij # and others # # Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU # General Public License as public by the Free Software Foundation; version 2.0 # or (at your option) any later version. You can redistribute it and/or # modify it under the terms of either of these two licenses. # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # # You should have received a copy of the licenses; if not, see # for a copy of the GNU General Public License # and for a copy of the Apache # License, Version 2.0. # """Git object store interfaces and implementation.""" from io import BytesIO import os import stat import sys from dulwich.diff_tree import ( tree_changes, walk_trees, ) from dulwich.errors import ( NotTreeError, ) from dulwich.file import GitFile from dulwich.objects import ( Commit, ShaFile, Tag, Tree, ZERO_SHA, hex_to_sha, sha_to_hex, hex_to_filename, S_ISGITLINK, object_class, valid_hexsha, ) from dulwich.pack import ( Pack, PackData, PackInflater, PackFileDisappeared, load_pack_index_file, iter_sha1, pack_objects_to_data, write_pack_header, write_pack_index_v2, write_pack_data, write_pack_object, compute_file_sha, PackIndexer, PackStreamCopier, ) from dulwich.refs import ANNOTATED_TAG_SUFFIX INFODIR = "info" PACKDIR = "pack" class BaseObjectStore(object): """Object store interface.""" - def determine_wants_all(self, refs): + def _determine_wants_all(self, refs, force=False): return [ sha for (ref, sha) in refs.items() - if sha not in self + if (force or sha not in self) and not ref.endswith(ANNOTATED_TAG_SUFFIX) and not sha == ZERO_SHA ] + def determine_wants_all(self, refs): + return self._determine_wants_all(refs) + + def determine_wants_force(self, refs): + return self._determine_wants_all(refs, force=True) + def iter_shas(self, shas): """Iterate over the objects for the specified shas. 
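# The credential store parsed by get_credentials_from_store above holds one
# URL per line, e.g. (hypothetical):
#
#     https://alice:s3cret@example.com
#
# Matching is by scheme and hostname (and username, when supplied):
creds = get_credentials_from_store(b"https", b"example.com", username=b"alice")
# creds == (b"alice", b"s3cret"), or None if nothing matched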
Args: shas: Iterable object with SHAs Returns: Object iterator """ return ObjectStoreIterator(self, shas) def contains_loose(self, sha): """Check if a particular object is present by SHA1 and is loose.""" raise NotImplementedError(self.contains_loose) def contains_packed(self, sha): """Check if a particular object is present by SHA1 and is packed.""" raise NotImplementedError(self.contains_packed) def __contains__(self, sha): """Check if a particular object is present by SHA1. This method makes no distinction between loose and packed objects. """ return self.contains_packed(sha) or self.contains_loose(sha) @property def packs(self): """Iterable of pack objects.""" raise NotImplementedError def get_raw(self, name): """Obtain the raw text for an object. Args: name: sha for the object. Returns: tuple with numeric type and object contents. """ raise NotImplementedError(self.get_raw) def __getitem__(self, sha): """Obtain an object by SHA1.""" type_num, uncomp = self.get_raw(sha) return ShaFile.from_raw_string(type_num, uncomp, sha=sha) def __iter__(self): """Iterate over the SHAs that are present in this store.""" raise NotImplementedError(self.__iter__) def add_object(self, obj): """Add a single object to this object store.""" raise NotImplementedError(self.add_object) def add_objects(self, objects, progress=None): """Add a set of objects to this object store. Args: objects: Iterable over a list of (object, path) tuples """ raise NotImplementedError(self.add_objects) def add_pack_data(self, count, pack_data, progress=None): """Add pack data to this object store. Args: num_items: Number of items to add pack_data: Iterator over pack data tuples """ if count == 0: # Don't bother writing an empty pack file return f, commit, abort = self.add_pack() try: write_pack_data( f, count, pack_data, progress, compression_level=self.pack_compression_level, ) except BaseException: abort() raise else: return commit() def tree_changes( self, source, target, want_unchanged=False, include_trees=False, change_type_same=False, rename_detector=None, ): """Find the differences between the contents of two trees Args: source: SHA1 of the source tree target: SHA1 of the target tree want_unchanged: Whether unchanged files should be reported include_trees: Whether to include trees change_type_same: Whether to report files changing type in the same entry. Returns: Iterator over tuples with (oldpath, newpath), (oldmode, newmode), (oldsha, newsha) """ for change in tree_changes( self, source, target, want_unchanged=want_unchanged, include_trees=include_trees, change_type_same=change_type_same, rename_detector=rename_detector, ): yield ( (change.old.path, change.new.path), (change.old.mode, change.new.mode), (change.old.sha, change.new.sha), ) def iter_tree_contents(self, tree_id, include_trees=False): """Iterate the contents of a tree and all subtrees. Iteration is depth-first pre-order, as in e.g. os.walk. Args: tree_id: SHA1 of the tree. include_trees: If True, include tree objects in the iteration. Returns: Iterator over TreeEntry namedtuples for all the objects in a tree. """ for entry, _ in walk_trees(self, tree_id, None): if ( entry.mode is not None and not stat.S_ISDIR(entry.mode) ) or include_trees: yield entry def find_missing_objects( self, haves, wants, shallow=None, progress=None, get_tagged=None, get_parents=lambda commit: commit.parents, depth=None, ): """Find the missing objects required for a set of revisions. Args: haves: Iterable over SHAs already in common. wants: Iterable over SHAs of objects to fetch. 
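# Worked example (sketch) for tree_changes above, using MemoryObjectStore
# (defined elsewhere in this module) and two one-file trees:
from dulwich.object_store import MemoryObjectStore
from dulwich.objects import Blob, Tree

store = MemoryObjectStore()
old_blob, new_blob = Blob.from_string(b"one"), Blob.from_string(b"two")
old_tree, new_tree = Tree(), Tree()
old_tree.add(b"file.txt", 0o100644, old_blob.id)
new_tree.add(b"file.txt", 0o100644, new_blob.id)
for obj in (old_blob, new_blob, old_tree, new_tree):
    store.add_object(obj)
for (paths, modes, shas) in store.tree_changes(old_tree.id, new_tree.id):
    print(paths, modes, shas)  # one entry: same path, same mode, old/new blob SHAs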
shallow: Set of shallow commit SHA1s to skip progress: Simple progress function that will be called with updated progress strings. get_tagged: Function that returns a dict of pointed-to sha -> tag sha for including tags. get_parents: Optional function for getting the parents of a commit. Returns: Iterator over (sha, path) pairs. """ finder = MissingObjectFinder( self, haves, wants, shallow, progress, get_tagged, get_parents=get_parents, ) return iter(finder.next, None) def find_common_revisions(self, graphwalker): """Find which revisions this store has in common using graphwalker. Args: graphwalker: A graphwalker object. Returns: List of SHAs that are in common """ haves = [] sha = next(graphwalker) while sha: if sha in self: haves.append(sha) graphwalker.ack(sha) sha = next(graphwalker) return haves def generate_pack_contents(self, have, want, shallow=None, progress=None): """Iterate over the contents of a pack file. Args: have: List of SHA1s of objects that should not be sent want: List of SHA1s of objects that should be sent shallow: Set of shallow commit SHA1s to skip progress: Optional progress reporting method """ missing = self.find_missing_objects(have, want, shallow, progress) return self.iter_shas(missing) def generate_pack_data( self, have, want, shallow=None, progress=None, ofs_delta=True ): """Generate pack data objects for a set of wants/haves. Args: have: List of SHA1s of objects that should not be sent want: List of SHA1s of objects that should be sent shallow: Set of shallow commit SHA1s to skip ofs_delta: Whether OFS deltas can be included progress: Optional progress reporting method """ # TODO(jelmer): More efficient implementation return pack_objects_to_data( self.generate_pack_contents(have, want, shallow, progress) ) def peel_sha(self, sha): """Peel all tags from a SHA. Args: sha: The object SHA to peel. Returns: The fully-peeled SHA1 of a tag object, after peeling all intermediate tags; if the original ref does not point to a tag, this will equal the original SHA1. """ obj = self[sha] obj_class = object_class(obj.type_name) while obj_class is Tag: obj_class, sha = obj.object obj = self[sha] return obj def _collect_ancestors( self, heads, common=set(), shallow=set(), get_parents=lambda commit: commit.parents, ): """Collect all ancestors of heads up to (excluding) those in common. Args: heads: commits to start from common: commits to end at, or empty set to walk repository completely get_parents: Optional function for getting the parents of a commit. Returns: a tuple (A, B) where A - all commits reachable from heads but not present in common, B - common (shared) elements that are directly reachable from heads """ bases = set() commits = set() queue = [] queue.extend(heads) while queue: e = queue.pop(0) if e in common: bases.add(e) elif e not in commits: commits.add(e) if e in shallow: continue cmt = self[e] queue.extend(get_parents(cmt)) return (commits, bases) def close(self): """Close any files opened by this object store.""" # Default implementation is a NO-OP class PackBasedObjectStore(BaseObjectStore): def __init__(self, pack_compression_level=-1): self._pack_cache = {} self.pack_compression_level = pack_compression_level @property def alternates(self): return [] def contains_packed(self, sha): """Check if a particular object is present by SHA1 and is packed. This does not check alternates. 
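# Minimal sketch (not from the diff): peel_sha() following an annotated tag to
# the commit it points at. All objects and identities below are synthetic.
from time import time

from dulwich.object_store import MemoryObjectStore
from dulwich.objects import Blob, Commit, Tag, Tree

store = MemoryObjectStore()
blob = Blob.from_string(b"data")
tree = Tree()
tree.add(b"f", 0o100644, blob.id)
commit = Commit()
commit.tree = tree.id
commit.author = commit.committer = b"Alice <alice@example.com>"
commit.author_time = commit.commit_time = int(time())
commit.author_timezone = commit.commit_timezone = 0
commit.message = b"initial"
tag = Tag()
tag.name = b"v1.0"
tag.tagger = b"Alice <alice@example.com>"
tag.tag_time = int(time())
tag.tag_timezone = 0
tag.object = (Commit, commit.id)
tag.message = b"release"
for obj in (blob, tree, commit, tag):
    store.add_object(obj)
# Peeling the tag yields the commit; peeling a non-tag is a no-op.
assert store.peel_sha(tag.id).id == commit.id
assert store.peel_sha(blob.id).id == blob.id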
""" for pack in self.packs: try: if sha in pack: return True except PackFileDisappeared: pass return False def __contains__(self, sha): """Check if a particular object is present by SHA1. This method makes no distinction between loose and packed objects. """ if self.contains_packed(sha) or self.contains_loose(sha): return True for alternate in self.alternates: if sha in alternate: return True return False def _add_cached_pack(self, base_name, pack): """Add a newly appeared pack to the cache by path.""" prev_pack = self._pack_cache.get(base_name) if prev_pack is not pack: self._pack_cache[base_name] = pack if prev_pack: prev_pack.close() def _clear_cached_packs(self): pack_cache = self._pack_cache self._pack_cache = {} while pack_cache: (name, pack) = pack_cache.popitem() pack.close() def _iter_cached_packs(self): return self._pack_cache.values() def _update_pack_cache(self): raise NotImplementedError(self._update_pack_cache) def close(self): self._clear_cached_packs() @property def packs(self): """List with pack objects.""" return list(self._iter_cached_packs()) + list(self._update_pack_cache()) def _iter_alternate_objects(self): """Iterate over the SHAs of all the objects in alternate stores.""" for alternate in self.alternates: for alternate_object in alternate: yield alternate_object def _iter_loose_objects(self): """Iterate over the SHAs of all loose objects.""" raise NotImplementedError(self._iter_loose_objects) def _get_loose_object(self, sha): raise NotImplementedError(self._get_loose_object) def _remove_loose_object(self, sha): raise NotImplementedError(self._remove_loose_object) def _remove_pack(self, name): raise NotImplementedError(self._remove_pack) def pack_loose_objects(self): """Pack loose objects. Returns: Number of objects packed """ objects = set() for sha in self._iter_loose_objects(): objects.add((self._get_loose_object(sha), None)) self.add_objects(list(objects)) for obj, path in objects: self._remove_loose_object(obj.id) return len(objects) def repack(self): """Repack the packs in this repository. Note that this implementation is fairly naive and currently keeps all objects in memory while it repacks. """ loose_objects = set() for sha in self._iter_loose_objects(): loose_objects.add(self._get_loose_object(sha)) objects = {(obj, None) for obj in loose_objects} old_packs = {p.name(): p for p in self.packs} for name, pack in old_packs.items(): objects.update((obj, None) for obj in pack.iterobjects()) # The name of the consolidated pack might match the name of a # pre-existing pack. Take care not to remove the newly created # consolidated pack. consolidated = self.add_objects(objects) old_packs.pop(consolidated.name(), None) for obj in loose_objects: self._remove_loose_object(obj.id) for name, pack in old_packs.items(): self._remove_pack(pack) self._update_pack_cache() return len(objects) def __iter__(self): """Iterate over the SHAs that are present in this store.""" self._update_pack_cache() for pack in self._iter_cached_packs(): try: for sha in pack: yield sha except PackFileDisappeared: pass for sha in self._iter_loose_objects(): yield sha for sha in self._iter_alternate_objects(): yield sha def contains_loose(self, sha): """Check if a particular object is present by SHA1 and is loose. This does not check alternates. """ return self._get_loose_object(sha) is not None def get_raw(self, name): """Obtain the raw fulltext for an object. Args: name: sha for the object. Returns: tuple with numeric type and object contents. 
""" if name == ZERO_SHA: raise KeyError(name) if len(name) == 40: sha = hex_to_sha(name) hexsha = name elif len(name) == 20: sha = name hexsha = None else: raise AssertionError("Invalid object name %r" % (name,)) for pack in self._iter_cached_packs(): try: return pack.get_raw(sha) except (KeyError, PackFileDisappeared): pass if hexsha is None: hexsha = sha_to_hex(name) ret = self._get_loose_object(hexsha) if ret is not None: return ret.type_num, ret.as_raw_string() # Maybe something else has added a pack with the object # in the mean time? for pack in self._update_pack_cache(): try: return pack.get_raw(sha) except KeyError: pass for alternate in self.alternates: try: return alternate.get_raw(hexsha) except KeyError: pass raise KeyError(hexsha) def add_objects(self, objects, progress=None): """Add a set of objects to this object store. Args: objects: Iterable over (object, path) tuples, should support __len__. Returns: Pack object of the objects written. """ return self.add_pack_data(*pack_objects_to_data(objects), progress=progress) class DiskObjectStore(PackBasedObjectStore): """Git-style object store that exists on disk.""" def __init__(self, path, loose_compression_level=-1, pack_compression_level=-1): """Open an object store. Args: path: Path of the object store. loose_compression_level: zlib compression level for loose objects pack_compression_level: zlib compression level for pack objects """ super(DiskObjectStore, self).__init__( pack_compression_level=pack_compression_level ) self.path = path self.pack_dir = os.path.join(self.path, PACKDIR) self._alternates = None self.loose_compression_level = loose_compression_level self.pack_compression_level = pack_compression_level def __repr__(self): return "<%s(%r)>" % (self.__class__.__name__, self.path) @classmethod def from_config(cls, path, config): try: default_compression_level = int( config.get((b"core",), b"compression").decode() ) except KeyError: default_compression_level = -1 try: loose_compression_level = int( config.get((b"core",), b"looseCompression").decode() ) except KeyError: loose_compression_level = default_compression_level try: pack_compression_level = int( config.get((b"core",), "packCompression").decode() ) except KeyError: pack_compression_level = default_compression_level return cls(path, loose_compression_level, pack_compression_level) @property def alternates(self): if self._alternates is not None: return self._alternates self._alternates = [] for path in self._read_alternate_paths(): self._alternates.append(DiskObjectStore(path)) return self._alternates def _read_alternate_paths(self): try: f = GitFile(os.path.join(self.path, INFODIR, "alternates"), "rb") except FileNotFoundError: return with f: for line in f.readlines(): line = line.rstrip(b"\n") if line.startswith(b"#"): continue if os.path.isabs(line): yield os.fsdecode(line) else: yield os.fsdecode(os.path.join(os.fsencode(self.path), line)) def add_alternate_path(self, path): """Add an alternate path to this object store.""" try: os.mkdir(os.path.join(self.path, INFODIR)) except FileExistsError: pass alternates_path = os.path.join(self.path, INFODIR, "alternates") with GitFile(alternates_path, "wb") as f: try: orig_f = open(alternates_path, "rb") except FileNotFoundError: pass else: with orig_f: f.write(orig_f.read()) f.write(os.fsencode(path) + b"\n") if not os.path.isabs(path): path = os.path.join(self.path, path) self.alternates.append(DiskObjectStore(path)) def _update_pack_cache(self): """Read and iterate over new pack files and cache them.""" try: 
pack_dir_contents = os.listdir(self.pack_dir) except FileNotFoundError: self.close() return [] pack_files = set() for name in pack_dir_contents: if name.startswith("pack-") and name.endswith(".pack"): # verify that idx exists first (otherwise the pack was not yet # fully written) idx_name = os.path.splitext(name)[0] + ".idx" if idx_name in pack_dir_contents: pack_name = name[: -len(".pack")] pack_files.add(pack_name) # Open newly appeared pack files new_packs = [] for f in pack_files: if f not in self._pack_cache: pack = Pack(os.path.join(self.pack_dir, f)) new_packs.append(pack) self._pack_cache[f] = pack # Remove disappeared pack files for f in set(self._pack_cache) - pack_files: self._pack_cache.pop(f).close() return new_packs def _get_shafile_path(self, sha): # Check from object dir return hex_to_filename(self.path, sha) def _iter_loose_objects(self): for base in os.listdir(self.path): if len(base) != 2: continue for rest in os.listdir(os.path.join(self.path, base)): sha = os.fsencode(base + rest) if not valid_hexsha(sha): continue yield sha def _get_loose_object(self, sha): path = self._get_shafile_path(sha) try: return ShaFile.from_path(path) except FileNotFoundError: return None def _remove_loose_object(self, sha): os.remove(self._get_shafile_path(sha)) def _remove_pack(self, pack): try: del self._pack_cache[os.path.basename(pack._basename)] except KeyError: pass pack.close() os.remove(pack.data.path) os.remove(pack.index.path) def _get_pack_basepath(self, entries): suffix = iter_sha1(entry[0] for entry in entries) # TODO: Handle self.pack_dir being bytes suffix = suffix.decode("ascii") return os.path.join(self.pack_dir, "pack-" + suffix) def _complete_thin_pack(self, f, path, copier, indexer): """Move a specific file containing a pack into the pack directory. Note: The file should be on the same file system as the packs directory. Args: f: Open file object for the pack. path: Path to the pack file. copier: A PackStreamCopier to use for writing pack data. indexer: A PackIndexer for indexing the pack. """ entries = list(indexer) # Update the header with the new number of objects. f.seek(0) write_pack_header(f, len(entries) + len(indexer.ext_refs())) # Must flush before reading (http://bugs.python.org/issue3207) f.flush() # Rescan the rest of the pack, computing the SHA with the new header. new_sha = compute_file_sha(f, end_ofs=-20) # Must reposition before writing (http://bugs.python.org/issue3207) f.seek(0, os.SEEK_CUR) # Complete the pack. for ext_sha in indexer.ext_refs(): assert len(ext_sha) == 20 type_num, data = self.get_raw(ext_sha) offset = f.tell() crc32 = write_pack_object( f, type_num, data, sha=new_sha, compression_level=self.pack_compression_level, ) entries.append((ext_sha, offset, crc32)) pack_sha = new_sha.digest() f.write(pack_sha) f.close() # Move the pack in. entries.sort() pack_base_name = self._get_pack_basepath(entries) target_pack = pack_base_name + ".pack" if sys.platform == "win32": # Windows might have the target pack file lingering. Attempt # removal, silently passing if the target does not exist. try: os.remove(target_pack) except FileNotFoundError: pass os.rename(path, target_pack) # Write the index. index_file = GitFile(pack_base_name + ".idx", "wb") try: write_pack_index_v2(index_file, entries, pack_sha) index_file.close() finally: index_file.abort() # Add the pack to the store and return it. 
final_pack = Pack(pack_base_name) final_pack.check_length_and_checksum() self._add_cached_pack(pack_base_name, final_pack) return final_pack def add_thin_pack(self, read_all, read_some): """Add a new thin pack to this object store. Thin packs are packs that contain deltas with parents that exist outside the pack. They should never be placed in the object store directly, and always indexed and completed as they are copied. Args: read_all: Read function that blocks until the number of requested bytes are read. read_some: Read function that returns at least one byte, but may not return the number of bytes requested. Returns: A Pack object pointing at the now-completed thin pack in the objects/pack directory. """ import tempfile fd, path = tempfile.mkstemp(dir=self.path, prefix="tmp_pack_") with os.fdopen(fd, "w+b") as f: indexer = PackIndexer(f, resolve_ext_ref=self.get_raw) copier = PackStreamCopier(read_all, read_some, f, delta_iter=indexer) copier.verify() return self._complete_thin_pack(f, path, copier, indexer) def move_in_pack(self, path): """Move a specific file containing a pack into the pack directory. Note: The file should be on the same file system as the packs directory. Args: path: Path to the pack file. """ with PackData(path) as p: entries = p.sorted_entries() basename = self._get_pack_basepath(entries) index_name = basename + ".idx" if not os.path.exists(index_name): with GitFile(index_name, "wb") as f: write_pack_index_v2(f, entries, p.get_stored_checksum()) for pack in self.packs: if pack._basename == basename: return pack target_pack = basename + ".pack" if sys.platform == "win32": # Windows might have the target pack file lingering. Attempt # removal, silently passing if the target does not exist. try: os.remove(target_pack) except FileNotFoundError: pass os.rename(path, target_pack) final_pack = Pack(basename) self._add_cached_pack(basename, final_pack) return final_pack def add_pack(self): """Add a new pack to this object store. Returns: Fileobject to write to, a commit function to call when the pack is finished and an abort function. """ import tempfile fd, path = tempfile.mkstemp(dir=self.pack_dir, suffix=".pack") f = os.fdopen(fd, "wb") def commit(): f.flush() os.fsync(fd) f.close() if os.path.getsize(path) > 0: return self.move_in_pack(path) else: os.remove(path) return None def abort(): f.close() os.remove(path) return f, commit, abort def add_object(self, obj): """Add a single object to this object store. 
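# Minimal sketch (not from the diff): the add_pack() protocol above. The
# returned triple is (file object, commit callback, abort callback);
# write_pack_objects() supplies well-formed pack data for a single blob.
import tempfile

from dulwich.object_store import DiskObjectStore
from dulwich.objects import Blob
from dulwich.pack import write_pack_objects

store = DiskObjectStore.init(tempfile.mkdtemp())
blob = Blob.from_string(b"packed directly")
f, commit, abort = store.add_pack()
try:
    write_pack_objects(f, [(blob, None)])
except BaseException:
    abort()
    raise
else:
    pack = commit()  # renames the temp file into the pack directory
assert pack is not None
assert store.contains_packed(blob.id)
store.close()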
Args: obj: Object to add """ path = self._get_shafile_path(obj.id) dir = os.path.dirname(path) try: os.mkdir(dir) except FileExistsError: pass if os.path.exists(path): return # Already there, no need to write again with GitFile(path, "wb") as f: f.write( obj.as_legacy_object(compression_level=self.loose_compression_level) ) @classmethod def init(cls, path): try: os.mkdir(path) except FileExistsError: pass os.mkdir(os.path.join(path, "info")) os.mkdir(os.path.join(path, PACKDIR)) return cls(path) class MemoryObjectStore(BaseObjectStore): """Object store that keeps all objects in memory.""" def __init__(self): super(MemoryObjectStore, self).__init__() self._data = {} self.pack_compression_level = -1 def _to_hexsha(self, sha): if len(sha) == 40: return sha elif len(sha) == 20: return sha_to_hex(sha) else: raise ValueError("Invalid sha %r" % (sha,)) def contains_loose(self, sha): """Check if a particular object is present by SHA1 and is loose.""" return self._to_hexsha(sha) in self._data def contains_packed(self, sha): """Check if a particular object is present by SHA1 and is packed.""" return False def __iter__(self): """Iterate over the SHAs that are present in this store.""" return iter(self._data.keys()) @property def packs(self): """List with pack objects.""" return [] def get_raw(self, name): """Obtain the raw text for an object. Args: name: sha for the object. Returns: tuple with numeric type and object contents. """ obj = self[self._to_hexsha(name)] return obj.type_num, obj.as_raw_string() def __getitem__(self, name): return self._data[self._to_hexsha(name)].copy() def __delitem__(self, name): """Delete an object from this store, for testing only.""" del self._data[self._to_hexsha(name)] def add_object(self, obj): """Add a single object to this object store.""" self._data[obj.id] = obj.copy() def add_objects(self, objects, progress=None): """Add a set of objects to this object store. Args: objects: Iterable over a list of (object, path) tuples """ for obj, path in objects: self.add_object(obj) def add_pack(self): """Add a new pack to this object store. Because this object store doesn't support packs, we extract and add the individual objects. Returns: Fileobject to write to and a commit function to call when the pack is finished. """ f = BytesIO() def commit(): p = PackData.from_file(BytesIO(f.getvalue()), f.tell()) f.close() for obj in PackInflater.for_pack_data(p, self.get_raw): self.add_object(obj) def abort(): pass return f, commit, abort def _complete_thin_pack(self, f, indexer): """Complete a thin pack by adding external references. Args: f: Open file object for the pack. indexer: A PackIndexer for indexing the pack. """ entries = list(indexer) # Update the header with the new number of objects. f.seek(0) write_pack_header(f, len(entries) + len(indexer.ext_refs())) # Rescan the rest of the pack, computing the SHA with the new header. new_sha = compute_file_sha(f, end_ofs=-20) # Complete the pack. for ext_sha in indexer.ext_refs(): assert len(ext_sha) == 20 type_num, data = self.get_raw(ext_sha) write_pack_object(f, type_num, data, sha=new_sha) pack_sha = new_sha.digest() f.write(pack_sha) def add_thin_pack(self, read_all, read_some): """Add a new thin pack to this object store. Thin packs are packs that contain deltas with parents that exist outside the pack. Because this object store doesn't support packs, we extract and add the individual objects. Args: read_all: Read function that blocks until the number of requested bytes are read. 
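# Minimal sketch (not from the diff): the in-memory store keeps everything
# loose, so get_raw() round-trips without touching disk.
from dulwich.object_store import MemoryObjectStore
from dulwich.objects import Blob

store = MemoryObjectStore()
blob = Blob.from_string(b"in memory only")
store.add_object(blob)
type_num, raw = store.get_raw(blob.id)
assert (type_num, raw) == (Blob.type_num, b"in memory only")
assert blob.id in store
assert not store.contains_packed(blob.id)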
read_some: Read function that returns at least one byte, but may not return the number of bytes requested. """ f, commit, abort = self.add_pack() try: indexer = PackIndexer(f, resolve_ext_ref=self.get_raw) copier = PackStreamCopier(read_all, read_some, f, delta_iter=indexer) copier.verify() self._complete_thin_pack(f, indexer) except BaseException: abort() raise else: commit() class ObjectIterator(object): """Interface for iterating over objects.""" def iterobjects(self): raise NotImplementedError(self.iterobjects) class ObjectStoreIterator(ObjectIterator): """ObjectIterator that works on top of an ObjectStore.""" def __init__(self, store, sha_iter): """Create a new ObjectIterator. Args: store: Object store to retrieve from sha_iter: Iterator over (sha, path) tuples """ self.store = store self.sha_iter = sha_iter self._shas = [] def __iter__(self): """Yield tuple with next object and path.""" for sha, path in self.itershas(): yield self.store[sha], path def iterobjects(self): """Iterate over just the objects.""" for o, path in self: yield o def itershas(self): """Iterate over the SHAs.""" for sha in self._shas: yield sha for sha in self.sha_iter: self._shas.append(sha) yield sha def __contains__(self, needle): """Check if an object is present. Note: This checks if the object is present in the underlying object store, not if it would be yielded by the iterator. Args: needle: SHA1 of the object to check for """ if needle == ZERO_SHA: return False return needle in self.store def __getitem__(self, key): """Find an object by SHA1. Note: This retrieves the object from the underlying object store. It will also succeed if the object would not be returned by the iterator. """ return self.store[key] def __len__(self): """Return the number of objects.""" return len(list(self.itershas())) def empty(self): import warnings warnings.warn("Use bool() instead.", DeprecationWarning) return self._empty() def _empty(self): it = self.itershas() try: next(it) except StopIteration: return True else: return False def __bool__(self): """Indicate whether this object has contents.""" return not self._empty() def tree_lookup_path(lookup_obj, root_sha, path): """Look up an object in a Git tree. Args: lookup_obj: Callback for retrieving object by SHA1 root_sha: SHA1 of the root tree path: Path to lookup Returns: A tuple of (mode, SHA) of the resulting path. """ tree = lookup_obj(root_sha) if not isinstance(tree, Tree): raise NotTreeError(root_sha) return tree.lookup_path(lookup_obj, path) def _collect_filetree_revs(obj_store, tree_sha, kset): """Collect SHA1s of files and directories for specified tree. Args: obj_store: Object store to get objects by SHA from tree_sha: tree reference to walk kset: set to fill with references to files and directories """ filetree = obj_store[tree_sha] for name, mode, sha in filetree.iteritems(): if not S_ISGITLINK(mode) and sha not in kset: kset.add(sha) if stat.S_ISDIR(mode): _collect_filetree_revs(obj_store, sha, kset) def _split_commits_and_tags(obj_store, lst, ignore_unknown=False): """Split object id list into three lists with commit, tag, and other SHAs. Commits referenced by tags are included into commits list as well. Only SHA1s known in this repository will get through, and unless ignore_unknown argument is True, KeyError is thrown for SHA1 missing in the repository Args: obj_store: Object store to get objects by SHA1 from lst: Collection of commit and tag SHAs ignore_unknown: True to skip SHA1 missing in the repository silently. 
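# Minimal sketch (not from the diff): tree_lookup_path() resolving a nested
# path to its (mode, sha) pair via a lookup callback; contents are synthetic.
from dulwich.object_store import MemoryObjectStore, tree_lookup_path
from dulwich.objects import Blob, Tree

store = MemoryObjectStore()
blob = Blob.from_string(b"content")
sub = Tree()
sub.add(b"file.txt", 0o100644, blob.id)
root = Tree()
root.add(b"dir", 0o040000, sub.id)
for obj in (blob, sub, root):
    store.add_object(obj)
mode, sha = tree_lookup_path(store.__getitem__, root.id, b"dir/file.txt")
assert (mode, sha) == (0o100644, blob.id)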
Returns: A tuple of (commits, tags, others) SHA1s """ commits = set() tags = set() others = set() for e in lst: try: o = obj_store[e] except KeyError: if not ignore_unknown: raise else: if isinstance(o, Commit): commits.add(e) elif isinstance(o, Tag): tags.add(e) tagged = o.object[1] c, t, o = _split_commits_and_tags( obj_store, [tagged], ignore_unknown=ignore_unknown ) commits |= c tags |= t others |= o else: others.add(e) return (commits, tags, others) class MissingObjectFinder(object): """Find the objects missing from another object store. Args: object_store: Object store containing at least all objects to be sent haves: SHA1s of commits not to send (already present in target) wants: SHA1s of commits to send progress: Optional function to report progress to. get_tagged: Function that returns a dict of pointed-to sha -> tag sha for including tags. get_parents: Optional function for getting the parents of a commit. tagged: dict of pointed-to sha -> tag sha for including tags """ def __init__( self, object_store, haves, wants, shallow=None, progress=None, get_tagged=None, get_parents=lambda commit: commit.parents, ): self.object_store = object_store if shallow is None: shallow = set() self._get_parents = get_parents # process Commits and Tags differently # Note, while haves may list commits/tags not available locally, # and such SHAs would get filtered out by _split_commits_and_tags, # wants shall list only known SHAs, and otherwise # _split_commits_and_tags fails with KeyError have_commits, have_tags, have_others = _split_commits_and_tags( object_store, haves, True ) want_commits, want_tags, want_others = _split_commits_and_tags( object_store, wants, False ) # all_ancestors is a set of commits that shall not be sent # (complete repository up to 'haves') all_ancestors = object_store._collect_ancestors( have_commits, shallow=shallow, get_parents=self._get_parents )[0] # all_missing - complete set of commits between haves and wants # common - commits from all_ancestors we hit into while # traversing parent hierarchy of wants missing_commits, common_commits = object_store._collect_ancestors( want_commits, all_ancestors, shallow=shallow, get_parents=self._get_parents, ) self.sha_done = set() # Now, fill sha_done with commits and revisions of # files and directories known to be both locally # and on target. 
# Thus these commits and files # won't get selected for fetch for h in common_commits: self.sha_done.add(h) cmt = object_store[h] _collect_filetree_revs(object_store, cmt.tree, self.sha_done) # record tags we have as visited, too for t in have_tags: self.sha_done.add(t) missing_tags = want_tags.difference(have_tags) missing_others = want_others.difference(have_others) # in fact, what we 'want' is commits, tags, and others # we've found missing wants = missing_commits.union(missing_tags) wants = wants.union(missing_others) self.objects_to_send = set([(w, None, False) for w in wants]) if progress is None: self.progress = lambda x: None else: self.progress = progress self._tagged = get_tagged and get_tagged() or {} def add_todo(self, entries): self.objects_to_send.update([e for e in entries if not e[0] in self.sha_done]) def next(self): while True: if not self.objects_to_send: return None (sha, name, leaf) = self.objects_to_send.pop() if sha not in self.sha_done: break if not leaf: o = self.object_store[sha] if isinstance(o, Commit): self.add_todo([(o.tree, "", False)]) elif isinstance(o, Tree): self.add_todo( [ (s, n, not stat.S_ISDIR(m)) for n, m, s in o.iteritems() if not S_ISGITLINK(m) ] ) elif isinstance(o, Tag): self.add_todo([(o.object[1], None, False)]) if sha in self._tagged: self.add_todo([(self._tagged[sha], None, True)]) self.sha_done.add(sha) self.progress(("counting objects: %d\r" % len(self.sha_done)).encode("ascii")) return (sha, name) __next__ = next class ObjectStoreGraphWalker(object): """Graph walker that finds what commits are missing from an object store. :ivar heads: Revisions without descendants in the local repo :ivar get_parents: Function to retrieve parents in the local repo """ def __init__(self, local_heads, get_parents, shallow=None): """Create a new instance. Args: local_heads: Heads to start search with get_parents: Function for finding the parents of a SHA1. """ self.heads = set(local_heads) self.get_parents = get_parents self.parents = {} if shallow is None: shallow = set() self.shallow = shallow def ack(self, sha): """Ack that a revision and its ancestors are present in the source.""" if len(sha) != 40: raise ValueError("unexpected sha %r received" % sha) ancestors = set([sha]) # stop if we run out of heads to remove while self.heads: for a in ancestors: if a in self.heads: self.heads.remove(a) # collect all ancestors new_ancestors = set() for a in ancestors: ps = self.parents.get(a) if ps is not None: new_ancestors.update(ps) self.parents[a] = None # no more ancestors; stop if not new_ancestors: break ancestors = new_ancestors def next(self): """Iterate over ancestors of heads in the target.""" if self.heads: ret = self.heads.pop() ps = self.get_parents(ret) self.parents[ret] = ps self.heads.update([p for p in ps if p not in self.parents]) return ret return None __next__ = next def commit_tree_changes(object_store, tree, changes): """Commit a specified set of changes to a tree structure. This will apply a set of changes on top of an existing tree, storing new objects in object_store. changes are a list of tuples with (path, mode, object_sha). Paths can be both blobs and trees. Setting the mode and object sha to None deletes the path. This method works especially well if there are only a small number of changes to a big tree. For a large number of changes to a large tree, use e.g. commit_tree. Args: object_store: Object store to store new objects in and retrieve old ones from.
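# Minimal sketch (not from the diff): ObjectStoreGraphWalker driving have
# negotiation over a synthetic three-commit chain c1 <- c2 <- c3, with
# parents supplied by a plain dict instead of a real object store.
from dulwich.object_store import ObjectStoreGraphWalker

c1, c2, c3 = b"1" * 40, b"2" * 40, b"3" * 40
parents = {c3: [c2], c2: [c1], c1: []}
walker = ObjectStoreGraphWalker([c3], parents.__getitem__)
assert next(walker) == c3  # local heads are offered first
walker.ack(c3)             # remote already has c3; its ancestors are pruned
assert next(walker) is None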
tree: Original tree root changes: changes to apply Returns: New tree root object """ # TODO(jelmer): Save up the objects and add them using .add_objects # rather than with individual calls to .add_object. nested_changes = {} for (path, new_mode, new_sha) in changes: try: (dirname, subpath) = path.split(b"/", 1) except ValueError: if new_sha is None: del tree[path] else: tree[path] = (new_mode, new_sha) else: nested_changes.setdefault(dirname, []).append((subpath, new_mode, new_sha)) for name, subchanges in nested_changes.items(): try: orig_subtree = object_store[tree[name][1]] except KeyError: orig_subtree = Tree() subtree = commit_tree_changes(object_store, orig_subtree, subchanges) if len(subtree) == 0: del tree[name] else: tree[name] = (stat.S_IFDIR, subtree.id) object_store.add_object(tree) return tree class OverlayObjectStore(BaseObjectStore): """Object store that can overlay multiple object stores.""" def __init__(self, bases, add_store=None): self.bases = bases self.add_store = add_store def add_object(self, object): if self.add_store is None: raise NotImplementedError(self.add_object) return self.add_store.add_object(object) def add_objects(self, objects, progress=None): if self.add_store is None: raise NotImplementedError(self.add_object) return self.add_store.add_objects(objects, progress) @property def packs(self): ret = [] for b in self.bases: ret.extend(b.packs) return ret def __iter__(self): done = set() for b in self.bases: for o_id in b: if o_id not in done: yield o_id done.add(o_id) def get_raw(self, sha_id): for b in self.bases: try: return b.get_raw(sha_id) except KeyError: pass raise KeyError(sha_id) def contains_packed(self, sha): for b in self.bases: if b.contains_packed(sha): return True return False def contains_loose(self, sha): for b in self.bases: if b.contains_loose(sha): return True return False def read_packs_file(f): """Yield the packs listed in a packs file.""" for line in f.read().splitlines(): if not line: continue (kind, name) = line.split(b" ", 1) if kind != b"P": continue yield os.fsdecode(name) class BucketBasedObjectStore(PackBasedObjectStore): """Object store implementation that uses a bucket store like S3 as backend. """ def _iter_loose_objects(self): """Iterate over the SHAs of all loose objects.""" return iter([]) def _get_loose_object(self, sha): return None def _remove_loose_object(self, sha): # Doesn't exist.. pass def _remove_pack(self, name): raise NotImplementedError(self._remove_pack) def _iter_pack_names(self): raise NotImplementedError(self._iter_pack_names) def _get_pack(self, name): raise NotImplementedError(self._get_pack) def _update_pack_cache(self): pack_files = set(self._iter_pack_names()) # Open newly appeared pack files new_packs = [] for f in pack_files: if f not in self._pack_cache: pack = self._get_pack(f) new_packs.append(pack) self._pack_cache[f] = pack # Remove disappeared pack files for f in set(self._pack_cache) - pack_files: self._pack_cache.pop(f).close() return new_packs def _upload_pack(self, basename, pack_file, index_file): raise NotImplementedError def add_pack(self): """Add a new pack to this object store. Returns: Fileobject to write to, a commit function to call when the pack is finished and an abort function. 
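# Minimal sketch (not from the diff): commit_tree_changes() grafting one new
# blob into an (initially empty) tree, creating the intermediate "docs"
# subtree on the fly; paths and contents are invented.
from dulwich.object_store import MemoryObjectStore, commit_tree_changes
from dulwich.objects import Blob, Tree

store = MemoryObjectStore()
blob = Blob.from_string(b"new file")
store.add_object(blob)
root = Tree()
store.add_object(root)
new_root = commit_tree_changes(
    store, root, [(b"docs/readme.txt", 0o100644, blob.id)]
)
dir_mode, dir_sha = new_root[b"docs"]
assert store[dir_sha][b"readme.txt"] == (0o100644, blob.id)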
""" import tempfile pf = tempfile.SpooledTemporaryFile() def commit(): if pf.tell() == 0: pf.close() return None pf.seek(0) p = PackData(pf.name, pf) entries = p.sorted_entries() basename = iter_sha1(entry[0] for entry in entries).decode('ascii') idxf = tempfile.SpooledTemporaryFile() checksum = p.get_stored_checksum() write_pack_index_v2(idxf, entries, checksum) idxf.seek(0) idx = load_pack_index_file(basename + '.idx', idxf) for pack in self.packs: if pack.get_stored_checksum() == p.get_stored_checksum(): p.close() idx.close() return pack pf.seek(0) idxf.seek(0) self._upload_pack(basename, pf, idxf) final_pack = Pack.from_objects(p, idx) self._add_cached_pack(basename, final_pack) return final_pack return pf, commit, pf.close diff --git a/dulwich/protocol.py b/dulwich/protocol.py index 9641d06f..553e9cd0 100644 --- a/dulwich/protocol.py +++ b/dulwich/protocol.py @@ -1,583 +1,585 @@ # protocol.py -- Shared parts of the git protocols # Copyright (C) 2008 John Carr # Copyright (C) 2008-2012 Jelmer Vernooij # # Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU # General Public License as public by the Free Software Foundation; version 2.0 # or (at your option) any later version. You can redistribute it and/or # modify it under the terms of either of these two licenses. # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # # You should have received a copy of the licenses; if not, see # for a copy of the GNU General Public License # and for a copy of the Apache # License, Version 2.0. # """Generic functions for talking the git smart server protocol.""" from io import BytesIO from os import ( SEEK_END, ) import socket import dulwich from dulwich.errors import ( HangupException, GitProtocolError, ) TCP_GIT_PORT = 9418 ZERO_SHA = b"0" * 40 SINGLE_ACK = 0 MULTI_ACK = 1 MULTI_ACK_DETAILED = 2 # pack data SIDE_BAND_CHANNEL_DATA = 1 # progress messages SIDE_BAND_CHANNEL_PROGRESS = 2 # fatal error message just before stream aborts SIDE_BAND_CHANNEL_FATAL = 3 CAPABILITY_ATOMIC = b"atomic" CAPABILITY_DEEPEN_SINCE = b"deepen-since" CAPABILITY_DEEPEN_NOT = b"deepen-not" CAPABILITY_DEEPEN_RELATIVE = b"deepen-relative" CAPABILITY_DELETE_REFS = b"delete-refs" CAPABILITY_INCLUDE_TAG = b"include-tag" CAPABILITY_MULTI_ACK = b"multi_ack" CAPABILITY_MULTI_ACK_DETAILED = b"multi_ack_detailed" CAPABILITY_NO_DONE = b"no-done" CAPABILITY_NO_PROGRESS = b"no-progress" CAPABILITY_OFS_DELTA = b"ofs-delta" CAPABILITY_QUIET = b"quiet" CAPABILITY_REPORT_STATUS = b"report-status" CAPABILITY_SHALLOW = b"shallow" CAPABILITY_SIDE_BAND = b"side-band" CAPABILITY_SIDE_BAND_64K = b"side-band-64k" CAPABILITY_THIN_PACK = b"thin-pack" CAPABILITY_AGENT = b"agent" CAPABILITY_SYMREF = b"symref" CAPABILITY_ALLOW_TIP_SHA1_IN_WANT = b"allow-tip-sha1-in-want" CAPABILITY_ALLOW_REACHABLE_SHA1_IN_WANT = b"allow-reachable-sha1-in-want" # Magic ref that is used to attach capabilities to when # there are no refs. Should always be ste to ZERO_SHA. 
CAPABILITIES_REF = b"capabilities^{}" COMMON_CAPABILITIES = [ CAPABILITY_OFS_DELTA, CAPABILITY_SIDE_BAND, CAPABILITY_SIDE_BAND_64K, CAPABILITY_AGENT, CAPABILITY_NO_PROGRESS, ] KNOWN_UPLOAD_CAPABILITIES = set( COMMON_CAPABILITIES + [ CAPABILITY_THIN_PACK, CAPABILITY_MULTI_ACK, CAPABILITY_MULTI_ACK_DETAILED, CAPABILITY_INCLUDE_TAG, CAPABILITY_DEEPEN_SINCE, CAPABILITY_SYMREF, CAPABILITY_SHALLOW, CAPABILITY_DEEPEN_NOT, CAPABILITY_DEEPEN_RELATIVE, CAPABILITY_ALLOW_TIP_SHA1_IN_WANT, CAPABILITY_ALLOW_REACHABLE_SHA1_IN_WANT, ] ) KNOWN_RECEIVE_CAPABILITIES = set( COMMON_CAPABILITIES + [ CAPABILITY_REPORT_STATUS, CAPABILITY_DELETE_REFS, CAPABILITY_QUIET, CAPABILITY_ATOMIC, ] ) +DEPTH_INFINITE = 0x7FFFFFFF + def agent_string(): return ("dulwich/%d.%d.%d" % dulwich.__version__).encode("ascii") def capability_agent(): return CAPABILITY_AGENT + b"=" + agent_string() def capability_symref(from_ref, to_ref): return CAPABILITY_SYMREF + b"=" + from_ref + b":" + to_ref def extract_capability_names(capabilities): return {parse_capability(c)[0] for c in capabilities} def parse_capability(capability): parts = capability.split(b"=", 1) if len(parts) == 1: return (parts[0], None) return tuple(parts) def symref_capabilities(symrefs): return [capability_symref(*k) for k in symrefs] COMMAND_DEEPEN = b"deepen" COMMAND_SHALLOW = b"shallow" COMMAND_UNSHALLOW = b"unshallow" COMMAND_DONE = b"done" COMMAND_WANT = b"want" COMMAND_HAVE = b"have" class ProtocolFile(object): """A dummy file for network ops that expect file-like objects.""" def __init__(self, read, write): self.read = read self.write = write def tell(self): pass def close(self): pass def format_cmd_pkt(cmd, *args): return cmd + b" " + b"".join([(a + b"\0") for a in args]) def parse_cmd_pkt(line): splice_at = line.find(b" ") cmd, args = line[:splice_at], line[splice_at + 1 :] assert args[-1:] == b"\x00" return cmd, args[:-1].split(b"\0") def pkt_line(data): """Wrap data in a pkt-line. Args: data: The data to wrap, as a str or None. Returns: The data prefixed with its length in pkt-line format; if data was None, returns the flush-pkt ('0000'). """ if data is None: return b"0000" return ("%04x" % (len(data) + 4)).encode("ascii") + data class Protocol(object): """Class for interacting with a remote git process over the wire. Parts of the git wire protocol use 'pkt-lines' to communicate. A pkt-line consists of the length of the line as a 4-byte hex string, followed by the payload data. The length includes the 4-byte header. The special line '0000' indicates the end of a section of input and is called a 'flush-pkt'. For details on the pkt-line format, see the cgit distribution: Documentation/technical/protocol-common.txt """ def __init__(self, read, write, close=None, report_activity=None): self.read = read self.write = write self._close = close self.report_activity = report_activity self._readahead = None def close(self): if self._close: self._close() def __enter__(self): return self def __exit__(self, exc_type, exc_val, exc_tb): self.close() def read_pkt_line(self): """Reads a pkt-line from the remote git process. This method may read from the readahead buffer; see unread_pkt_line. Returns: The next string from the stream, without the length prefix, or None for a flush-pkt ('0000'). 
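# Minimal sketch (not from the diff): pkt-line framing as produced by
# pkt_line() and consumed by Protocol, using an in-memory stream.
from io import BytesIO

from dulwich.protocol import Protocol, pkt_line

assert pkt_line(b"hello\n") == b"000ahello\n"  # 4-byte hex length, incl. prefix
assert pkt_line(None) == b"0000"               # flush-pkt

stream = BytesIO(pkt_line(b"hello\n") + pkt_line(None))
proto = Protocol(stream.read, lambda data: None)
assert proto.read_pkt_line() == b"hello\n"
assert proto.read_pkt_line() is None  # a flush-pkt reads back as None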
""" if self._readahead is None: read = self.read else: read = self._readahead.read self._readahead = None try: sizestr = read(4) if not sizestr: raise HangupException() size = int(sizestr, 16) if size == 0: if self.report_activity: self.report_activity(4, "read") return None if self.report_activity: self.report_activity(size, "read") pkt_contents = read(size - 4) except socket.error as e: raise GitProtocolError(e) else: if len(pkt_contents) + 4 != size: raise GitProtocolError( "Length of pkt read %04x does not match length prefix %04x" % (len(pkt_contents) + 4, size) ) return pkt_contents def eof(self): """Test whether the protocol stream has reached EOF. Note that this refers to the actual stream EOF and not just a flush-pkt. Returns: True if the stream is at EOF, False otherwise. """ try: next_line = self.read_pkt_line() except HangupException: return True self.unread_pkt_line(next_line) return False def unread_pkt_line(self, data): """Unread a single line of data into the readahead buffer. This method can be used to unread a single pkt-line into a fixed readahead buffer. Args: data: The data to unread, without the length prefix. Raises: ValueError: If more than one pkt-line is unread. """ if self._readahead is not None: raise ValueError("Attempted to unread multiple pkt-lines.") self._readahead = BytesIO(pkt_line(data)) def read_pkt_seq(self): """Read a sequence of pkt-lines from the remote git process. Returns: Yields each line of data up to but not including the next flush-pkt. """ pkt = self.read_pkt_line() while pkt: yield pkt pkt = self.read_pkt_line() def write_pkt_line(self, line): """Sends a pkt-line to the remote git process. Args: line: A string containing the data to send, without the length prefix. """ try: line = pkt_line(line) self.write(line) if self.report_activity: self.report_activity(len(line), "write") except socket.error as e: raise GitProtocolError(e) def write_file(self): """Return a writable file-like object for this protocol.""" class ProtocolFile(object): def __init__(self, proto): self._proto = proto self._offset = 0 def write(self, data): self._proto.write(data) self._offset += len(data) def tell(self): return self._offset def close(self): pass return ProtocolFile(self) def write_sideband(self, channel, blob): """Write multiplexed data to the sideband. Args: channel: An int specifying the channel to write to. blob: A blob of data (as a string) to send on this channel. """ # a pktline can be a max of 65520. a sideband line can therefore be # 65520-5 = 65515 # WTF: Why have the len in ASCII, but the channel in binary. while blob: self.write_pkt_line(bytes(bytearray([channel])) + blob[:65515]) blob = blob[65515:] def send_cmd(self, cmd, *args): """Send a command and some arguments to a git server. Only used for the TCP git protocol (git://). Args: cmd: The remote service to access. args: List of arguments to send to remove service. """ self.write_pkt_line(format_cmd_pkt(cmd, *args)) def read_cmd(self): """Read a command and some arguments from the git client Only used for the TCP git protocol (git://). Returns: A tuple of (command, [list of arguments]). """ line = self.read_pkt_line() return parse_cmd_pkt(line) _RBUFSIZE = 8192 # Default read buffer size. class ReceivableProtocol(Protocol): """Variant of Protocol that allows reading up to a size without blocking. This class has a recv() method that behaves like socket.recv() in addition to a read() method. If you want to read n bytes from the wire and block until exactly n bytes (or EOF) are read, use read(n). 
If you want to read at most n bytes from the wire but don't care if you get less, use recv(n). Note that recv(n) will still block until at least one byte is read. """ def __init__( self, recv, write, close=None, report_activity=None, rbufsize=_RBUFSIZE ): super(ReceivableProtocol, self).__init__( self.read, write, close=close, report_activity=report_activity ) self._recv = recv self._rbuf = BytesIO() self._rbufsize = rbufsize def read(self, size): # From _fileobj.read in socket.py in the Python 2.6.5 standard library, # with the following modifications: # - omit the size <= 0 branch # - seek back to start rather than 0 in case some buffer has been # consumed. # - use SEEK_END instead of the magic number. # Copyright (c) 2001-2010 Python Software Foundation; All Rights # Reserved # Licensed under the Python Software Foundation License. # TODO: see if buffer is more efficient than cBytesIO. assert size > 0 # Our use of BytesIO rather than lists of string objects returned by # recv() minimizes memory usage and fragmentation that occurs when # rbufsize is large compared to the typical return value of recv(). buf = self._rbuf start = buf.tell() buf.seek(0, SEEK_END) # buffer may have been partially consumed by recv() buf_len = buf.tell() - start if buf_len >= size: # Already have size bytes in our buffer? Extract and return. buf.seek(start) rv = buf.read(size) self._rbuf = BytesIO() self._rbuf.write(buf.read()) self._rbuf.seek(0) return rv self._rbuf = BytesIO() # reset _rbuf. we consume it via buf. while True: left = size - buf_len # recv() will malloc the amount of memory given as its # parameter even though it often returns much less data # than that. The returned data string is short lived # as we copy it into a BytesIO and free it. This avoids # fragmentation issues on many platforms. data = self._recv(left) if not data: break n = len(data) if n == size and not buf_len: # Shortcut. Avoid buffer data copies when: # - We have no data in our buffer. # AND # - Our call to recv returned exactly the # number of bytes we were asked to read. return data if n == left: buf.write(data) del data # explicit free break assert n <= left, "_recv(%d) returned %d bytes" % (left, n) buf.write(data) buf_len += n del data # explicit free # assert buf_len == buf.tell() buf.seek(start) return buf.read() def recv(self, size): assert size > 0 buf = self._rbuf start = buf.tell() buf.seek(0, SEEK_END) buf_len = buf.tell() buf.seek(start) left = buf_len - start if not left: # only read from the wire if our read buffer is exhausted data = self._recv(self._rbufsize) if len(data) == size: # shortcut: skip the buffer if we read exactly size bytes return data buf = BytesIO() buf.write(data) buf.seek(0) del data # explicit free self._rbuf = buf return buf.read(size) def extract_capabilities(text): """Extract a capabilities list from a string, if present. Args: text: String to extract from Returns: Tuple with text with capabilities removed and list of capabilities """ if b"\0" not in text: return text, [] text, capabilities = text.rstrip().split(b"\0") return (text, capabilities.strip().split(b" ")) def extract_want_line_capabilities(text): """Extract a capabilities list from a want line, if present. Note that want lines have capabilities separated from the rest of the line by a space instead of a null byte. Thus want lines have the form: want obj-id cap1 cap2 ... 
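# Minimal sketch (not from the diff): the two capability extractors. Ref
# advertisement lines carry capabilities after a NUL byte; want lines carry
# them after a space. The sha here is a synthetic placeholder.
from dulwich.protocol import (
    extract_capabilities,
    extract_want_line_capabilities,
)

ref_line = b"refs/heads/master\x00side-band-64k ofs-delta"
assert extract_capabilities(ref_line) == (
    b"refs/heads/master",
    [b"side-band-64k", b"ofs-delta"],
)

want_line = b"want " + b"a" * 40 + b" multi_ack thin-pack"
text, caps = extract_want_line_capabilities(want_line)
assert text == b"want " + b"a" * 40
assert caps == [b"multi_ack", b"thin-pack"]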
Args: text: Want line to extract from Returns: Tuple with text with capabilities removed and list of capabilities """ split_text = text.rstrip().split(b" ") if len(split_text) < 3: return text, [] return (b" ".join(split_text[:2]), split_text[2:]) def ack_type(capabilities): """Extract the ack type from a capabilities list.""" if b"multi_ack_detailed" in capabilities: return MULTI_ACK_DETAILED elif b"multi_ack" in capabilities: return MULTI_ACK return SINGLE_ACK class BufferedPktLineWriter(object): """Writer that wraps its data in pkt-lines and has an independent buffer. Consecutive calls to write() wrap the data in a pkt-line and then buffer it until enough lines have been written such that their total length (including length prefix) reaches the buffer size. """ def __init__(self, write, bufsize=65515): """Initialize the BufferedPktLineWriter. Args: write: A write callback for the underlying writer. bufsize: The internal buffer size, including length prefixes. """ self._write = write self._bufsize = bufsize self._wbuf = BytesIO() self._buflen = 0 def write(self, data): """Write data, wrapping it in a pkt-line.""" line = pkt_line(data) line_len = len(line) over = self._buflen + line_len - self._bufsize if over >= 0: start = line_len - over self._wbuf.write(line[:start]) self.flush() else: start = 0 saved = line[start:] self._wbuf.write(saved) self._buflen += len(saved) def flush(self): """Flush all data from the buffer.""" data = self._wbuf.getvalue() if data: self._write(data) self._buflen = 0 self._wbuf = BytesIO() class PktLineParser(object): """Packet line parser that hands completed packets off to a callback.""" def __init__(self, handle_pkt): self.handle_pkt = handle_pkt self._readahead = BytesIO() def parse(self, data): """Parse a fragment of data and call back for any completed packets.""" self._readahead.write(data) buf = self._readahead.getvalue() if len(buf) < 4: return while len(buf) >= 4: size = int(buf[:4], 16) if size == 0: self.handle_pkt(None) buf = buf[4:] elif size <= len(buf): self.handle_pkt(buf[4:size]) buf = buf[size:] else: break self._readahead = BytesIO() self._readahead.write(buf) def get_tail(self): """Read back any unused data.""" return self._readahead.getvalue()
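# Minimal sketch (not from the diff): PktLineParser reassembling pkt-lines
# from arbitrary fragments, as they might arrive off a socket.
from dulwich.protocol import PktLineParser

received = []
parser = PktLineParser(received.append)
parser.parse(b"000ahello\n00")  # one whole pkt-line plus half a length prefix
parser.parse(b"00")             # the rest of the flush-pkt
assert received == [b"hello\n", None]
assert parser.get_tail() == b""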