diff --git a/dulwich/client.py b/dulwich/client.py index 712f542d..bf93b4f3 100644 --- a/dulwich/client.py +++ b/dulwich/client.py @@ -1,1897 +1,1896 @@ # client.py -- Implementation of the client side git protocols # Copyright (C) 2008-2013 Jelmer Vernooij # # Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU # General Public License as public by the Free Software Foundation; version 2.0 # or (at your option) any later version. You can redistribute it and/or # modify it under the terms of either of these two licenses. # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # # You should have received a copy of the licenses; if not, see # for a copy of the GNU General Public License # and for a copy of the Apache # License, Version 2.0. # """Client side support for the Git protocol. The Dulwich client supports the following capabilities: * thin-pack * multi_ack_detailed * multi_ack * side-band-64k * ofs-delta * quiet * report-status * delete-refs * shallow Known capabilities that are not supported: * no-progress * include-tag """ from contextlib import closing from io import BytesIO, BufferedReader import errno import os import select import socket import subprocess import sys from urllib.parse import ( quote as urlquote, unquote as urlunquote, urlparse, urljoin, - urlunparse, urlunsplit, urlunparse, ) import dulwich from dulwich.config import get_xdg_config_home_path from dulwich.errors import ( GitProtocolError, NotGitRepository, SendPackError, UpdateRefsError, ) from dulwich.protocol import ( HangupException, _RBUFSIZE, agent_string, capability_agent, extract_capability_names, CAPABILITY_AGENT, CAPABILITY_DELETE_REFS, CAPABILITY_INCLUDE_TAG, CAPABILITY_MULTI_ACK, CAPABILITY_MULTI_ACK_DETAILED, CAPABILITY_OFS_DELTA, CAPABILITY_QUIET, CAPABILITY_REPORT_STATUS, CAPABILITY_SHALLOW, CAPABILITY_SYMREF, CAPABILITY_SIDE_BAND_64K, CAPABILITY_THIN_PACK, CAPABILITIES_REF, KNOWN_RECEIVE_CAPABILITIES, KNOWN_UPLOAD_CAPABILITIES, COMMAND_DEEPEN, COMMAND_SHALLOW, COMMAND_UNSHALLOW, COMMAND_DONE, COMMAND_HAVE, COMMAND_WANT, SIDE_BAND_CHANNEL_DATA, SIDE_BAND_CHANNEL_PROGRESS, SIDE_BAND_CHANNEL_FATAL, PktLineParser, Protocol, ProtocolFile, TCP_GIT_PORT, ZERO_SHA, extract_capabilities, parse_capability, ) from dulwich.pack import ( write_pack_data, write_pack_objects, ) from dulwich.refs import ( read_info_refs, ANNOTATED_TAG_SUFFIX, ) class InvalidWants(Exception): """Invalid wants.""" def __init__(self, wants): Exception.__init__( self, "requested wants not in server provided refs: %r" % wants) def _fileno_can_read(fileno): """Check if a file descriptor is readable. """ return len(select.select([fileno], [], [], 0)[0]) > 0 def _win32_peek_avail(handle): """Wrapper around PeekNamedPipe to check how many bytes are available. 
""" from ctypes import byref, wintypes, windll c_avail = wintypes.DWORD() c_message = wintypes.DWORD() success = windll.kernel32.PeekNamedPipe( handle, None, 0, None, byref(c_avail), byref(c_message)) if not success: raise OSError(wintypes.GetLastError()) return c_avail.value COMMON_CAPABILITIES = [CAPABILITY_OFS_DELTA, CAPABILITY_SIDE_BAND_64K] UPLOAD_CAPABILITIES = ([CAPABILITY_THIN_PACK, CAPABILITY_MULTI_ACK, CAPABILITY_MULTI_ACK_DETAILED, CAPABILITY_SHALLOW] + COMMON_CAPABILITIES) RECEIVE_CAPABILITIES = ( [CAPABILITY_REPORT_STATUS, CAPABILITY_DELETE_REFS] + COMMON_CAPABILITIES) class ReportStatusParser(object): """Handle status as reported by servers with 'report-status' capability.""" def __init__(self): self._done = False self._pack_status = None self._ref_status_ok = True self._ref_statuses = [] def check(self): """Check if there were any errors and, if so, raise exceptions. Raises: SendPackError: Raised when the server could not unpack UpdateRefsError: Raised when refs could not be updated """ if self._pack_status not in (b'unpack ok', None): raise SendPackError(self._pack_status) if not self._ref_status_ok: ref_status = {} ok = set() for status in self._ref_statuses: if b' ' not in status: # malformed response, move on to the next one continue status, ref = status.split(b' ', 1) if status == b'ng': if b' ' in ref: ref, status = ref.split(b' ', 1) else: ok.add(ref) ref_status[ref] = status # TODO(jelmer): don't assume encoding of refs is ascii. raise UpdateRefsError(', '.join([ refname.decode('ascii') for refname in ref_status if refname not in ok]) + ' failed to update', ref_status=ref_status) def handle_packet(self, pkt): """Handle a packet. Raises: GitProtocolError: Raised when packets are received after a flush packet. """ if self._done: raise GitProtocolError("received more data after status report") if pkt is None: self._done = True return if self._pack_status is None: self._pack_status = pkt.strip() else: ref_status = pkt.strip() self._ref_statuses.append(ref_status) if not ref_status.startswith(b'ok '): self._ref_status_ok = False def read_pkt_refs(proto): server_capabilities = None refs = {} # Receive refs from server for pkt in proto.read_pkt_seq(): (sha, ref) = pkt.rstrip(b'\n').split(None, 1) if sha == b'ERR': raise GitProtocolError(ref.decode('utf-8', 'replace')) if server_capabilities is None: (ref, server_capabilities) = extract_capabilities(ref) refs[ref] = sha if len(refs) == 0: return {}, set([]) if refs == {CAPABILITIES_REF: ZERO_SHA}: refs = {} return refs, set(server_capabilities) class FetchPackResult(object): """Result of a fetch-pack operation. 
Attributes: refs: Dictionary with all remote refs symrefs: Dictionary with remote symrefs agent: User agent string """ _FORWARDED_ATTRS = [ 'clear', 'copy', 'fromkeys', 'get', 'has_key', 'items', 'iteritems', 'iterkeys', 'itervalues', 'keys', 'pop', 'popitem', 'setdefault', 'update', 'values', 'viewitems', 'viewkeys', 'viewvalues'] def __init__(self, refs, symrefs, agent, new_shallow=None, new_unshallow=None): self.refs = refs self.symrefs = symrefs self.agent = agent self.new_shallow = new_shallow self.new_unshallow = new_unshallow def _warn_deprecated(self): import warnings warnings.warn( "Use FetchPackResult.refs instead.", DeprecationWarning, stacklevel=3) def __eq__(self, other): if isinstance(other, dict): self._warn_deprecated() return (self.refs == other) return (self.refs == other.refs and self.symrefs == other.symrefs and self.agent == other.agent) def __contains__(self, name): self._warn_deprecated() return name in self.refs def __getitem__(self, name): self._warn_deprecated() return self.refs[name] def __len__(self): self._warn_deprecated() return len(self.refs) def __iter__(self): self._warn_deprecated() return iter(self.refs) def __getattribute__(self, name): if name in type(self)._FORWARDED_ATTRS: self._warn_deprecated() return getattr(self.refs, name) return super(FetchPackResult, self).__getattribute__(name) def __repr__(self): return "%s(%r, %r, %r)" % ( self.__class__.__name__, self.refs, self.symrefs, self.agent) def _read_shallow_updates(proto): new_shallow = set() new_unshallow = set() for pkt in proto.read_pkt_seq(): cmd, sha = pkt.split(b' ', 1) if cmd == COMMAND_SHALLOW: new_shallow.add(sha.strip()) elif cmd == COMMAND_UNSHALLOW: new_unshallow.add(sha.strip()) else: raise GitProtocolError('unknown command %s' % pkt) return (new_shallow, new_unshallow) # TODO(durin42): this doesn't correctly degrade if the server doesn't # support some capabilities. This should work properly with servers # that don't support multi_ack. class GitClient(object): """Git smart server client.""" def __init__(self, thin_packs=True, report_activity=None, quiet=False, include_tags=False): """Create a new GitClient instance. Args: thin_packs: Whether or not thin packs should be retrieved report_activity: Optional callback for reporting transport activity. include_tags: send annotated tags when sending the objects they point to """ self._report_activity = report_activity self._report_status_parser = None self._fetch_capabilities = set(UPLOAD_CAPABILITIES) self._fetch_capabilities.add(capability_agent()) self._send_capabilities = set(RECEIVE_CAPABILITIES) self._send_capabilities.add(capability_agent()) if quiet: self._send_capabilities.add(CAPABILITY_QUIET) if not thin_packs: self._fetch_capabilities.remove(CAPABILITY_THIN_PACK) if include_tags: self._fetch_capabilities.add(CAPABILITY_INCLUDE_TAG) def get_url(self, path): """Retrieves full url to given path. Args: path: Repository path (as string) Returns: Url to path (as string) """ raise NotImplementedError(self.get_url) @classmethod def from_parsedurl(cls, parsedurl, **kwargs): """Create an instance of this client from a urlparse.parsed object. Args: parsedurl: Result of urlparse() Returns: A `GitClient` object """ raise NotImplementedError(cls.from_parsedurl) def send_pack(self, path, update_refs, generate_pack_data, progress=None): """Upload a pack to a remote repository. Args: path: Repository path (as bytestring) update_refs: Function to determine changes to remote refs. 
Receive dict with existing remote refs, returns dict with changed refs (name -> sha, where sha=ZERO_SHA for deletions) generate_pack_data: Function that can return a tuple with number of objects and list of pack data to include progress: Optional progress function Returns: new_refs dictionary containing the changes that were made {refname: new_ref}, including deleted refs. Raises: SendPackError: if server rejects the pack data UpdateRefsError: if the server supports report-status and rejects ref updates """ raise NotImplementedError(self.send_pack) def fetch(self, path, target, determine_wants=None, progress=None, depth=None): """Fetch into a target repository. Args: path: Path to fetch from (as bytestring) target: Target repository to fetch into determine_wants: Optional function to determine what refs to fetch. Receives dictionary of name->sha, should return list of shas to fetch. Defaults to all shas. progress: Optional progress function depth: Depth to fetch at Returns: Dictionary with all remote refs (not just those fetched) """ if determine_wants is None: determine_wants = target.object_store.determine_wants_all if CAPABILITY_THIN_PACK in self._fetch_capabilities: # TODO(jelmer): Avoid reading entire file into memory and # only processing it after the whole file has been fetched. f = BytesIO() def commit(): if f.tell(): f.seek(0) target.object_store.add_thin_pack(f.read, None) def abort(): pass else: f, commit, abort = target.object_store.add_pack() try: result = self.fetch_pack( path, determine_wants, target.get_graph_walker(), f.write, progress=progress, depth=depth) except BaseException: abort() raise else: commit() target.update_shallow(result.new_shallow, result.new_unshallow) return result def fetch_pack(self, path, determine_wants, graph_walker, pack_data, progress=None, depth=None): """Retrieve a pack from a git smart server. Args: path: Remote path to fetch from determine_wants: Function determine what refs to fetch. Receives dictionary of name->sha, should return list of shas to fetch. graph_walker: Object with next() and ack(). pack_data: Callback called for each bit of data in the pack progress: Callback for progress reports (strings) depth: Shallow fetch depth Returns: FetchPackResult object """ raise NotImplementedError(self.fetch_pack) def get_refs(self, path): """Retrieve the current refs from a git smart server. Args: path: Path to the repo to fetch from. (as bytestring) Returns: """ raise NotImplementedError(self.get_refs) def _parse_status_report(self, proto): unpack = proto.read_pkt_line().strip() if unpack != b'unpack ok': st = True # flush remaining error data while st is not None: st = proto.read_pkt_line() raise SendPackError(unpack) statuses = [] errs = False ref_status = proto.read_pkt_line() while ref_status: ref_status = ref_status.strip() statuses.append(ref_status) if not ref_status.startswith(b'ok '): errs = True ref_status = proto.read_pkt_line() if errs: ref_status = {} ok = set() for status in statuses: if b' ' not in status: # malformed response, move on to the next one continue status, ref = status.split(b' ', 1) if status == b'ng': if b' ' in ref: ref, status = ref.split(b' ', 1) else: ok.add(ref) ref_status[ref] = status raise UpdateRefsError(', '.join([ refname.decode('ascii') for refname in ref_status if refname not in ok]) + ' failed to update', ref_status=ref_status) def _read_side_band64k_data(self, proto, channel_callbacks): """Read per-channel data. This requires the side-band-64k capability.
Args: proto: Protocol object to read from channel_callbacks: Dictionary mapping channels to packet handlers to use. None for a callback discards channel data. """ for pkt in proto.read_pkt_seq(): channel = ord(pkt[:1]) pkt = pkt[1:] try: cb = channel_callbacks[channel] except KeyError: raise AssertionError('Invalid sideband channel %d' % channel) else: if cb is not None: cb(pkt) def _handle_receive_pack_head(self, proto, capabilities, old_refs, new_refs): """Handle the head of a 'git-receive-pack' request. Args: proto: Protocol object to read from capabilities: List of negotiated capabilities old_refs: Old refs, as received from the server new_refs: Refs to change Returns: (have, want) tuple """ want = [] have = [x for x in old_refs.values() if not x == ZERO_SHA] sent_capabilities = False for refname in new_refs: if not isinstance(refname, bytes): raise TypeError('refname is not a bytestring: %r' % refname) old_sha1 = old_refs.get(refname, ZERO_SHA) if not isinstance(old_sha1, bytes): raise TypeError('old sha1 for %s is not a bytestring: %r' % (refname, old_sha1)) new_sha1 = new_refs.get(refname, ZERO_SHA) if not isinstance(new_sha1, bytes): raise TypeError('new sha1 for %s is not a bytestring: %r' % (refname, new_sha1)) if old_sha1 != new_sha1: if sent_capabilities: proto.write_pkt_line(old_sha1 + b' ' + new_sha1 + b' ' + refname) else: proto.write_pkt_line( old_sha1 + b' ' + new_sha1 + b' ' + refname + b'\0' + b' '.join(sorted(capabilities))) sent_capabilities = True if new_sha1 not in have and new_sha1 != ZERO_SHA: want.append(new_sha1) proto.write_pkt_line(None) return (have, want) def _negotiate_receive_pack_capabilities(self, server_capabilities): negotiated_capabilities = ( self._send_capabilities & server_capabilities) unknown_capabilities = ( # noqa: F841 extract_capability_names(server_capabilities) - KNOWN_RECEIVE_CAPABILITIES) # TODO(jelmer): warn about unknown capabilities return negotiated_capabilities def _handle_receive_pack_tail(self, proto, capabilities, progress=None): """Handle the tail of a 'git-receive-pack' request. Args: proto: Protocol object to read from capabilities: List of negotiated capabilities progress: Optional progress reporting function Returns: """ if CAPABILITY_SIDE_BAND_64K in capabilities: if progress is None: def progress(x): pass channel_callbacks = {2: progress} if CAPABILITY_REPORT_STATUS in capabilities: channel_callbacks[1] = PktLineParser( self._report_status_parser.handle_packet).parse self._read_side_band64k_data(proto, channel_callbacks) else: if CAPABILITY_REPORT_STATUS in capabilities: for pkt in proto.read_pkt_seq(): self._report_status_parser.handle_packet(pkt) if self._report_status_parser is not None: self._report_status_parser.check() def _negotiate_upload_pack_capabilities(self, server_capabilities): unknown_capabilities = ( # noqa: F841 extract_capability_names(server_capabilities) - KNOWN_UPLOAD_CAPABILITIES) # TODO(jelmer): warn about unknown capabilities symrefs = {} agent = None for capability in server_capabilities: k, v = parse_capability(capability) if k == CAPABILITY_SYMREF: (src, dst) = v.split(b':', 1) symrefs[src] = dst if k == CAPABILITY_AGENT: agent = v negotiated_capabilities = ( self._fetch_capabilities & server_capabilities) return (negotiated_capabilities, symrefs, agent) def _handle_upload_pack_head(self, proto, capabilities, graph_walker, wants, can_read, depth): """Handle the head of a 'git-upload-pack' request.
Args: proto: Protocol object to read from capabilities: List of negotiated capabilities graph_walker: GraphWalker instance to call .ack() on wants: List of commits to fetch can_read: function that returns a boolean that indicates whether there is extra graph data to read on proto depth: Depth for request Returns: """ assert isinstance(wants, list) and isinstance(wants[0], bytes) proto.write_pkt_line(COMMAND_WANT + b' ' + wants[0] + b' ' + b' '.join(sorted(capabilities)) + b'\n') for want in wants[1:]: proto.write_pkt_line(COMMAND_WANT + b' ' + want + b'\n') if depth not in (0, None) or getattr(graph_walker, 'shallow', None): if CAPABILITY_SHALLOW not in capabilities: raise GitProtocolError( "server does not support shallow capability required for " "depth") for sha in graph_walker.shallow: proto.write_pkt_line(COMMAND_SHALLOW + b' ' + sha + b'\n') if depth is not None: proto.write_pkt_line(COMMAND_DEEPEN + b' ' + str(depth).encode('ascii') + b'\n') proto.write_pkt_line(None) if can_read is not None: (new_shallow, new_unshallow) = _read_shallow_updates(proto) else: new_shallow = new_unshallow = None else: new_shallow = new_unshallow = set() proto.write_pkt_line(None) have = next(graph_walker) while have: proto.write_pkt_line(COMMAND_HAVE + b' ' + have + b'\n') if can_read is not None and can_read(): pkt = proto.read_pkt_line() parts = pkt.rstrip(b'\n').split(b' ') if parts[0] == b'ACK': graph_walker.ack(parts[1]) if parts[2] in (b'continue', b'common'): pass elif parts[2] == b'ready': break else: raise AssertionError( "%s not in ('continue', 'ready', 'common')" % parts[2]) have = next(graph_walker) proto.write_pkt_line(COMMAND_DONE + b'\n') return (new_shallow, new_unshallow) def _handle_upload_pack_tail(self, proto, capabilities, graph_walker, pack_data, progress=None, rbufsize=_RBUFSIZE): """Handle the tail of a 'git-upload-pack' request. Args: proto: Protocol object to read from capabilities: List of negotiated capabilities graph_walker: GraphWalker instance to call .ack() on pack_data: Function to call with pack data progress: Optional progress reporting function rbufsize: Read buffer size Returns: """ pkt = proto.read_pkt_line() while pkt: parts = pkt.rstrip(b'\n').split(b' ') if parts[0] == b'ACK': graph_walker.ack(parts[1]) if len(parts) < 3 or parts[2] not in ( b'ready', b'continue', b'common'): break pkt = proto.read_pkt_line() if CAPABILITY_SIDE_BAND_64K in capabilities: if progress is None: # Just ignore progress data def progress(x): pass self._read_side_band64k_data(proto, { SIDE_BAND_CHANNEL_DATA: pack_data, SIDE_BAND_CHANNEL_PROGRESS: progress} ) else: while True: data = proto.read(rbufsize) if data == b"": break pack_data(data) def check_wants(wants, refs): """Check that a set of wants is valid.
Args: wants: Set of object SHAs to fetch refs: Refs dictionary to check against Returns: """ missing = set(wants) - { v for (k, v) in refs.items() if not k.endswith(ANNOTATED_TAG_SUFFIX)} if missing: raise InvalidWants(missing) def remote_error_from_stderr(stderr): if stderr is None: return HangupException() for line in stderr.readlines(): if line.startswith(b'ERROR: '): return GitProtocolError( line[len(b'ERROR: '):].decode('utf-8', 'replace')) return GitProtocolError(line.decode('utf-8', 'replace')) return HangupException() class TraditionalGitClient(GitClient): """Traditional Git client.""" DEFAULT_ENCODING = 'utf-8' def __init__(self, path_encoding=DEFAULT_ENCODING, **kwargs): self._remote_path_encoding = path_encoding super(TraditionalGitClient, self).__init__(**kwargs) def _connect(self, cmd, path): """Create a connection to the server. This method is abstract - concrete implementations should implement their own variant which connects to the server and returns an initialized Protocol object with the service ready for use and a can_read function which may be used to see if reads would block. Args: cmd: The git service name to which we should connect. path: The path we should pass to the service. (as bytestring) """ raise NotImplementedError() def send_pack(self, path, update_refs, generate_pack_data, progress=None): """Upload a pack to a remote repository. Args: path: Repository path (as bytestring) update_refs: Function to determine changes to remote refs. Receive dict with existing remote refs, returns dict with changed refs (name -> sha, where sha=ZERO_SHA for deletions) generate_pack_data: Function that can return a tuple with number of objects and pack data to upload. progress: Optional callback called with progress updates Returns: new_refs dictionary containing the changes that were made {refname: new_ref}, including deleted refs. Raises: SendPackError: if server rejects the pack data UpdateRefsError: if the server supports report-status and rejects ref updates """ proto, unused_can_read, stderr = self._connect(b'receive-pack', path) with proto: try: old_refs, server_capabilities = read_pkt_refs(proto) except HangupException: raise remote_error_from_stderr(stderr) negotiated_capabilities = \ self._negotiate_receive_pack_capabilities(server_capabilities) if CAPABILITY_REPORT_STATUS in negotiated_capabilities: self._report_status_parser = ReportStatusParser() report_status_parser = self._report_status_parser try: new_refs = orig_new_refs = update_refs(dict(old_refs)) except BaseException: proto.write_pkt_line(None) raise if CAPABILITY_DELETE_REFS not in server_capabilities: # Server does not support deletions. Fail later.
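# The loop below copies the requested refs, drops each deletion, and
# queues a synthetic 'ng' status line for it, so the failure still
# surfaces when the report-status parser runs check().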
new_refs = dict(orig_new_refs) for ref, sha in orig_new_refs.items(): if sha == ZERO_SHA: if CAPABILITY_REPORT_STATUS in negotiated_capabilities: report_status_parser._ref_statuses.append( b'ng ' + ref + b' remote does not support deleting refs') report_status_parser._ref_status_ok = False del new_refs[ref] if new_refs is None: proto.write_pkt_line(None) return old_refs if len(new_refs) == 0 and len(orig_new_refs): # NOOP - Original new refs filtered out by policy proto.write_pkt_line(None) if report_status_parser is not None: report_status_parser.check() return old_refs (have, want) = self._handle_receive_pack_head( proto, negotiated_capabilities, old_refs, new_refs) if (not want and set(new_refs.items()).issubset(set(old_refs.items()))): return new_refs pack_data_count, pack_data = generate_pack_data( have, want, ofs_delta=(CAPABILITY_OFS_DELTA in negotiated_capabilities)) dowrite = bool(pack_data_count) dowrite = dowrite or any(old_refs.get(ref) != sha for (ref, sha) in new_refs.items() if sha != ZERO_SHA) if dowrite: write_pack_data(proto.write_file(), pack_data_count, pack_data) self._handle_receive_pack_tail( proto, negotiated_capabilities, progress) return new_refs def fetch_pack(self, path, determine_wants, graph_walker, pack_data, progress=None, depth=None): """Retrieve a pack from a git smart server. Args: path: Remote path to fetch from determine_wants: Function determine what refs to fetch. Receives dictionary of name->sha, should return list of shas to fetch. graph_walker: Object with next() and ack(). pack_data: Callback called for each bit of data in the pack progress: Callback for progress reports (strings) depth: Shallow fetch depth Returns: FetchPackResult object """ proto, can_read, stderr = self._connect(b'upload-pack', path) with proto: try: refs, server_capabilities = read_pkt_refs(proto) except HangupException: raise remote_error_from_stderr(stderr) negotiated_capabilities, symrefs, agent = ( self._negotiate_upload_pack_capabilities( server_capabilities)) if refs is None: proto.write_pkt_line(None) return FetchPackResult(refs, symrefs, agent) try: wants = determine_wants(refs) except BaseException: proto.write_pkt_line(None) raise if wants is not None: wants = [cid for cid in wants if cid != ZERO_SHA] if not wants: proto.write_pkt_line(None) return FetchPackResult(refs, symrefs, agent) (new_shallow, new_unshallow) = self._handle_upload_pack_head( proto, negotiated_capabilities, graph_walker, wants, can_read, depth=depth) self._handle_upload_pack_tail( proto, negotiated_capabilities, graph_walker, pack_data, progress) return FetchPackResult( refs, symrefs, agent, new_shallow, new_unshallow) def get_refs(self, path): """Retrieve the current refs from a git smart server.
""" # stock `git ls-remote` uses upload-pack proto, _, stderr = self._connect(b'upload-pack', path) with proto: try: refs, _ = read_pkt_refs(proto) except HangupException: raise remote_error_from_stderr(stderr) proto.write_pkt_line(None) return refs def archive(self, path, committish, write_data, progress=None, write_error=None, format=None, subdirs=None, prefix=None): proto, can_read, stderr = self._connect(b'upload-archive', path) with proto: if format is not None: proto.write_pkt_line(b"argument --format=" + format) proto.write_pkt_line(b"argument " + committish) if subdirs is not None: for subdir in subdirs: proto.write_pkt_line(b"argument " + subdir) if prefix is not None: proto.write_pkt_line(b"argument --prefix=" + prefix) proto.write_pkt_line(None) try: pkt = proto.read_pkt_line() except HangupException: raise remote_error_from_stderr(stderr) if pkt == b"NACK\n": return elif pkt == b"ACK\n": pass elif pkt.startswith(b"ERR "): raise GitProtocolError( pkt[4:].rstrip(b"\n").decode('utf-8', 'replace')) else: raise AssertionError("invalid response %r" % pkt) ret = proto.read_pkt_line() if ret is not None: raise AssertionError("expected pkt tail") self._read_side_band64k_data(proto, { SIDE_BAND_CHANNEL_DATA: write_data, SIDE_BAND_CHANNEL_PROGRESS: progress, SIDE_BAND_CHANNEL_FATAL: write_error}) class TCPGitClient(TraditionalGitClient): """A Git Client that works over TCP directly (i.e. git://).""" def __init__(self, host, port=None, **kwargs): if port is None: port = TCP_GIT_PORT self._host = host self._port = port super(TCPGitClient, self).__init__(**kwargs) @classmethod def from_parsedurl(cls, parsedurl, **kwargs): return cls(parsedurl.hostname, port=parsedurl.port, **kwargs) def get_url(self, path): netloc = self._host if self._port is not None and self._port != TCP_GIT_PORT: netloc += ":%d" % self._port return urlunsplit(("git", netloc, path, '', '')) def _connect(self, cmd, path): if not isinstance(cmd, bytes): raise TypeError(cmd) if not isinstance(path, bytes): path = path.encode(self._remote_path_encoding) sockaddrs = socket.getaddrinfo( self._host, self._port, socket.AF_UNSPEC, socket.SOCK_STREAM) s = None err = socket.error("no address found for %s" % self._host) for (family, socktype, proto, canonname, sockaddr) in sockaddrs: s = socket.socket(family, socktype, proto) s.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) try: s.connect(sockaddr) break except socket.error as e: err = e if s is not None: s.close() s = None if s is None: raise err # -1 means system default buffering rfile = s.makefile('rb', -1) # 0 means unbuffered wfile = s.makefile('wb', 0) def close(): rfile.close() wfile.close() s.close() proto = Protocol(rfile.read, wfile.write, close, report_activity=self._report_activity) if path.startswith(b"/~"): path = path[1:] # TODO(jelmer): Alternative to ascii? 
proto.send_cmd( b'git-' + cmd, path, b'host=' + self._host.encode('ascii')) return proto, lambda: _fileno_can_read(s), None class SubprocessWrapper(object): """A socket-like object that talks to a subprocess via pipes.""" def __init__(self, proc): self.proc = proc self.read = BufferedReader(proc.stdout).read self.write = proc.stdin.write @property def stderr(self): return self.proc.stderr def can_read(self): if sys.platform == 'win32': from msvcrt import get_osfhandle handle = get_osfhandle(self.proc.stdout.fileno()) return _win32_peek_avail(handle) != 0 else: return _fileno_can_read(self.proc.stdout.fileno()) def close(self): self.proc.stdin.close() self.proc.stdout.close() if self.proc.stderr: self.proc.stderr.close() self.proc.wait() def find_git_command(): """Find command to run for system Git (usually C Git).""" if sys.platform == 'win32': # support .exe, .bat and .cmd try: # to avoid overhead import win32api except ImportError: # run through cmd.exe with some overhead return ['cmd', '/c', 'git'] else: status, git = win32api.FindExecutable('git') return [git] else: return ['git'] class SubprocessGitClient(TraditionalGitClient): """Git client that talks to a server using a subprocess.""" @classmethod def from_parsedurl(cls, parsedurl, **kwargs): return cls(**kwargs) git_command = None def _connect(self, service, path): if not isinstance(service, bytes): raise TypeError(service) if isinstance(path, bytes): path = path.decode(self._remote_path_encoding) if self.git_command is None: git_command = find_git_command() else: git_command = self.git_command argv = git_command + [service.decode('ascii'), path] p = subprocess.Popen(argv, bufsize=0, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE) pw = SubprocessWrapper(p) return (Protocol(pw.read, pw.write, pw.close, report_activity=self._report_activity), pw.can_read, p.stderr) class LocalGitClient(GitClient): """Git Client that just uses a local Repo.""" def __init__(self, thin_packs=True, report_activity=None, config=None): """Create a new LocalGitClient instance. Args: thin_packs: Whether or not thin packs should be retrieved report_activity: Optional callback for reporting transport activity. """ self._report_activity = report_activity # Ignore the thin_packs argument def get_url(self, path): return urlunsplit(('file', '', path, '', '')) @classmethod def from_parsedurl(cls, parsedurl, **kwargs): return cls(**kwargs) @classmethod def _open_repo(cls, path): from dulwich.repo import Repo if not isinstance(path, str): path = os.fsdecode(path) return closing(Repo(path)) def send_pack(self, path, update_refs, generate_pack_data, progress=None): """Upload a pack to a remote repository. Args: path: Repository path (as bytestring) update_refs: Function to determine changes to remote refs. Receive dict with existing remote refs, returns dict with changed refs (name -> sha, where sha=ZERO_SHA for deletions) generate_pack_data: Function that can return a tuple with number of items and pack data to upload. progress: Optional progress function Returns: new_refs dictionary containing the changes that were made {refname: new_ref}, including deleted refs.
Raises: SendPackError: if server rejects the pack data UpdateRefsError: if the server supports report-status and rejects ref updates """ if not progress: def progress(x): pass with self._open_repo(path) as target: old_refs = target.get_refs() new_refs = update_refs(dict(old_refs)) have = [sha1 for sha1 in old_refs.values() if sha1 != ZERO_SHA] want = [] for refname, new_sha1 in new_refs.items(): if (new_sha1 not in have and new_sha1 not in want and new_sha1 != ZERO_SHA): want.append(new_sha1) if (not want and set(new_refs.items()).issubset(set(old_refs.items()))): return new_refs target.object_store.add_pack_data( *generate_pack_data(have, want, ofs_delta=True)) for refname, new_sha1 in new_refs.items(): old_sha1 = old_refs.get(refname, ZERO_SHA) if new_sha1 != ZERO_SHA: if not target.refs.set_if_equals( refname, old_sha1, new_sha1): progress('unable to set %s to %s' % (refname, new_sha1)) else: if not target.refs.remove_if_equals(refname, old_sha1): progress('unable to remove %s' % refname) return new_refs def fetch(self, path, target, determine_wants=None, progress=None, depth=None): """Fetch into a target repository. Args: path: Path to fetch from (as bytestring) target: Target repository to fetch into determine_wants: Optional function determine what refs to fetch. Receives dictionary of name->sha, should return list of shas to fetch. Defaults to all shas. progress: Optional progress function depth: Shallow fetch depth Returns: FetchPackResult object """ with self._open_repo(path) as r: refs = r.fetch(target, determine_wants=determine_wants, progress=progress, depth=depth) return FetchPackResult(refs, r.refs.get_symrefs(), agent_string()) def fetch_pack(self, path, determine_wants, graph_walker, pack_data, progress=None, depth=None): """Retrieve a pack from a git smart server. Args: path: Remote path to fetch from determine_wants: Function determine what refs to fetch. Receives dictionary of name->sha, should return list of shas to fetch. graph_walker: Object with next() and ack(). pack_data: Callback called for each bit of data in the pack progress: Callback for progress reports (strings) depth: Shallow fetch depth Returns: FetchPackResult object """ with self._open_repo(path) as r: objects_iter = r.fetch_objects( determine_wants, graph_walker, progress=progress, depth=depth) symrefs = r.refs.get_symrefs() agent = agent_string() # Did the process short-circuit (e.g. in a stateless RPC call)? # Note that the client still expects a 0-object pack in most cases. if objects_iter is None: return FetchPackResult(None, symrefs, agent) protocol = ProtocolFile(None, pack_data) write_pack_objects(protocol, objects_iter) return FetchPackResult(r.get_refs(), symrefs, agent) def get_refs(self, path): """Retrieve the current refs from a git smart server. """ with self._open_repo(path) as target: return target.get_refs() # What Git client to use for local access default_local_git_client_cls = LocalGitClient class SSHVendor(object): """A client side SSH implementation.""" def connect_ssh(self, host, command, username=None, port=None, password=None, key_filename=None): # This function was deprecated in 0.9.1 import warnings warnings.warn( "SSHVendor.connect_ssh has been renamed to SSHVendor.run_command", DeprecationWarning) return self.run_command(host, command, username=username, port=port, password=password, key_filename=key_filename) def run_command(self, host, command, username=None, port=None, password=None, key_filename=None): """Connect to an SSH server. 
Run a command remotely and return a file-like object for interaction with the remote command. Args: host: Host name command: Command to run (as argv array) username: Optional name of user to log in as port: Optional SSH port to use password: Optional ssh password for login or private key key_filename: Optional path to private keyfile Returns: """ raise NotImplementedError(self.run_command) class StrangeHostname(Exception): """Refusing to connect to strange SSH hostname.""" def __init__(self, hostname): super(StrangeHostname, self).__init__(hostname) class SubprocessSSHVendor(SSHVendor): """SSH vendor that shells out to the local 'ssh' command.""" def run_command(self, host, command, username=None, port=None, password=None, key_filename=None): if password is not None: raise NotImplementedError( "Setting password not supported by SubprocessSSHVendor.") args = ['ssh', '-x'] if port: args.extend(['-p', str(port)]) if key_filename: args.extend(['-i', str(key_filename)]) if username: host = '%s@%s' % (username, host) if host.startswith('-'): raise StrangeHostname(hostname=host) args.append(host) proc = subprocess.Popen(args + [command], bufsize=0, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE) return SubprocessWrapper(proc) class PLinkSSHVendor(SSHVendor): """SSH vendor that shells out to the local 'plink' command.""" def run_command(self, host, command, username=None, port=None, password=None, key_filename=None): if sys.platform == 'win32': args = ['plink.exe', '-ssh'] else: args = ['plink', '-ssh'] if password is not None: import warnings warnings.warn( "Invoking PLink with a password exposes the password in the " "process list.") args.extend(['-pw', str(password)]) if port: args.extend(['-P', str(port)]) if key_filename: args.extend(['-i', str(key_filename)]) if username: host = '%s@%s' % (username, host) if host.startswith('-'): raise StrangeHostname(hostname=host) args.append(host) proc = subprocess.Popen(args + [command], bufsize=0, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE) return SubprocessWrapper(proc) def ParamikoSSHVendor(**kwargs): import warnings warnings.warn( "ParamikoSSHVendor has been moved to dulwich.contrib.paramiko_vendor.", DeprecationWarning) from dulwich.contrib.paramiko_vendor import ParamikoSSHVendor return ParamikoSSHVendor(**kwargs) # Can be overridden by users get_ssh_vendor = SubprocessSSHVendor class SSHGitClient(TraditionalGitClient): def __init__(self, host, port=None, username=None, vendor=None, config=None, password=None, key_filename=None, **kwargs): self.host = host self.port = port self.username = username self.password = password self.key_filename = key_filename super(SSHGitClient, self).__init__(**kwargs) self.alternative_paths = {} if vendor is not None: self.ssh_vendor = vendor else: self.ssh_vendor = get_ssh_vendor() def get_url(self, path): netloc = self.host if self.port is not None: netloc += ":%d" % self.port if self.username is not None: netloc = urlquote(self.username, '@/:') + "@" + netloc return urlunsplit(('ssh', netloc, path, '', '')) @classmethod def from_parsedurl(cls, parsedurl, **kwargs): return cls(host=parsedurl.hostname, port=parsedurl.port, username=parsedurl.username, **kwargs) def _get_cmd_path(self, cmd): cmd = self.alternative_paths.get(cmd, b'git-' + cmd) assert isinstance(cmd, bytes) return cmd def _connect(self, cmd, path): if not isinstance(cmd, bytes): raise TypeError(cmd) if isinstance(path, bytes): path = path.decode(self._remote_path_encoding) if path.startswith("/~"): path
= path[1:] argv = (self._get_cmd_path(cmd).decode(self._remote_path_encoding) + " '" + path + "'") kwargs = {} if self.password is not None: kwargs['password'] = self.password if self.key_filename is not None: kwargs['key_filename'] = self.key_filename con = self.ssh_vendor.run_command( self.host, argv, port=self.port, username=self.username, **kwargs) return (Protocol(con.read, con.write, con.close, report_activity=self._report_activity), con.can_read, getattr(con, 'stderr', None)) def default_user_agent_string(): # Start user agent with "git/", because GitHub requires this. :-( See # https://github.com/jelmer/dulwich/issues/562 for details. return "git/dulwich/%s" % ".".join([str(x) for x in dulwich.__version__]) def default_urllib3_manager(config, pool_manager_cls=None, proxy_manager_cls=None, **override_kwargs): """Return `urllib3` connection pool manager. Honour detected proxy configurations. Args: config: `dulwich.config.ConfigDict` instance with Git configuration. override_kwargs: Additional arguments for `urllib3.ProxyManager` Returns: `proxy_manager_cls` (defaults to `urllib3.ProxyManager`) instance for proxy configurations, `pool_manager_cls` (defaults to `urllib3.PoolManager`) instance otherwise. """ proxy_server = user_agent = None ca_certs = ssl_verify = None if config is not None: try: proxy_server = config.get(b"http", b"proxy") except KeyError: pass try: user_agent = config.get(b"http", b"useragent") except KeyError: pass # TODO(jelmer): Support per-host settings try: ssl_verify = config.get_boolean(b"http", b"sslVerify") except KeyError: ssl_verify = True try: ca_certs = config.get(b"http", b"sslCAInfo") except KeyError: ca_certs = None if user_agent is None: user_agent = default_user_agent_string() headers = {"User-agent": user_agent} kwargs = {} if ssl_verify is True: kwargs['cert_reqs'] = "CERT_REQUIRED" elif ssl_verify is False: kwargs['cert_reqs'] = 'CERT_NONE' else: # Default to SSL verification kwargs['cert_reqs'] = "CERT_REQUIRED" if ca_certs is not None: kwargs['ca_certs'] = ca_certs kwargs.update(override_kwargs) # Try really hard to find a SSL certificate path if 'ca_certs' not in kwargs and kwargs.get('cert_reqs') != 'CERT_NONE': try: import certifi except ImportError: pass else: kwargs['ca_certs'] = certifi.where() import urllib3 if proxy_server is not None: if proxy_manager_cls is None: proxy_manager_cls = urllib3.ProxyManager # `urllib3` requires a `str` object in both Python 2 and 3, while # `ConfigDict` coerces entries to `bytes` on Python 3. Compensate.
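# (The proxy value typically comes from git configuration, e.g.
# `git config --global http.proxy http://proxy.example.com:8080`;
# the host in that example is illustrative.)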
if not isinstance(proxy_server, str): proxy_server = proxy_server.decode() manager = proxy_manager_cls(proxy_server, headers=headers, **kwargs) else: if pool_manager_cls is None: pool_manager_cls = urllib3.PoolManager manager = pool_manager_cls(headers=headers, **kwargs) return manager class HttpGitClient(GitClient): def __init__(self, base_url, dumb=None, pool_manager=None, config=None, username=None, password=None, **kwargs): self._base_url = base_url.rstrip("/") + "/" self._username = username self._password = password self.dumb = dumb if pool_manager is None: self.pool_manager = default_urllib3_manager(config) else: self.pool_manager = pool_manager if username is not None: # No escaping needed: ":" is not allowed in username: # https://tools.ietf.org/html/rfc2617#section-2 credentials = "%s:%s" % (username, password) import urllib3.util basic_auth = urllib3.util.make_headers(basic_auth=credentials) self.pool_manager.headers.update(basic_auth) GitClient.__init__(self, **kwargs) def get_url(self, path): return self._get_url(path).rstrip("/") @classmethod def from_parsedurl(cls, parsedurl, **kwargs): password = parsedurl.password if password is not None: kwargs['password'] = urlunquote(password) username = parsedurl.username if username is not None: kwargs['username'] = urlunquote(username) netloc = parsedurl.hostname if parsedurl.port: netloc = "%s:%s" % (netloc, parsedurl.port) if parsedurl.username: netloc = "%s@%s" % (parsedurl.username, netloc) parsedurl = parsedurl._replace(netloc=netloc) return cls(urlunparse(parsedurl), **kwargs) def __repr__(self): return "%s(%r, dumb=%r)" % ( type(self).__name__, self._base_url, self.dumb) def _get_url(self, path): if not isinstance(path, str): # urllib3.util.url._encode_invalid_chars() converts the path back # to bytes using the utf-8 codec. path = path.decode('utf-8') return urljoin(self._base_url, path).rstrip("/") + "/" def _http_request(self, url, headers=None, data=None, allow_compression=False): """Perform HTTP request. Args: url: Request URL. headers: Optional custom headers to override defaults. data: Request data. allow_compression: Allow GZipped communication. Returns: Tuple (`response`, `read`), where response is an `urllib3` response object with additional `content_type` and `redirect_location` properties, and `read` is a consumable read method for the response data. """ req_headers = self.pool_manager.headers.copy() if headers is not None: req_headers.update(headers) req_headers["Pragma"] = "no-cache" if allow_compression: req_headers["Accept-Encoding"] = "gzip" else: req_headers["Accept-Encoding"] = "identity" if data is None: resp = self.pool_manager.request("GET", url, headers=req_headers) else: resp = self.pool_manager.request("POST", url, headers=req_headers, body=data) if resp.status == 404: raise NotGitRepository() elif resp.status != 200: raise GitProtocolError("unexpected http resp %d for %s" % (resp.status, url)) # TODO: Optimization available by adding `preload_content=False` to the # request and just passing the `read` method on instead of going via # `BytesIO`, if we can guarantee that the entire response is consumed # before issuing the next to still allow for connection reuse from the # pool. 
read = BytesIO(resp.data).read resp.content_type = resp.getheader("Content-Type") # Check if geturl() is available (urllib3 version >= 1.23) try: resp_url = resp.geturl() except AttributeError: # get_redirect_location() is available for urllib3 >= 1.1 resp.redirect_location = resp.get_redirect_location() else: resp.redirect_location = resp_url if resp_url != url else '' return resp, read def _discover_references(self, service, base_url): assert base_url[-1] == "/" tail = "info/refs" headers = {"Accept": "*/*"} if self.dumb is not True: tail += "?service=%s" % service.decode('ascii') url = urljoin(base_url, tail) resp, read = self._http_request(url, headers, allow_compression=True) if resp.redirect_location: # Something changed (redirect!), so let's update the base URL if not resp.redirect_location.endswith(tail): raise GitProtocolError( "Redirected from URL %s to URL %s without %s" % ( url, resp.redirect_location, tail)) base_url = resp.redirect_location[:-len(tail)] try: self.dumb = not resp.content_type.startswith("application/x-git-") if not self.dumb: proto = Protocol(read, None) # The first line should mention the service try: [pkt] = list(proto.read_pkt_seq()) except ValueError: raise GitProtocolError( "unexpected number of packets received") if pkt.rstrip(b'\n') != (b'# service=' + service): raise GitProtocolError( "unexpected first line %r from smart server" % pkt) return read_pkt_refs(proto) + (base_url, ) else: return read_info_refs(resp), set(), base_url finally: resp.close() def _smart_request(self, service, url, data): assert url[-1] == "/" url = urljoin(url, service) result_content_type = "application/x-%s-result" % service headers = { "Content-Type": "application/x-%s-request" % service, "Accept": result_content_type, "Content-Length": str(len(data)), } resp, read = self._http_request(url, headers, data) if resp.content_type != result_content_type: raise GitProtocolError("Invalid content-type from server: %s" % resp.content_type) return resp, read def send_pack(self, path, update_refs, generate_pack_data, progress=None): """Upload a pack to a remote repository. Args: path: Repository path (as bytestring) update_refs: Function to determine changes to remote refs. Receives dict with existing remote refs, returns dict with changed refs (name -> sha, where sha=ZERO_SHA for deletions) generate_pack_data: Function that can return a tuple with number of elements and pack data to upload. progress: Optional progress function Returns: new_refs dictionary containing the changes that were made {refname: new_ref}, including deleted refs. Raises: SendPackError: if server rejects the pack data UpdateRefsError: if the server supports report-status and rejects ref updates """ url = self._get_url(path) old_refs, server_capabilities, url = self._discover_references( b"git-receive-pack", url) negotiated_capabilities = self._negotiate_receive_pack_capabilities( server_capabilities) negotiated_capabilities.add(capability_agent()) if CAPABILITY_REPORT_STATUS in negotiated_capabilities: self._report_status_parser = ReportStatusParser() new_refs = update_refs(dict(old_refs)) if new_refs is None: # Determine wants function is aborting the push. 
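# A hypothetical update_refs callback illustrating an aborted push
# (the names below are illustrative, not part of the dulwich API):
#
#   def update_refs(refs):
#       if refs.get(b'refs/heads/master') != expected_sha:
#           return None  # abort: the remote moved under us
#       refs[b'refs/heads/master'] = new_sha
#       return refs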
return old_refs if self.dumb: raise NotImplementedError(self.send_pack) req_data = BytesIO() req_proto = Protocol(None, req_data.write) (have, want) = self._handle_receive_pack_head( req_proto, negotiated_capabilities, old_refs, new_refs) if not want and set(new_refs.items()).issubset(set(old_refs.items())): return new_refs pack_data_count, pack_data = generate_pack_data( have, want, ofs_delta=(CAPABILITY_OFS_DELTA in negotiated_capabilities)) if pack_data_count: write_pack_data(req_proto.write_file(), pack_data_count, pack_data) resp, read = self._smart_request("git-receive-pack", url, data=req_data.getvalue()) try: resp_proto = Protocol(read, None) self._handle_receive_pack_tail( resp_proto, negotiated_capabilities, progress) return new_refs finally: resp.close() def fetch_pack(self, path, determine_wants, graph_walker, pack_data, progress=None, depth=None): """Retrieve a pack from a git smart server. Args: path: Path to fetch from determine_wants: Callback that returns list of commits to fetch graph_walker: Object with next() and ack(). pack_data: Callback called for each bit of data in the pack progress: Callback for progress reports (strings) depth: Depth for request Returns: FetchPackResult object """ url = self._get_url(path) refs, server_capabilities, url = self._discover_references( b"git-upload-pack", url) negotiated_capabilities, symrefs, agent = ( self._negotiate_upload_pack_capabilities( server_capabilities)) wants = determine_wants(refs) if wants is not None: wants = [cid for cid in wants if cid != ZERO_SHA] if not wants: return FetchPackResult(refs, symrefs, agent) if self.dumb: raise NotImplementedError(self.fetch_pack) req_data = BytesIO() req_proto = Protocol(None, req_data.write) (new_shallow, new_unshallow) = self._handle_upload_pack_head( req_proto, negotiated_capabilities, graph_walker, wants, can_read=None, depth=depth) resp, read = self._smart_request( "git-upload-pack", url, data=req_data.getvalue()) try: resp_proto = Protocol(read, None) if new_shallow is None and new_unshallow is None: (new_shallow, new_unshallow) = _read_shallow_updates( resp_proto) self._handle_upload_pack_tail( resp_proto, negotiated_capabilities, graph_walker, pack_data, progress) return FetchPackResult( refs, symrefs, agent, new_shallow, new_unshallow) finally: resp.close() def get_refs(self, path): """Retrieve the current refs from a git smart server. """ url = self._get_url(path) refs, _, _ = self._discover_references( b"git-upload-pack", url) return refs def get_transport_and_path_from_url(url, config=None, **kwargs): """Obtain a git client from a URL. Args: url: URL to open (a unicode string) config: Optional config object thin_packs: Whether or not thin packs should be retrieved report_activity: Optional callback for reporting transport activity. Returns: Tuple with client instance and relative path. """ parsed = urlparse(url) if parsed.scheme == 'git': return (TCPGitClient.from_parsedurl(parsed, **kwargs), parsed.path) elif parsed.scheme in ('git+ssh', 'ssh'): return SSHGitClient.from_parsedurl(parsed, **kwargs), parsed.path elif parsed.scheme in ('http', 'https'): return HttpGitClient.from_parsedurl( parsed, config=config, **kwargs), parsed.path elif parsed.scheme == 'file': return default_local_git_client_cls.from_parsedurl( parsed, **kwargs), parsed.path raise ValueError("unknown scheme '%s'" % parsed.scheme) def parse_rsync_url(location): """Parse an rsync-style URL. """ if ':' in location and '@' not in location: # SSH with no user@, zero or one leading slash.
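# e.g. 'git.example.com:/srv/repo.git' -> (None, 'git.example.com', '/srv/repo.git')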
(host, path) = location.split(':', 1) user = None elif ':' in location: # SSH with user@host:foo. user_host, path = location.split(':', 1) if '@' in user_host: user, host = user_host.rsplit('@', 1) else: user = None host = user_host else: raise ValueError('not a valid rsync-style URL') return (user, host, path) def get_transport_and_path(location, **kwargs): """Obtain a git client from a URL. Args: location: URL or path (a string) config: Optional config object thin_packs: Whether or not thin packs should be retrieved report_activity: Optional callback for reporting transport activity. Returns: Tuple with client instance and relative path. """ # First, try to parse it as a URL try: return get_transport_and_path_from_url(location, **kwargs) except ValueError: pass if (sys.platform == 'win32' and location[0].isalpha() and location[1:3] == ':\\'): # Windows local path return default_local_git_client_cls(**kwargs), location try: (username, hostname, path) = parse_rsync_url(location) except ValueError: # Otherwise, assume it's a local path. return default_local_git_client_cls(**kwargs), location else: return SSHGitClient(hostname, username=username, **kwargs), path DEFAULT_GIT_CREDENTIALS_PATHS = [ os.path.expanduser('~/.git-credentials'), get_xdg_config_home_path('git', 'credentials')] def get_credentials_from_store(scheme, hostname, username=None, fnames=DEFAULT_GIT_CREDENTIALS_PATHS): for fname in fnames: try: with open(fname, 'rb') as f: for line in f: parsed_line = urlparse(line) if (parsed_line.scheme == scheme and parsed_line.hostname == hostname and (username is None or parsed_line.username == username)): return parsed_line.username, parsed_line.password except OSError as e: if e.errno != errno.ENOENT: raise # If the file doesn't exist, try the next one. continue diff --git a/dulwich/ignore.py b/dulwich/ignore.py index 72396401..2227acb1 100644 --- a/dulwich/ignore.py +++ b/dulwich/ignore.py @@ -1,374 +1,373 @@ # Copyright (C) 2017 Jelmer Vernooij # # Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU # General Public License as public by the Free Software Foundation; version 2.0 # or (at your option) any later version. You can redistribute it and/or # modify it under the terms of either of these two licenses. # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # # You should have received a copy of the licenses; if not, see # for a copy of the GNU General Public License # and for a copy of the Apache # License, Version 2.0. # """Parsing of gitignore files.
For details of the matching rules, see https://git-scm.com/docs/gitignore """ import os.path import re -import sys from dulwich.config import get_xdg_config_home_path def _translate_segment(segment): if segment == b"*": return b'[^/]+' res = b"" i, n = 0, len(segment) while i < n: c = segment[i:i+1] i = i+1 if c == b'*': res += b'[^/]*' elif c == b'?': res += b'[^/]' elif c == b'[': j = i if j < n and segment[j:j+1] == b'!': j = j+1 if j < n and segment[j:j+1] == b']': j = j+1 while j < n and segment[j:j+1] != b']': j = j+1 if j >= n: res += b'\\[' else: stuff = segment[i:j].replace(b'\\', b'\\\\') i = j+1 if stuff.startswith(b'!'): stuff = b'^' + stuff[1:] elif stuff.startswith(b'^'): stuff = b'\\' + stuff res += b'[' + stuff + b']' else: res += re.escape(c) return res def translate(pat): """Translate a shell PATTERN to a regular expression. There is no way to quote meta-characters. Originally copied from fnmatch in Python 2.7, but modified for Dulwich to cope with features in Git ignore patterns. """ res = b'(?ms)' if b'/' not in pat[:-1]: # If there's no slash, this is a filename-based match res += b'(.*/)?' if pat.startswith(b'**/'): # Leading **/ pat = pat[2:] res += b'(.*/)?' if pat.startswith(b'/'): pat = pat[1:] for i, segment in enumerate(pat.split(b'/')): if segment == b'**': res += b'(/.*)?' continue else: res += ((re.escape(b'/') if i > 0 else b'') + _translate_segment(segment)) if not pat.endswith(b'/'): res += b'/?' return res + b'\\Z' def read_ignore_patterns(f): """Read a git ignore file. Args: f: File-like object to read from Returns: List of patterns """ for line in f: line = line.rstrip(b"\r\n") # Ignore blank lines, they're used for readability. if not line: continue if line.startswith(b'#'): # Comment continue # Trailing spaces are ignored unless they are quoted with a backslash. while line.endswith(b' ') and not line.endswith(b'\\ '): line = line[:-1] line = line.replace(b'\\ ', b' ') yield line def match_pattern(path, pattern, ignorecase=False): """Match a gitignore-style pattern against a path. Args: path: Path to match pattern: Pattern to match ignorecase: Whether to match case-insensitively Returns: bool indicating whether the pattern matched """ return Pattern(pattern, ignorecase).match(path) class Pattern(object): """A single ignore pattern.""" def __init__(self, pattern, ignorecase=False): self.pattern = pattern self.ignorecase = ignorecase if pattern[0:1] == b'!': self.is_exclude = False pattern = pattern[1:] else: if pattern[0:1] == b'\\': pattern = pattern[1:] self.is_exclude = True flags = 0 if self.ignorecase: flags = re.IGNORECASE self._re = re.compile(translate(pattern), flags) def __bytes__(self): return self.pattern def __str__(self): return os.fsdecode(self.pattern) def __eq__(self, other): return (type(self) == type(other) and self.pattern == other.pattern and self.ignorecase == other.ignorecase) def __repr__(self): return "%s(%s, %r)" % ( type(self).__name__, self.pattern, self.ignorecase) def match(self, path): """Try to match a path against this ignore pattern. Args: path: Path to match (relative to ignore location) Returns: boolean """ return bool(self._re.match(path)) class IgnoreFilter(object): def __init__(self, patterns, ignorecase=False): self._patterns = [] self._ignorecase = ignorecase for pattern in patterns: self.append_pattern(pattern) def append_pattern(self, pattern): """Add a pattern to the set.""" self._patterns.append(Pattern(pattern, self._ignorecase)) def find_matching(self, path): """Yield all matching patterns for path.
Args: path: Path to match Returns: Iterator over matching Pattern objects """ if not isinstance(path, bytes): path = os.fsencode(path) for pattern in self._patterns: if pattern.match(path): yield pattern def is_ignored(self, path): """Check whether a path is ignored. For directories, include a trailing slash. Returns: status is None if file is not mentioned, True if it is included, False if it is explicitly excluded. """ status = None for pattern in self.find_matching(path): status = pattern.is_exclude return status @classmethod def from_path(cls, path, ignorecase=False): with open(path, 'rb') as f: ret = cls(read_ignore_patterns(f), ignorecase) ret._path = path return ret def __repr__(self): if getattr(self, '_path', None) is None: return "<%s>" % (type(self).__name__) else: return "%s.from_path(%r)" % (type(self).__name__, self._path) class IgnoreFilterStack(object): """Check for ignore status in multiple filters.""" def __init__(self, filters): self._filters = filters def is_ignored(self, path): """Check whether a path is explicitly included or excluded in ignores. Args: path: Path to check Returns: None if the file is not mentioned, True if it is included, False if it is explicitly excluded. """ status = None for filter in self._filters: status = filter.is_ignored(path) if status is not None: return status return status def default_user_ignore_filter_path(config): """Return default user ignore filter path. Args: config: A Config object Returns: Path to a global ignore file """ try: return config.get((b'core', ), b'excludesFile') except KeyError: pass return get_xdg_config_home_path('git', 'ignore') class IgnoreFilterManager(object): """Ignore file manager.""" def __init__(self, top_path, global_filters, ignorecase): self._path_filters = {} self._top_path = top_path self._global_filters = global_filters self._ignorecase = ignorecase def __repr__(self): return "%s(%s, %r, %r)" % ( type(self).__name__, self._top_path, self._global_filters, self._ignorecase) def _load_path(self, path): try: return self._path_filters[path] except KeyError: pass p = os.path.join(self._top_path, path, '.gitignore') try: self._path_filters[path] = IgnoreFilter.from_path( p, self._ignorecase) except IOError: self._path_filters[path] = None return self._path_filters[path] def find_matching(self, path): """Find matching patterns for path. Stops after the first ignore file with matches. Args: path: Path to check Returns: Iterator over Pattern instances """ if os.path.isabs(path): raise ValueError('%s is an absolute path' % path) filters = [(0, f) for f in self._global_filters] if os.path.sep != '/': path = path.replace(os.path.sep, '/') parts = path.split('/') for i in range(len(parts)+1): dirname = '/'.join(parts[:i]) for s, f in filters: relpath = '/'.join(parts[s:i]) if i < len(parts): # Paths leading up to the final part are all directories, # so need a trailing slash. relpath += '/' matches = list(f.find_matching(relpath)) if matches: return iter(matches) ignore_filter = self._load_path(dirname) if ignore_filter is not None: filters.insert(0, (i, ignore_filter)) return iter([]) def is_ignored(self, path): """Check whether a path is explicitly included or excluded in ignores. Args: path: Path to check Returns: None if the file is not mentioned, True if it is included, False if it is explicitly excluded. """ matches = list(self.find_matching(path)) if matches: return matches[-1].is_exclude return None @classmethod def from_repo(cls, repo): """Create an IgnoreFilterManager from a repository.
Args: repo: Repository object Returns: An `IgnoreFilterManager` object """ global_filters = [] for p in [ os.path.join(repo.controldir(), 'info', 'exclude'), default_user_ignore_filter_path(repo.get_config_stack())]: try: global_filters.append( IgnoreFilter.from_path(os.path.expanduser(p))) except IOError: pass config = repo.get_config_stack() ignorecase = config.get_boolean((b'core',), b'ignorecase', False) return cls(repo.path, global_filters, ignorecase) diff --git a/dulwich/refs.py b/dulwich/refs.py index 54ce22a3..317261e1 100644 --- a/dulwich/refs.py +++ b/dulwich/refs.py @@ -1,968 +1,967 @@ # refs.py -- For dealing with git refs # Copyright (C) 2008-2013 Jelmer Vernooij # # Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU # General Public License as public by the Free Software Foundation; version 2.0 # or (at your option) any later version. You can redistribute it and/or # modify it under the terms of either of these two licenses. # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # # You should have received a copy of the licenses; if not, see # for a copy of the GNU General Public License # and for a copy of the Apache # License, Version 2.0. # """Ref handling. """ import errno import os -import sys from dulwich.errors import ( PackedRefsException, RefFormatError, ) from dulwich.objects import ( git_line, valid_hexsha, ZERO_SHA, ) from dulwich.file import ( GitFile, ensure_dir_exists, ) SYMREF = b'ref: ' LOCAL_BRANCH_PREFIX = b'refs/heads/' LOCAL_TAG_PREFIX = b'refs/tags/' BAD_REF_CHARS = set(b'\177 ~^:?*[') ANNOTATED_TAG_SUFFIX = b'^{}' def parse_symref_value(contents): """Parse a symref value. Args: contents: Contents to parse Returns: Destination """ if contents.startswith(SYMREF): return contents[len(SYMREF):].rstrip(b'\r\n') raise ValueError(contents) def check_ref_format(refname): """Check if a refname is correctly formatted. Implements all the same rules as git-check-ref-format[1]. [1] http://www.kernel.org/pub/software/scm/git/docs/git-check-ref-format.html Args: refname: The refname to check Returns: True if refname is valid, False otherwise """ # These could be combined into one big expression, but are listed # separately to parallel [1]. if b'/.' in refname or refname.startswith(b'.'): return False if b'/' not in refname: return False if b'..' in refname: return False for i, c in enumerate(refname): if ord(refname[i:i+1]) < 0o40 or c in BAD_REF_CHARS: return False if refname[-1] in b'/.': return False if refname.endswith(b'.lock'): return False if b'@{' in refname: return False if b'\\' in refname: return False return True class RefsContainer(object): """A container for refs.""" def __init__(self, logger=None): self._logger = logger def _log(self, ref, old_sha, new_sha, committer=None, timestamp=None, timezone=None, message=None): if self._logger is None: return if message is None: return self._logger(ref, old_sha, new_sha, committer, timestamp, timezone, message) def set_symbolic_ref(self, name, other, committer=None, timestamp=None, timezone=None, message=None): """Make a ref point at another ref.
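For example (illustrative, not part of the original docstring): set_symbolic_ref(b'HEAD', b'refs/heads/develop') makes HEAD a symbolic reference to the develop branch without touching the branch itself.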
Args: name: Name of the ref to set other: Name of the ref to point at message: Optional message """ raise NotImplementedError(self.set_symbolic_ref) def get_packed_refs(self): """Get contents of the packed-refs file. Returns: Dictionary mapping ref names to SHA1s Note: Will return an empty dictionary when no packed-refs file is present. """ raise NotImplementedError(self.get_packed_refs) def get_peeled(self, name): """Return the cached peeled value of a ref, if available. Args: name: Name of the ref to peel Returns: The peeled value of the ref. If the ref is known not to point to a tag, this will be the SHA the ref refers to. If the ref may point to a tag, but no cached information is available, None is returned. """ return None def import_refs(self, base, other, committer=None, timestamp=None, timezone=None, message=None, prune=False): if prune: to_delete = set(self.subkeys(base)) else: to_delete = set() for name, value in other.items(): self.set_if_equals(b'/'.join((base, name)), None, value, message=message) if to_delete: try: to_delete.remove(name) except KeyError: pass for ref in to_delete: self.remove_if_equals(b'/'.join((base, ref)), None) def allkeys(self): """All refs present in this container.""" raise NotImplementedError(self.allkeys) def __iter__(self): return iter(self.allkeys()) def keys(self, base=None): """Refs present in this container. Args: base: An optional base to return refs under. Returns: An unsorted set of valid refs in this container, including packed refs. """ if base is not None: return self.subkeys(base) else: return self.allkeys() def subkeys(self, base): """Refs present in this container under a base. Args: base: The base to return refs under. Returns: A set of valid refs in this container under the base; the base prefix is stripped from the ref names returned. """ keys = set() base_len = len(base) + 1 for refname in self.allkeys(): if refname.startswith(base): keys.add(refname[base_len:]) return keys def as_dict(self, base=None): """Return the contents of this container as a dictionary. """ ret = {} keys = self.keys(base) if base is None: base = b'' else: base = base.rstrip(b'/') for key in keys: try: ret[key] = self[(base + b'/' + key).strip(b'/')] except KeyError: continue # Unable to resolve return ret def _check_refname(self, name): """Ensure a refname is valid and lives in refs or is HEAD. HEAD is not a valid refname according to git-check-ref-format, but this class needs to be able to touch HEAD. Also, check_ref_format expects refnames without the leading 'refs/', but this class requires the prefix so that it cannot touch anything outside the refs dir (or HEAD). Args: name: The name of the reference. Raises: KeyError: if a refname is not HEAD or is otherwise not valid. """ if name in (b'HEAD', b'refs/stash'): return if not name.startswith(b'refs/') or not check_ref_format(name[5:]): raise RefFormatError(name) def read_ref(self, refname): """Read a reference without following any references. Args: refname: The name of the reference Returns: The contents of the ref file, or None if it does not exist. """ contents = self.read_loose_ref(refname) if not contents: contents = self.get_packed_refs().get(refname, None) return contents def read_loose_ref(self, name): """Read a loose reference and return its contents. Args: name: the refname to read Returns: The contents of the ref file, or None if it does not exist. """ raise NotImplementedError(self.read_loose_ref) def follow(self, name): """Follow a reference name.
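For example (an illustrative sketch, not from the original docstring): following b'HEAD' in a typical repository returns ([b'HEAD', b'refs/heads/master'], sha_of_master), while following a direct ref returns a single-element refnames list.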
Returns: a tuple of (refnames, sha), where refnames are the names of references in the chain """ contents = SYMREF + name depth = 0 refnames = [] while contents.startswith(SYMREF): refname = contents[len(SYMREF):] refnames.append(refname) contents = self.read_ref(refname) if not contents: break depth += 1 if depth > 5: raise KeyError(name) return refnames, contents def _follow(self, name): import warnings warnings.warn( "RefsContainer._follow is deprecated. Use RefsContainer.follow " "instead.", DeprecationWarning) refnames, contents = self.follow(name) if not refnames: return (None, contents) return (refnames[-1], contents) def __contains__(self, refname): if self.read_ref(refname): return True return False def __getitem__(self, name): """Get the SHA1 for a reference name. This method follows all symbolic references. """ _, sha = self.follow(name) if sha is None: raise KeyError(name) return sha def set_if_equals(self, name, old_ref, new_ref, committer=None, timestamp=None, timezone=None, message=None): """Set a refname to new_ref only if it currently equals old_ref. This method follows all symbolic references if applicable for the subclass, and can be used to perform an atomic compare-and-swap operation. Args: name: The refname to set. old_ref: The old sha the refname must refer to, or None to set unconditionally. new_ref: The new sha the refname will refer to. message: Message for reflog Returns: True if the set was successful, False otherwise. """ raise NotImplementedError(self.set_if_equals) def add_if_new(self, name, ref): """Add a new reference only if it does not already exist. Args: name: Ref name ref: Ref value """ raise NotImplementedError(self.add_if_new) def __setitem__(self, name, ref): """Set a reference name to point to the given SHA1. This method follows all symbolic references if applicable for the subclass. Note: This method unconditionally overwrites the contents of a reference. To update atomically only if the reference has not changed, use set_if_equals(). Args: name: The refname to set. ref: The new sha the refname will refer to. """ self.set_if_equals(name, None, ref) def remove_if_equals(self, name, old_ref, committer=None, timestamp=None, timezone=None, message=None): """Remove a refname only if it currently equals old_ref. This method does not follow symbolic references, even if applicable for the subclass. It can be used to perform an atomic compare-and-delete operation. Args: name: The refname to delete. old_ref: The old sha the refname must refer to, or None to delete unconditionally. message: Message for reflog Returns: True if the delete was successful, False otherwise. """ raise NotImplementedError(self.remove_if_equals) def __delitem__(self, name): """Remove a refname. This method does not follow symbolic references, even if applicable for the subclass. Note: This method unconditionally deletes the contents of a reference. To delete atomically only if the reference has not changed, use remove_if_equals(). Args: name: The refname to delete. """ self.remove_if_equals(name, None) def get_symrefs(self): """Get a dict with all symrefs in this container. Returns: Dictionary mapping source ref to target ref """ ret = {} for src in self.allkeys(): try: dst = parse_symref_value(self.read_ref(src)) except ValueError: pass else: ret[src] = dst return ret class DictRefsContainer(RefsContainer): """RefsContainer backed by a simple dict. This container does not support symbolic or packed references and is not threadsafe.
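A minimal usage sketch (illustrative, not part of the original docstring): refs = DictRefsContainer({}); refs[b'refs/heads/master'] = sha stores the value via set_if_equals(), and refs[b'refs/heads/master'] reads it back.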
""" def __init__(self, refs, logger=None): super(DictRefsContainer, self).__init__(logger=logger) self._refs = refs self._peeled = {} def allkeys(self): return self._refs.keys() def read_loose_ref(self, name): return self._refs.get(name, None) def get_packed_refs(self): return {} def set_symbolic_ref(self, name, other, committer=None, timestamp=None, timezone=None, message=None): old = self.follow(name)[-1] self._refs[name] = SYMREF + other self._log(name, old, old, committer=committer, timestamp=timestamp, timezone=timezone, message=message) def set_if_equals(self, name, old_ref, new_ref, committer=None, timestamp=None, timezone=None, message=None): if old_ref is not None and self._refs.get(name, ZERO_SHA) != old_ref: return False realnames, _ = self.follow(name) for realname in realnames: self._check_refname(realname) old = self._refs.get(realname) self._refs[realname] = new_ref self._log(realname, old, new_ref, committer=committer, timestamp=timestamp, timezone=timezone, message=message) return True def add_if_new(self, name, ref, committer=None, timestamp=None, timezone=None, message=None): if name in self._refs: return False self._refs[name] = ref self._log(name, None, ref, committer=committer, timestamp=timestamp, timezone=timezone, message=message) return True def remove_if_equals(self, name, old_ref, committer=None, timestamp=None, timezone=None, message=None): if old_ref is not None and self._refs.get(name, ZERO_SHA) != old_ref: return False try: old = self._refs.pop(name) except KeyError: pass else: self._log(name, old, None, committer=committer, timestamp=timestamp, timezone=timezone, message=message) return True def get_peeled(self, name): return self._peeled.get(name) def _update(self, refs): """Update multiple refs; intended only for testing.""" # TODO(dborowitz): replace this with a public function that uses # set_if_equal. 
self._refs.update(refs) def _update_peeled(self, peeled): """Update cached peeled refs; intended only for testing.""" self._peeled.update(peeled) class InfoRefsContainer(RefsContainer): """Refs container that reads refs from a info/refs file.""" def __init__(self, f): self._refs = {} self._peeled = {} for line in f.readlines(): sha, name = line.rstrip(b'\n').split(b'\t') if name.endswith(ANNOTATED_TAG_SUFFIX): name = name[:-3] if not check_ref_format(name): raise ValueError("invalid ref name %r" % name) self._peeled[name] = sha else: if not check_ref_format(name): raise ValueError("invalid ref name %r" % name) self._refs[name] = sha def allkeys(self): return self._refs.keys() def read_loose_ref(self, name): return self._refs.get(name, None) def get_packed_refs(self): return {} def get_peeled(self, name): try: return self._peeled[name] except KeyError: return self._refs[name] class DiskRefsContainer(RefsContainer): """Refs container that reads refs from disk.""" def __init__(self, path, worktree_path=None, logger=None): super(DiskRefsContainer, self).__init__(logger=logger) if getattr(path, 'encode', None) is not None: path = os.fsencode(path) self.path = path if worktree_path is None: worktree_path = path if getattr(worktree_path, 'encode', None) is not None: worktree_path = os.fsencode(worktree_path) self.worktree_path = worktree_path self._packed_refs = None self._peeled_refs = None def __repr__(self): return "%s(%r)" % (self.__class__.__name__, self.path) def subkeys(self, base): subkeys = set() path = self.refpath(base) for root, unused_dirs, files in os.walk(path): dir = root[len(path):] if os.path.sep != '/': dir = dir.replace(os.fsencode(os.path.sep), b"/") dir = dir.strip(b'/') for filename in files: refname = b"/".join(([dir] if dir else []) + [filename]) # check_ref_format requires at least one /, so we prepend the # base before calling it. if check_ref_format(base + b'/' + refname): subkeys.add(refname) for key in self.get_packed_refs(): if key.startswith(base): subkeys.add(key[len(base):].strip(b'/')) return subkeys def allkeys(self): allkeys = set() if os.path.exists(self.refpath(b'HEAD')): allkeys.add(b'HEAD') path = self.refpath(b'') refspath = self.refpath(b'refs') for root, unused_dirs, files in os.walk(refspath): dir = root[len(path):] if os.path.sep != '/': dir = dir.replace(os.fsencode(os.path.sep), b"/") for filename in files: refname = b"/".join([dir, filename]) if check_ref_format(refname): allkeys.add(refname) allkeys.update(self.get_packed_refs()) return allkeys def refpath(self, name): """Return the disk path of a ref. """ if os.path.sep != "/": name = name.replace(b"/", os.fsencode(os.path.sep)) # TODO: as the 'HEAD' reference is working tree specific, it # should actually not be a part of RefsContainer if name == b'HEAD': return os.path.join(self.worktree_path, name) else: return os.path.join(self.path, name) def get_packed_refs(self): """Get contents of the packed-refs file. Returns: Dictionary mapping ref names to SHA1s Note: Will return an empty dictionary when no packed-refs file is present. """ # TODO: invalidate the cache on repacking if self._packed_refs is None: # set both to empty because we want _peeled_refs to be # None if and only if _packed_refs is also None. 
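# (Illustrative note, not in the original source: a packed-refs file looks like # # pack-refs with: peeled # 4cffe90e0a41ad3f5190079d7c8f036bde29cbe6 refs/tags/v0.1 # ^60dacdc733de308bb77bb76ce0fb0f9b44c9769e # where a '^' line carries the peeled value of the tag on the preceding line; the shas here are placeholders.)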
self._packed_refs = {} self._peeled_refs = {} path = os.path.join(self.path, b'packed-refs') try: f = GitFile(path, 'rb') except IOError as e: if e.errno == errno.ENOENT: return {} raise with f: first_line = next(iter(f)).rstrip() if (first_line.startswith(b'# pack-refs') and b' peeled' in first_line): for sha, name, peeled in read_packed_refs_with_peeled(f): self._packed_refs[name] = sha if peeled: self._peeled_refs[name] = peeled else: f.seek(0) for sha, name in read_packed_refs(f): self._packed_refs[name] = sha return self._packed_refs def get_peeled(self, name): """Return the cached peeled value of a ref, if available. Args: name: Name of the ref to peel Returns: The peeled value of the ref. If the ref is known not to point to a tag, this will be the SHA the ref refers to. If the ref may point to a tag, but no cached information is available, None is returned. """ self.get_packed_refs() if self._peeled_refs is None or name not in self._packed_refs: # No cache: no peeled refs were read, or this ref is loose return None if name in self._peeled_refs: return self._peeled_refs[name] else: # Known not peelable return self[name] def read_loose_ref(self, name): """Read a reference file and return its contents. If the reference file is a symbolic reference, only read the first line of the file. Otherwise, only read the first 40 bytes. Args: name: the refname to read, relative to refpath Returns: The contents of the ref file, or None if the file does not exist. Raises: IOError: if any other error occurs """ filename = self.refpath(name) try: with GitFile(filename, 'rb') as f: header = f.read(len(SYMREF)) if header == SYMREF: # Read only the first line return header + next(iter(f)).rstrip(b'\r\n') else: # Read only the first 40 bytes return header + f.read(40 - len(SYMREF)) except IOError as e: if e.errno in (errno.ENOENT, errno.EISDIR, errno.ENOTDIR): return None raise def _remove_packed_ref(self, name): if self._packed_refs is None: return filename = os.path.join(self.path, b'packed-refs') # reread cached refs from disk, while holding the lock f = GitFile(filename, 'wb') try: self._packed_refs = None self.get_packed_refs() if name not in self._packed_refs: return del self._packed_refs[name] if name in self._peeled_refs: del self._peeled_refs[name] write_packed_refs(f, self._packed_refs, self._peeled_refs) f.close() finally: f.abort() def set_symbolic_ref(self, name, other, committer=None, timestamp=None, timezone=None, message=None): """Make a ref point at another ref. Args: name: Name of the ref to set other: Name of the ref to point at message: Optional message to describe the change """ self._check_refname(name) self._check_refname(other) filename = self.refpath(name) f = GitFile(filename, 'wb') try: f.write(SYMREF + other + b'\n') sha = self.follow(name)[-1] self._log(name, sha, sha, committer=committer, timestamp=timestamp, timezone=timezone, message=message) except BaseException: f.abort() raise else: f.close() def set_if_equals(self, name, old_ref, new_ref, committer=None, timestamp=None, timezone=None, message=None): """Set a refname to new_ref only if it currently equals old_ref. This method follows all symbolic references, and can be used to perform an atomic compare-and-swap operation. Args: name: The refname to set. old_ref: The old sha the refname must refer to, or None to set unconditionally. new_ref: The new sha the refname will refer to. message: Set message for reflog Returns: True if the set was successful, False otherwise.
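An illustrative example (not part of the original docstring): set_if_equals(b'refs/heads/master', old_sha, new_sha) succeeds only if the branch still points at old_sha once the lock is taken, which makes it suitable for optimistic-concurrency updates.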
""" self._check_refname(name) try: realnames, _ = self.follow(name) realname = realnames[-1] except (KeyError, IndexError): realname = name filename = self.refpath(realname) # make sure none of the ancestor folders is in packed refs probe_ref = os.path.dirname(realname) packed_refs = self.get_packed_refs() while probe_ref: if packed_refs.get(probe_ref, None) is not None: raise OSError(errno.ENOTDIR, 'Not a directory: {}'.format(filename)) probe_ref = os.path.dirname(probe_ref) ensure_dir_exists(os.path.dirname(filename)) with GitFile(filename, 'wb') as f: if old_ref is not None: try: # read again while holding the lock orig_ref = self.read_loose_ref(realname) if orig_ref is None: orig_ref = self.get_packed_refs().get( realname, ZERO_SHA) if orig_ref != old_ref: f.abort() return False except (OSError, IOError): f.abort() raise try: f.write(new_ref + b'\n') except (OSError, IOError): f.abort() raise self._log(realname, old_ref, new_ref, committer=committer, timestamp=timestamp, timezone=timezone, message=message) return True def add_if_new(self, name, ref, committer=None, timestamp=None, timezone=None, message=None): """Add a new reference only if it does not already exist. This method follows symrefs, and only ensures that the last ref in the chain does not exist. Args: name: The refname to set. ref: The new sha the refname will refer to. message: Optional message for reflog Returns: True if the add was successful, False otherwise. """ try: realnames, contents = self.follow(name) if contents is not None: return False realname = realnames[-1] except (KeyError, IndexError): realname = name self._check_refname(realname) filename = self.refpath(realname) ensure_dir_exists(os.path.dirname(filename)) with GitFile(filename, 'wb') as f: if os.path.exists(filename) or name in self.get_packed_refs(): f.abort() return False try: f.write(ref + b'\n') except (OSError, IOError): f.abort() raise else: self._log(name, None, ref, committer=committer, timestamp=timestamp, timezone=timezone, message=message) return True def remove_if_equals(self, name, old_ref, committer=None, timestamp=None, timezone=None, message=None): """Remove a refname only if it currently equals old_ref. This method does not follow symbolic references. It can be used to perform an atomic compare-and-delete operation. Args: name: The refname to delete. old_ref: The old sha the refname must refer to, or None to delete unconditionally. message: Optional message Returns: True if the delete was successful, False otherwise. """ self._check_refname(name) filename = self.refpath(name) ensure_dir_exists(os.path.dirname(filename)) f = GitFile(filename, 'wb') try: if old_ref is not None: orig_ref = self.read_loose_ref(name) if orig_ref is None: orig_ref = self.get_packed_refs().get(name, ZERO_SHA) if orig_ref != old_ref: return False # remove the reference file itself try: os.remove(filename) except OSError as e: if e.errno != errno.ENOENT: # may only be packed raise self._remove_packed_ref(name) self._log(name, old_ref, None, committer=committer, timestamp=timestamp, timezone=timezone, message=message) finally: # never write, we just wanted the lock f.abort() # outside of the lock, clean-up any parent directory that might now # be empty. 
this ensures that re-creating a reference of the same # name as what was previously a directory works as expected parent = name while True: try: parent, _ = parent.rsplit(b'/', 1) except ValueError: break parent_filename = self.refpath(parent) try: os.rmdir(parent_filename) except OSError: # this can be caused by the parent directory being # removed by another process, not being empty, etc. # in any case, this is non-fatal because we already # removed the reference, just ignore it break return True def _split_ref_line(line): """Split a single ref line into a tuple of SHA1 and name.""" fields = line.rstrip(b'\n\r').split(b' ') if len(fields) != 2: raise PackedRefsException("invalid ref line %r" % line) sha, name = fields if not valid_hexsha(sha): raise PackedRefsException("Invalid hex sha %r" % sha) if not check_ref_format(name): raise PackedRefsException("invalid ref name %r" % name) return (sha, name) def read_packed_refs(f): """Read a packed refs file. Args: f: file-like object to read from Returns: Iterator over tuples with SHA1s and ref names. """ for line in f: if line.startswith(b'#'): # Comment continue if line.startswith(b'^'): raise PackedRefsException( "found peeled ref in packed-refs without peeled") yield _split_ref_line(line) def read_packed_refs_with_peeled(f): """Read a packed refs file including peeled refs. Assumes the "# pack-refs with: peeled" line was already read. Yields tuples with ref names, SHA1s, and peeled SHA1s (or None). Args: f: file-like object to read from, seek'ed to the second line """ last = None for line in f: if line.startswith(b'#'): continue line = line.rstrip(b'\r\n') if line.startswith(b'^'): if not last: raise PackedRefsException("unexpected peeled ref line") if not valid_hexsha(line[1:]): raise PackedRefsException("Invalid hex sha %r" % line[1:]) sha, name = _split_ref_line(last) last = None yield (sha, name, line[1:]) else: if last: sha, name = _split_ref_line(last) yield (sha, name, None) last = line if last: sha, name = _split_ref_line(last) yield (sha, name, None) def write_packed_refs(f, packed_refs, peeled_refs=None): """Write a packed refs file.
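(Illustrative note, not from the original docstring: when peeled_refs is given, a '# pack-refs with: peeled' header is emitted first and each peeled SHA follows its ref on a '^'-prefixed line, matching what read_packed_refs_with_peeled expects.)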
Args: f: empty file-like object to write to packed_refs: dict of refname to sha of packed refs to write peeled_refs: dict of refname to peeled value of sha """ if peeled_refs is None: peeled_refs = {} else: f.write(b'# pack-refs with: peeled\n') for refname in sorted(packed_refs.keys()): f.write(git_line(packed_refs[refname], refname)) if refname in peeled_refs: f.write(b'^' + peeled_refs[refname] + b'\n') def read_info_refs(f): ret = {} for line in f.readlines(): (sha, name) = line.rstrip(b"\r\n").split(b"\t", 1) ret[name] = sha return ret def write_info_refs(refs, store): """Generate info refs.""" for name, sha in sorted(refs.items()): # get_refs() includes HEAD as a special case, but we don't want to # advertise it if name == b'HEAD': continue try: o = store[sha] except KeyError: continue peeled = store.peel_sha(sha) yield o.id + b'\t' + name + b'\n' if o.id != peeled.id: yield peeled.id + b'\t' + name + ANNOTATED_TAG_SUFFIX + b'\n' def is_local_branch(x): return x.startswith(LOCAL_BRANCH_PREFIX) def strip_peeled_refs(refs): """Remove all peeled refs""" return {ref: sha for (ref, sha) in refs.items() if not ref.endswith(ANNOTATED_TAG_SUFFIX)} diff --git a/dulwich/tests/test_repository.py b/dulwich/tests/test_repository.py index 23a78e09..5bb3b0ca 100644 --- a/dulwich/tests/test_repository.py +++ b/dulwich/tests/test_repository.py @@ -1,1148 +1,1149 @@ # -*- coding: utf-8 -*- # test_repository.py -- tests for repository.py # Copyright (C) 2007 James Westby # # Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU # General Public License as public by the Free Software Foundation; version 2.0 # or (at your option) any later version. You can redistribute it and/or # modify it under the terms of either of these two licenses. # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # # You should have received a copy of the licenses; if not, see # for a copy of the GNU General Public License # and for a copy of the Apache # License, Version 2.0. 
# """Tests for the repository.""" import locale import os import stat import shutil import sys import tempfile import warnings from dulwich import errors from dulwich.object_store import ( tree_lookup_path, ) from dulwich import objects from dulwich.config import Config from dulwich.errors import NotGitRepository from dulwich.repo import ( InvalidUserIdentity, Repo, MemoryRepo, check_user_identity, ) from dulwich.tests import ( TestCase, skipIf, ) from dulwich.tests.utils import ( open_repo, tear_down_repo, setup_warning_catcher, ) missing_sha = b'b91fa4d900e17e99b433218e988c4eb4a3e9a097' class CreateRepositoryTests(TestCase): def assertFileContentsEqual(self, expected, repo, path): f = repo.get_named_file(path) if not f: self.assertEqual(expected, None) else: with f: self.assertEqual(expected, f.read()) def _check_repo_contents(self, repo, expect_bare): self.assertEqual(expect_bare, repo.bare) self.assertFileContentsEqual( b'Unnamed repository', repo, 'description') self.assertFileContentsEqual( b'', repo, os.path.join('info', 'exclude')) self.assertFileContentsEqual(None, repo, 'nonexistent file') barestr = b'bare = ' + str(expect_bare).lower().encode('ascii') with repo.get_named_file('config') as f: config_text = f.read() self.assertTrue(barestr in config_text, "%r" % config_text) expect_filemode = sys.platform != 'win32' barestr = b'filemode = ' + str(expect_filemode).lower().encode('ascii') with repo.get_named_file('config') as f: config_text = f.read() self.assertTrue(barestr in config_text, "%r" % config_text) def test_create_memory(self): repo = MemoryRepo.init_bare([], {}) self._check_repo_contents(repo, True) def test_create_disk_bare(self): tmp_dir = tempfile.mkdtemp() self.addCleanup(shutil.rmtree, tmp_dir) repo = Repo.init_bare(tmp_dir) self.assertEqual(tmp_dir, repo._controldir) self._check_repo_contents(repo, True) def test_create_disk_non_bare(self): tmp_dir = tempfile.mkdtemp() self.addCleanup(shutil.rmtree, tmp_dir) repo = Repo.init(tmp_dir) self.assertEqual(os.path.join(tmp_dir, '.git'), repo._controldir) self._check_repo_contents(repo, False) def test_create_disk_non_bare_mkdir(self): tmp_dir = tempfile.mkdtemp() target_dir = os.path.join(tmp_dir, "target") self.addCleanup(shutil.rmtree, tmp_dir) repo = Repo.init(target_dir, mkdir=True) self.assertEqual(os.path.join(target_dir, '.git'), repo._controldir) self._check_repo_contents(repo, False) def test_create_disk_bare_mkdir(self): tmp_dir = tempfile.mkdtemp() target_dir = os.path.join(tmp_dir, "target") self.addCleanup(shutil.rmtree, tmp_dir) repo = Repo.init_bare(target_dir, mkdir=True) self.assertEqual(target_dir, repo._controldir) self._check_repo_contents(repo, True) class MemoryRepoTests(TestCase): def test_set_description(self): r = MemoryRepo.init_bare([], {}) description = b"Some description" r.set_description(description) self.assertEqual(description, r.get_description()) class RepositoryRootTests(TestCase): def mkdtemp(self): return tempfile.mkdtemp() def open_repo(self, name): temp_dir = self.mkdtemp() repo = open_repo(name, temp_dir) self.addCleanup(tear_down_repo, repo) return repo def test_simple_props(self): r = self.open_repo('a.git') self.assertEqual(r.controldir(), r.path) def test_setitem(self): r = self.open_repo('a.git') r[b"refs/tags/foo"] = b'a90fa2d900a17e99b433217e988c4eb4a2e9a097' self.assertEqual(b'a90fa2d900a17e99b433217e988c4eb4a2e9a097', r[b"refs/tags/foo"].id) def test_getitem_unicode(self): r = self.open_repo('a.git') test_keys = [ (b'refs/heads/master', True), 
(b'a90fa2d900a17e99b433217e988c4eb4a2e9a097', True), (b'11' * 19 + b'--', False), ] for k, contained in test_keys: self.assertEqual(k in r, contained) # Avoid deprecation warning under Py3.2+ if getattr(self, 'assertRaisesRegex', None): assertRaisesRegexp = self.assertRaisesRegex else: assertRaisesRegexp = self.assertRaisesRegexp for k, _ in test_keys: assertRaisesRegexp( TypeError, "'name' must be bytestring, not int", r.__getitem__, 12 ) def test_delitem(self): r = self.open_repo('a.git') del r[b'refs/heads/master'] self.assertRaises(KeyError, lambda: r[b'refs/heads/master']) del r[b'HEAD'] self.assertRaises(KeyError, lambda: r[b'HEAD']) self.assertRaises(ValueError, r.__delitem__, b'notrefs/foo') def test_get_refs(self): r = self.open_repo('a.git') self.assertEqual({ b'HEAD': b'a90fa2d900a17e99b433217e988c4eb4a2e9a097', b'refs/heads/master': b'a90fa2d900a17e99b433217e988c4eb4a2e9a097', b'refs/tags/mytag': b'28237f4dc30d0d462658d6b937b08a0f0b6ef55a', b'refs/tags/mytag-packed': b'b0931cadc54336e78a1d980420e3268903b57a50', }, r.get_refs()) def test_head(self): r = self.open_repo('a.git') self.assertEqual(r.head(), b'a90fa2d900a17e99b433217e988c4eb4a2e9a097') def test_get_object(self): r = self.open_repo('a.git') obj = r.get_object(r.head()) self.assertEqual(obj.type_name, b'commit') def test_get_object_non_existant(self): r = self.open_repo('a.git') self.assertRaises(KeyError, r.get_object, missing_sha) def test_contains_object(self): r = self.open_repo('a.git') self.assertTrue(r.head() in r) def test_contains_ref(self): r = self.open_repo('a.git') self.assertTrue(b"HEAD" in r) def test_get_no_description(self): r = self.open_repo('a.git') self.assertIs(None, r.get_description()) def test_get_description(self): r = self.open_repo('a.git') with open(os.path.join(r.path, 'description'), 'wb') as f: f.write(b"Some description") self.assertEqual(b"Some description", r.get_description()) def test_set_description(self): r = self.open_repo('a.git') description = b"Some description" r.set_description(description) self.assertEqual(description, r.get_description()) def test_contains_missing(self): r = self.open_repo('a.git') self.assertFalse(b"bar" in r) def test_get_peeled(self): # unpacked ref r = self.open_repo('a.git') tag_sha = b'28237f4dc30d0d462658d6b937b08a0f0b6ef55a' self.assertNotEqual(r[tag_sha].sha().hexdigest(), r.head()) self.assertEqual(r.get_peeled(b'refs/tags/mytag'), r.head()) # packed ref with cached peeled value packed_tag_sha = b'b0931cadc54336e78a1d980420e3268903b57a50' parent_sha = r[r.head()].parents[0] self.assertNotEqual(r[packed_tag_sha].sha().hexdigest(), parent_sha) self.assertEqual(r.get_peeled(b'refs/tags/mytag-packed'), parent_sha) # TODO: add more corner cases to test repo def test_get_peeled_not_tag(self): r = self.open_repo('a.git') self.assertEqual(r.get_peeled(b'HEAD'), r.head()) def test_get_walker(self): r = self.open_repo('a.git') # include defaults to [r.head()] self.assertEqual( [e.commit.id for e in r.get_walker()], [r.head(), b'2a72d929692c41d8554c07f6301757ba18a65d91']) self.assertEqual( [e.commit.id for e in r.get_walker([b'2a72d929692c41d8554c07f6301757ba18a65d91'])], [b'2a72d929692c41d8554c07f6301757ba18a65d91']) self.assertEqual( [e.commit.id for e in r.get_walker(b'2a72d929692c41d8554c07f6301757ba18a65d91')], [b'2a72d929692c41d8554c07f6301757ba18a65d91']) def assertFilesystemHidden(self, path): if sys.platform != 'win32': return import ctypes from ctypes.wintypes import DWORD, LPCWSTR GetFileAttributesW = ctypes.WINFUNCTYPE(DWORD, LPCWSTR)( 
('GetFileAttributesW', ctypes.windll.kernel32)) self.assertTrue(2 & GetFileAttributesW(path)) def test_init_existing(self): tmp_dir = self.mkdtemp() self.addCleanup(shutil.rmtree, tmp_dir) t = Repo.init(tmp_dir) self.addCleanup(t.close) self.assertEqual(os.listdir(tmp_dir), ['.git']) self.assertFilesystemHidden(os.path.join(tmp_dir, '.git')) def test_init_mkdir(self): tmp_dir = self.mkdtemp() self.addCleanup(shutil.rmtree, tmp_dir) repo_dir = os.path.join(tmp_dir, 'a-repo') t = Repo.init(repo_dir, mkdir=True) self.addCleanup(t.close) self.assertEqual(os.listdir(repo_dir), ['.git']) self.assertFilesystemHidden(os.path.join(repo_dir, '.git')) def test_init_mkdir_unicode(self): repo_name = u'\xa7' try: os.fsencode(repo_name) except UnicodeEncodeError: self.skipTest('filesystem lacks unicode support') tmp_dir = self.mkdtemp() self.addCleanup(shutil.rmtree, tmp_dir) repo_dir = os.path.join(tmp_dir, repo_name) t = Repo.init(repo_dir, mkdir=True) self.addCleanup(t.close) self.assertEqual(os.listdir(repo_dir), ['.git']) self.assertFilesystemHidden(os.path.join(repo_dir, '.git')) @skipIf(sys.platform == 'win32', 'fails on Windows') def test_fetch(self): r = self.open_repo('a.git') tmp_dir = self.mkdtemp() self.addCleanup(shutil.rmtree, tmp_dir) t = Repo.init(tmp_dir) self.addCleanup(t.close) r.fetch(t) self.assertIn(b'a90fa2d900a17e99b433217e988c4eb4a2e9a097', t) self.assertIn(b'a90fa2d900a17e99b433217e988c4eb4a2e9a097', t) self.assertIn(b'a90fa2d900a17e99b433217e988c4eb4a2e9a097', t) self.assertIn(b'28237f4dc30d0d462658d6b937b08a0f0b6ef55a', t) self.assertIn(b'b0931cadc54336e78a1d980420e3268903b57a50', t) @skipIf(sys.platform == 'win32', 'fails on Windows') def test_fetch_ignores_missing_refs(self): r = self.open_repo('a.git') missing = b'1234566789123456789123567891234657373833' r.refs[b'refs/heads/blah'] = missing tmp_dir = self.mkdtemp() self.addCleanup(shutil.rmtree, tmp_dir) t = Repo.init(tmp_dir) self.addCleanup(t.close) r.fetch(t) self.assertIn(b'a90fa2d900a17e99b433217e988c4eb4a2e9a097', t) self.assertIn(b'a90fa2d900a17e99b433217e988c4eb4a2e9a097', t) self.assertIn(b'a90fa2d900a17e99b433217e988c4eb4a2e9a097', t) self.assertIn(b'28237f4dc30d0d462658d6b937b08a0f0b6ef55a', t) self.assertIn(b'b0931cadc54336e78a1d980420e3268903b57a50', t) self.assertNotIn(missing, t) def test_clone(self): r = self.open_repo('a.git') tmp_dir = self.mkdtemp() self.addCleanup(shutil.rmtree, tmp_dir) with r.clone(tmp_dir, mkdir=False) as t: self.assertEqual({ b'HEAD': b'a90fa2d900a17e99b433217e988c4eb4a2e9a097', b'refs/remotes/origin/master': b'a90fa2d900a17e99b433217e988c4eb4a2e9a097', b'refs/heads/master': b'a90fa2d900a17e99b433217e988c4eb4a2e9a097', b'refs/tags/mytag': b'28237f4dc30d0d462658d6b937b08a0f0b6ef55a', b'refs/tags/mytag-packed': b'b0931cadc54336e78a1d980420e3268903b57a50', }, t.refs.as_dict()) shas = [e.commit.id for e in r.get_walker()] self.assertEqual(shas, [t.head(), b'2a72d929692c41d8554c07f6301757ba18a65d91']) c = t.get_config() encoded_path = r.path if not isinstance(encoded_path, bytes): encoded_path = os.fsencode(encoded_path) - self.assertEqual(encoded_path, c.get((b'remote', b'origin'), b'url')) + self.assertEqual( + encoded_path, c.get((b'remote', b'origin'), b'url')) self.assertEqual( b'+refs/heads/*:refs/remotes/origin/*', c.get((b'remote', b'origin'), b'fetch')) def test_clone_no_head(self): temp_dir = self.mkdtemp() self.addCleanup(shutil.rmtree, temp_dir) repo_dir = os.path.join(os.path.dirname(__file__), 'data', 'repos') dest_dir = os.path.join(temp_dir, 'a.git') 
shutil.copytree(os.path.join(repo_dir, 'a.git'), dest_dir, symlinks=True) r = Repo(dest_dir) del r.refs[b"refs/heads/master"] del r.refs[b"HEAD"] t = r.clone(os.path.join(temp_dir, 'b.git'), mkdir=True) self.assertEqual({ b'refs/tags/mytag': b'28237f4dc30d0d462658d6b937b08a0f0b6ef55a', b'refs/tags/mytag-packed': b'b0931cadc54336e78a1d980420e3268903b57a50', }, t.refs.as_dict()) def test_clone_empty(self): """Test clone() doesn't crash if HEAD points to a non-existing ref. This simulates cloning server-side bare repository either when it is still empty or if user renames master branch and pushes private repo to the server. Non-bare repo HEAD always points to an existing ref. """ r = self.open_repo('empty.git') tmp_dir = self.mkdtemp() self.addCleanup(shutil.rmtree, tmp_dir) r.clone(tmp_dir, mkdir=False, bare=True) def test_clone_bare(self): r = self.open_repo('a.git') tmp_dir = self.mkdtemp() self.addCleanup(shutil.rmtree, tmp_dir) t = r.clone(tmp_dir, mkdir=False) t.close() def test_clone_checkout_and_bare(self): r = self.open_repo('a.git') tmp_dir = self.mkdtemp() self.addCleanup(shutil.rmtree, tmp_dir) self.assertRaises(ValueError, r.clone, tmp_dir, mkdir=False, checkout=True, bare=True) def test_merge_history(self): r = self.open_repo('simple_merge.git') shas = [e.commit.id for e in r.get_walker()] self.assertEqual(shas, [b'5dac377bdded4c9aeb8dff595f0faeebcc8498cc', b'ab64bbdcc51b170d21588e5c5d391ee5c0c96dfd', b'4cffe90e0a41ad3f5190079d7c8f036bde29cbe6', b'60dacdc733de308bb77bb76ce0fb0f9b44c9769e', b'0d89f20333fbb1d2f3a94da77f4981373d8f4310']) def test_out_of_order_merge(self): """Test that revision history is ordered by date, not parent order.""" r = self.open_repo('ooo_merge.git') shas = [e.commit.id for e in r.get_walker()] self.assertEqual(shas, [b'7601d7f6231db6a57f7bbb79ee52e4d462fd44d1', b'f507291b64138b875c28e03469025b1ea20bc614', b'fb5b0425c7ce46959bec94d54b9a157645e114f5', b'f9e39b120c68182a4ba35349f832d0e4e61f485c']) def test_get_tags_empty(self): r = self.open_repo('ooo_merge.git') self.assertEqual({}, r.refs.as_dict(b'refs/tags')) def test_get_config(self): r = self.open_repo('ooo_merge.git') self.assertIsInstance(r.get_config(), Config) def test_get_config_stack(self): r = self.open_repo('ooo_merge.git') self.assertIsInstance(r.get_config_stack(), Config) @skipIf(not getattr(os, 'symlink', None), 'Requires symlink support') def test_submodule(self): temp_dir = self.mkdtemp() self.addCleanup(shutil.rmtree, temp_dir) repo_dir = os.path.join(os.path.dirname(__file__), 'data', 'repos') shutil.copytree(os.path.join(repo_dir, 'a.git'), os.path.join(temp_dir, 'a.git'), symlinks=True) rel = os.path.relpath(os.path.join(repo_dir, 'submodule'), temp_dir) os.symlink(os.path.join(rel, 'dotgit'), os.path.join(temp_dir, '.git')) with Repo(temp_dir) as r: self.assertEqual(r.head(), b'a90fa2d900a17e99b433217e988c4eb4a2e9a097') def test_common_revisions(self): """ This test demonstrates that ``find_common_revisions()`` actually returns common heads, not revisions; dulwich already uses ``find_common_revisions()`` in such a manner (see ``Repo.fetch_objects()``). """ expected_shas = set([b'60dacdc733de308bb77bb76ce0fb0f9b44c9769e']) # Source for objects. r_base = self.open_repo('simple_merge.git') # Re-create each-side of the merge in simple_merge.git. # # Since the trees and blobs are missing, the repository created is # corrupted, but we're only checking for commits for the purpose of # this test, so it's immaterial. 
r1_dir = self.mkdtemp() self.addCleanup(shutil.rmtree, r1_dir) r1_commits = [b'ab64bbdcc51b170d21588e5c5d391ee5c0c96dfd', # HEAD b'60dacdc733de308bb77bb76ce0fb0f9b44c9769e', b'0d89f20333fbb1d2f3a94da77f4981373d8f4310'] r2_dir = self.mkdtemp() self.addCleanup(shutil.rmtree, r2_dir) r2_commits = [b'4cffe90e0a41ad3f5190079d7c8f036bde29cbe6', # HEAD b'60dacdc733de308bb77bb76ce0fb0f9b44c9769e', b'0d89f20333fbb1d2f3a94da77f4981373d8f4310'] r1 = Repo.init_bare(r1_dir) for c in r1_commits: r1.object_store.add_object(r_base.get_object(c)) r1.refs[b'HEAD'] = r1_commits[0] r2 = Repo.init_bare(r2_dir) for c in r2_commits: r2.object_store.add_object(r_base.get_object(c)) r2.refs[b'HEAD'] = r2_commits[0] # Finally, the 'real' testing! shas = r2.object_store.find_common_revisions(r1.get_graph_walker()) self.assertEqual(set(shas), expected_shas) shas = r1.object_store.find_common_revisions(r2.get_graph_walker()) self.assertEqual(set(shas), expected_shas) def test_shell_hook_pre_commit(self): if os.name != 'posix': self.skipTest('shell hook tests requires POSIX shell') pre_commit_fail = """#!/bin/sh exit 1 """ pre_commit_success = """#!/bin/sh exit 0 """ repo_dir = os.path.join(self.mkdtemp()) self.addCleanup(shutil.rmtree, repo_dir) r = Repo.init(repo_dir) self.addCleanup(r.close) pre_commit = os.path.join(r.controldir(), 'hooks', 'pre-commit') with open(pre_commit, 'w') as f: f.write(pre_commit_fail) os.chmod(pre_commit, stat.S_IREAD | stat.S_IWRITE | stat.S_IEXEC) self.assertRaises(errors.CommitError, r.do_commit, 'failed commit', committer='Test Committer ', author='Test Author ', commit_timestamp=12345, commit_timezone=0, author_timestamp=12345, author_timezone=0) with open(pre_commit, 'w') as f: f.write(pre_commit_success) os.chmod(pre_commit, stat.S_IREAD | stat.S_IWRITE | stat.S_IEXEC) commit_sha = r.do_commit( b'empty commit', committer=b'Test Committer ', author=b'Test Author ', commit_timestamp=12395, commit_timezone=0, author_timestamp=12395, author_timezone=0) self.assertEqual([], r[commit_sha].parents) def test_shell_hook_commit_msg(self): if os.name != 'posix': self.skipTest('shell hook tests requires POSIX shell') commit_msg_fail = """#!/bin/sh exit 1 """ commit_msg_success = """#!/bin/sh exit 0 """ repo_dir = self.mkdtemp() self.addCleanup(shutil.rmtree, repo_dir) r = Repo.init(repo_dir) self.addCleanup(r.close) commit_msg = os.path.join(r.controldir(), 'hooks', 'commit-msg') with open(commit_msg, 'w') as f: f.write(commit_msg_fail) os.chmod(commit_msg, stat.S_IREAD | stat.S_IWRITE | stat.S_IEXEC) self.assertRaises(errors.CommitError, r.do_commit, b'failed commit', committer=b'Test Committer ', author=b'Test Author ', commit_timestamp=12345, commit_timezone=0, author_timestamp=12345, author_timezone=0) with open(commit_msg, 'w') as f: f.write(commit_msg_success) os.chmod(commit_msg, stat.S_IREAD | stat.S_IWRITE | stat.S_IEXEC) commit_sha = r.do_commit( b'empty commit', committer=b'Test Committer ', author=b'Test Author ', commit_timestamp=12395, commit_timezone=0, author_timestamp=12395, author_timezone=0) self.assertEqual([], r[commit_sha].parents) def test_shell_hook_post_commit(self): if os.name != 'posix': self.skipTest('shell hook tests requires POSIX shell') repo_dir = self.mkdtemp() self.addCleanup(shutil.rmtree, repo_dir) r = Repo.init(repo_dir) self.addCleanup(r.close) (fd, path) = tempfile.mkstemp(dir=repo_dir) os.close(fd) post_commit_msg = """#!/bin/sh rm """ + path + """ """ root_sha = r.do_commit( b'empty commit', committer=b'Test Committer ', author=b'Test Author ', 
commit_timestamp=12345, commit_timezone=0, author_timestamp=12345, author_timezone=0) self.assertEqual([], r[root_sha].parents) post_commit = os.path.join(r.controldir(), 'hooks', 'post-commit') with open(post_commit, 'wb') as f: f.write(post_commit_msg.encode(locale.getpreferredencoding())) os.chmod(post_commit, stat.S_IREAD | stat.S_IWRITE | stat.S_IEXEC) commit_sha = r.do_commit( b'empty commit', committer=b'Test Committer ', author=b'Test Author ', commit_timestamp=12345, commit_timezone=0, author_timestamp=12345, author_timezone=0) self.assertEqual([root_sha], r[commit_sha].parents) self.assertFalse(os.path.exists(path)) post_commit_msg_fail = """#!/bin/sh exit 1 """ with open(post_commit, 'w') as f: f.write(post_commit_msg_fail) os.chmod(post_commit, stat.S_IREAD | stat.S_IWRITE | stat.S_IEXEC) warnings.simplefilter("always", UserWarning) self.addCleanup(warnings.resetwarnings) warnings_list, restore_warnings = setup_warning_catcher() self.addCleanup(restore_warnings) commit_sha2 = r.do_commit( b'empty commit', committer=b'Test Committer ', author=b'Test Author ', commit_timestamp=12345, commit_timezone=0, author_timestamp=12345, author_timezone=0) expected_warning = UserWarning( 'post-commit hook failed: Hook post-commit exited with ' 'non-zero status',) for w in warnings_list: if (type(w) == type(expected_warning) and w.args == expected_warning.args): break else: raise AssertionError( 'Expected warning %r not in %r' % (expected_warning, warnings_list)) self.assertEqual([commit_sha], r[commit_sha2].parents) def test_as_dict(self): def check(repo): self.assertEqual( repo.refs.subkeys(b'refs/tags'), repo.refs.subkeys(b'refs/tags/')) self.assertEqual( repo.refs.as_dict(b'refs/tags'), repo.refs.as_dict(b'refs/tags/')) self.assertEqual( repo.refs.as_dict(b'refs/heads'), repo.refs.as_dict(b'refs/heads/')) bare = self.open_repo('a.git') tmp_dir = self.mkdtemp() self.addCleanup(shutil.rmtree, tmp_dir) with bare.clone(tmp_dir, mkdir=False) as nonbare: check(nonbare) check(bare) def test_working_tree(self): temp_dir = tempfile.mkdtemp() self.addCleanup(shutil.rmtree, temp_dir) worktree_temp_dir = tempfile.mkdtemp() self.addCleanup(shutil.rmtree, worktree_temp_dir) r = Repo.init(temp_dir) self.addCleanup(r.close) root_sha = r.do_commit( b'empty commit', committer=b'Test Committer ', author=b'Test Author ', commit_timestamp=12345, commit_timezone=0, author_timestamp=12345, author_timezone=0) r.refs[b'refs/heads/master'] = root_sha w = Repo._init_new_working_directory(worktree_temp_dir, r) self.addCleanup(w.close) new_sha = w.do_commit( b'new commit', committer=b'Test Committer ', author=b'Test Author ', commit_timestamp=12345, commit_timezone=0, author_timestamp=12345, author_timezone=0) w.refs[b'HEAD'] = new_sha self.assertEqual(os.path.abspath(r.controldir()), os.path.abspath(w.commondir())) self.assertEqual(r.refs.keys(), w.refs.keys()) self.assertNotEqual(r.head(), w.head()) class BuildRepoRootTests(TestCase): """Tests that build on-disk repos from scratch. Repos live in a temp dir and are torn down after each test. They start with a single commit in master having single file named 'a'. 
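(Illustrative clarification, not part of the original docstring: after setUp() runs, self._root_commit holds the SHA of that initial commit and refs/heads/master points at it, so every test starts from a known single-commit history.)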
""" def get_repo_dir(self): return os.path.join(tempfile.mkdtemp(), 'test') def setUp(self): super(BuildRepoRootTests, self).setUp() self._repo_dir = self.get_repo_dir() os.makedirs(self._repo_dir) r = self._repo = Repo.init(self._repo_dir) self.addCleanup(tear_down_repo, r) self.assertFalse(r.bare) self.assertEqual(b'ref: refs/heads/master', r.refs.read_ref(b'HEAD')) self.assertRaises(KeyError, lambda: r.refs[b'refs/heads/master']) with open(os.path.join(r.path, 'a'), 'wb') as f: f.write(b'file contents') r.stage(['a']) commit_sha = r.do_commit( b'msg', committer=b'Test Committer ', author=b'Test Author ', commit_timestamp=12345, commit_timezone=0, author_timestamp=12345, author_timezone=0) self.assertEqual([], r[commit_sha].parents) self._root_commit = commit_sha def test_get_shallow(self): self.assertEqual(set(), self._repo.get_shallow()) with open(os.path.join(self._repo.path, '.git', 'shallow'), 'wb') as f: f.write(b'a90fa2d900a17e99b433217e988c4eb4a2e9a097\n') self.assertEqual({b'a90fa2d900a17e99b433217e988c4eb4a2e9a097'}, self._repo.get_shallow()) def test_update_shallow(self): self._repo.update_shallow(None, None) # no op self.assertEqual(set(), self._repo.get_shallow()) self._repo.update_shallow( [b'a90fa2d900a17e99b433217e988c4eb4a2e9a097'], None) self.assertEqual( {b'a90fa2d900a17e99b433217e988c4eb4a2e9a097'}, self._repo.get_shallow()) self._repo.update_shallow( [b'a90fa2d900a17e99b433217e988c4eb4a2e9a097'], [b'f9e39b120c68182a4ba35349f832d0e4e61f485c']) self.assertEqual({b'a90fa2d900a17e99b433217e988c4eb4a2e9a097'}, self._repo.get_shallow()) def test_build_repo(self): r = self._repo self.assertEqual(b'ref: refs/heads/master', r.refs.read_ref(b'HEAD')) self.assertEqual(self._root_commit, r.refs[b'refs/heads/master']) expected_blob = objects.Blob.from_string(b'file contents') self.assertEqual(expected_blob.data, r[expected_blob.id].data) actual_commit = r[self._root_commit] self.assertEqual(b'msg', actual_commit.message) def test_commit_modified(self): r = self._repo with open(os.path.join(r.path, 'a'), 'wb') as f: f.write(b'new contents') r.stage(['a']) commit_sha = r.do_commit( b'modified a', committer=b'Test Committer ', author=b'Test Author ', commit_timestamp=12395, commit_timezone=0, author_timestamp=12395, author_timezone=0) self.assertEqual([self._root_commit], r[commit_sha].parents) a_mode, a_id = tree_lookup_path(r.get_object, r[commit_sha].tree, b'a') self.assertEqual(stat.S_IFREG | 0o644, a_mode) self.assertEqual(b'new contents', r[a_id].data) @skipIf(not getattr(os, 'symlink', None), 'Requires symlink support') def test_commit_symlink(self): r = self._repo os.symlink('a', os.path.join(r.path, 'b')) r.stage(['a', 'b']) commit_sha = r.do_commit( b'Symlink b', committer=b'Test Committer ', author=b'Test Author ', commit_timestamp=12395, commit_timezone=0, author_timestamp=12395, author_timezone=0) self.assertEqual([self._root_commit], r[commit_sha].parents) b_mode, b_id = tree_lookup_path(r.get_object, r[commit_sha].tree, b'b') self.assertTrue(stat.S_ISLNK(b_mode)) self.assertEqual(b'a', r[b_id].data) def test_commit_merge_heads_file(self): tmp_dir = tempfile.mkdtemp() self.addCleanup(shutil.rmtree, tmp_dir) r = Repo.init(tmp_dir) with open(os.path.join(r.path, 'a'), 'w') as f: f.write('initial text') c1 = r.do_commit( b'initial commit', committer=b'Test Committer ', author=b'Test Author ', commit_timestamp=12395, commit_timezone=0, author_timestamp=12395, author_timezone=0) with open(os.path.join(r.path, 'a'), 'w') as f: f.write('merged text') with 
open(os.path.join(r.path, '.git', 'MERGE_HEADS'), 'w') as f: f.write('c27a2d21dd136312d7fa9e8baabb82561a1727d0\n') r.stage(['a']) commit_sha = r.do_commit( b'deleted a', committer=b'Test Committer ', author=b'Test Author ', commit_timestamp=12395, commit_timezone=0, author_timestamp=12395, author_timezone=0) self.assertEqual([ c1, b'c27a2d21dd136312d7fa9e8baabb82561a1727d0'], r[commit_sha].parents) def test_commit_deleted(self): r = self._repo os.remove(os.path.join(r.path, 'a')) r.stage(['a']) commit_sha = r.do_commit( b'deleted a', committer=b'Test Committer ', author=b'Test Author ', commit_timestamp=12395, commit_timezone=0, author_timestamp=12395, author_timezone=0) self.assertEqual([self._root_commit], r[commit_sha].parents) self.assertEqual([], list(r.open_index())) tree = r[r[commit_sha].tree] self.assertEqual([], list(tree.iteritems())) def test_commit_follows(self): r = self._repo r.refs.set_symbolic_ref(b'HEAD', b'refs/heads/bla') commit_sha = r.do_commit( b'commit with strange character', committer=b'Test Committer ', author=b'Test Author ', commit_timestamp=12395, commit_timezone=0, author_timestamp=12395, author_timezone=0, ref=b'HEAD') self.assertEqual(commit_sha, r[b'refs/heads/bla'].id) def test_commit_encoding(self): r = self._repo commit_sha = r.do_commit( b'commit with strange character \xee', committer=b'Test Committer ', author=b'Test Author ', commit_timestamp=12395, commit_timezone=0, author_timestamp=12395, author_timezone=0, encoding=b"iso8859-1") self.assertEqual(b"iso8859-1", r[commit_sha].encoding) def test_compression_level(self): r = self._repo c = r.get_config() c.set(('core',), 'compression', '3') c.set(('core',), 'looseCompression', '4') c.write_to_path() r = Repo(self._repo_dir) self.assertEqual(r.object_store.loose_compression_level, 4) def test_commit_encoding_from_config(self): r = self._repo c = r.get_config() c.set(('i18n',), 'commitEncoding', 'iso8859-1') c.write_to_path() commit_sha = r.do_commit( b'commit with strange character \xee', committer=b'Test Committer ', author=b'Test Author ', commit_timestamp=12395, commit_timezone=0, author_timestamp=12395, author_timezone=0) self.assertEqual(b"iso8859-1", r[commit_sha].encoding) def test_commit_config_identity(self): # commit falls back to the users' identity if it wasn't specified r = self._repo c = r.get_config() c.set((b"user", ), b"name", b"Jelmer") c.set((b"user", ), b"email", b"jelmer@apache.org") c.write_to_path() commit_sha = r.do_commit(b'message') self.assertEqual( b"Jelmer ", r[commit_sha].author) self.assertEqual( b"Jelmer ", r[commit_sha].committer) def test_commit_config_identity_strips_than(self): # commit falls back to the users' identity if it wasn't specified, # and strips superfluous <> r = self._repo c = r.get_config() c.set((b"user", ), b"name", b"Jelmer") c.set((b"user", ), b"email", b"") c.write_to_path() commit_sha = r.do_commit(b'message') self.assertEqual( b"Jelmer ", r[commit_sha].author) self.assertEqual( b"Jelmer ", r[commit_sha].committer) def test_commit_config_identity_in_memoryrepo(self): # commit falls back to the users' identity if it wasn't specified r = MemoryRepo.init_bare([], {}) c = r.get_config() c.set((b"user", ), b"name", b"Jelmer") c.set((b"user", ), b"email", b"jelmer@apache.org") commit_sha = r.do_commit(b'message', tree=objects.Tree().id) self.assertEqual( b"Jelmer ", r[commit_sha].author) self.assertEqual( b"Jelmer ", r[commit_sha].committer) def overrideEnv(self, name, value): def restore(): if oldval is not None: os.environ[name] = oldval else: del 
os.environ[name] oldval = os.environ.get(name) os.environ[name] = value self.addCleanup(restore) def test_commit_config_identity_from_env(self): # commit falls back to the users' identity if it wasn't specified self.overrideEnv('GIT_COMMITTER_NAME', 'joe') self.overrideEnv('GIT_COMMITTER_EMAIL', 'joe@example.com') r = self._repo c = r.get_config() c.set((b"user", ), b"name", b"Jelmer") c.set((b"user", ), b"email", b"jelmer@apache.org") c.write_to_path() commit_sha = r.do_commit(b'message') self.assertEqual( b"Jelmer ", r[commit_sha].author) self.assertEqual( b"joe ", r[commit_sha].committer) def test_commit_fail_ref(self): r = self._repo def set_if_equals(name, old_ref, new_ref, **kwargs): return False r.refs.set_if_equals = set_if_equals def add_if_new(name, new_ref, **kwargs): self.fail('Unexpected call to add_if_new') r.refs.add_if_new = add_if_new old_shas = set(r.object_store) self.assertRaises(errors.CommitError, r.do_commit, b'failed commit', committer=b'Test Committer ', author=b'Test Author ', commit_timestamp=12345, commit_timezone=0, author_timestamp=12345, author_timezone=0) new_shas = set(r.object_store) - old_shas self.assertEqual(1, len(new_shas)) # Check that the new commit (now garbage) was added. new_commit = r[new_shas.pop()] self.assertEqual(r[self._root_commit].tree, new_commit.tree) self.assertEqual(b'failed commit', new_commit.message) def test_commit_branch(self): r = self._repo commit_sha = r.do_commit( b'commit to branch', committer=b'Test Committer ', author=b'Test Author ', commit_timestamp=12395, commit_timezone=0, author_timestamp=12395, author_timezone=0, ref=b"refs/heads/new_branch") self.assertEqual(self._root_commit, r[b"HEAD"].id) self.assertEqual(commit_sha, r[b"refs/heads/new_branch"].id) self.assertEqual([], r[commit_sha].parents) self.assertTrue(b"refs/heads/new_branch" in r) new_branch_head = commit_sha commit_sha = r.do_commit( b'commit to branch 2', committer=b'Test Committer ', author=b'Test Author ', commit_timestamp=12395, commit_timezone=0, author_timestamp=12395, author_timezone=0, ref=b"refs/heads/new_branch") self.assertEqual(self._root_commit, r[b"HEAD"].id) self.assertEqual(commit_sha, r[b"refs/heads/new_branch"].id) self.assertEqual([new_branch_head], r[commit_sha].parents) def test_commit_merge_heads(self): r = self._repo merge_1 = r.do_commit( b'commit to branch 2', committer=b'Test Committer ', author=b'Test Author ', commit_timestamp=12395, commit_timezone=0, author_timestamp=12395, author_timezone=0, ref=b"refs/heads/new_branch") commit_sha = r.do_commit( b'commit with merge', committer=b'Test Committer ', author=b'Test Author ', commit_timestamp=12395, commit_timezone=0, author_timestamp=12395, author_timezone=0, merge_heads=[merge_1]) self.assertEqual( [self._root_commit, merge_1], r[commit_sha].parents) def test_commit_dangling_commit(self): r = self._repo old_shas = set(r.object_store) old_refs = r.get_refs() commit_sha = r.do_commit( b'commit with no ref', committer=b'Test Committer ', author=b'Test Author ', commit_timestamp=12395, commit_timezone=0, author_timestamp=12395, author_timezone=0, ref=None) new_shas = set(r.object_store) - old_shas # New sha is added, but no new refs self.assertEqual(1, len(new_shas)) new_commit = r[new_shas.pop()] self.assertEqual(r[self._root_commit].tree, new_commit.tree) self.assertEqual([], r[commit_sha].parents) self.assertEqual(old_refs, r.get_refs()) def test_commit_dangling_commit_with_parents(self): r = self._repo old_shas = set(r.object_store) old_refs = r.get_refs() commit_sha = 
r.do_commit( b'commit with no ref', committer=b'Test Committer ', author=b'Test Author ', commit_timestamp=12395, commit_timezone=0, author_timestamp=12395, author_timezone=0, ref=None, merge_heads=[self._root_commit]) new_shas = set(r.object_store) - old_shas # New sha is added, but no new refs self.assertEqual(1, len(new_shas)) new_commit = r[new_shas.pop()] self.assertEqual(r[self._root_commit].tree, new_commit.tree) self.assertEqual([self._root_commit], r[commit_sha].parents) self.assertEqual(old_refs, r.get_refs()) def test_stage_absolute(self): r = self._repo os.remove(os.path.join(r.path, 'a')) self.assertRaises(ValueError, r.stage, [os.path.join(r.path, 'a')]) def test_stage_deleted(self): r = self._repo os.remove(os.path.join(r.path, 'a')) r.stage(['a']) r.stage(['a']) # double-stage a deleted path def test_stage_directory(self): r = self._repo os.mkdir(os.path.join(r.path, 'c')) r.stage(['c']) self.assertEqual([b'a'], list(r.open_index())) @skipIf(sys.platform == 'win32', 'tries to implicitly decode as utf8') def test_commit_no_encode_decode(self): r = self._repo repo_path_bytes = os.fsencode(r.path) encodings = ('utf8', 'latin1') names = [u'À'.encode(encoding) for encoding in encodings] for name, encoding in zip(names, encodings): full_path = os.path.join(repo_path_bytes, name) with open(full_path, 'wb') as f: f.write(encoding.encode('ascii')) # These files break tear_down_repo, so clean up these files # ourselves. self.addCleanup(os.remove, full_path) r.stage(names) commit_sha = r.do_commit( b'Files with different encodings', committer=b'Test Committer ', author=b'Test Author ', commit_timestamp=12395, commit_timezone=0, author_timestamp=12395, author_timezone=0, ref=None, merge_heads=[self._root_commit]) for name, encoding in zip(names, encodings): mode, id = tree_lookup_path(r.get_object, r[commit_sha].tree, name) self.assertEqual(stat.S_IFREG | 0o644, mode) self.assertEqual(encoding.encode('ascii'), r[id].data) def test_discover_intended(self): path = os.path.join(self._repo_dir, 'b/c') r = Repo.discover(path) self.assertEqual(r.head(), self._repo.head()) def test_discover_isrepo(self): r = Repo.discover(self._repo_dir) self.assertEqual(r.head(), self._repo.head()) def test_discover_notrepo(self): with self.assertRaises(NotGitRepository): Repo.discover('/') class CheckUserIdentityTests(TestCase): def test_valid(self): check_user_identity(b'Me ') def test_invalid(self): self.assertRaises(InvalidUserIdentity, check_user_identity, b'No Email') self.assertRaises(InvalidUserIdentity, check_user_identity, b'Fullname ') self.assertRaises(InvalidUserIdentity, check_user_identity, b'Fullname >order<>')
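# An illustrative sketch (not part of the upstream test suite): callers # typically invoke check_user_identity() for its side effect and treat the # InvalidUserIdentity exception as the failure signal, e.g.: # # try: # check_user_identity(identity) # except InvalidUserIdentity: # ... # reject the commit metadata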