diff --git a/dulwich/refs.py b/dulwich/refs.py index aded324b..75e1cba1 100644 --- a/dulwich/refs.py +++ b/dulwich/refs.py @@ -1,1008 +1,1052 @@ # refs.py -- For dealing with git refs # Copyright (C) 2008-2013 Jelmer Vernooij # # Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU # General Public License as public by the Free Software Foundation; version 2.0 # or (at your option) any later version. You can redistribute it and/or # modify it under the terms of either of these two licenses. # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # # You should have received a copy of the licenses; if not, see # for a copy of the GNU General Public License # and for a copy of the Apache # License, Version 2.0. # """Ref handling. """ import os from dulwich.errors import ( PackedRefsException, RefFormatError, ) from dulwich.objects import ( git_line, valid_hexsha, ZERO_SHA, ) from dulwich.file import ( GitFile, ensure_dir_exists, ) SYMREF = b'ref: ' LOCAL_BRANCH_PREFIX = b'refs/heads/' LOCAL_TAG_PREFIX = b'refs/tags/' BAD_REF_CHARS = set(b'\177 ~^:?*[') ANNOTATED_TAG_SUFFIX = b'^{}' def parse_symref_value(contents): """Parse a symref value. Args: contents: Contents to parse Returns: Destination """ if contents.startswith(SYMREF): return contents[len(SYMREF):].rstrip(b'\r\n') raise ValueError(contents) def check_ref_format(refname): """Check if a refname is correctly formatted. Implements all the same rules as git-check-ref-format[1]. [1] http://www.kernel.org/pub/software/scm/git/docs/git-check-ref-format.html Args: refname: The refname to check Returns: True if refname is valid, False otherwise """ # These could be combined into one big expression, but are listed # separately to parallel [1]. if b'/.' in refname or refname.startswith(b'.'): return False if b'/' not in refname: return False if b'..' in refname: return False for i, c in enumerate(refname): if ord(refname[i:i+1]) < 0o40 or c in BAD_REF_CHARS: return False if refname[-1] in b'/.': return False if refname.endswith(b'.lock'): return False if b'@{' in refname: return False if b'\\' in refname: return False return True class RefsContainer(object): """A container for refs.""" def __init__(self, logger=None): self._logger = logger def _log(self, ref, old_sha, new_sha, committer=None, timestamp=None, timezone=None, message=None): if self._logger is None: return if message is None: return self._logger(ref, old_sha, new_sha, committer, timestamp, timezone, message) def set_symbolic_ref(self, name, other, committer=None, timestamp=None, timezone=None, message=None): """Make a ref point at another ref. Args: name: Name of the ref to set other: Name of the ref to point at message: Optional message """ raise NotImplementedError(self.set_symbolic_ref) def get_packed_refs(self): """Get contents of the packed-refs file. Returns: Dictionary mapping ref names to SHA1s Note: Will return an empty dictionary when no packed-refs file is present. """ raise NotImplementedError(self.get_packed_refs) def get_peeled(self, name): """Return the cached peeled value of a ref, if available. Args: name: Name of the ref to peel Returns: The peeled value of the ref. If the ref is known not point to a tag, this will be the SHA the ref refers to. If the ref may point to a tag, but no cached information is available, None is returned. """ return None def import_refs(self, base, other, committer=None, timestamp=None, timezone=None, message=None, prune=False): if prune: to_delete = set(self.subkeys(base)) else: to_delete = set() for name, value in other.items(): if value is None: to_delete.add(name) else: self.set_if_equals(b'/'.join((base, name)), None, value, message=message) if to_delete: try: to_delete.remove(name) except KeyError: pass for ref in to_delete: self.remove_if_equals( b'/'.join((base, ref)), None, message=message) def allkeys(self): """All refs present in this container.""" raise NotImplementedError(self.allkeys) def __iter__(self): return iter(self.allkeys()) def keys(self, base=None): """Refs present in this container. Args: base: An optional base to return refs under. Returns: An unsorted set of valid refs in this container, including packed refs. """ if base is not None: return self.subkeys(base) else: return self.allkeys() def subkeys(self, base): """Refs present in this container under a base. Args: base: The base to return refs under. Returns: A set of valid refs in this container under the base; the base prefix is stripped from the ref names returned. """ keys = set() base_len = len(base) + 1 for refname in self.allkeys(): if refname.startswith(base): keys.add(refname[base_len:]) return keys def as_dict(self, base=None): """Return the contents of this container as a dictionary. """ ret = {} keys = self.keys(base) if base is None: base = b'' else: base = base.rstrip(b'/') for key in keys: try: ret[key] = self[(base + b'/' + key).strip(b'/')] except KeyError: continue # Unable to resolve return ret def _check_refname(self, name): """Ensure a refname is valid and lives in refs or is HEAD. HEAD is not a valid refname according to git-check-ref-format, but this class needs to be able to touch HEAD. Also, check_ref_format expects refnames without the leading 'refs/', but this class requires that so it cannot touch anything outside the refs dir (or HEAD). Args: name: The name of the reference. Raises: KeyError: if a refname is not HEAD or is otherwise not valid. """ if name in (b'HEAD', b'refs/stash'): return if not name.startswith(b'refs/') or not check_ref_format(name[5:]): raise RefFormatError(name) def read_ref(self, refname): """Read a reference without following any references. Args: refname: The name of the reference Returns: The contents of the ref file, or None if it does not exist. """ contents = self.read_loose_ref(refname) if not contents: contents = self.get_packed_refs().get(refname, None) return contents def read_loose_ref(self, name): """Read a loose reference and return its contents. Args: name: the refname to read Returns: The contents of the ref file, or None if it does not exist. """ raise NotImplementedError(self.read_loose_ref) def follow(self, name): """Follow a reference name. Returns: a tuple of (refnames, sha), wheres refnames are the names of references in the chain """ contents = SYMREF + name depth = 0 refnames = [] while contents.startswith(SYMREF): refname = contents[len(SYMREF):] refnames.append(refname) contents = self.read_ref(refname) if not contents: break depth += 1 if depth > 5: raise KeyError(name) return refnames, contents def _follow(self, name): import warnings warnings.warn( "RefsContainer._follow is deprecated. Use RefsContainer.follow " "instead.", DeprecationWarning) refnames, contents = self.follow(name) if not refnames: return (None, contents) return (refnames[-1], contents) def __contains__(self, refname): if self.read_ref(refname): return True return False def __getitem__(self, name): """Get the SHA1 for a reference name. This method follows all symbolic references. """ _, sha = self.follow(name) if sha is None: raise KeyError(name) return sha def set_if_equals(self, name, old_ref, new_ref, committer=None, timestamp=None, timezone=None, message=None): """Set a refname to new_ref only if it currently equals old_ref. This method follows all symbolic references if applicable for the subclass, and can be used to perform an atomic compare-and-swap operation. Args: name: The refname to set. old_ref: The old sha the refname must refer to, or None to set unconditionally. new_ref: The new sha the refname will refer to. message: Message for reflog Returns: True if the set was successful, False otherwise. """ raise NotImplementedError(self.set_if_equals) def add_if_new(self, name, ref): """Add a new reference only if it does not already exist. Args: name: Ref name ref: Ref value message: Message for reflog """ raise NotImplementedError(self.add_if_new) def __setitem__(self, name, ref): """Set a reference name to point to the given SHA1. This method follows all symbolic references if applicable for the subclass. Note: This method unconditionally overwrites the contents of a reference. To update atomically only if the reference has not changed, use set_if_equals(). Args: name: The refname to set. ref: The new sha the refname will refer to. """ self.set_if_equals(name, None, ref) def remove_if_equals(self, name, old_ref, committer=None, timestamp=None, timezone=None, message=None): """Remove a refname only if it currently equals old_ref. This method does not follow symbolic references, even if applicable for the subclass. It can be used to perform an atomic compare-and-delete operation. Args: name: The refname to delete. old_ref: The old sha the refname must refer to, or None to delete unconditionally. message: Message for reflog Returns: True if the delete was successful, False otherwise. """ raise NotImplementedError(self.remove_if_equals) def __delitem__(self, name): """Remove a refname. This method does not follow symbolic references, even if applicable for the subclass. Note: This method unconditionally deletes the contents of a reference. To delete atomically only if the reference has not changed, use remove_if_equals(). Args: name: The refname to delete. """ self.remove_if_equals(name, None) def get_symrefs(self): """Get a dict with all symrefs in this container. Returns: Dictionary mapping source ref to target ref """ ret = {} for src in self.allkeys(): try: dst = parse_symref_value(self.read_ref(src)) except ValueError: pass else: ret[src] = dst return ret def watch(self): """Watch for changes to the refs in this container. - Returns a context manager that yields tuples with (refname, old_sha, - new_sha) + Returns a context manager that yields tuples with (refname, new_sha) """ raise NotImplementedError(self.watch) class _DictRefsWatcher(object): def __init__(self, refs): self._refs = refs def __enter__(self): from queue import Queue self.queue = Queue() self._refs._watchers.add(self) return self def __next__(self): return self.queue.get() def _notify(self, entry): self.queue.put_nowait(entry) def __exit__(self, exc_type, exc_val, exc_tb): self._refs._watchers.remove(self) return False class DictRefsContainer(RefsContainer): """RefsContainer backed by a simple dict. This container does not support symbolic or packed references and is not threadsafe. """ def __init__(self, refs, logger=None): super(DictRefsContainer, self).__init__(logger=logger) self._refs = refs self._peeled = {} self._watchers = set() def allkeys(self): return self._refs.keys() def read_loose_ref(self, name): return self._refs.get(name, None) def get_packed_refs(self): return {} - def _notify(self, ref, oldsha, newsha): + def _notify(self, ref, newsha): for watcher in self._watchers: - watcher._notify((ref, oldsha, newsha)) + watcher._notify((ref, newsha)) def watch(self): return _DictRefsWatcher(self) def set_symbolic_ref(self, name, other, committer=None, timestamp=None, timezone=None, message=None): old = self.follow(name)[-1] new = SYMREF + other self._refs[name] = new - self._notify(name, old, new) + self._notify(name, new) self._log(name, old, new, committer=committer, timestamp=timestamp, timezone=timezone, message=message) def set_if_equals(self, name, old_ref, new_ref, committer=None, timestamp=None, timezone=None, message=None): if old_ref is not None and self._refs.get(name, ZERO_SHA) != old_ref: return False realnames, _ = self.follow(name) for realname in realnames: self._check_refname(realname) old = self._refs.get(realname) self._refs[realname] = new_ref - self._notify(realname, old, new_ref) + self._notify(realname, new_ref) self._log(realname, old, new_ref, committer=committer, timestamp=timestamp, timezone=timezone, message=message) return True def add_if_new(self, name, ref, committer=None, timestamp=None, timezone=None, message=None): if name in self._refs: return False self._refs[name] = ref - self._notify(name, None, ref) + self._notify(name, ref) self._log(name, None, ref, committer=committer, timestamp=timestamp, timezone=timezone, message=message) return True def remove_if_equals(self, name, old_ref, committer=None, timestamp=None, timezone=None, message=None): if old_ref is not None and self._refs.get(name, ZERO_SHA) != old_ref: return False try: old = self._refs.pop(name) except KeyError: pass else: - self._notify(name, old, None) + self._notify(name, None) self._log(name, old, None, committer=committer, timestamp=timestamp, timezone=timezone, message=message) return True def get_peeled(self, name): return self._peeled.get(name) def _update(self, refs): """Update multiple refs; intended only for testing.""" # TODO(dborowitz): replace this with a public function that uses # set_if_equal. for ref, sha in refs.items(): self.set_if_equal(ref, None, sha) def _update_peeled(self, peeled): """Update cached peeled refs; intended only for testing.""" self._peeled.update(peeled) class InfoRefsContainer(RefsContainer): """Refs container that reads refs from a info/refs file.""" def __init__(self, f): self._refs = {} self._peeled = {} for line in f.readlines(): sha, name = line.rstrip(b'\n').split(b'\t') if name.endswith(ANNOTATED_TAG_SUFFIX): name = name[:-3] if not check_ref_format(name): raise ValueError("invalid ref name %r" % name) self._peeled[name] = sha else: if not check_ref_format(name): raise ValueError("invalid ref name %r" % name) self._refs[name] = sha def allkeys(self): return self._refs.keys() def read_loose_ref(self, name): return self._refs.get(name, None) def get_packed_refs(self): return {} def get_peeled(self, name): try: return self._peeled[name] except KeyError: return self._refs[name] +class _InotifyRefsWatcher(object): + + def __init__(self, path): + import pyinotify + from queue import Queue + self.path = os.fsdecode(path) + self.manager = pyinotify.WatchManager() + self.manager.add_watch( + self.path, pyinotify.IN_DELETE | + pyinotify.IN_CLOSE_WRITE | pyinotify.IN_MOVED_TO, rec=True, + auto_add=True) + + self.notifier = pyinotify.ThreadedNotifier( + self.manager, default_proc_fun=self._notify) + self.queue = Queue() + + def _notify(self, event): + if event.dir: + return + if event.pathname.endswith('.lock'): + return + ref = os.fsencode(os.path.relpath(event.pathname, self.path)) + if event.maskname == 'IN_DELETE': + self.queue.put_nowait((ref, None)) + elif event.maskname in ('IN_CLOSE_WRITE', 'IN_MOVED_TO'): + with open(event.pathname, 'rb') as f: + sha = f.readline().rstrip(b'\n\r') + self.queue.put_nowait((ref, sha)) + + def __next__(self): + return self.queue.get() + + def __enter__(self): + self.notifier.start() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.notifier.stop() + return False + + class DiskRefsContainer(RefsContainer): """Refs container that reads refs from disk.""" def __init__(self, path, worktree_path=None, logger=None): super(DiskRefsContainer, self).__init__(logger=logger) if getattr(path, 'encode', None) is not None: path = os.fsencode(path) self.path = path if worktree_path is None: worktree_path = path if getattr(worktree_path, 'encode', None) is not None: worktree_path = os.fsencode(worktree_path) self.worktree_path = worktree_path self._packed_refs = None self._peeled_refs = None def __repr__(self): return "%s(%r)" % (self.__class__.__name__, self.path) def subkeys(self, base): subkeys = set() path = self.refpath(base) for root, unused_dirs, files in os.walk(path): dir = root[len(path):] if os.path.sep != '/': dir = dir.replace(os.fsencode(os.path.sep), b"/") dir = dir.strip(b'/') for filename in files: refname = b"/".join(([dir] if dir else []) + [filename]) # check_ref_format requires at least one /, so we prepend the # base before calling it. if check_ref_format(base + b'/' + refname): subkeys.add(refname) for key in self.get_packed_refs(): if key.startswith(base): subkeys.add(key[len(base):].strip(b'/')) return subkeys def allkeys(self): allkeys = set() if os.path.exists(self.refpath(b'HEAD')): allkeys.add(b'HEAD') path = self.refpath(b'') refspath = self.refpath(b'refs') for root, unused_dirs, files in os.walk(refspath): dir = root[len(path):] if os.path.sep != '/': dir = dir.replace(os.fsencode(os.path.sep), b"/") for filename in files: refname = b"/".join([dir, filename]) if check_ref_format(refname): allkeys.add(refname) allkeys.update(self.get_packed_refs()) return allkeys def refpath(self, name): """Return the disk path of a ref. """ if os.path.sep != "/": name = name.replace(b"/", os.fsencode(os.path.sep)) # TODO: as the 'HEAD' reference is working tree specific, it # should actually not be a part of RefsContainer if name == b'HEAD': return os.path.join(self.worktree_path, name) else: return os.path.join(self.path, name) def get_packed_refs(self): """Get contents of the packed-refs file. Returns: Dictionary mapping ref names to SHA1s Note: Will return an empty dictionary when no packed-refs file is present. """ # TODO: invalidate the cache on repacking if self._packed_refs is None: # set both to empty because we want _peeled_refs to be # None if and only if _packed_refs is also None. self._packed_refs = {} self._peeled_refs = {} path = os.path.join(self.path, b'packed-refs') try: f = GitFile(path, 'rb') except FileNotFoundError: return {} with f: first_line = next(iter(f)).rstrip() if (first_line.startswith(b'# pack-refs') and b' peeled' in first_line): for sha, name, peeled in read_packed_refs_with_peeled(f): self._packed_refs[name] = sha if peeled: self._peeled_refs[name] = peeled else: f.seek(0) for sha, name in read_packed_refs(f): self._packed_refs[name] = sha return self._packed_refs def get_peeled(self, name): """Return the cached peeled value of a ref, if available. Args: name: Name of the ref to peel Returns: The peeled value of the ref. If the ref is known not point to a tag, this will be the SHA the ref refers to. If the ref may point to a tag, but no cached information is available, None is returned. """ self.get_packed_refs() if self._peeled_refs is None or name not in self._packed_refs: # No cache: no peeled refs were read, or this ref is loose return None if name in self._peeled_refs: return self._peeled_refs[name] else: # Known not peelable return self[name] def read_loose_ref(self, name): """Read a reference file and return its contents. If the reference file a symbolic reference, only read the first line of the file. Otherwise, only read the first 40 bytes. Args: name: the refname to read, relative to refpath Returns: The contents of the ref file, or None if the file does not exist. Raises: IOError: if any other error occurs """ filename = self.refpath(name) try: with GitFile(filename, 'rb') as f: header = f.read(len(SYMREF)) if header == SYMREF: # Read only the first line return header + next(iter(f)).rstrip(b'\r\n') else: # Read only the first 40 bytes return header + f.read(40 - len(SYMREF)) except (FileNotFoundError, IsADirectoryError, NotADirectoryError): return None def _remove_packed_ref(self, name): if self._packed_refs is None: return filename = os.path.join(self.path, b'packed-refs') # reread cached refs from disk, while holding the lock f = GitFile(filename, 'wb') try: self._packed_refs = None self.get_packed_refs() if name not in self._packed_refs: return del self._packed_refs[name] if name in self._peeled_refs: del self._peeled_refs[name] write_packed_refs(f, self._packed_refs, self._peeled_refs) f.close() finally: f.abort() def set_symbolic_ref(self, name, other, committer=None, timestamp=None, timezone=None, message=None): """Make a ref point at another ref. Args: name: Name of the ref to set other: Name of the ref to point at message: Optional message to describe the change """ self._check_refname(name) self._check_refname(other) filename = self.refpath(name) f = GitFile(filename, 'wb') try: f.write(SYMREF + other + b'\n') sha = self.follow(name)[-1] self._log(name, sha, sha, committer=committer, timestamp=timestamp, timezone=timezone, message=message) except BaseException: f.abort() raise else: f.close() def set_if_equals(self, name, old_ref, new_ref, committer=None, timestamp=None, timezone=None, message=None): """Set a refname to new_ref only if it currently equals old_ref. This method follows all symbolic references, and can be used to perform an atomic compare-and-swap operation. Args: name: The refname to set. old_ref: The old sha the refname must refer to, or None to set unconditionally. new_ref: The new sha the refname will refer to. message: Set message for reflog Returns: True if the set was successful, False otherwise. """ self._check_refname(name) try: realnames, _ = self.follow(name) realname = realnames[-1] except (KeyError, IndexError): realname = name filename = self.refpath(realname) # make sure none of the ancestor folders is in packed refs probe_ref = os.path.dirname(realname) packed_refs = self.get_packed_refs() while probe_ref: if packed_refs.get(probe_ref, None) is not None: raise NotADirectoryError(filename) probe_ref = os.path.dirname(probe_ref) ensure_dir_exists(os.path.dirname(filename)) with GitFile(filename, 'wb') as f: if old_ref is not None: try: # read again while holding the lock orig_ref = self.read_loose_ref(realname) if orig_ref is None: orig_ref = self.get_packed_refs().get( realname, ZERO_SHA) if orig_ref != old_ref: f.abort() return False except (OSError, IOError): f.abort() raise try: f.write(new_ref + b'\n') except (OSError, IOError): f.abort() raise self._log(realname, old_ref, new_ref, committer=committer, timestamp=timestamp, timezone=timezone, message=message) return True def add_if_new(self, name, ref, committer=None, timestamp=None, timezone=None, message=None): """Add a new reference only if it does not already exist. This method follows symrefs, and only ensures that the last ref in the chain does not exist. Args: name: The refname to set. ref: The new sha the refname will refer to. message: Optional message for reflog Returns: True if the add was successful, False otherwise. """ try: realnames, contents = self.follow(name) if contents is not None: return False realname = realnames[-1] except (KeyError, IndexError): realname = name self._check_refname(realname) filename = self.refpath(realname) ensure_dir_exists(os.path.dirname(filename)) with GitFile(filename, 'wb') as f: if os.path.exists(filename) or name in self.get_packed_refs(): f.abort() return False try: f.write(ref + b'\n') except (OSError, IOError): f.abort() raise else: self._log(name, None, ref, committer=committer, timestamp=timestamp, timezone=timezone, message=message) return True def remove_if_equals(self, name, old_ref, committer=None, timestamp=None, timezone=None, message=None): """Remove a refname only if it currently equals old_ref. This method does not follow symbolic references. It can be used to perform an atomic compare-and-delete operation. Args: name: The refname to delete. old_ref: The old sha the refname must refer to, or None to delete unconditionally. message: Optional message Returns: True if the delete was successful, False otherwise. """ self._check_refname(name) filename = self.refpath(name) ensure_dir_exists(os.path.dirname(filename)) f = GitFile(filename, 'wb') try: if old_ref is not None: orig_ref = self.read_loose_ref(name) if orig_ref is None: orig_ref = self.get_packed_refs().get(name, ZERO_SHA) if orig_ref != old_ref: return False # remove the reference file itself try: os.remove(filename) except FileNotFoundError: pass # may only be packed self._remove_packed_ref(name) self._log(name, old_ref, None, committer=committer, timestamp=timestamp, timezone=timezone, message=message) finally: # never write, we just wanted the lock f.abort() # outside of the lock, clean-up any parent directory that might now # be empty. this ensures that re-creating a reference of the same # name of what was previously a directory works as expected parent = name while True: try: parent, _ = parent.rsplit(b'/', 1) except ValueError: break parent_filename = self.refpath(parent) try: os.rmdir(parent_filename) except OSError: # this can be caused by the parent directory being # removed by another process, being not empty, etc. # in any case, this is non fatal because we already # removed the reference, just ignore it break return True + def watch(self): + import pyinotify # noqa: F401 + return _InotifyRefsWatcher(self.path) + def _split_ref_line(line): """Split a single ref line into a tuple of SHA1 and name.""" fields = line.rstrip(b'\n\r').split(b' ') if len(fields) != 2: raise PackedRefsException("invalid ref line %r" % line) sha, name = fields if not valid_hexsha(sha): raise PackedRefsException("Invalid hex sha %r" % sha) if not check_ref_format(name): raise PackedRefsException("invalid ref name %r" % name) return (sha, name) def read_packed_refs(f): """Read a packed refs file. Args: f: file-like object to read from Returns: Iterator over tuples with SHA1s and ref names. """ for line in f: if line.startswith(b'#'): # Comment continue if line.startswith(b'^'): raise PackedRefsException( "found peeled ref in packed-refs without peeled") yield _split_ref_line(line) def read_packed_refs_with_peeled(f): """Read a packed refs file including peeled refs. Assumes the "# pack-refs with: peeled" line was already read. Yields tuples with ref names, SHA1s, and peeled SHA1s (or None). Args: f: file-like object to read from, seek'ed to the second line """ last = None for line in f: if line[0] == b'#': continue line = line.rstrip(b'\r\n') if line.startswith(b'^'): if not last: raise PackedRefsException("unexpected peeled ref line") if not valid_hexsha(line[1:]): raise PackedRefsException("Invalid hex sha %r" % line[1:]) sha, name = _split_ref_line(last) last = None yield (sha, name, line[1:]) else: if last: sha, name = _split_ref_line(last) yield (sha, name, None) last = line if last: sha, name = _split_ref_line(last) yield (sha, name, None) def write_packed_refs(f, packed_refs, peeled_refs=None): """Write a packed refs file. Args: f: empty file-like object to write to packed_refs: dict of refname to sha of packed refs to write peeled_refs: dict of refname to peeled value of sha """ if peeled_refs is None: peeled_refs = {} else: f.write(b'# pack-refs with: peeled\n') for refname in sorted(packed_refs.keys()): f.write(git_line(packed_refs[refname], refname)) if refname in peeled_refs: f.write(b'^' + peeled_refs[refname] + b'\n') def read_info_refs(f): ret = {} for line in f.readlines(): (sha, name) = line.rstrip(b"\r\n").split(b"\t", 1) ret[name] = sha return ret def write_info_refs(refs, store): """Generate info refs.""" for name, sha in sorted(refs.items()): # get_refs() includes HEAD as a special case, but we don't want to # advertise it if name == b'HEAD': continue try: o = store[sha] except KeyError: continue peeled = store.peel_sha(sha) yield o.id + b'\t' + name + b'\n' if o.id != peeled.id: yield peeled.id + b'\t' + name + ANNOTATED_TAG_SUFFIX + b'\n' def is_local_branch(x): return x.startswith(LOCAL_BRANCH_PREFIX) def strip_peeled_refs(refs): """Remove all peeled refs""" return {ref: sha for (ref, sha) in refs.items() if not ref.endswith(ANNOTATED_TAG_SUFFIX)} diff --git a/dulwich/tests/test_refs.py b/dulwich/tests/test_refs.py index b2b24795..ac0ec513 100644 --- a/dulwich/tests/test_refs.py +++ b/dulwich/tests/test_refs.py @@ -1,715 +1,713 @@ # test_refs.py -- tests for refs.py # encoding: utf-8 # Copyright (C) 2013 Jelmer Vernooij # # Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU # General Public License as public by the Free Software Foundation; version 2.0 # or (at your option) any later version. You can redistribute it and/or # modify it under the terms of either of these two licenses. # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # # You should have received a copy of the licenses; if not, see # for a copy of the GNU General Public License # and for a copy of the Apache # License, Version 2.0. # """Tests for dulwich.refs.""" from io import BytesIO import os import sys import tempfile from dulwich import errors from dulwich.file import ( GitFile, ) from dulwich.objects import ZERO_SHA from dulwich.refs import ( DictRefsContainer, InfoRefsContainer, check_ref_format, _split_ref_line, parse_symref_value, read_packed_refs_with_peeled, read_packed_refs, strip_peeled_refs, write_packed_refs, ) from dulwich.repo import Repo from dulwich.tests import ( SkipTest, TestCase, ) from dulwich.tests.utils import ( open_repo, tear_down_repo, ) class CheckRefFormatTests(TestCase): """Tests for the check_ref_format function. These are the same tests as in the git test suite. """ def test_valid(self): self.assertTrue(check_ref_format(b'heads/foo')) self.assertTrue(check_ref_format(b'foo/bar/baz')) self.assertTrue(check_ref_format(b'refs///heads/foo')) self.assertTrue(check_ref_format(b'foo./bar')) self.assertTrue(check_ref_format(b'heads/foo@bar')) self.assertTrue(check_ref_format(b'heads/fix.lock.error')) def test_invalid(self): self.assertFalse(check_ref_format(b'foo')) self.assertFalse(check_ref_format(b'heads/foo/')) self.assertFalse(check_ref_format(b'./foo')) self.assertFalse(check_ref_format(b'.refs/foo')) self.assertFalse(check_ref_format(b'heads/foo..bar')) self.assertFalse(check_ref_format(b'heads/foo?bar')) self.assertFalse(check_ref_format(b'heads/foo.lock')) self.assertFalse(check_ref_format(b'heads/v@{ation')) self.assertFalse(check_ref_format(b'heads/foo\bar')) ONES = b'1' * 40 TWOS = b'2' * 40 THREES = b'3' * 40 FOURS = b'4' * 40 class PackedRefsFileTests(TestCase): def test_split_ref_line_errors(self): self.assertRaises(errors.PackedRefsException, _split_ref_line, b'singlefield') self.assertRaises(errors.PackedRefsException, _split_ref_line, b'badsha name') self.assertRaises(errors.PackedRefsException, _split_ref_line, ONES + b' bad/../refname') def test_read_without_peeled(self): f = BytesIO(b'\n'.join([ b'# comment', ONES + b' ref/1', TWOS + b' ref/2'])) self.assertEqual([(ONES, b'ref/1'), (TWOS, b'ref/2')], list(read_packed_refs(f))) def test_read_without_peeled_errors(self): f = BytesIO(b'\n'.join([ ONES + b' ref/1', b'^' + TWOS])) self.assertRaises(errors.PackedRefsException, list, read_packed_refs(f)) def test_read_with_peeled(self): f = BytesIO(b'\n'.join([ ONES + b' ref/1', TWOS + b' ref/2', b'^' + THREES, FOURS + b' ref/4'])) self.assertEqual([ (ONES, b'ref/1', None), (TWOS, b'ref/2', THREES), (FOURS, b'ref/4', None), ], list(read_packed_refs_with_peeled(f))) def test_read_with_peeled_errors(self): f = BytesIO(b'\n'.join([ b'^' + TWOS, ONES + b' ref/1'])) self.assertRaises(errors.PackedRefsException, list, read_packed_refs(f)) f = BytesIO(b'\n'.join([ ONES + b' ref/1', b'^' + TWOS, b'^' + THREES])) self.assertRaises(errors.PackedRefsException, list, read_packed_refs(f)) def test_write_with_peeled(self): f = BytesIO() write_packed_refs(f, {b'ref/1': ONES, b'ref/2': TWOS}, {b'ref/1': THREES}) self.assertEqual( b'\n'.join([b'# pack-refs with: peeled', ONES + b' ref/1', b'^' + THREES, TWOS + b' ref/2']) + b'\n', f.getvalue()) def test_write_without_peeled(self): f = BytesIO() write_packed_refs(f, {b'ref/1': ONES, b'ref/2': TWOS}) self.assertEqual(b'\n'.join([ONES + b' ref/1', TWOS + b' ref/2']) + b'\n', f.getvalue()) # Dict of refs that we expect all RefsContainerTests subclasses to define. _TEST_REFS = { b'HEAD': b'42d06bd4b77fed026b154d16493e5deab78f02ec', b'refs/heads/40-char-ref-aaaaaaaaaaaaaaaaaa': b'42d06bd4b77fed026b154d16493e5deab78f02ec', b'refs/heads/master': b'42d06bd4b77fed026b154d16493e5deab78f02ec', b'refs/heads/packed': b'42d06bd4b77fed026b154d16493e5deab78f02ec', b'refs/tags/refs-0.1': b'df6800012397fb85c56e7418dd4eb9405dee075c', b'refs/tags/refs-0.2': b'3ec9c43c84ff242e3ef4a9fc5bc111fd780a76a8', b'refs/heads/loop': b'ref: refs/heads/loop', } class RefsContainerTests(object): def test_keys(self): actual_keys = set(self._refs.keys()) self.assertEqual(set(self._refs.allkeys()), actual_keys) self.assertEqual(set(_TEST_REFS.keys()), actual_keys) actual_keys = self._refs.keys(b'refs/heads') actual_keys.discard(b'loop') self.assertEqual( [b'40-char-ref-aaaaaaaaaaaaaaaaaa', b'master', b'packed'], sorted(actual_keys)) self.assertEqual([b'refs-0.1', b'refs-0.2'], sorted(self._refs.keys(b'refs/tags'))) def test_iter(self): actual_keys = set(self._refs.keys()) self.assertEqual(set(self._refs), actual_keys) self.assertEqual(set(_TEST_REFS.keys()), actual_keys) def test_as_dict(self): # refs/heads/loop does not show up even if it exists expected_refs = dict(_TEST_REFS) del expected_refs[b'refs/heads/loop'] self.assertEqual(expected_refs, self._refs.as_dict()) def test_get_symrefs(self): self._refs.set_symbolic_ref(b'refs/heads/src', b'refs/heads/dst') symrefs = self._refs.get_symrefs() if b'HEAD' in symrefs: symrefs.pop(b'HEAD') self.assertEqual({b'refs/heads/src': b'refs/heads/dst', b'refs/heads/loop': b'refs/heads/loop'}, symrefs) def test_setitem(self): self._refs[b'refs/some/ref'] = ( b'42d06bd4b77fed026b154d16493e5deab78f02ec') self.assertEqual(b'42d06bd4b77fed026b154d16493e5deab78f02ec', self._refs[b'refs/some/ref']) self.assertRaises( errors.RefFormatError, self._refs.__setitem__, b'notrefs/foo', b'42d06bd4b77fed026b154d16493e5deab78f02ec') def test_set_if_equals(self): nines = b'9' * 40 self.assertFalse(self._refs.set_if_equals(b'HEAD', b'c0ffee', nines)) self.assertEqual(b'42d06bd4b77fed026b154d16493e5deab78f02ec', self._refs[b'HEAD']) self.assertTrue(self._refs.set_if_equals( b'HEAD', b'42d06bd4b77fed026b154d16493e5deab78f02ec', nines)) self.assertEqual(nines, self._refs[b'HEAD']) # Setting the ref again is a no-op, but will return True. self.assertTrue(self._refs.set_if_equals(b'HEAD', nines, nines)) self.assertEqual(nines, self._refs[b'HEAD']) self.assertTrue(self._refs.set_if_equals(b'refs/heads/master', None, nines)) self.assertEqual(nines, self._refs[b'refs/heads/master']) self.assertTrue(self._refs.set_if_equals( b'refs/heads/nonexistant', ZERO_SHA, nines)) self.assertEqual(nines, self._refs[b'refs/heads/nonexistant']) def test_add_if_new(self): nines = b'9' * 40 self.assertFalse(self._refs.add_if_new(b'refs/heads/master', nines)) self.assertEqual(b'42d06bd4b77fed026b154d16493e5deab78f02ec', self._refs[b'refs/heads/master']) self.assertTrue(self._refs.add_if_new(b'refs/some/ref', nines)) self.assertEqual(nines, self._refs[b'refs/some/ref']) def test_set_symbolic_ref(self): self._refs.set_symbolic_ref(b'refs/heads/symbolic', b'refs/heads/master') self.assertEqual(b'ref: refs/heads/master', self._refs.read_loose_ref(b'refs/heads/symbolic')) self.assertEqual(b'42d06bd4b77fed026b154d16493e5deab78f02ec', self._refs[b'refs/heads/symbolic']) def test_set_symbolic_ref_overwrite(self): nines = b'9' * 40 self.assertFalse(b'refs/heads/symbolic' in self._refs) self._refs[b'refs/heads/symbolic'] = nines self.assertEqual(nines, self._refs.read_loose_ref(b'refs/heads/symbolic')) self._refs.set_symbolic_ref(b'refs/heads/symbolic', b'refs/heads/master') self.assertEqual(b'ref: refs/heads/master', self._refs.read_loose_ref(b'refs/heads/symbolic')) self.assertEqual(b'42d06bd4b77fed026b154d16493e5deab78f02ec', self._refs[b'refs/heads/symbolic']) def test_check_refname(self): self._refs._check_refname(b'HEAD') self._refs._check_refname(b'refs/stash') self._refs._check_refname(b'refs/heads/foo') self.assertRaises(errors.RefFormatError, self._refs._check_refname, b'refs') self.assertRaises(errors.RefFormatError, self._refs._check_refname, b'notrefs/foo') def test_contains(self): self.assertTrue(b'refs/heads/master' in self._refs) self.assertFalse(b'refs/heads/bar' in self._refs) def test_delitem(self): self.assertEqual(b'42d06bd4b77fed026b154d16493e5deab78f02ec', self._refs[b'refs/heads/master']) del self._refs[b'refs/heads/master'] self.assertRaises(KeyError, lambda: self._refs[b'refs/heads/master']) def test_remove_if_equals(self): self.assertFalse(self._refs.remove_if_equals(b'HEAD', b'c0ffee')) self.assertEqual(b'42d06bd4b77fed026b154d16493e5deab78f02ec', self._refs[b'HEAD']) self.assertTrue(self._refs.remove_if_equals( b'refs/tags/refs-0.2', b'3ec9c43c84ff242e3ef4a9fc5bc111fd780a76a8')) self.assertTrue(self._refs.remove_if_equals( b'refs/tags/refs-0.2', ZERO_SHA)) self.assertFalse(b'refs/tags/refs-0.2' in self._refs) def test_import_refs_name(self): self._refs[b'refs/remotes/origin/other'] = ( b'48d01bd4b77fed026b154d16493e5deab78f02ec') self._refs.import_refs( b'refs/remotes/origin', {b'master': b'42d06bd4b77fed026b154d16493e5deab78f02ec'}) self.assertEqual( b'42d06bd4b77fed026b154d16493e5deab78f02ec', self._refs[b'refs/remotes/origin/master']) self.assertEqual( b'48d01bd4b77fed026b154d16493e5deab78f02ec', self._refs[b'refs/remotes/origin/other']) def test_import_refs_name_prune(self): self._refs[b'refs/remotes/origin/other'] = ( b'48d01bd4b77fed026b154d16493e5deab78f02ec') self._refs.import_refs( b'refs/remotes/origin', {b'master': b'42d06bd4b77fed026b154d16493e5deab78f02ec'}, prune=True) self.assertEqual( b'42d06bd4b77fed026b154d16493e5deab78f02ec', self._refs[b'refs/remotes/origin/master']) self.assertNotIn( b'refs/remotes/origin/other', self._refs) def test_watch(self): try: watcher = self._refs.watch() except NotImplementedError: self.skipTest('watching not supported') with watcher: self._refs[b'refs/remotes/origin/other'] = ( b'48d01bd4b77fed026b154d16493e5deab78f02ec') change = next(watcher) self.assertEqual( - (b'refs/remotes/origin/other', None, + (b'refs/remotes/origin/other', b'48d01bd4b77fed026b154d16493e5deab78f02ec'), change) self._refs[b'refs/remotes/origin/other'] = ( b'48d01bd4b77fed026b154d16493e5deab78f02ed') change = next(watcher) self.assertEqual( (b'refs/remotes/origin/other', - b'48d01bd4b77fed026b154d16493e5deab78f02ec', b'48d01bd4b77fed026b154d16493e5deab78f02ed'), change) del self._refs[b'refs/remotes/origin/other'] change = next(watcher) self.assertEqual( (b'refs/remotes/origin/other', - b'48d01bd4b77fed026b154d16493e5deab78f02ed', None), change) class DictRefsContainerTests(RefsContainerTests, TestCase): def setUp(self): TestCase.setUp(self) self._refs = DictRefsContainer(dict(_TEST_REFS)) def test_invalid_refname(self): # FIXME: Move this test into RefsContainerTests, but requires # some way of injecting invalid refs. self._refs._refs[b'refs/stash'] = b'00' * 20 expected_refs = dict(_TEST_REFS) del expected_refs[b'refs/heads/loop'] expected_refs[b'refs/stash'] = b'00' * 20 self.assertEqual(expected_refs, self._refs.as_dict()) class DiskRefsContainerTests(RefsContainerTests, TestCase): def setUp(self): TestCase.setUp(self) self._repo = open_repo('refs.git') self.addCleanup(tear_down_repo, self._repo) self._refs = self._repo.refs def test_get_packed_refs(self): self.assertEqual({ b'refs/heads/packed': b'42d06bd4b77fed026b154d16493e5deab78f02ec', b'refs/tags/refs-0.1': b'df6800012397fb85c56e7418dd4eb9405dee075c', }, self._refs.get_packed_refs()) def test_get_peeled_not_packed(self): # not packed self.assertEqual(None, self._refs.get_peeled(b'refs/tags/refs-0.2')) self.assertEqual(b'3ec9c43c84ff242e3ef4a9fc5bc111fd780a76a8', self._refs[b'refs/tags/refs-0.2']) # packed, known not peelable self.assertEqual(self._refs[b'refs/heads/packed'], self._refs.get_peeled(b'refs/heads/packed')) # packed, peeled self.assertEqual(b'42d06bd4b77fed026b154d16493e5deab78f02ec', self._refs.get_peeled(b'refs/tags/refs-0.1')) def test_setitem(self): RefsContainerTests.test_setitem(self) path = os.path.join(self._refs.path, b'refs', b'some', b'ref') with open(path, 'rb') as f: self.assertEqual(b'42d06bd4b77fed026b154d16493e5deab78f02ec', f.read()[:40]) self.assertRaises( OSError, self._refs.__setitem__, b'refs/some/ref/sub', b'42d06bd4b77fed026b154d16493e5deab78f02ec') def test_setitem_packed(self): with open(os.path.join(self._refs.path, b'packed-refs'), 'w') as f: f.write('# pack-refs with: peeled fully-peeled sorted \n') f.write( '42d06bd4b77fed026b154d16493e5deab78f02ec refs/heads/packed\n') # It's allowed to set a new ref on a packed ref, the new ref will be # placed outside on refs/ self._refs[b'refs/heads/packed'] = ( b'3ec9c43c84ff242e3ef4a9fc5bc111fd780a76a8' ) packed_ref_path = os.path.join( self._refs.path, b'refs', b'heads', b'packed') with open(packed_ref_path, 'rb') as f: self.assertEqual( b'3ec9c43c84ff242e3ef4a9fc5bc111fd780a76a8', f.read()[:40]) self.assertRaises( OSError, self._refs.__setitem__, b'refs/heads/packed/sub', b'42d06bd4b77fed026b154d16493e5deab78f02ec') def test_setitem_symbolic(self): ones = b'1' * 40 self._refs[b'HEAD'] = ones self.assertEqual(ones, self._refs[b'HEAD']) # ensure HEAD was not modified f = open(os.path.join(self._refs.path, b'HEAD'), 'rb') v = next(iter(f)).rstrip(b'\n\r') f.close() self.assertEqual(b'ref: refs/heads/master', v) # ensure the symbolic link was written through f = open(os.path.join(self._refs.path, b'refs', b'heads', b'master'), 'rb') self.assertEqual(ones, f.read()[:40]) f.close() def test_set_if_equals(self): RefsContainerTests.test_set_if_equals(self) # ensure symref was followed self.assertEqual(b'9' * 40, self._refs[b'refs/heads/master']) # ensure lockfile was deleted self.assertFalse(os.path.exists( os.path.join(self._refs.path, b'refs', b'heads', b'master.lock'))) self.assertFalse(os.path.exists( os.path.join(self._refs.path, b'HEAD.lock'))) def test_add_if_new_packed(self): # don't overwrite packed ref self.assertFalse(self._refs.add_if_new(b'refs/tags/refs-0.1', b'9' * 40)) self.assertEqual(b'df6800012397fb85c56e7418dd4eb9405dee075c', self._refs[b'refs/tags/refs-0.1']) def test_add_if_new_symbolic(self): # Use an empty repo instead of the default. repo_dir = os.path.join(tempfile.mkdtemp(), 'test') os.makedirs(repo_dir) repo = Repo.init(repo_dir) self.addCleanup(tear_down_repo, repo) refs = repo.refs nines = b'9' * 40 self.assertEqual(b'ref: refs/heads/master', refs.read_ref(b'HEAD')) self.assertFalse(b'refs/heads/master' in refs) self.assertTrue(refs.add_if_new(b'HEAD', nines)) self.assertEqual(b'ref: refs/heads/master', refs.read_ref(b'HEAD')) self.assertEqual(nines, refs[b'HEAD']) self.assertEqual(nines, refs[b'refs/heads/master']) self.assertFalse(refs.add_if_new(b'HEAD', b'1' * 40)) self.assertEqual(nines, refs[b'HEAD']) self.assertEqual(nines, refs[b'refs/heads/master']) def test_follow(self): self.assertEqual(([b'HEAD', b'refs/heads/master'], b'42d06bd4b77fed026b154d16493e5deab78f02ec'), self._refs.follow(b'HEAD')) self.assertEqual(([b'refs/heads/master'], b'42d06bd4b77fed026b154d16493e5deab78f02ec'), self._refs.follow(b'refs/heads/master')) self.assertRaises(KeyError, self._refs.follow, b'refs/heads/loop') def test_delitem(self): RefsContainerTests.test_delitem(self) ref_file = os.path.join(self._refs.path, b'refs', b'heads', b'master') self.assertFalse(os.path.exists(ref_file)) self.assertFalse(b'refs/heads/master' in self._refs.get_packed_refs()) def test_delitem_symbolic(self): self.assertEqual(b'ref: refs/heads/master', self._refs.read_loose_ref(b'HEAD')) del self._refs[b'HEAD'] self.assertRaises(KeyError, lambda: self._refs[b'HEAD']) self.assertEqual(b'42d06bd4b77fed026b154d16493e5deab78f02ec', self._refs[b'refs/heads/master']) self.assertFalse( os.path.exists(os.path.join(self._refs.path, b'HEAD'))) def test_remove_if_equals_symref(self): # HEAD is a symref, so shouldn't equal its dereferenced value self.assertFalse(self._refs.remove_if_equals( b'HEAD', b'42d06bd4b77fed026b154d16493e5deab78f02ec')) self.assertTrue(self._refs.remove_if_equals( b'refs/heads/master', b'42d06bd4b77fed026b154d16493e5deab78f02ec')) self.assertRaises(KeyError, lambda: self._refs[b'refs/heads/master']) # HEAD is now a broken symref self.assertRaises(KeyError, lambda: self._refs[b'HEAD']) self.assertEqual(b'ref: refs/heads/master', self._refs.read_loose_ref(b'HEAD')) self.assertFalse(os.path.exists( os.path.join(self._refs.path, b'refs', b'heads', b'master.lock'))) self.assertFalse(os.path.exists( os.path.join(self._refs.path, b'HEAD.lock'))) def test_remove_packed_without_peeled(self): refs_file = os.path.join(self._repo.path, 'packed-refs') f = GitFile(refs_file) refs_data = f.read() f.close() f = GitFile(refs_file, 'wb') f.write(b'\n'.join(line for line in refs_data.split(b'\n') if not line or line[0] not in b'#^')) f.close() self._repo = Repo(self._repo.path) refs = self._repo.refs self.assertTrue(refs.remove_if_equals( b'refs/heads/packed', b'42d06bd4b77fed026b154d16493e5deab78f02ec')) def test_remove_if_equals_packed(self): # test removing ref that is only packed self.assertEqual(b'df6800012397fb85c56e7418dd4eb9405dee075c', self._refs[b'refs/tags/refs-0.1']) self.assertTrue( self._refs.remove_if_equals( b'refs/tags/refs-0.1', b'df6800012397fb85c56e7418dd4eb9405dee075c')) self.assertRaises(KeyError, lambda: self._refs[b'refs/tags/refs-0.1']) def test_remove_parent(self): self._refs[b'refs/heads/foo/bar'] = ( b'df6800012397fb85c56e7418dd4eb9405dee075c' ) del self._refs[b'refs/heads/foo/bar'] ref_file = os.path.join( self._refs.path, b'refs', b'heads', b'foo', b'bar', ) self.assertFalse(os.path.exists(ref_file)) ref_file = os.path.join(self._refs.path, b'refs', b'heads', b'foo') self.assertFalse(os.path.exists(ref_file)) ref_file = os.path.join(self._refs.path, b'refs', b'heads') self.assertTrue(os.path.exists(ref_file)) self._refs[b'refs/heads/foo'] = ( b'df6800012397fb85c56e7418dd4eb9405dee075c' ) def test_read_ref(self): self.assertEqual(b'ref: refs/heads/master', self._refs.read_ref(b'HEAD')) self.assertEqual(b'42d06bd4b77fed026b154d16493e5deab78f02ec', self._refs.read_ref(b'refs/heads/packed')) self.assertEqual(None, self._refs.read_ref(b'nonexistant')) def test_read_loose_ref(self): self._refs[b'refs/heads/foo'] = ( b'df6800012397fb85c56e7418dd4eb9405dee075c' ) self.assertEqual(None, self._refs.read_ref(b'refs/heads/foo/bar')) def test_non_ascii(self): try: encoded_ref = os.fsencode(u'refs/tags/schön') except UnicodeEncodeError: raise SkipTest( "filesystem encoding doesn't support special character") p = os.path.join(os.fsencode(self._repo.path), encoded_ref) with open(p, 'w') as f: f.write('00' * 20) expected_refs = dict(_TEST_REFS) expected_refs[encoded_ref] = b'00' * 20 del expected_refs[b'refs/heads/loop'] self.assertEqual(expected_refs, self._repo.get_refs()) def test_cyrillic(self): if sys.platform in ('darwin', 'win32'): raise SkipTest( "filesystem encoding doesn't support arbitrary bytes") # reported in https://github.com/dulwich/dulwich/issues/608 name = b'\xcd\xee\xe2\xe0\xff\xe2\xe5\xf2\xea\xe01' encoded_ref = b'refs/heads/' + name with open(os.path.join( os.fsencode(self._repo.path), encoded_ref), 'w') as f: f.write('00' * 20) expected_refs = set(_TEST_REFS.keys()) expected_refs.add(encoded_ref) self.assertEqual(expected_refs, set(self._repo.refs.allkeys())) self.assertEqual({r[len(b'refs/'):] for r in expected_refs if r.startswith(b'refs/')}, set(self._repo.refs.subkeys(b'refs/'))) expected_refs.remove(b'refs/heads/loop') expected_refs.add(b'HEAD') self.assertEqual(expected_refs, set(self._repo.get_refs().keys())) _TEST_REFS_SERIALIZED = ( b'42d06bd4b77fed026b154d16493e5deab78f02ec\t' b'refs/heads/40-char-ref-aaaaaaaaaaaaaaaaaa\n' b'42d06bd4b77fed026b154d16493e5deab78f02ec\trefs/heads/master\n' b'42d06bd4b77fed026b154d16493e5deab78f02ec\trefs/heads/packed\n' b'df6800012397fb85c56e7418dd4eb9405dee075c\trefs/tags/refs-0.1\n' b'3ec9c43c84ff242e3ef4a9fc5bc111fd780a76a8\trefs/tags/refs-0.2\n') class InfoRefsContainerTests(TestCase): def test_invalid_refname(self): text = _TEST_REFS_SERIALIZED + b'00' * 20 + b'\trefs/stash\n' refs = InfoRefsContainer(BytesIO(text)) expected_refs = dict(_TEST_REFS) del expected_refs[b'HEAD'] expected_refs[b'refs/stash'] = b'00' * 20 del expected_refs[b'refs/heads/loop'] self.assertEqual(expected_refs, refs.as_dict()) def test_keys(self): refs = InfoRefsContainer(BytesIO(_TEST_REFS_SERIALIZED)) actual_keys = set(refs.keys()) self.assertEqual(set(refs.allkeys()), actual_keys) expected_refs = dict(_TEST_REFS) del expected_refs[b'HEAD'] del expected_refs[b'refs/heads/loop'] self.assertEqual(set(expected_refs.keys()), actual_keys) actual_keys = refs.keys(b'refs/heads') actual_keys.discard(b'loop') self.assertEqual( [b'40-char-ref-aaaaaaaaaaaaaaaaaa', b'master', b'packed'], sorted(actual_keys)) self.assertEqual([b'refs-0.1', b'refs-0.2'], sorted(refs.keys(b'refs/tags'))) def test_as_dict(self): refs = InfoRefsContainer(BytesIO(_TEST_REFS_SERIALIZED)) # refs/heads/loop does not show up even if it exists expected_refs = dict(_TEST_REFS) del expected_refs[b'HEAD'] del expected_refs[b'refs/heads/loop'] self.assertEqual(expected_refs, refs.as_dict()) def test_contains(self): refs = InfoRefsContainer(BytesIO(_TEST_REFS_SERIALIZED)) self.assertTrue(b'refs/heads/master' in refs) self.assertFalse(b'refs/heads/bar' in refs) def test_get_peeled(self): refs = InfoRefsContainer(BytesIO(_TEST_REFS_SERIALIZED)) # refs/heads/loop does not show up even if it exists self.assertEqual( _TEST_REFS[b'refs/heads/master'], refs.get_peeled(b'refs/heads/master')) class ParseSymrefValueTests(TestCase): def test_valid(self): self.assertEqual( b'refs/heads/foo', parse_symref_value(b'ref: refs/heads/foo')) def test_invalid(self): self.assertRaises(ValueError, parse_symref_value, b'foobar') class StripPeeledRefsTests(TestCase): all_refs = { b'refs/heads/master': b'8843d7f92416211de9ebb963ff4ce28125932878', b'refs/heads/testing': b'186a005b134d8639a58b6731c7c1ea821a6eedba', b'refs/tags/1.0.0': b'a93db4b0360cc635a2b93675010bac8d101f73f0', b'refs/tags/1.0.0^{}': b'a93db4b0360cc635a2b93675010bac8d101f73f0', b'refs/tags/2.0.0': b'0749936d0956c661ac8f8d3483774509c165f89e', b'refs/tags/2.0.0^{}': b'0749936d0956c661ac8f8d3483774509c165f89e', } non_peeled_refs = { b'refs/heads/master': b'8843d7f92416211de9ebb963ff4ce28125932878', b'refs/heads/testing': b'186a005b134d8639a58b6731c7c1ea821a6eedba', b'refs/tags/1.0.0': b'a93db4b0360cc635a2b93675010bac8d101f73f0', b'refs/tags/2.0.0': b'0749936d0956c661ac8f8d3483774509c165f89e', } def test_strip_peeled_refs(self): # Simple check of two dicts self.assertEqual( strip_peeled_refs(self.all_refs), self.non_peeled_refs)