diff --git a/dulwich/_diff_tree.c b/dulwich/_diff_tree.c
index 7c67ea28..3ddc23de 100644
--- a/dulwich/_diff_tree.c
+++ b/dulwich/_diff_tree.c
@@ -1,504 +1,505 @@
/*
* Copyright (C) 2010 Google, Inc.
*
* Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU
* General Public License as published by the Free Software Foundation; version 2.0
* or (at your option) any later version. You can redistribute it and/or
* modify it under the terms of either of these two licenses.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* You should have received a copy of the licenses; if not, see
* <http://www.gnu.org/licenses/> for a copy of the GNU General Public License
* and <http://www.apache.org/licenses/LICENSE-2.0> for a copy of the Apache
* License, Version 2.0.
*/
#include <Python.h>
#include <sys/stat.h>
#ifdef _MSC_VER
typedef unsigned short mode_t;
#endif
#if PY_MAJOR_VERSION < 3
typedef long Py_hash_t;
#endif
#if PY_MAJOR_VERSION >= 3
#define PyInt_FromLong PyLong_FromLong
#define PyInt_AsLong PyLong_AsLong
#define PyInt_AS_LONG PyLong_AS_LONG
#define PyString_AS_STRING PyBytes_AS_STRING
#define PyString_AsStringAndSize PyBytes_AsStringAndSize
#define PyString_Check PyBytes_Check
#define PyString_CheckExact PyBytes_CheckExact
#define PyString_FromStringAndSize PyBytes_FromStringAndSize
#define PyString_FromString PyBytes_FromString
#define PyString_GET_SIZE PyBytes_GET_SIZE
#define PyString_Size PyBytes_Size
#define _PyString_Join _PyBytes_Join
#endif
static PyObject *tree_entry_cls = NULL, *null_entry = NULL,
*defaultdict_cls = NULL, *int_cls = NULL;
static int block_size;
/**
* Free an array of PyObject pointers, decrementing any references.
*/
static void free_objects(PyObject **objs, Py_ssize_t n)
{
Py_ssize_t i;
for (i = 0; i < n; i++)
Py_XDECREF(objs[i]);
PyMem_Free(objs);
}
/**
* Get the entries of a tree, prepending the given path.
*
- * :param path: The path to prepend, without trailing slashes.
- * :param path_len: The length of path.
- * :param tree: The Tree object to iterate.
- * :param n: Set to the length of result.
- * :return: A (C) array of PyObject pointers to TreeEntry objects for each path
+ * Args:
+ * path: The path to prepend, without trailing slashes.
+ * path_len: The length of path.
+ * tree: The Tree object to iterate.
+ * n: Set to the length of result.
+ * Returns: A (C) array of PyObject pointers to TreeEntry objects for each path
* in tree.
*/
static PyObject **tree_entries(char *path, Py_ssize_t path_len, PyObject *tree,
Py_ssize_t *n)
{
PyObject *iteritems, *items, **result = NULL;
PyObject *old_entry, *name, *sha;
Py_ssize_t i = 0, name_len, new_path_len;
char *new_path;
if (tree == Py_None) {
*n = 0;
result = PyMem_New(PyObject*, 0);
if (!result) {
PyErr_NoMemory();
return NULL;
}
return result;
}
iteritems = PyObject_GetAttrString(tree, "iteritems");
if (!iteritems)
return NULL;
items = PyObject_CallFunctionObjArgs(iteritems, Py_True, NULL);
Py_DECREF(iteritems);
if (items == NULL) {
return NULL;
}
/* The C implementation of iteritems returns a list, so depend on that. */
if (!PyList_Check(items)) {
PyErr_SetString(PyExc_TypeError,
"Tree.iteritems() did not return a list");
return NULL;
}
*n = PyList_Size(items);
result = PyMem_New(PyObject*, *n);
if (!result) {
PyErr_NoMemory();
goto error;
}
for (i = 0; i < *n; i++) {
old_entry = PyList_GetItem(items, i);
if (!old_entry)
goto error;
sha = PyTuple_GetItem(old_entry, 2);
if (!sha)
goto error;
name = PyTuple_GET_ITEM(old_entry, 0);
name_len = PyString_Size(name);
if (PyErr_Occurred())
goto error;
new_path_len = name_len;
if (path_len)
new_path_len += path_len + 1;
new_path = PyMem_Malloc(new_path_len);
if (!new_path) {
PyErr_NoMemory();
goto error;
}
if (path_len) {
memcpy(new_path, path, path_len);
new_path[path_len] = '/';
memcpy(new_path + path_len + 1, PyString_AS_STRING(name), name_len);
} else {
memcpy(new_path, PyString_AS_STRING(name), name_len);
}
#if PY_MAJOR_VERSION >= 3
result[i] = PyObject_CallFunction(tree_entry_cls, "y#OO", new_path,
new_path_len, PyTuple_GET_ITEM(old_entry, 1), sha);
#else
result[i] = PyObject_CallFunction(tree_entry_cls, "s#OO", new_path,
new_path_len, PyTuple_GET_ITEM(old_entry, 1), sha);
#endif
PyMem_Free(new_path);
if (!result[i]) {
goto error;
}
}
Py_DECREF(items);
return result;
error:
if (result)
free_objects(result, i);
Py_DECREF(items);
return NULL;
}
/**
* Use strcmp to compare the paths of two TreeEntry objects.
*/
static int entry_path_cmp(PyObject *entry1, PyObject *entry2)
{
PyObject *path1 = NULL, *path2 = NULL;
int result = 0;
path1 = PyObject_GetAttrString(entry1, "path");
if (!path1)
goto done;
if (!PyString_Check(path1)) {
PyErr_SetString(PyExc_TypeError, "path is not a (byte)string");
goto done;
}
path2 = PyObject_GetAttrString(entry2, "path");
if (!path2)
goto done;
if (!PyString_Check(path2)) {
PyErr_SetString(PyExc_TypeError, "path is not a (byte)string");
goto done;
}
result = strcmp(PyString_AS_STRING(path1), PyString_AS_STRING(path2));
done:
Py_XDECREF(path1);
Py_XDECREF(path2);
return result;
}
static PyObject *py_merge_entries(PyObject *self, PyObject *args)
{
PyObject *tree1, *tree2, **entries1 = NULL, **entries2 = NULL;
PyObject *e1, *e2, *pair, *result = NULL;
Py_ssize_t n1 = 0, n2 = 0, i1 = 0, i2 = 0;
int path_len;
char *path_str;
int cmp;
#if PY_MAJOR_VERSION >= 3
if (!PyArg_ParseTuple(args, "y#OO", &path_str, &path_len, &tree1, &tree2))
#else
if (!PyArg_ParseTuple(args, "s#OO", &path_str, &path_len, &tree1, &tree2))
#endif
return NULL;
entries1 = tree_entries(path_str, path_len, tree1, &n1);
if (!entries1)
goto error;
entries2 = tree_entries(path_str, path_len, tree2, &n2);
if (!entries2)
goto error;
result = PyList_New(0);
if (!result)
goto error;
while (i1 < n1 && i2 < n2) {
cmp = entry_path_cmp(entries1[i1], entries2[i2]);
if (PyErr_Occurred())
goto error;
if (!cmp) {
e1 = entries1[i1++];
e2 = entries2[i2++];
} else if (cmp < 0) {
e1 = entries1[i1++];
e2 = null_entry;
} else {
e1 = null_entry;
e2 = entries2[i2++];
}
pair = PyTuple_Pack(2, e1, e2);
if (!pair)
goto error;
PyList_Append(result, pair);
Py_DECREF(pair);
}
while (i1 < n1) {
pair = PyTuple_Pack(2, entries1[i1++], null_entry);
if (!pair)
goto error;
PyList_Append(result, pair);
Py_DECREF(pair);
}
while (i2 < n2) {
pair = PyTuple_Pack(2, null_entry, entries2[i2++]);
if (!pair)
goto error;
PyList_Append(result, pair);
Py_DECREF(pair);
}
goto done;
error:
Py_XDECREF(result);
result = NULL;
done:
if (entries1)
free_objects(entries1, n1);
if (entries2)
free_objects(entries2, n2);
return result;
}
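For readers following the C: this is a plain two-pointer merge of two path-sorted entry lists, emitting (entry, null_entry) pairs for paths present on only one side. A minimal Python sketch of the same logic (merge_entries_sketch is a hypothetical name; the real pure-Python fallback is dulwich.diff_tree._merge_entries):

def merge_entries_sketch(entries1, entries2, null_entry=None):
    # entries1/entries2 are TreeEntry lists sorted by path.
    result = []
    i1 = i2 = 0
    while i1 < len(entries1) and i2 < len(entries2):
        e1, e2 = entries1[i1], entries2[i2]
        if e1.path == e2.path:        # present in both trees
            result.append((e1, e2))
            i1 += 1
            i2 += 1
        elif e1.path < e2.path:       # only in the first tree
            result.append((e1, null_entry))
            i1 += 1
        else:                         # only in the second tree
            result.append((null_entry, e2))
            i2 += 1
    # Drain whichever side still has entries.
    result.extend((e, null_entry) for e in entries1[i1:])
    result.extend((null_entry, e) for e in entries2[i2:])
    return result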
static PyObject *py_is_tree(PyObject *self, PyObject *args)
{
PyObject *entry, *mode, *result;
long lmode;
if (!PyArg_ParseTuple(args, "O", &entry))
return NULL;
mode = PyObject_GetAttrString(entry, "mode");
if (!mode)
return NULL;
if (mode == Py_None) {
result = Py_False;
Py_INCREF(result);
} else {
lmode = PyInt_AsLong(mode);
if (lmode == -1 && PyErr_Occurred()) {
Py_DECREF(mode);
return NULL;
}
result = PyBool_FromLong(S_ISDIR((mode_t)lmode));
}
Py_DECREF(mode);
return result;
}
static Py_hash_t add_hash(PyObject *get, PyObject *set, char *str, int n)
{
PyObject *str_obj = NULL, *hash_obj = NULL, *value = NULL,
*set_value = NULL;
Py_hash_t hash;
/* It would be nice to hash without copying str into a PyString, but that
* isn't exposed by the API. */
str_obj = PyString_FromStringAndSize(str, n);
if (!str_obj)
goto error;
hash = PyObject_Hash(str_obj);
if (hash == -1)
goto error;
hash_obj = PyInt_FromLong(hash);
if (!hash_obj)
goto error;
value = PyObject_CallFunctionObjArgs(get, hash_obj, NULL);
if (!value)
goto error;
set_value = PyObject_CallFunction(set, "(Ol)", hash_obj,
PyInt_AS_LONG(value) + n);
if (!set_value)
goto error;
Py_DECREF(str_obj);
Py_DECREF(hash_obj);
Py_DECREF(value);
Py_DECREF(set_value);
return 0;
error:
Py_XDECREF(str_obj);
Py_XDECREF(hash_obj);
Py_XDECREF(value);
Py_XDECREF(set_value);
return -1;
}
static PyObject *py_count_blocks(PyObject *self, PyObject *args)
{
PyObject *obj, *chunks = NULL, *chunk, *counts = NULL, *get = NULL,
*set = NULL;
char *chunk_str, *block = NULL;
Py_ssize_t num_chunks, chunk_len;
int i, j, n = 0;
char c;
if (!PyArg_ParseTuple(args, "O", &obj))
goto error;
counts = PyObject_CallFunctionObjArgs(defaultdict_cls, int_cls, NULL);
if (!counts)
goto error;
get = PyObject_GetAttrString(counts, "__getitem__");
set = PyObject_GetAttrString(counts, "__setitem__");
chunks = PyObject_CallMethod(obj, "as_raw_chunks", NULL);
if (!chunks)
goto error;
if (!PyList_Check(chunks)) {
PyErr_SetString(PyExc_TypeError,
"as_raw_chunks() did not return a list");
goto error;
}
num_chunks = PyList_GET_SIZE(chunks);
block = PyMem_New(char, block_size);
if (!block) {
PyErr_NoMemory();
goto error;
}
for (i = 0; i < num_chunks; i++) {
chunk = PyList_GET_ITEM(chunks, i);
if (!PyString_Check(chunk)) {
PyErr_SetString(PyExc_TypeError, "chunk is not a string");
goto error;
}
if (PyString_AsStringAndSize(chunk, &chunk_str, &chunk_len) == -1)
goto error;
for (j = 0; j < chunk_len; j++) {
c = chunk_str[j];
block[n++] = c;
if (c == '\n' || n == block_size) {
if (add_hash(get, set, block, n) == -1)
goto error;
n = 0;
}
}
}
if (n && add_hash(get, set, block, n) == -1)
goto error;
Py_DECREF(chunks);
Py_DECREF(get);
Py_DECREF(set);
PyMem_Free(block);
return counts;
error:
Py_XDECREF(chunks);
Py_XDECREF(get);
Py_XDECREF(set);
Py_XDECREF(counts);
PyMem_Free(block);
return NULL;
}
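The loop above slices each object's raw chunks into blocks terminated by a newline or capped at block_size, then accumulates block lengths keyed by the block's hash. A hedged Python sketch of the same accounting (count_blocks_sketch is a hypothetical name; the real fallback is dulwich.diff_tree._count_blocks, and 64 stands in for _BLOCK_SIZE):

from collections import defaultdict

def count_blocks_sketch(chunks, block_size=64):
    counts = defaultdict(int)
    block = bytearray()
    for chunk in chunks:              # chunk is a bytes object
        for byte in chunk:            # iterating bytes yields ints
            block.append(byte)
            if byte == ord(b'\n') or len(block) == block_size:
                counts[hash(bytes(block))] += len(block)
                block = bytearray()
    if block:                         # flush the trailing partial block
        counts[hash(bytes(block))] += len(block)
    return counts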
static PyMethodDef py_diff_tree_methods[] = {
{ "_is_tree", (PyCFunction)py_is_tree, METH_VARARGS, NULL },
{ "_merge_entries", (PyCFunction)py_merge_entries, METH_VARARGS, NULL },
{ "_count_blocks", (PyCFunction)py_count_blocks, METH_VARARGS, NULL },
{ NULL, NULL, 0, NULL }
};
static PyObject *
moduleinit(void)
{
PyObject *m, *objects_mod = NULL, *diff_tree_mod = NULL;
PyObject *block_size_obj = NULL;
#if PY_MAJOR_VERSION >= 3
static struct PyModuleDef moduledef = {
PyModuleDef_HEAD_INIT,
"_diff_tree", /* m_name */
NULL, /* m_doc */
-1, /* m_size */
py_diff_tree_methods, /* m_methods */
NULL, /* m_reload */
NULL, /* m_traverse */
NULL, /* m_clear*/
NULL, /* m_free */
};
m = PyModule_Create(&moduledef);
#else
m = Py_InitModule("_diff_tree", py_diff_tree_methods);
#endif
if (!m)
goto error;
objects_mod = PyImport_ImportModule("dulwich.objects");
if (!objects_mod)
goto error;
tree_entry_cls = PyObject_GetAttrString(objects_mod, "TreeEntry");
Py_DECREF(objects_mod);
if (!tree_entry_cls)
goto error;
diff_tree_mod = PyImport_ImportModule("dulwich.diff_tree");
if (!diff_tree_mod)
goto error;
null_entry = PyObject_GetAttrString(diff_tree_mod, "_NULL_ENTRY");
if (!null_entry)
goto error;
block_size_obj = PyObject_GetAttrString(diff_tree_mod, "_BLOCK_SIZE");
if (!block_size_obj)
goto error;
block_size = (int)PyInt_AsLong(block_size_obj);
if (PyErr_Occurred())
goto error;
defaultdict_cls = PyObject_GetAttrString(diff_tree_mod, "defaultdict");
if (!defaultdict_cls)
goto error;
/* This is kind of hacky, but I don't know of a better way to get the
* PyObject* version of int. */
int_cls = PyDict_GetItemString(PyEval_GetBuiltins(), "int");
if (!int_cls) {
PyErr_SetString(PyExc_NameError, "int");
goto error;
}
Py_DECREF(diff_tree_mod);
return m;
error:
Py_XDECREF(objects_mod);
Py_XDECREF(diff_tree_mod);
Py_XDECREF(null_entry);
Py_XDECREF(block_size_obj);
Py_XDECREF(defaultdict_cls);
Py_XDECREF(int_cls);
return NULL;
}
#if PY_MAJOR_VERSION >= 3
PyMODINIT_FUNC
PyInit__diff_tree(void)
{
return moduleinit();
}
#else
PyMODINIT_FUNC
init_diff_tree(void)
{
moduleinit();
}
#endif
diff --git a/dulwich/config.py b/dulwich/config.py
index aebd6fd1..d19038f3 100644
--- a/dulwich/config.py
+++ b/dulwich/config.py
@@ -1,563 +1,563 @@
# config.py - Reading and writing Git config files
# Copyright (C) 2011-2013 Jelmer Vernooij
#
# Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU
# General Public License as published by the Free Software Foundation; version 2.0
# or (at your option) any later version. You can redistribute it and/or
# modify it under the terms of either of these two licenses.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# You should have received a copy of the licenses; if not, see
# <http://www.gnu.org/licenses/> for a copy of the GNU General Public License
# and <http://www.apache.org/licenses/LICENSE-2.0> for a copy of the Apache
# License, Version 2.0.
#
"""Reading and writing Git configuration files.
TODO:
* preserve formatting when updating configuration files
* treat subsection names as case-insensitive for [branch.foo] style
subsections
"""
import errno
import os
import sys
from collections import (
OrderedDict,
)
try:
from collections.abc import (
Iterable,
MutableMapping,
)
except ImportError: # python < 3.7
from collections import (
Iterable,
MutableMapping,
)
from dulwich.file import GitFile
SENTINAL = object()
def lower_key(key):
if isinstance(key, (bytes, str)):
return key.lower()
if isinstance(key, Iterable):
return type(key)(
map(lower_key, key)
)
return key
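lower_key folds string and bytes keys to lowercase and recurses into iterable keys such as section tuples; anything else passes through. For illustration:

lower_key(b"CORE")                 # -> b"core"
lower_key((b"Branch", b"Master"))  # -> (b"branch", b"master")
lower_key(42)                      # -> 42 (non-string keys unchanged)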
class CaseInsensitiveDict(OrderedDict):
@classmethod
def make(cls, dict_in=None):
if isinstance(dict_in, cls):
return dict_in
out = cls()
if dict_in is None:
return out
if not isinstance(dict_in, MutableMapping):
raise TypeError
for key, value in dict_in.items():
out[key] = value
return out
def __setitem__(self, key, value, **kwargs):
key = lower_key(key)
super(CaseInsensitiveDict, self).__setitem__(key, value, **kwargs)
def __getitem__(self, item):
key = lower_key(item)
return super(CaseInsensitiveDict, self).__getitem__(key)
def get(self, key, default=SENTINAL):
try:
return self[key]
except KeyError:
pass
if default is SENTINAL:
return type(self)()
return default
def setdefault(self, key, default=SENTINAL):
try:
return self[key]
except KeyError:
self[key] = self.get(key, default)
return self[key]
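A short usage sketch: keys are normalized on insert and on lookup, and get() falls back to an empty instance rather than raising:

d = CaseInsensitiveDict()
d[(b"CORE",)] = {b"bare": b"false"}
d[(b"core",)]          # -> {b"bare": b"false"}
d.get((b"missing",))   # -> CaseInsensitiveDict() rather than a KeyError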
class Config(object):
"""A Git configuration."""
def get(self, section, name):
"""Retrieve the contents of a configuration setting.
Args:
section: Tuple with section name and optional subsection name
name: Name of the setting
Returns:
Contents of the setting
Raises:
KeyError: if the value is not set
"""
raise NotImplementedError(self.get)
def get_boolean(self, section, name, default=None):
"""Retrieve a configuration setting as boolean.
Args:
section: Tuple with section name and optional subsection name
name: Name of the setting, including section and possible
subsection.
Returns:
Contents of the setting
Raises:
KeyError: if the value is not set
"""
try:
value = self.get(section, name)
except KeyError:
return default
if value.lower() == b"true":
return True
elif value.lower() == b"false":
return False
raise ValueError("not a valid boolean string: %r" % value)
def set(self, section, name, value):
"""Set a configuration value.
Args:
section: Tuple with section name and optional subsection name
name: Name of the configuration value, including section
and optional subsection
value: value of the setting
"""
raise NotImplementedError(self.set)
def iteritems(self, section):
"""Iterate over the configuration pairs for a specific section.
Args:
section: Tuple with section name and optional subsection name
Returns:
Iterator over (name, value) pairs
"""
raise NotImplementedError(self.iteritems)
def itersections(self):
"""Iterate over the sections.
- :return: Iterator over section tuples
+ Returns: Iterator over section tuples
"""
raise NotImplementedError(self.itersections)
def has_section(self, name):
"""Check if a specified section exists.
Args:
name: Name of section to check for
Returns:
boolean indicating whether the section exists
"""
return (name in self.itersections())
class ConfigDict(Config, MutableMapping):
"""Git configuration stored in a dictionary."""
def __init__(self, values=None, encoding=None):
"""Create a new ConfigDict."""
if encoding is None:
encoding = sys.getdefaultencoding()
self.encoding = encoding
self._values = CaseInsensitiveDict.make(values)
def __repr__(self):
return "%s(%r)" % (self.__class__.__name__, self._values)
def __eq__(self, other):
return (
isinstance(other, self.__class__) and
other._values == self._values)
def __getitem__(self, key):
return self._values.__getitem__(key)
def __setitem__(self, key, value):
return self._values.__setitem__(key, value)
def __delitem__(self, key):
return self._values.__delitem__(key)
def __iter__(self):
return self._values.__iter__()
def __len__(self):
return self._values.__len__()
@classmethod
def _parse_setting(cls, name):
parts = name.split(".")
if len(parts) == 3:
return (parts[0], parts[1], parts[2])
else:
return (parts[0], None, parts[1])
def _check_section_and_name(self, section, name):
if not isinstance(section, tuple):
section = (section, )
section = tuple([
subsection.encode(self.encoding)
if not isinstance(subsection, bytes) else subsection
for subsection in section
])
if not isinstance(name, bytes):
name = name.encode(self.encoding)
return section, name
def get(self, section, name):
section, name = self._check_section_and_name(section, name)
if len(section) > 1:
try:
return self._values[section][name]
except KeyError:
pass
return self._values[(section[0],)][name]
def set(self, section, name, value):
section, name = self._check_section_and_name(section, name)
if type(value) not in (bool, bytes):
value = value.encode(self.encoding)
self._values.setdefault(section)[name] = value
def iteritems(self, section):
return self._values.get(section).items()
def itersections(self):
return self._values.keys()
def _format_string(value):
if (value.startswith(b" ") or
value.startswith(b"\t") or
value.endswith(b" ") or
b'#' in value or
value.endswith(b"\t")):
return b'"' + _escape_value(value) + b'"'
else:
return _escape_value(value)
_ESCAPE_TABLE = {
ord(b"\\"): ord(b"\\"),
ord(b"\""): ord(b"\""),
ord(b"n"): ord(b"\n"),
ord(b"t"): ord(b"\t"),
ord(b"b"): ord(b"\b"),
}
_COMMENT_CHARS = [ord(b"#"), ord(b";")]
_WHITESPACE_CHARS = [ord(b"\t"), ord(b" ")]
def _parse_string(value):
value = bytearray(value.strip())
ret = bytearray()
whitespace = bytearray()
in_quotes = False
i = 0
while i < len(value):
c = value[i]
if c == ord(b"\\"):
i += 1
try:
v = _ESCAPE_TABLE[value[i]]
except IndexError:
raise ValueError(
"escape character in %r at %d before end of string" %
(value, i))
except KeyError:
raise ValueError(
"escape character followed by unknown character "
"%s at %d in %r" % (value[i], i, value))
if whitespace:
ret.extend(whitespace)
whitespace = bytearray()
ret.append(v)
elif c == ord(b"\""):
in_quotes = (not in_quotes)
elif c in _COMMENT_CHARS and not in_quotes:
# the rest of the line is a comment
break
elif c in _WHITESPACE_CHARS:
whitespace.append(c)
else:
if whitespace:
ret.extend(whitespace)
whitespace = bytearray()
ret.append(c)
i += 1
if in_quotes:
raise ValueError("missing end quote")
return bytes(ret)
def _escape_value(value):
"""Escape a value."""
value = value.replace(b"\\", b"\\\\")
value = value.replace(b"\n", b"\\n")
value = value.replace(b"\t", b"\\t")
value = value.replace(b"\"", b"\\\"")
return value
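A few behaviors of the parser and escaper above, for illustration:

_parse_string(b'hello # world')    # -> b'hello' (unquoted comment stripped)
_parse_string(b'"hello # world"')  # -> b'hello # world' (quotes protect '#')
_parse_string(b'tab\\there')       # -> b'tab\there' (escape decoded)
_escape_value(b'tab\there')        # -> b'tab\\there' (and re-encoded)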
def _check_variable_name(name):
for i in range(len(name)):
c = name[i:i+1]
if not c.isalnum() and c != b'-':
return False
return True
def _check_section_name(name):
for i in range(len(name)):
c = name[i:i+1]
if not c.isalnum() and c not in (b'-', b'.'):
return False
return True
def _strip_comments(line):
comment_bytes = {ord(b"#"), ord(b";")}
quote = ord(b'"')
string_open = False
# Normalize line to bytearray for simple 2/3 compatibility
for i, character in enumerate(bytearray(line)):
# Comment characters outside balanced quotes denote comment start
if character == quote:
string_open = not string_open
elif not string_open and character in comment_bytes:
return line[:i]
return line
class ConfigFile(ConfigDict):
"""A Git configuration file, like .git/config or ~/.gitconfig.
"""
@classmethod
def from_file(cls, f):
"""Read configuration from a file-like object."""
ret = cls()
section = None
setting = None
for lineno, line in enumerate(f.readlines()):
line = line.lstrip()
if setting is None:
# Parse section header ("[bla]")
if len(line) > 0 and line[:1] == b"[":
line = _strip_comments(line).rstrip()
try:
last = line.index(b"]")
except ValueError:
raise ValueError("expected trailing ]")
pts = line[1:last].split(b" ", 1)
line = line[last+1:]
if len(pts) == 2:
if pts[1][:1] != b"\"" or pts[1][-1:] != b"\"":
raise ValueError(
"Invalid subsection %r" % pts[1])
else:
pts[1] = pts[1][1:-1]
if not _check_section_name(pts[0]):
raise ValueError("invalid section name %r" %
pts[0])
section = (pts[0], pts[1])
else:
if not _check_section_name(pts[0]):
raise ValueError(
"invalid section name %r" % pts[0])
pts = pts[0].split(b".", 1)
if len(pts) == 2:
section = (pts[0], pts[1])
else:
section = (pts[0], )
ret._values.setdefault(section)
if _strip_comments(line).strip() == b"":
continue
if section is None:
raise ValueError("setting %r without section" % line)
try:
setting, value = line.split(b"=", 1)
except ValueError:
setting = line
value = b"true"
setting = setting.strip()
if not _check_variable_name(setting):
raise ValueError("invalid variable name %s" % setting)
if value.endswith(b"\\\n"):
continuation = value[:-2]
else:
continuation = None
value = _parse_string(value)
ret._values[section][setting] = value
setting = None
else: # continuation line
if line.endswith(b"\\\n"):
continuation += line[:-2]
else:
continuation += line
value = _parse_string(continuation)
ret._values[section][setting] = value
continuation = None
setting = None
return ret
@classmethod
def from_path(cls, path):
"""Read configuration from a file on disk."""
with GitFile(path, 'rb') as f:
ret = cls.from_file(f)
ret.path = path
return ret
def write_to_path(self, path=None):
"""Write configuration to a file on disk."""
if path is None:
path = self.path
with GitFile(path, 'wb') as f:
self.write_to_file(f)
def write_to_file(self, f):
"""Write configuration to a file-like object."""
for section, values in self._values.items():
try:
section_name, subsection_name = section
except ValueError:
(section_name, ) = section
subsection_name = None
if subsection_name is None:
f.write(b"[" + section_name + b"]\n")
else:
f.write(b"[" + section_name +
b" \"" + subsection_name + b"\"]\n")
for key, value in values.items():
if value is True:
value = b"true"
elif value is False:
value = b"false"
else:
value = _format_string(value)
f.write(b"\t" + key + b" = " + value + b"\n")
class StackedConfig(Config):
"""Configuration which reads from multiple config files.."""
def __init__(self, backends, writable=None):
self.backends = backends
self.writable = writable
def __repr__(self):
return "<%s for %r>" % (self.__class__.__name__, self.backends)
@classmethod
def default(cls):
return cls(cls.default_backends())
@classmethod
def default_backends(cls):
"""Retrieve the default configuration.
See git-config(1) for details on the files searched.
"""
paths = []
paths.append(os.path.expanduser("~/.gitconfig"))
xdg_config_home = os.environ.get(
"XDG_CONFIG_HOME", os.path.expanduser("~/.config/"),
)
paths.append(os.path.join(xdg_config_home, "git", "config"))
if "GIT_CONFIG_NOSYSTEM" not in os.environ:
paths.append("/etc/gitconfig")
backends = []
for path in paths:
try:
cf = ConfigFile.from_path(path)
except (IOError, OSError) as e:
if e.errno != errno.ENOENT:
raise
else:
continue
backends.append(cf)
return backends
def get(self, section, name):
if not isinstance(section, tuple):
section = (section, )
for backend in self.backends:
try:
return backend.get(section, name)
except KeyError:
pass
raise KeyError(name)
def set(self, section, name, value):
if self.writable is None:
raise NotImplementedError(self.set)
return self.writable.set(section, name, value)
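Backends are consulted in order, so a value in ~/.gitconfig shadows the XDG and system files. A minimal sketch:

cfg = StackedConfig.default()
try:
    email = cfg.get((b'user',), b'email')
except KeyError:
    email = None   # not set in any of the searched files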
def parse_submodules(config):
"""Parse a gitmodules GitConfig file, returning submodules.
Args:
config: A `ConfigFile`
Returns:
list of tuples (submodule path, url, name),
where name is quoted part of the section's name.
"""
for section in config.keys():
section_kind, section_name = section
if section_kind == b'submodule':
sm_path = config.get(section, b'path')
sm_url = config.get(section, b'url')
yield (sm_path, sm_url, section_name)
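For illustration, iterating a checked-out .gitmodules file with the helper above:

config = ConfigFile.from_path('.gitmodules')
for path, url, name in parse_submodules(config):
    print(path, url, name)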
diff --git a/dulwich/contrib/swift.py b/dulwich/contrib/swift.py
index 8436c252..eb96d2e9 100644
--- a/dulwich/contrib/swift.py
+++ b/dulwich/contrib/swift.py
@@ -1,1056 +1,1058 @@
# swift.py -- Repo implementation atop OpenStack SWIFT
# Copyright (C) 2013 eNovance SAS
#
# Author: Fabien Boucher
#
# Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU
# General Public License as published by the Free Software Foundation; version 2.0
# or (at your option) any later version. You can redistribute it and/or
# modify it under the terms of either of these two licenses.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# You should have received a copy of the licenses; if not, see
# <http://www.gnu.org/licenses/> for a copy of the GNU General Public License
# and <http://www.apache.org/licenses/LICENSE-2.0> for a copy of the Apache
# License, Version 2.0.
#
"""Repo implementation atop OpenStack SWIFT."""
# TODO: Refactor to share more code with dulwich/repo.py.
# TODO(fbo): Second attempt to _send() must be notified via real log
# TODO(fbo): More logs for operations
import os
import stat
import zlib
import tempfile
import posixpath
try:
import urlparse
except ImportError:
import urllib.parse as urlparse
from io import BytesIO
try:
from ConfigParser import ConfigParser
except ImportError:
from configparser import ConfigParser
from geventhttpclient import HTTPClient
from dulwich.greenthreads import (
GreenThreadsMissingObjectFinder,
GreenThreadsObjectStoreIterator,
)
from dulwich.lru_cache import LRUSizeCache
from dulwich.objects import (
Blob,
Commit,
Tree,
Tag,
S_ISGITLINK,
)
from dulwich.object_store import (
PackBasedObjectStore,
PACKDIR,
INFODIR,
)
from dulwich.pack import (
PackData,
Pack,
PackIndexer,
PackStreamCopier,
write_pack_header,
compute_file_sha,
iter_sha1,
write_pack_index_v2,
load_pack_index_file,
read_pack_header,
_compute_object_size,
unpack_object,
write_pack_object,
)
from dulwich.protocol import TCP_GIT_PORT
from dulwich.refs import (
InfoRefsContainer,
read_info_refs,
write_info_refs,
)
from dulwich.repo import (
BaseRepo,
OBJECTDIR,
)
from dulwich.server import (
Backend,
TCPGitServer,
)
try:
from simplejson import loads as json_loads
from simplejson import dumps as json_dumps
except ImportError:
from json import loads as json_loads
from json import dumps as json_dumps
import sys
"""
# Configuration file sample
[swift]
# Authentication URL (Keystone or Swift)
auth_url = http://127.0.0.1:5000/v2.0
# Authentication version to use
auth_ver = 2
# The tenant and username separated by a semicolon
username = admin;admin
# The user password
password = pass
# The Object storage region to use (auth v2) (Default RegionOne)
region_name = RegionOne
# The Object storage endpoint URL to use (auth v2) (Default internalURL)
endpoint_type = internalURL
# Concurrency to use for parallel tasks (Default 10)
concurrency = 10
# Size of the HTTP pool (Default 10)
http_pool_length = 10
# Timeout delay for HTTP connections (Default 20)
http_timeout = 20
# Chunk size to read from pack (Bytes) (Default 12228)
chunk_length = 12228
# Cache size (MBytes) (Default 20)
cache_length = 20
"""
class PackInfoObjectStoreIterator(GreenThreadsObjectStoreIterator):
def __len__(self):
while len(self.finder.objects_to_send):
for _ in range(0, len(self.finder.objects_to_send)):
sha = self.finder.next()
self._shas.append(sha)
return len(self._shas)
class PackInfoMissingObjectFinder(GreenThreadsMissingObjectFinder):
def next(self):
while True:
if not self.objects_to_send:
return None
(sha, name, leaf) = self.objects_to_send.pop()
if sha not in self.sha_done:
break
if not leaf:
info = self.object_store.pack_info_get(sha)
if info[0] == Commit.type_num:
self.add_todo([(info[2], "", False)])
elif info[0] == Tree.type_num:
self.add_todo([tuple(i) for i in info[1]])
elif info[0] == Tag.type_num:
self.add_todo([(info[1], None, False)])
if sha in self._tagged:
self.add_todo([(self._tagged[sha], None, True)])
self.sha_done.add(sha)
self.progress("counting objects: %d\r" % len(self.sha_done))
return (sha, name)
def load_conf(path=None, file=None):
"""Load configuration in global var CONF
Args:
path: The path to the configuration file
file: If provided read instead the file like object
"""
conf = ConfigParser()
if file:
try:
conf.read_file(file, path)
except AttributeError:
# read_file only exists in Python3
conf.readfp(file)
return conf
confpath = None
if not path:
try:
confpath = os.environ['DULWICH_SWIFT_CFG']
except KeyError:
raise Exception("You need to specify a configuration file")
else:
confpath = path
if not os.path.isfile(confpath):
raise Exception("Unable to read configuration file %s" % confpath)
conf.read(confpath)
return conf
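load_conf can be pointed at a path or handed a file-like object; a sketch (the path below is hypothetical):

from io import StringIO
conf = load_conf(file=StringIO(u'[swift]\nauth_url = http://127.0.0.1:5000/v2.0\n'))
conf.get('swift', 'auth_url')   # -> 'http://127.0.0.1:5000/v2.0'
conf = load_conf(path='/etc/dulwich-swift.cfg')   # or via DULWICH_SWIFT_CFG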
def swift_load_pack_index(scon, filename):
"""Read a pack index file from Swift
Args:
scon: a `SwiftConnector` instance
filename: Path to the index file object
Returns: a `PackIndexer` instance
"""
with scon.get_object(filename) as f:
return load_pack_index_file(filename, f)
def pack_info_create(pack_data, pack_index):
pack = Pack.from_objects(pack_data, pack_index)
info = {}
for obj in pack.iterobjects():
# Commit
if obj.type_num == Commit.type_num:
info[obj.id] = (obj.type_num, obj.parents, obj.tree)
# Tree
elif obj.type_num == Tree.type_num:
shas = [(s, n, not stat.S_ISDIR(m)) for
n, m, s in obj.items() if not S_ISGITLINK(m)]
info[obj.id] = (obj.type_num, shas)
# Blob
elif obj.type_num == Blob.type_num:
info[obj.id] = None
# Tag
elif obj.type_num == Tag.type_num:
info[obj.id] = (obj.type_num, obj.object[1])
return zlib.compress(json_dumps(info))
def load_pack_info(filename, scon=None, file=None):
if not file:
f = scon.get_object(filename)
else:
f = file
if not f:
return None
try:
return json_loads(zlib.decompress(f.read()))
finally:
f.close()
class SwiftException(Exception):
pass
class SwiftConnector(object):
"""A Connector to swift that manage authentication and errors catching
"""
def __init__(self, root, conf):
""" Initialize a SwiftConnector
Args:
root: The swift container that will act as Git bare repository
conf: A ConfigParser Object
"""
self.conf = conf
self.auth_ver = self.conf.get("swift", "auth_ver")
if self.auth_ver not in ["1", "2"]:
raise NotImplementedError(
"Wrong authentication version use either 1 or 2")
self.auth_url = self.conf.get("swift", "auth_url")
self.user = self.conf.get("swift", "username")
self.password = self.conf.get("swift", "password")
self.concurrency = self.conf.getint('swift', 'concurrency') or 10
self.http_timeout = self.conf.getint('swift', 'http_timeout') or 20
self.http_pool_length = \
self.conf.getint('swift', 'http_pool_length') or 10
self.region_name = self.conf.get("swift", "region_name") or "RegionOne"
self.endpoint_type = \
self.conf.get("swift", "endpoint_type") or "internalURL"
self.cache_length = self.conf.getint("swift", "cache_length") or 20
self.chunk_length = self.conf.getint("swift", "chunk_length") or 12228
self.root = root
block_size = 1024 * 12 # 12KB
if self.auth_ver == "1":
self.storage_url, self.token = self.swift_auth_v1()
else:
self.storage_url, self.token = self.swift_auth_v2()
token_header = {'X-Auth-Token': str(self.token)}
self.httpclient = \
HTTPClient.from_url(str(self.storage_url),
concurrency=self.http_pool_length,
block_size=block_size,
connection_timeout=self.http_timeout,
network_timeout=self.http_timeout,
headers=token_header)
self.base_path = str(posixpath.join(
urlparse.urlparse(self.storage_url).path, self.root))
def swift_auth_v1(self):
self.user = self.user.replace(";", ":")
auth_httpclient = HTTPClient.from_url(
self.auth_url,
connection_timeout=self.http_timeout,
network_timeout=self.http_timeout,
)
headers = {'X-Auth-User': self.user,
'X-Auth-Key': self.password}
path = urlparse.urlparse(self.auth_url).path
ret = auth_httpclient.request('GET', path, headers=headers)
# Should do something with redirections (301 in my case)
if ret.status_code < 200 or ret.status_code >= 300:
raise SwiftException('AUTH v1.0 request failed on ' +
'%s with error code %s (%s)'
% (str(auth_httpclient.get_base_url()) +
path, ret.status_code,
str(ret.items())))
storage_url = ret['X-Storage-Url']
token = ret['X-Auth-Token']
return storage_url, token
def swift_auth_v2(self):
self.tenant, self.user = self.user.split(';')
auth_dict = {}
auth_dict['auth'] = {'passwordCredentials':
{
'username': self.user,
'password': self.password,
},
'tenantName': self.tenant}
auth_json = json_dumps(auth_dict)
headers = {'Content-Type': 'application/json'}
auth_httpclient = HTTPClient.from_url(
self.auth_url,
connection_timeout=self.http_timeout,
network_timeout=self.http_timeout,
)
path = urlparse.urlparse(self.auth_url).path
if not path.endswith('tokens'):
path = posixpath.join(path, 'tokens')
ret = auth_httpclient.request('POST', path,
body=auth_json,
headers=headers)
if ret.status_code < 200 or ret.status_code >= 300:
raise SwiftException('AUTH v2.0 request failed on ' +
'%s with error code %s (%s)'
% (str(auth_httpclient.get_base_url()) +
path, ret.status_code,
str(ret.items())))
auth_ret_json = json_loads(ret.read())
token = auth_ret_json['access']['token']['id']
catalogs = auth_ret_json['access']['serviceCatalog']
object_store = [o_store for o_store in catalogs if
o_store['type'] == 'object-store'][0]
endpoints = object_store['endpoints']
endpoint = [endp for endp in endpoints if
endp["region"] == self.region_name][0]
return endpoint[self.endpoint_type], token
def test_root_exists(self):
"""Check that Swift container exist
- :return: True if exist or None it not
+ Returns: True if it exists or None if not
"""
ret = self.httpclient.request('HEAD', self.base_path)
if ret.status_code == 404:
return None
if ret.status_code < 200 or ret.status_code > 300:
raise SwiftException('HEAD request failed with error code %s'
% ret.status_code)
return True
def create_root(self):
"""Create the Swift container
- :raise: `SwiftException` if unable to create
+ Raises:
+ SwiftException: if unable to create
"""
if not self.test_root_exists():
ret = self.httpclient.request('PUT', self.base_path)
if ret.status_code < 200 or ret.status_code > 300:
raise SwiftException('PUT request failed with error code %s'
% ret.status_code)
def get_container_objects(self):
"""Retrieve objects list in a container
- :return: A list of dict that describe objects
+ Returns: A list of dicts describing the objects,
or None if the container does not exist
"""
qs = '?format=json'
path = self.base_path + qs
ret = self.httpclient.request('GET', path)
if ret.status_code == 404:
return None
if ret.status_code < 200 or ret.status_code > 300:
raise SwiftException('GET request failed with error code %s'
% ret.status_code)
content = ret.read()
return json_loads(content)
def get_object_stat(self, name):
"""Retrieve object stat
Args:
name: The object name
Returns:
A dict describing the object, or None if the object does not exist
"""
path = self.base_path + '/' + name
ret = self.httpclient.request('HEAD', path)
if ret.status_code == 404:
return None
if ret.status_code < 200 or ret.status_code > 300:
raise SwiftException('HEAD request failed with error code %s'
% ret.status_code)
resp_headers = {}
for header, value in ret.items():
resp_headers[header.lower()] = value
return resp_headers
def put_object(self, name, content):
"""Put an object
Args:
name: The object name
content: A file object
Raises:
SwiftException: if unable to create
"""
content.seek(0)
data = content.read()
path = self.base_path + '/' + name
headers = {'Content-Length': str(len(data))}
def _send():
ret = self.httpclient.request('PUT', path,
body=data,
headers=headers)
return ret
try:
# Sometime got Broken Pipe - Dirty workaround
ret = _send()
except Exception:
# Second attempt work
ret = _send()
if ret.status_code < 200 or ret.status_code > 300:
raise SwiftException('PUT request failed with error code %s'
% ret.status_code)
def get_object(self, name, range=None):
"""Retrieve an object
Args:
name: The object name
range: A string range like "0-10" to
retrieve specified bytes in object content
Returns:
A file-like instance, or a bytestring if range is specified
"""
headers = {}
if range:
headers['Range'] = 'bytes=%s' % range
path = self.base_path + '/' + name
ret = self.httpclient.request('GET', path, headers=headers)
if ret.status_code == 404:
return None
if ret.status_code < 200 or ret.status_code > 300:
raise SwiftException('GET request failed with error code %s'
% ret.status_code)
content = ret.read()
if range:
return content
return BytesIO(content)
def del_object(self, name):
"""Delete an object
Args:
name: The object name
Raises:
SwiftException: if unable to delete
"""
path = self.base_path + '/' + name
ret = self.httpclient.request('DELETE', path)
if ret.status_code < 200 or ret.status_code > 300:
raise SwiftException('DELETE request failed with error code %s'
% ret.status_code)
def del_root(self):
"""Delete the root container by removing container content
- :raise: `SwiftException` if unable to delete
+ Raises:
+ SwiftException: if unable to delete
"""
for obj in self.get_container_objects():
self.del_object(obj['name'])
ret = self.httpclient.request('DELETE', self.base_path)
if ret.status_code < 200 or ret.status_code > 300:
raise SwiftException('DELETE request failed with error code %s'
% ret.status_code)
class SwiftPackReader(object):
"""A SwiftPackReader that mimic read and sync method
The reader allows to read a specified amount of bytes from
a given offset of a Swift object. A read offset is kept internaly.
The reader will read from Swift a specified amount of data to complete
its internal buffer. chunk_length specifiy the amount of data
to read from Swift.
"""
def __init__(self, scon, filename, pack_length):
"""Initialize a SwiftPackReader
Args:
scon: a `SwiftConnector` instance
filename: the pack filename
pack_length: The size of the pack object
"""
self.scon = scon
self.filename = filename
self.pack_length = pack_length
self.offset = 0
self.base_offset = 0
self.buff = b''
self.buff_length = self.scon.chunk_length
def _read(self, more=False):
if more:
self.buff_length = self.buff_length * 2
offset = self.base_offset
r = min(self.base_offset + self.buff_length, self.pack_length)
ret = self.scon.get_object(self.filename, range="%s-%s" % (offset, r))
self.buff = ret
def read(self, length):
"""Read a specified amount of Bytes form the pack object
Args:
length: amount of bytes to read
Returns:
a bytestring
"""
end = self.offset+length
if self.base_offset + end > self.pack_length:
data = self.buff[self.offset:]
self.offset = end
return data
if end > len(self.buff):
# Need to read more from swift
self._read(more=True)
return self.read(length)
data = self.buff[self.offset:end]
self.offset = end
return data
def seek(self, offset):
"""Seek to a specified offset
Args:
offset: the offset to seek to
"""
self.base_offset = offset
self._read()
self.offset = 0
def read_checksum(self):
"""Read the checksum from the pack
- :return: the checksum bytestring
+ Returns: the checksum bytestring
"""
return self.scon.get_object(self.filename, range="-20")
class SwiftPackData(PackData):
"""The data contained in a packfile.
We use the SwiftPackReader to read bytes from packs stored in Swift
using the Range header feature of Swift.
"""
def __init__(self, scon, filename):
""" Initialize a SwiftPackReader
Args:
scon: a `SwiftConnector` instance
filename: the pack filename
"""
self.scon = scon
self._filename = filename
self._header_size = 12
headers = self.scon.get_object_stat(self._filename)
self.pack_length = int(headers['content-length'])
pack_reader = SwiftPackReader(self.scon, self._filename,
self.pack_length)
(version, self._num_objects) = read_pack_header(pack_reader.read)
self._offset_cache = LRUSizeCache(1024*1024*self.scon.cache_length,
compute_size=_compute_object_size)
self.pack = None
def get_object_at(self, offset):
if offset in self._offset_cache:
return self._offset_cache[offset]
assert offset >= self._header_size
pack_reader = SwiftPackReader(self.scon, self._filename,
self.pack_length)
pack_reader.seek(offset)
unpacked, _ = unpack_object(pack_reader.read)
return (unpacked.pack_type_num, unpacked._obj())
def get_stored_checksum(self):
pack_reader = SwiftPackReader(self.scon, self._filename,
self.pack_length)
return pack_reader.read_checksum()
def close(self):
pass
class SwiftPack(Pack):
"""A Git pack object.
Same implementation as pack.Pack except that _idx_load and
_data_load are bounded to Swift version of load_pack_index and
PackData.
"""
def __init__(self, *args, **kwargs):
self.scon = kwargs['scon']
del kwargs['scon']
super(SwiftPack, self).__init__(*args, **kwargs)
self._pack_info_path = self._basename + '.info'
self._pack_info = None
self._pack_info_load = lambda: load_pack_info(self._pack_info_path,
self.scon)
self._idx_load = lambda: swift_load_pack_index(self.scon,
self._idx_path)
self._data_load = lambda: SwiftPackData(self.scon, self._data_path)
@property
def pack_info(self):
"""The pack data object being used."""
if self._pack_info is None:
self._pack_info = self._pack_info_load()
return self._pack_info
class SwiftObjectStore(PackBasedObjectStore):
"""A Swift Object Store
Allow to manage a bare Git repository from Openstack Swift.
This object store only supports pack files and not loose objects.
"""
def __init__(self, scon):
"""Open a Swift object store.
Args:
scon: A `SwiftConnector` instance
"""
super(SwiftObjectStore, self).__init__()
self.scon = scon
self.root = self.scon.root
self.pack_dir = posixpath.join(OBJECTDIR, PACKDIR)
self._alternates = None
def _update_pack_cache(self):
objects = self.scon.get_container_objects()
pack_files = [o['name'].replace(".pack", "")
for o in objects if o['name'].endswith(".pack")]
ret = []
for basename in pack_files:
pack = SwiftPack(basename, scon=self.scon)
self._pack_cache[basename] = pack
ret.append(pack)
return ret
def _iter_loose_objects(self):
"""Loose objects are not supported by this repository
"""
return []
def iter_shas(self, finder):
"""An iterator over pack's ObjectStore.
- :return: a `ObjectStoreIterator` or `GreenThreadsObjectStoreIterator`
+ Returns: an `ObjectStoreIterator` or `GreenThreadsObjectStoreIterator`
instance if gevent is enabled
"""
shas = iter(finder.next, None)
return PackInfoObjectStoreIterator(
self, shas, finder, self.scon.concurrency)
def find_missing_objects(self, *args, **kwargs):
kwargs['concurrency'] = self.scon.concurrency
return PackInfoMissingObjectFinder(self, *args, **kwargs)
def pack_info_get(self, sha):
for pack in self.packs:
if sha in pack:
return pack.pack_info[sha]
def _collect_ancestors(self, heads, common=set()):
def _find_parents(commit):
for pack in self.packs:
if commit in pack:
try:
parents = pack.pack_info[commit][1]
except KeyError:
# Seems to have no parents
return []
return parents
bases = set()
commits = set()
queue = []
queue.extend(heads)
while queue:
e = queue.pop(0)
if e in common:
bases.add(e)
elif e not in commits:
commits.add(e)
parents = _find_parents(e)
queue.extend(parents)
return (commits, bases)
def add_pack(self):
"""Add a new pack to this object store.
- :return: Fileobject to write to and a commit function to
+ Returns: Fileobject to write to and a commit function to
call when the pack is finished.
"""
f = BytesIO()
def commit():
f.seek(0)
pack = PackData(file=f, filename="")
entries = pack.sorted_entries()
if len(entries):
basename = posixpath.join(self.pack_dir,
"pack-%s" %
iter_sha1(entry[0] for
entry in entries))
index = BytesIO()
write_pack_index_v2(index, entries, pack.get_stored_checksum())
self.scon.put_object(basename + ".pack", f)
f.close()
self.scon.put_object(basename + ".idx", index)
index.close()
final_pack = SwiftPack(basename, scon=self.scon)
final_pack.check_length_and_checksum()
self._add_cached_pack(basename, final_pack)
return final_pack
else:
return None
def abort():
pass
return f, commit, abort
def add_object(self, obj):
self.add_objects([(obj, None), ])
def _pack_cache_stale(self):
return False
def _get_loose_object(self, sha):
return None
def add_thin_pack(self, read_all, read_some):
"""Read a thin pack
Read it from a stream and complete it in a temporary file.
Then the pack and the corresponding index file are uploaded to Swift.
"""
fd, path = tempfile.mkstemp(prefix='tmp_pack_')
f = os.fdopen(fd, 'w+b')
try:
indexer = PackIndexer(f, resolve_ext_ref=self.get_raw)
copier = PackStreamCopier(read_all, read_some, f,
delta_iter=indexer)
copier.verify()
return self._complete_thin_pack(f, path, copier, indexer)
finally:
f.close()
os.unlink(path)
def _complete_thin_pack(self, f, path, copier, indexer):
entries = list(indexer)
# Update the header with the new number of objects.
f.seek(0)
write_pack_header(f, len(entries) + len(indexer.ext_refs()))
# Must flush before reading (http://bugs.python.org/issue3207)
f.flush()
# Rescan the rest of the pack, computing the SHA with the new header.
new_sha = compute_file_sha(f, end_ofs=-20)
# Must reposition before writing (http://bugs.python.org/issue3207)
f.seek(0, os.SEEK_CUR)
# Complete the pack.
for ext_sha in indexer.ext_refs():
assert len(ext_sha) == 20
type_num, data = self.get_raw(ext_sha)
offset = f.tell()
crc32 = write_pack_object(f, type_num, data, sha=new_sha)
entries.append((ext_sha, offset, crc32))
pack_sha = new_sha.digest()
f.write(pack_sha)
f.flush()
# Move the pack in.
entries.sort()
pack_base_name = posixpath.join(
self.pack_dir,
'pack-' + iter_sha1(e[0] for e in entries).decode(
sys.getfilesystemencoding()))
self.scon.put_object(pack_base_name + '.pack', f)
# Write the index.
filename = pack_base_name + '.idx'
index_file = BytesIO()
write_pack_index_v2(index_file, entries, pack_sha)
self.scon.put_object(filename, index_file)
# Write pack info.
f.seek(0)
pack_data = PackData(filename="", file=f)
index_file.seek(0)
pack_index = load_pack_index_file('', index_file)
serialized_pack_info = pack_info_create(pack_data, pack_index)
f.close()
index_file.close()
pack_info_file = BytesIO(serialized_pack_info)
filename = pack_base_name + '.info'
self.scon.put_object(filename, pack_info_file)
pack_info_file.close()
# Add the pack to the store and return it.
final_pack = SwiftPack(pack_base_name, scon=self.scon)
final_pack.check_length_and_checksum()
self._add_cached_pack(pack_base_name, final_pack)
return final_pack
class SwiftInfoRefsContainer(InfoRefsContainer):
"""Manage references in info/refs object.
"""
def __init__(self, scon, store):
self.scon = scon
self.filename = 'info/refs'
self.store = store
f = self.scon.get_object(self.filename)
if not f:
f = BytesIO(b'')
super(SwiftInfoRefsContainer, self).__init__(f)
def _load_check_ref(self, name, old_ref):
self._check_refname(name)
f = self.scon.get_object(self.filename)
if not f:
return {}
refs = read_info_refs(f)
if old_ref is not None:
if refs[name] != old_ref:
return False
return refs
def _write_refs(self, refs):
f = BytesIO()
f.writelines(write_info_refs(refs, self.store))
self.scon.put_object(self.filename, f)
def set_if_equals(self, name, old_ref, new_ref):
"""Set a refname to new_ref only if it currently equals old_ref.
"""
if name == 'HEAD':
return True
refs = self._load_check_ref(name, old_ref)
if not isinstance(refs, dict):
return False
refs[name] = new_ref
self._write_refs(refs)
self._refs[name] = new_ref
return True
def remove_if_equals(self, name, old_ref):
"""Remove a refname only if it currently equals old_ref.
"""
if name == 'HEAD':
return True
refs = self._load_check_ref(name, old_ref)
if not isinstance(refs, dict):
return False
del refs[name]
self._write_refs(refs)
del self._refs[name]
return True
def allkeys(self):
try:
self._refs['HEAD'] = self._refs['refs/heads/master']
except KeyError:
pass
return self._refs.keys()
class SwiftRepo(BaseRepo):
def __init__(self, root, conf):
"""Init a Git bare Repository on top of a Swift container.
References are managed in info/refs objects by
`SwiftInfoRefsContainer`. The root attribute is the Swift
container that contain the Git bare repository.
Args:
root: The container which contains the bare repo
conf: A ConfigParser object
"""
self.root = root.lstrip('/')
self.conf = conf
self.scon = SwiftConnector(self.root, self.conf)
objects = self.scon.get_container_objects()
if not objects:
raise Exception('There is not any GIT repo here : %s' % self.root)
objects = [o['name'].split('/')[0] for o in objects]
if OBJECTDIR not in objects:
raise Exception('This repository (%s) is not bare.' % self.root)
self.bare = True
self._controldir = self.root
object_store = SwiftObjectStore(self.scon)
refs = SwiftInfoRefsContainer(self.scon, object_store)
BaseRepo.__init__(self, object_store, refs)
def _determine_file_mode(self):
"""Probe the file-system to determine whether permissions can be trusted.
- :return: True if permissions can be trusted, False otherwise.
+ Returns: True if permissions can be trusted, False otherwise.
"""
return False
def _put_named_file(self, filename, contents):
"""Put an object in a Swift container
Args:
filename: the path to the object to put on Swift
contents: the content as bytestring
"""
with BytesIO() as f:
f.write(contents)
self.scon.put_object(filename, f)
@classmethod
def init_bare(cls, scon, conf):
"""Create a new bare repository.
Args:
scon: a `SwiftConnector` instance
conf: a ConfigParser object
Returns:
a `SwiftRepo` instance
"""
scon.create_root()
for obj in [posixpath.join(OBJECTDIR, PACKDIR),
posixpath.join(INFODIR, 'refs')]:
scon.put_object(obj, BytesIO(b''))
ret = cls(scon.root, conf)
ret._init_files(True)
return ret
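Putting the pieces together, creating and opening a bare repository in a container is a short sketch (container name and config path are hypothetical):

conf = load_conf('/etc/dulwich-swift.cfg')
scon = SwiftConnector('myrepo', conf)
repo = SwiftRepo.init_bare(scon, conf)   # writes objects/pack and info/refs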
class SwiftSystemBackend(Backend):
def __init__(self, logger, conf):
self.conf = conf
self.logger = logger
def open_repository(self, path):
self.logger.info('opening repository at %s', path)
return SwiftRepo(path, self.conf)
def cmd_daemon(args):
"""Entry point for starting a TCP git server."""
import optparse
parser = optparse.OptionParser()
parser.add_option("-l", "--listen_address", dest="listen_address",
default="127.0.0.1",
help="Binding IP address.")
parser.add_option("-p", "--port", dest="port", type=int,
default=TCP_GIT_PORT,
help="Binding TCP port.")
parser.add_option("-c", "--swift_config", dest="swift_config",
default="",
help="Path to the configuration file for Swift backend.")
options, args = parser.parse_args(args)
try:
import gevent
import geventhttpclient # noqa: F401
except ImportError:
print("gevent and geventhttpclient libraries are mandatory "
" for use the Swift backend.")
sys.exit(1)
import gevent.monkey
gevent.monkey.patch_socket()
from dulwich import log_utils
logger = log_utils.getLogger(__name__)
conf = load_conf(options.swift_config)
backend = SwiftSystemBackend(logger, conf)
log_utils.default_logging_config()
server = TCPGitServer(backend, options.listen_address,
port=options.port)
server.serve_forever()
def cmd_init(args):
import optparse
parser = optparse.OptionParser()
parser.add_option("-c", "--swift_config", dest="swift_config",
default="",
help="Path to the configuration file for Swift backend.")
options, args = parser.parse_args(args)
conf = load_conf(options.swift_config)
if args == []:
parser.error("missing repository name")
repo = args[0]
scon = SwiftConnector(repo, conf)
SwiftRepo.init_bare(scon, conf)
def main(argv=sys.argv):
commands = {
"init": cmd_init,
"daemon": cmd_daemon,
}
if len(sys.argv) < 2:
print("Usage: %s <%s> [OPTIONS...]" % (
sys.argv[0], "|".join(commands.keys())))
sys.exit(1)
cmd = sys.argv[1]
if cmd not in commands:
print("No such subcommand: %s" % cmd)
sys.exit(1)
commands[cmd](sys.argv[2:])
if __name__ == '__main__':
main()
diff --git a/dulwich/file.py b/dulwich/file.py
index 5bcb7725..19d9e2e7 100644
--- a/dulwich/file.py
+++ b/dulwich/file.py
@@ -1,195 +1,196 @@
# file.py -- Safe access to git files
# Copyright (C) 2010 Google, Inc.
#
# Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU
# General Public License as published by the Free Software Foundation; version 2.0
# or (at your option) any later version. You can redistribute it and/or
# modify it under the terms of either of these two licenses.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# You should have received a copy of the licenses; if not, see
# <http://www.gnu.org/licenses/> for a copy of the GNU General Public License
# and <http://www.apache.org/licenses/LICENSE-2.0> for a copy of the Apache
# License, Version 2.0.
#
"""Safe access to git files."""
import errno
import io
import os
import sys
import tempfile
def ensure_dir_exists(dirname):
"""Ensure a directory exists, creating if necessary."""
try:
os.makedirs(dirname)
except OSError as e:
if e.errno != errno.EEXIST:
raise
def _fancy_rename(oldname, newname):
"""Rename file with temporary backup file to rollback if rename fails"""
if not os.path.exists(newname):
try:
os.rename(oldname, newname)
except OSError:
raise
return
# destination file exists
try:
(fd, tmpfile) = tempfile.mkstemp(".tmp", prefix=oldname, dir=".")
os.close(fd)
os.remove(tmpfile)
except OSError:
# either file could not be created (e.g. permission problem)
# or could not be deleted (e.g. rude virus scanner)
raise
try:
os.rename(newname, tmpfile)
except OSError:
raise # no rename occurred
try:
os.rename(oldname, newname)
except OSError:
os.rename(tmpfile, newname)
raise
os.remove(tmpfile)
def GitFile(filename, mode='rb', bufsize=-1):
"""Create a file object that obeys the git file locking protocol.
- :return: a builtin file object or a _GitFile object
+ Returns: a builtin file object or a _GitFile object
- :note: See _GitFile for a description of the file locking protocol.
+ Note: See _GitFile for a description of the file locking protocol.
Only read-only and write-only (binary) modes are supported; r+, w+, and a
are not. To read and write from the same file, you can take advantage of
the fact that opening a file for write does not actually open the file you
request.
"""
if 'a' in mode:
raise IOError('append mode not supported for Git files')
if '+' in mode:
raise IOError('read/write mode not supported for Git files')
if 'b' not in mode:
raise IOError('text mode not supported for Git files')
if 'w' in mode:
return _GitFile(filename, mode, bufsize)
else:
return io.open(filename, mode, bufsize)
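The write path in practice: all bytes go to filename + '.lock', which atomically replaces the target on close() and is discarded on abort(). A sketch:

f = GitFile('config', 'wb')
try:
    f.write(b'[core]\n\tbare = false\n')
except BaseException:
    f.abort()   # drop the lockfile, leave the original untouched
    raise
else:
    f.close()   # fsync, then rename config.lock over config
# Reads bypass locking entirely and return a plain file object:
with GitFile('config', 'rb') as f:
    data = f.read()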
class FileLocked(Exception):
"""File is already locked."""
def __init__(self, filename, lockfilename):
self.filename = filename
self.lockfilename = lockfilename
super(FileLocked, self).__init__(filename, lockfilename)
class _GitFile(object):
"""File that follows the git locking protocol for writes.
All writes to a file foo will be written into foo.lock in the same
directory, and the lockfile will be renamed to overwrite the original file
on close.
- :note: You *must* call close() or abort() on a _GitFile for the lock to be
+ Note: You *must* call close() or abort() on a _GitFile for the lock to be
released. Typically this will happen in a finally block.
"""
PROXY_PROPERTIES = set(['closed', 'encoding', 'errors', 'mode', 'name',
'newlines', 'softspace'])
PROXY_METHODS = ('__iter__', 'flush', 'fileno', 'isatty', 'read',
'readline', 'readlines', 'seek', 'tell',
'truncate', 'write', 'writelines')
def __init__(self, filename, mode, bufsize):
self._filename = filename
if isinstance(self._filename, bytes):
self._lockfilename = self._filename + b'.lock'
else:
self._lockfilename = self._filename + '.lock'
try:
fd = os.open(
self._lockfilename,
os.O_RDWR | os.O_CREAT | os.O_EXCL |
getattr(os, "O_BINARY", 0))
except OSError as e:
if e.errno == errno.EEXIST:
raise FileLocked(filename, self._lockfilename)
raise
self._file = os.fdopen(fd, mode, bufsize)
self._closed = False
for method in self.PROXY_METHODS:
setattr(self, method, getattr(self._file, method))
def abort(self):
"""Close and discard the lockfile without overwriting the target.
If the file is already closed, this is a no-op.
"""
if self._closed:
return
self._file.close()
try:
os.remove(self._lockfilename)
self._closed = True
except OSError as e:
# The file may have been removed already, which is ok.
if e.errno != errno.ENOENT:
raise
self._closed = True
def close(self):
"""Close this file, saving the lockfile over the original.
- :note: If this method fails, it will attempt to delete the lockfile.
+ Note: If this method fails, it will attempt to delete the lockfile.
However, it is not guaranteed to do so (e.g. if a filesystem
becomes suddenly read-only), which will prevent future writes to
this file until the lockfile is removed manually.
- :raises OSError: if the original file could not be overwritten. The
+ Raises:
+ OSError: if the original file could not be overwritten. The
lock file is still closed, so further attempts to write to the same
file object will raise ValueError.
"""
if self._closed:
return
os.fsync(self._file.fileno())
self._file.close()
try:
if getattr(os, 'replace', None) is not None:
os.replace(self._lockfilename, self._filename)
else:
if sys.platform != 'win32':
os.rename(self._lockfilename, self._filename)
else:
# Windows versions prior to Vista don't support atomic
# renames
_fancy_rename(self._lockfilename, self._filename)
finally:
self.abort()
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()
def __getattr__(self, name):
"""Proxy property calls to the underlying file."""
if name in self.PROXY_PROPERTIES:
return getattr(self._file, name)
raise AttributeError(name)
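# A minimal sketch of the locking protocol above, assuming a writable
# current directory; 'config' is a hypothetical file name. Writes land in
# 'config.lock' and only replace 'config' once close() succeeds.
def _example_git_file():
    f = GitFile('config', 'wb')
    try:
        f.write(b'[core]\n\tbare = false\n')
    except BaseException:
        f.abort()   # drop the lockfile, leave 'config' untouched
        raise
    f.close()       # atomically renames 'config.lock' over 'config'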
diff --git a/dulwich/ignore.py b/dulwich/ignore.py
index a04d29dc..2bcfecc7 100644
--- a/dulwich/ignore.py
+++ b/dulwich/ignore.py
@@ -1,374 +1,374 @@
# Copyright (C) 2017 Jelmer Vernooij
#
# Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU
# General Public License as public by the Free Software Foundation; version 2.0
# or (at your option) any later version. You can redistribute it and/or
# modify it under the terms of either of these two licenses.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# You should have received a copy of the licenses; if not, see
# for a copy of the GNU General Public License
# and for a copy of the Apache
# License, Version 2.0.
#
"""Parsing of gitignore files.
For details of the matching rules, see https://git-scm.com/docs/gitignore
"""
import os.path
import re
import sys
def _translate_segment(segment):
if segment == b"*":
return b'[^/]+'
res = b""
i, n = 0, len(segment)
while i < n:
c = segment[i:i+1]
i = i+1
if c == b'*':
res += b'[^/]*'
elif c == b'?':
res += b'[^/]'
elif c == b'[':
j = i
if j < n and segment[j:j+1] == b'!':
j = j+1
if j < n and segment[j:j+1] == b']':
j = j+1
while j < n and segment[j:j+1] != b']':
j = j+1
if j >= n:
res += b'\\['
else:
stuff = segment[i:j].replace(b'\\', b'\\\\')
i = j+1
if stuff.startswith(b'!'):
stuff = b'^' + stuff[1:]
elif stuff.startswith(b'^'):
stuff = b'\\' + stuff
res += b'[' + stuff + b']'
else:
res += re.escape(c)
return res
def translate(pat):
"""Translate a shell PATTERN to a regular expression.
There is no way to quote meta-characters.
Originally copied from fnmatch in Python 2.7, but modified for Dulwich
to cope with features in Git ignore patterns.
"""
res = b'(?ms)'
if b'/' not in pat[:-1]:
# If there's no slash, this is a filename-based match
res += b'(.*/)?'
if pat.startswith(b'**/'):
# Leading **/
pat = pat[2:]
res += b'(.*/)?'
if pat.startswith(b'/'):
pat = pat[1:]
for i, segment in enumerate(pat.split(b'/')):
if segment == b'**':
res += b'(/.*)?'
continue
else:
res += ((re.escape(b'/') if i > 0 else b'') +
_translate_segment(segment))
if not pat.endswith(b'/'):
res += b'/?'
return res + b'\\Z'
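# A quick sketch of the regexes translate() produces, using hypothetical
# patterns and paths; a pattern without a slash matches in any directory.
def _example_translate():
    regex = re.compile(translate(b'*.pyc'))
    assert regex.match(b'build/cache.pyc')
    assert regex.match(b'cache.pyc')
    assert not regex.match(b'cache.py')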
def read_ignore_patterns(f):
"""Read a git ignore file.
Args:
f: File-like object to read from
Returns: List of patterns
"""
for line in f:
line = line.rstrip(b"\r\n")
# Ignore blank lines, they're used for readability.
if not line:
continue
if line.startswith(b'#'):
# Comment
continue
# Trailing spaces are ignored unless they are quoted with a backslash.
while line.endswith(b' ') and not line.endswith(b'\\ '):
line = line[:-1]
line = line.replace(b'\\ ', b' ')
yield line
def match_pattern(path, pattern, ignorecase=False):
"""Match a gitignore-style pattern against a path.
Args:
path: Path to match
pattern: Pattern to match
ignorecase: Whether to do case-insensitive matching
Returns:
bool indicating whether the pattern matched
"""
return Pattern(pattern, ignorecase).match(path)
class Pattern(object):
"""A single ignore pattern."""
def __init__(self, pattern, ignorecase=False):
self.pattern = pattern
self.ignorecase = ignorecase
if pattern[0:1] == b'!':
self.is_exclude = False
pattern = pattern[1:]
else:
if pattern[0:1] == b'\\':
pattern = pattern[1:]
self.is_exclude = True
flags = 0
if self.ignorecase:
flags = re.IGNORECASE
self._re = re.compile(translate(pattern), flags)
def __bytes__(self):
return self.pattern
def __str__(self):
return self.pattern.decode(sys.getfilesystemencoding())
def __eq__(self, other):
return (type(self) == type(other) and
self.pattern == other.pattern and
self.ignorecase == other.ignorecase)
def __repr__(self):
return "%s(%s, %r)" % (
type(self).__name__, self.pattern, self.ignorecase)
def match(self, path):
"""Try to match a path against this ignore pattern.
Args:
path: Path to match (relative to ignore location)
Returns: boolean
"""
return bool(self._re.match(path))
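# A small sketch of match_pattern() with hypothetical inputs; ignorecase
# compiles the underlying regex with re.IGNORECASE, i.e. case-insensitive.
def _example_match_pattern():
    assert match_pattern(b'foo/bar.pyc', b'*.pyc')
    assert match_pattern(b'FOO.PYC', b'*.pyc', ignorecase=True)
    assert not match_pattern(b'bar.txt', b'*.pyc')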
class IgnoreFilter(object):
def __init__(self, patterns, ignorecase=False):
self._patterns = []
self._ignorecase = ignorecase
for pattern in patterns:
self.append_pattern(pattern)
def append_pattern(self, pattern):
"""Add a pattern to the set."""
self._patterns.append(Pattern(pattern, self._ignorecase))
def find_matching(self, path):
"""Yield all matching patterns for path.
Args:
path: Path to match
Returns:
Iterator over matching Pattern objects
"""
if not isinstance(path, bytes):
path = path.encode(sys.getfilesystemencoding())
for pattern in self._patterns:
if pattern.match(path):
yield pattern
def is_ignored(self, path):
"""Check whether a path is ignored.
For directories, include a trailing slash.
- :return: status is None if file is not mentioned, True if it is
+ Returns: status is None if file is not mentioned, True if it is
included, False if it is explicitly excluded.
"""
status = None
for pattern in self.find_matching(path):
status = pattern.is_exclude
return status
@classmethod
def from_path(cls, path, ignorecase=False):
with open(path, 'rb') as f:
ret = cls(read_ignore_patterns(f), ignorecase)
ret._path = path
return ret
def __repr__(self):
if getattr(self, '_path', None) is None:
return "<%s>" % (type(self).__name__)
else:
return "%s.from_path(%r)" % (type(self).__name__, self._path)
class IgnoreFilterStack(object):
"""Check for ignore status in multiple filters."""
def __init__(self, filters):
self._filters = filters
def is_ignored(self, path):
"""Check whether a path is explicitly included or excluded in ignores.
Args:
path: Path to check
Returns:
None if the file is not mentioned, True if it is included,
False if it is explicitly excluded.
"""
status = None
for filter in self._filters:
status = filter.is_ignored(path)
if status is not None:
return status
return status
def default_user_ignore_filter_path(config):
"""Return default user ignore filter path.
Args:
config: A Config object
Returns:
Path to a global ignore file
"""
try:
return config.get((b'core', ), b'excludesFile')
except KeyError:
pass
xdg_config_home = os.environ.get(
"XDG_CONFIG_HOME", os.path.expanduser("~/.config/"),
)
return os.path.join(xdg_config_home, 'git', 'ignore')
class IgnoreFilterManager(object):
"""Ignore file manager."""
def __init__(self, top_path, global_filters, ignorecase):
self._path_filters = {}
self._top_path = top_path
self._global_filters = global_filters
self._ignorecase = ignorecase
def __repr__(self):
return "%s(%s, %r, %r)" % (
type(self).__name__, self._top_path,
self._global_filters,
self._ignorecase)
def _load_path(self, path):
try:
return self._path_filters[path]
except KeyError:
pass
p = os.path.join(self._top_path, path, '.gitignore')
try:
self._path_filters[path] = IgnoreFilter.from_path(
p, self._ignorecase)
except IOError:
self._path_filters[path] = None
return self._path_filters[path]
def find_matching(self, path):
"""Find matching patterns for path.
Stops after the first ignore file with matches.
Args:
path: Path to check
Returns:
Iterator over Pattern instances
"""
if os.path.isabs(path):
raise ValueError('%s is an absolute path' % path)
filters = [(0, f) for f in self._global_filters]
if os.path.sep != '/':
path = path.replace(os.path.sep, '/')
parts = path.split('/')
for i in range(len(parts)+1):
dirname = '/'.join(parts[:i])
for s, f in filters:
relpath = '/'.join(parts[s:i])
if i < len(parts):
# Paths leading up to the final part are all directories,
# so need a trailing slash.
relpath += '/'
matches = list(f.find_matching(relpath))
if matches:
return iter(matches)
ignore_filter = self._load_path(dirname)
if ignore_filter is not None:
filters.insert(0, (i, ignore_filter))
return iter([])
def is_ignored(self, path):
"""Check whether a path is explicitly included or excluded in ignores.
Args:
path: Path to check
Returns:
None if the file is not mentioned, True if it is included,
False if it is explicitly excluded.
"""
matches = list(self.find_matching(path))
if matches:
return matches[-1].is_exclude
return None
@classmethod
def from_repo(cls, repo):
"""Create a IgnoreFilterManager from a repository.
Args:
repo: Repository object
Returns:
A `IgnoreFilterManager` object
"""
global_filters = []
for p in [
os.path.join(repo.controldir(), 'info', 'exclude'),
default_user_ignore_filter_path(repo.get_config_stack())]:
try:
global_filters.append(IgnoreFilter.from_path(p))
except IOError:
pass
config = repo.get_config_stack()
ignorecase = config.get_boolean((b'core',), b'ignorecase', False)
return cls(repo.path, global_filters, ignorecase)
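# A sketch of typical IgnoreFilterManager use, assuming a git checkout in
# the current directory; the queried path is hypothetical, must be
# relative, and directories carry a trailing slash.
def _example_manager():
    from dulwich.repo import Repo
    manager = IgnoreFilterManager.from_repo(Repo('.'))
    return manager.is_ignored('build/')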
diff --git a/dulwich/patch.py b/dulwich/patch.py
index 272488eb..ea1bdfe1 100644
--- a/dulwich/patch.py
+++ b/dulwich/patch.py
@@ -1,364 +1,374 @@
# patch.py -- For dealing with packed-style patches.
# Copyright (C) 2009-2013 Jelmer Vernooij
#
# Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU
# General Public License as public by the Free Software Foundation; version 2.0
# or (at your option) any later version. You can redistribute it and/or
# modify it under the terms of either of these two licenses.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# You should have received a copy of the licenses; if not, see
# for a copy of the GNU General Public License
# and for a copy of the Apache
# License, Version 2.0.
#
"""Classes for dealing with git am-style patches.
These patches are basically unified diffs with some extra metadata tacked
on.
"""
from difflib import SequenceMatcher
import email.parser
import time
from dulwich.objects import (
Blob,
Commit,
S_ISGITLINK,
)
FIRST_FEW_BYTES = 8000
def write_commit_patch(f, commit, contents, progress, version=None,
encoding=None):
"""Write a individual file patch.
- :param commit: Commit object
- :param progress: Tuple with current patch number and total.
- :return: tuple with filename and contents
+ Args:
+ commit: Commit object
+ progress: Tuple with current patch number and total.
"""
encoding = encoding or getattr(f, "encoding", "ascii")
if isinstance(contents, str):
contents = contents.encode(encoding)
(num, total) = progress
f.write(b"From " + commit.id + b" " +
time.ctime(commit.commit_time).encode(encoding) + b"\n")
f.write(b"From: " + commit.author + b"\n")
f.write(b"Date: " +
time.strftime("%a, %d %b %Y %H:%M:%S %Z").encode(encoding) + b"\n")
f.write(("Subject: [PATCH %d/%d] " % (num, total)).encode(encoding) +
commit.message + b"\n")
f.write(b"\n")
f.write(b"---\n")
try:
import subprocess
p = subprocess.Popen(["diffstat"], stdout=subprocess.PIPE,
stdin=subprocess.PIPE)
except (ImportError, OSError):
pass # diffstat not available?
else:
(diffstat, _) = p.communicate(contents)
f.write(diffstat)
f.write(b"\n")
f.write(contents)
f.write(b"-- \n")
if version is None:
from dulwich import __version__ as dulwich_version
f.write(b"Dulwich %d.%d.%d\n" % dulwich_version)
else:
f.write(version.encode(encoding) + b"\n")
def get_summary(commit):
"""Determine the summary line for use in a filename.
- :param commit: Commit
- :return: Summary string
+ Args:
+ commit: Commit
+ Returns: Summary string
"""
decoded = commit.message.decode(errors='replace')
return decoded.splitlines()[0].replace(" ", "-")
# Unified Diff
def _format_range_unified(start, stop):
'Convert range to the "ed" format'
# Per the diff spec at http://www.unix.org/single_unix_specification/
beginning = start + 1 # lines start numbering with one
length = stop - start
if length == 1:
return '{}'.format(beginning)
if not length:
beginning -= 1 # empty ranges begin at line just before the range
return '{},{}'.format(beginning, length)
def unified_diff(a, b, fromfile='', tofile='', fromfiledate='',
tofiledate='', n=3, lineterm='\n'):
"""difflib.unified_diff that can detect "No newline at end of file" as
original "git diff" does.
Based on the same function in Python2.7 difflib.py
"""
started = False
for group in SequenceMatcher(None, a, b).get_grouped_opcodes(n):
if not started:
started = True
fromdate = '\t{}'.format(fromfiledate) if fromfiledate else ''
todate = '\t{}'.format(tofiledate) if tofiledate else ''
yield '--- {}{}{}'.format(
fromfile.decode("ascii"),
fromdate,
lineterm
).encode('ascii')
yield '+++ {}{}{}'.format(
tofile.decode("ascii"),
todate,
lineterm
).encode('ascii')
first, last = group[0], group[-1]
file1_range = _format_range_unified(first[1], last[2])
file2_range = _format_range_unified(first[3], last[4])
yield '@@ -{} +{} @@{}'.format(
file1_range,
file2_range,
lineterm
).encode('ascii')
for tag, i1, i2, j1, j2 in group:
if tag == 'equal':
for line in a[i1:i2]:
yield b' ' + line
continue
if tag in ('replace', 'delete'):
for line in a[i1:i2]:
if not line[-1:] == b'\n':
line += b'\n\\ No newline at end of file\n'
yield b'-' + line
if tag in ('replace', 'insert'):
for line in b[j1:j2]:
if not line[-1:] == b'\n':
line += b'\n\\ No newline at end of file\n'
yield b'+' + line
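# A short sketch of unified_diff() on lists of byte lines; note how the
# missing trailing newline on b'world' is annotated the way "git diff"
# does. The file names here are hypothetical.
def _example_unified_diff():
    old = [b'hello\n', b'world']   # no newline at end of file
    new = [b'hello\n', b'world\n']
    return b''.join(unified_diff(old, new, b'a/x.txt', b'b/x.txt'))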
def is_binary(content):
"""See if the first few bytes contain any null characters.
- :param content: Bytestring to check for binary content
+ Args:
+ content: Bytestring to check for binary content
"""
return b'\0' in content[:FIRST_FEW_BYTES]
def shortid(hexsha):
if hexsha is None:
return b"0" * 7
else:
return hexsha[:7]
def patch_filename(p, root):
if p is None:
return b"/dev/null"
else:
return root + b"/" + p
def write_object_diff(f, store, old_file, new_file, diff_binary=False):
"""Write the diff for an object.
- :param f: File-like object to write to
- :param store: Store to retrieve objects from, if necessary
- :param old_file: (path, mode, hexsha) tuple
- :param new_file: (path, mode, hexsha) tuple
- :param diff_binary: Whether to diff files even if they
+ Args:
+ f: File-like object to write to
+ store: Store to retrieve objects from, if necessary
+ old_file: (path, mode, hexsha) tuple
+ new_file: (path, mode, hexsha) tuple
+ diff_binary: Whether to diff files even if they
are considered binary files by is_binary().
- :note: the tuple elements should be None for nonexistant files
+ Note: the tuple elements should be None for nonexistent files
"""
(old_path, old_mode, old_id) = old_file
(new_path, new_mode, new_id) = new_file
patched_old_path = patch_filename(old_path, b"a")
patched_new_path = patch_filename(new_path, b"b")
def content(mode, hexsha):
if hexsha is None:
return Blob.from_string(b'')
elif S_ISGITLINK(mode):
return Blob.from_string(b"Submodule commit " + hexsha + b"\n")
else:
return store[hexsha]
def lines(content):
if not content:
return []
else:
return content.splitlines()
f.writelines(gen_diff_header(
(old_path, new_path), (old_mode, new_mode), (old_id, new_id)))
old_content = content(old_mode, old_id)
new_content = content(new_mode, new_id)
if not diff_binary and (
is_binary(old_content.data) or is_binary(new_content.data)):
binary_diff = (
b"Binary files "
+ patched_old_path
+ b" and "
+ patched_new_path
+ b" differ\n"
)
f.write(binary_diff)
else:
f.writelines(unified_diff(lines(old_content), lines(new_content),
patched_old_path, patched_new_path))
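# A sketch of write_object_diff() against an in-memory store; the blob
# contents, path, and modes below are hypothetical.
def _example_object_diff():
    from io import BytesIO
    from dulwich.object_store import MemoryObjectStore
    old_blob = Blob.from_string(b'one\n')
    new_blob = Blob.from_string(b'one\ntwo\n')
    store = MemoryObjectStore()
    store.add_object(old_blob)
    store.add_object(new_blob)
    buf = BytesIO()
    write_object_diff(buf, store,
                      (b'f.txt', 0o100644, old_blob.id),
                      (b'f.txt', 0o100644, new_blob.id))
    return buf.getvalue()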
# TODO(jelmer): Support writing unicode, rather than bytes.
def gen_diff_header(paths, modes, shas):
"""Write a blob diff header.
- :param paths: Tuple with old and new path
- :param modes: Tuple with old and new modes
- :param shas: Tuple with old and new shas
+ Args:
+ paths: Tuple with old and new path
+ modes: Tuple with old and new modes
+ shas: Tuple with old and new shas
"""
(old_path, new_path) = paths
(old_mode, new_mode) = modes
(old_sha, new_sha) = shas
if old_path is None and new_path is not None:
old_path = new_path
if new_path is None and old_path is not None:
new_path = old_path
old_path = patch_filename(old_path, b"a")
new_path = patch_filename(new_path, b"b")
yield b"diff --git " + old_path + b" " + new_path + b"\n"
if old_mode != new_mode:
if new_mode is not None:
if old_mode is not None:
yield ("old file mode %o\n" % old_mode).encode('ascii')
yield ("new file mode %o\n" % new_mode).encode('ascii')
else:
yield ("deleted file mode %o\n" % old_mode).encode('ascii')
yield b"index " + shortid(old_sha) + b".." + shortid(new_sha)
if new_mode is not None and old_mode is not None:
yield (" %o" % new_mode).encode('ascii')
yield b"\n"
# TODO(jelmer): Support writing unicode, rather than bytes.
def write_blob_diff(f, old_file, new_file):
"""Write blob diff.
- :param f: File-like object to write to
- :param old_file: (path, mode, hexsha) tuple (None if nonexisting)
- :param new_file: (path, mode, hexsha) tuple (None if nonexisting)
+ Args:
+ f: File-like object to write to
+ old_file: (path, mode, hexsha) tuple (None if nonexistent)
+ new_file: (path, mode, hexsha) tuple (None if nonexistent)
- :note: The use of write_object_diff is recommended over this function.
+ Note: The use of write_object_diff is recommended over this function.
"""
(old_path, old_mode, old_blob) = old_file
(new_path, new_mode, new_blob) = new_file
patched_old_path = patch_filename(old_path, b"a")
patched_new_path = patch_filename(new_path, b"b")
def lines(blob):
if blob is not None:
return blob.splitlines()
else:
return []
f.writelines(gen_diff_header(
(old_path, new_path), (old_mode, new_mode),
(getattr(old_blob, "id", None), getattr(new_blob, "id", None))))
old_contents = lines(old_blob)
new_contents = lines(new_blob)
f.writelines(unified_diff(old_contents, new_contents,
patched_old_path, patched_new_path))
def write_tree_diff(f, store, old_tree, new_tree, diff_binary=False):
"""Write tree diff.
- :param f: File-like object to write to.
- :param old_tree: Old tree id
- :param new_tree: New tree id
- :param diff_binary: Whether to diff files even if they
+ Args:
+ f: File-like object to write to.
+ old_tree: Old tree id
+ new_tree: New tree id
+ diff_binary: Whether to diff files even if they
are considered binary files by is_binary().
"""
changes = store.tree_changes(old_tree, new_tree)
for (oldpath, newpath), (oldmode, newmode), (oldsha, newsha) in changes:
write_object_diff(f, store, (oldpath, oldmode, oldsha),
(newpath, newmode, newsha), diff_binary=diff_binary)
def git_am_patch_split(f, encoding=None):
"""Parse a git-am-style patch and split it up into bits.
- :param f: File-like object to parse
- :param encoding: Encoding to use when creating Git objects
- :return: Tuple with commit object, diff contents and git version
+ Args:
+ f: File-like object to parse
+ encoding: Encoding to use when creating Git objects
+ Returns: Tuple with commit object, diff contents and git version
"""
encoding = encoding or getattr(f, "encoding", "ascii")
encoding = encoding or "ascii"
contents = f.read()
if (isinstance(contents, bytes) and
getattr(email.parser, "BytesParser", None)):
parser = email.parser.BytesParser()
msg = parser.parsebytes(contents)
else:
parser = email.parser.Parser()
msg = parser.parsestr(contents)
return parse_patch_message(msg, encoding)
def parse_patch_message(msg, encoding=None):
"""Extract a Commit object and patch from an e-mail message.
- :param msg: An email message (email.message.Message)
- :param encoding: Encoding to use to encode Git commits
- :return: Tuple with commit object, diff contents and git version
+ Args:
+ msg: An email message (email.message.Message)
+ encoding: Encoding to use to encode Git commits
+ Returns: Tuple with commit object, diff contents and git version
"""
c = Commit()
c.author = msg["from"].encode(encoding)
c.committer = msg["from"].encode(encoding)
try:
patch_tag_start = msg["subject"].index("[PATCH")
except ValueError:
subject = msg["subject"]
else:
close = msg["subject"].index("] ", patch_tag_start)
subject = msg["subject"][close+2:]
c.message = (subject.replace("\n", "") + "\n").encode(encoding)
first = True
body = msg.get_payload(decode=True)
lines = body.splitlines(True)
line_iter = iter(lines)
for line in line_iter:
if line == b"---\n":
break
if first:
if line.startswith(b"From: "):
c.author = line[len(b"From: "):].rstrip()
else:
c.message += b"\n" + line
first = False
else:
c.message += line
diff = b""
for line in line_iter:
if line == b"-- \n":
break
diff += line
try:
version = next(line_iter).rstrip(b"\n")
except StopIteration:
version = None
return c, diff, version
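# A sketch of git_am_patch_split() on a tiny, hypothetical am-style patch;
# it yields the Commit, the diff body, and the trailing version string.
def _example_patch_split():
    from io import BytesIO
    text = (b'From: Jane Dev <jane@example.com>\n'
            b'Subject: [PATCH 1/1] Fix typo\n'
            b'\n'
            b'Longer description.\n'
            b'---\n'
            b' f.txt | 1 +\n'
            b'-- \n'
            b'Dulwich 0.19.0\n')
    c, diff, version = git_am_patch_split(BytesIO(text))
    return c.message, diff, version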
diff --git a/dulwich/porcelain.py b/dulwich/porcelain.py
index 210c74fb..8db8c92f 100644
--- a/dulwich/porcelain.py
+++ b/dulwich/porcelain.py
@@ -1,1532 +1,1581 @@
# porcelain.py -- Porcelain-like layer on top of Dulwich
# Copyright (C) 2013 Jelmer Vernooij
#
# Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU
# General Public License as public by the Free Software Foundation; version 2.0
# or (at your option) any later version. You can redistribute it and/or
# modify it under the terms of either of these two licenses.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# You should have received a copy of the licenses; if not, see
# for a copy of the GNU General Public License
# and for a copy of the Apache
# License, Version 2.0.
#
"""Simple wrapper that provides porcelain-like functions on top of Dulwich.
Currently implemented:
* archive
* add
* branch{_create,_delete,_list}
* check-ignore
* checkout
* clone
* commit
* commit-tree
* daemon
* describe
* diff-tree
* fetch
* init
* ls-files
* ls-remote
* ls-tree
* pull
* push
* rm
* remote{_add}
* receive-pack
* reset
* rev-list
* tag{_create,_delete,_list}
* upload-pack
* update-server-info
* status
* symbolic-ref
These functions are meant to behave similarly to the git subcommands.
Differences in behaviour are considered bugs.
Functions should generally accept both unicode strings and bytestrings.
"""
from collections import namedtuple
from contextlib import (
closing,
contextmanager,
)
from io import BytesIO, RawIOBase
import datetime
import os
import posixpath
import stat
import sys
import time
from dulwich.archive import (
tar_stream,
)
from dulwich.client import (
get_transport_and_path,
)
from dulwich.config import (
StackedConfig,
)
from dulwich.diff_tree import (
CHANGE_ADD,
CHANGE_DELETE,
CHANGE_MODIFY,
CHANGE_RENAME,
CHANGE_COPY,
RENAME_CHANGE_TYPES,
)
from dulwich.errors import (
SendPackError,
UpdateRefsError,
)
from dulwich.ignore import IgnoreFilterManager
from dulwich.index import (
blob_from_path_and_stat,
get_unstaged_changes,
)
from dulwich.object_store import (
tree_lookup_path,
)
from dulwich.objects import (
Commit,
Tag,
format_timezone,
parse_timezone,
pretty_format_tree_entry,
)
from dulwich.objectspec import (
parse_commit,
parse_object,
parse_ref,
parse_reftuples,
parse_tree,
)
from dulwich.pack import (
write_pack_index,
write_pack_objects,
)
from dulwich.patch import write_tree_diff
from dulwich.protocol import (
Protocol,
ZERO_SHA,
)
from dulwich.refs import (
ANNOTATED_TAG_SUFFIX,
LOCAL_BRANCH_PREFIX,
strip_peeled_refs,
)
from dulwich.repo import (BaseRepo, Repo)
from dulwich.server import (
FileSystemBackend,
TCPGitServer,
ReceivePackHandler,
UploadPackHandler,
update_server_info as server_update_server_info,
)
# Module level tuple definition for status output
GitStatus = namedtuple('GitStatus', 'staged unstaged untracked')
class NoneStream(RawIOBase):
"""Fallback if stdout or stderr are unavailable, does nothing."""
def read(self, size=-1):
return None
def readall(self):
return None
def readinto(self, b):
return None
def write(self, b):
return None
if sys.version_info[0] == 2:
default_bytes_out_stream = sys.stdout or NoneStream()
default_bytes_err_stream = sys.stderr or NoneStream()
else:
default_bytes_out_stream = (
getattr(sys.stdout, 'buffer', None) or NoneStream())
default_bytes_err_stream = (
getattr(sys.stderr, 'buffer', None) or NoneStream())
DEFAULT_ENCODING = 'utf-8'
class RemoteExists(Exception):
"""Raised when the remote already exists."""
def open_repo(path_or_repo):
"""Open an argument that can be a repository or a path for a repository."""
if isinstance(path_or_repo, BaseRepo):
return path_or_repo
return Repo(path_or_repo)
@contextmanager
def _noop_context_manager(obj):
"""Context manager that has the same api as closing but does nothing."""
yield obj
def open_repo_closing(path_or_repo):
"""Open an argument that can be a repository or a path for a repository.
Returns a context manager that will close the repo on exit if the argument
is a path, or does nothing if the argument is a repo.
"""
if isinstance(path_or_repo, BaseRepo):
return _noop_context_manager(path_or_repo)
return closing(Repo(path_or_repo))
def path_to_tree_path(repopath, path):
"""Convert a path to a path usable in an index, e.g. bytes and relative to
the repository root.
- :param repopath: Repository path, absolute or relative to the cwd
- :param path: A path, absolute or relative to the cwd
- :return: A path formatted for use in e.g. an index
+ Args:
+ repopath: Repository path, absolute or relative to the cwd
+ path: A path, absolute or relative to the cwd
+ Returns: A path formatted for use in e.g. an index
"""
if not isinstance(path, bytes):
path = path.encode(sys.getfilesystemencoding())
if not isinstance(repopath, bytes):
repopath = repopath.encode(sys.getfilesystemencoding())
treepath = os.path.relpath(path, repopath)
if treepath.startswith(b'..'):
raise ValueError('Path not in repo')
if os.path.sep != '/':
treepath = treepath.replace(os.path.sep.encode('ascii'), b'/')
return treepath
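# A small sketch of path_to_tree_path() with hypothetical POSIX paths; the
# result is bytes and relative to the repository root.
def _example_tree_path():
    # returns b'docs/index.rst'
    return path_to_tree_path('/srv/repo', '/srv/repo/docs/index.rst')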
def archive(repo, committish=None, outstream=default_bytes_out_stream,
errstream=default_bytes_err_stream):
"""Create an archive.
- :param repo: Path of repository for which to generate an archive.
- :param committish: Commit SHA1 or ref to use
- :param outstream: Output stream (defaults to stdout)
- :param errstream: Error stream (defaults to stderr)
+ Args:
+ repo: Path of repository for which to generate an archive.
+ committish: Commit SHA1 or ref to use
+ outstream: Output stream (defaults to stdout)
+ errstream: Error stream (defaults to stderr)
"""
if committish is None:
committish = "HEAD"
with open_repo_closing(repo) as repo_obj:
c = parse_commit(repo_obj, committish)
for chunk in tar_stream(
repo_obj.object_store, repo_obj.object_store[c.tree],
c.commit_time):
outstream.write(chunk)
def update_server_info(repo="."):
"""Update server info files for a repository.
- :param repo: path to the repository
+ Args:
+ repo: path to the repository
"""
with open_repo_closing(repo) as r:
server_update_server_info(r)
def symbolic_ref(repo, ref_name, force=False):
"""Set git symbolic ref into HEAD.
- :param repo: path to the repository
- :param ref_name: short name of the new ref
- :param force: force settings without checking if it exists in refs/heads
+ Args:
+ repo: path to the repository
+ ref_name: short name of the new ref
+ force: force settings without checking if it exists in refs/heads
"""
with open_repo_closing(repo) as repo_obj:
ref_path = _make_branch_ref(ref_name)
if not force and ref_path not in repo_obj.refs.keys():
raise ValueError('fatal: ref `%s` is not a ref' % ref_name)
repo_obj.refs.set_symbolic_ref(b'HEAD', ref_path)
def commit(repo=".", message=None, author=None, committer=None, encoding=None):
"""Create a new commit.
- :param repo: Path to repository
- :param message: Optional commit message
- :param author: Optional author name and email
- :param committer: Optional committer name and email
- :return: SHA1 of the new commit
+ Args:
+ repo: Path to repository
+ message: Optional commit message
+ author: Optional author name and email
+ committer: Optional committer name and email
+ Returns: SHA1 of the new commit
"""
# FIXME: Support --all argument
# FIXME: Support --signoff argument
if getattr(message, 'encode', None):
message = message.encode(encoding or DEFAULT_ENCODING)
if getattr(author, 'encode', None):
author = author.encode(encoding or DEFAULT_ENCODING)
if getattr(committer, 'encode', None):
committer = committer.encode(encoding or DEFAULT_ENCODING)
with open_repo_closing(repo) as r:
return r.do_commit(
message=message, author=author, committer=committer,
encoding=encoding)
def commit_tree(repo, tree, message=None, author=None, committer=None):
"""Create a new commit object.
- :param repo: Path to repository
- :param tree: An existing tree object
- :param author: Optional author name and email
- :param committer: Optional committer name and email
+ Args:
+ repo: Path to repository
+ tree: An existing tree object
+ message: Optional commit message
+ author: Optional author name and email
+ committer: Optional committer name and email
"""
with open_repo_closing(repo) as r:
return r.do_commit(
message=message, tree=tree, committer=committer, author=author)
def init(path=".", bare=False):
"""Create a new git repository.
- :param path: Path to repository.
- :param bare: Whether to create a bare repository.
- :return: A Repo instance
+ Args:
+ path: Path to repository.
+ bare: Whether to create a bare repository.
+ Returns: A Repo instance
"""
if not os.path.exists(path):
os.mkdir(path)
if bare:
return Repo.init_bare(path)
else:
return Repo.init(path)
def clone(source, target=None, bare=False, checkout=None,
errstream=default_bytes_err_stream, outstream=None,
origin=b"origin", depth=None, **kwargs):
"""Clone a local or remote git repository.
- :param source: Path or URL for source repository
- :param target: Path to target repository (optional)
- :param bare: Whether or not to create a bare repository
- :param checkout: Whether or not to check-out HEAD after cloning
- :param errstream: Optional stream to write progress to
- :param outstream: Optional stream to write progress to (deprecated)
- :param origin: Name of remote from the repository used to clone
- :param depth: Depth to fetch at
- :return: The new repository
+ Args:
+ source: Path or URL for source repository
+ target: Path to target repository (optional)
+ bare: Whether or not to create a bare repository
+ checkout: Whether or not to check-out HEAD after cloning
+ errstream: Optional stream to write progress to
+ outstream: Optional stream to write progress to (deprecated)
+ origin: Name of remote from the repository used to clone
+ depth: Depth to fetch at
+ Returns: The new repository
"""
# TODO(jelmer): This code overlaps quite a bit with Repo.clone
if outstream is not None:
import warnings
warnings.warn(
"outstream= has been deprecated in favour of errstream=.",
DeprecationWarning, stacklevel=3)
errstream = outstream
if checkout is None:
checkout = (not bare)
if checkout and bare:
raise ValueError("checkout and bare are incompatible")
if target is None:
target = source.split("/")[-1]
if not os.path.exists(target):
os.mkdir(target)
if bare:
r = Repo.init_bare(target)
else:
r = Repo.init(target)
reflog_message = b'clone: from ' + source.encode('utf-8')
try:
fetch_result = fetch(
r, source, origin, errstream=errstream, message=reflog_message,
depth=depth, **kwargs)
target_config = r.get_config()
if not isinstance(source, bytes):
source = source.encode(DEFAULT_ENCODING)
target_config.set((b'remote', origin), b'url', source)
target_config.set(
(b'remote', origin), b'fetch',
b'+refs/heads/*:refs/remotes/' + origin + b'/*')
target_config.write_to_path()
# TODO(jelmer): Support symref capability,
# https://github.com/jelmer/dulwich/issues/485
try:
head = r[fetch_result[b'HEAD']]
except KeyError:
head = None
else:
r[b'HEAD'] = head.id
if checkout and not bare and head is not None:
errstream.write(b'Checking out ' + head.id + b'\n')
r.reset_index(head.tree)
except BaseException:
r.close()
raise
return r
def add(repo=".", paths=None):
"""Add files to the staging area.
- :param repo: Repository for the files
- :param paths: Paths to add. No value passed stages all modified files.
- :return: Tuple with set of added files and ignored files
+ Args:
+ repo: Repository for the files
+ paths: Paths to add. No value passed stages all modified files.
+ Returns: Tuple with a list of added paths and a set of ignored paths
"""
ignored = set()
with open_repo_closing(repo) as r:
ignore_manager = IgnoreFilterManager.from_repo(r)
if not paths:
paths = list(
get_untracked_paths(os.getcwd(), r.path, r.open_index()))
relpaths = []
if not isinstance(paths, list):
paths = [paths]
for p in paths:
relpath = os.path.relpath(p, r.path)
if relpath.startswith('..' + os.path.sep):
raise ValueError('path %r is not in repo' % relpath)
# FIXME: Support patterns, directories.
if ignore_manager.is_ignored(relpath):
ignored.add(relpath)
continue
relpaths.append(relpath)
r.stage(relpaths)
return (relpaths, ignored)
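# A sketch of the init/add/commit workflow, assuming a writable and
# hypothetical /tmp/demo-repo directory.
def _example_add_commit():
    r = init('/tmp/demo-repo')
    with open('/tmp/demo-repo/hello.txt', 'w') as f:
        f.write('hi\n')
    added, ignored = add(r, ['/tmp/demo-repo/hello.txt'])
    return commit(r, message='Initial commit',
                  author='Jane Dev <jane@example.com>',
                  committer='Jane Dev <jane@example.com>')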
def _is_subdir(subdir, parentdir):
"""Check whether subdir is parentdir or a subdir of parentdir
If parentdir or subdir is a relative path, it will be disamgibuated
relative to the pwd.
"""
parentdir_abs = os.path.realpath(parentdir) + os.path.sep
subdir_abs = os.path.realpath(subdir) + os.path.sep
return subdir_abs.startswith(parentdir_abs)
# TODO: option to remove ignored files also, in line with `git clean -fdx`
def clean(repo=".", target_dir=None):
"""Remove any untracked files from the target directory recursively
Equivalent to running `git clean -fd` in target_dir.
- :param repo: Repository where the files may be tracked
- :param target_dir: Directory to clean - current directory if None
+ Args:
+ repo: Repository where the files may be tracked
+ target_dir: Directory to clean - current directory if None
"""
if target_dir is None:
target_dir = os.getcwd()
with open_repo_closing(repo) as r:
if not _is_subdir(target_dir, r.path):
raise ValueError("target_dir must be in the repo's working dir")
index = r.open_index()
ignore_manager = IgnoreFilterManager.from_repo(r)
paths_in_wd = _walk_working_dir_paths(target_dir, r.path)
# Reverse file visit order, so that files and subdirectories are
# removed before containing directory
for ap, is_dir in reversed(list(paths_in_wd)):
if is_dir:
# All subdirectories and files have been removed if untracked,
# so dir contains no tracked files iff it is empty.
is_empty = len(os.listdir(ap)) == 0
if is_empty:
os.rmdir(ap)
else:
ip = path_to_tree_path(r.path, ap)
is_tracked = ip in index
rp = os.path.relpath(ap, r.path)
is_ignored = ignore_manager.is_ignored(rp)
if not is_tracked and not is_ignored:
os.remove(ap)
def remove(repo=".", paths=None, cached=False):
"""Remove files from the staging area.
- :param repo: Repository for the files
- :param paths: Paths to remove
+ Args:
+ repo: Repository for the files
+ paths: Paths to remove
"""
with open_repo_closing(repo) as r:
index = r.open_index()
for p in paths:
full_path = os.path.abspath(p).encode(sys.getfilesystemencoding())
tree_path = path_to_tree_path(r.path, p)
try:
index_sha = index[tree_path].sha
except KeyError:
raise Exception('%s did not match any files' % p)
if not cached:
try:
st = os.lstat(full_path)
except OSError:
pass
else:
try:
blob = blob_from_path_and_stat(full_path, st)
except IOError:
pass
else:
try:
committed_sha = tree_lookup_path(
r.__getitem__, r[r.head()].tree, tree_path)[1]
except KeyError:
committed_sha = None
if blob.id != index_sha and index_sha != committed_sha:
raise Exception(
'file has staged content differing '
'from both the file and head: %s' % p)
if index_sha != committed_sha:
raise Exception(
'file has staged changes: %s' % p)
os.remove(full_path)
del index[tree_path]
index.write()
rm = remove
def commit_decode(commit, contents, default_encoding=DEFAULT_ENCODING):
if commit.encoding is not None:
return contents.decode(commit.encoding, "replace")
return contents.decode(default_encoding, "replace")
def print_commit(commit, decode, outstream=sys.stdout):
"""Write a human-readable commit log entry.
- :param commit: A `Commit` object
- :param outstream: A stream file to write to
+ Args:
+ commit: A `Commit` object
+ outstream: A stream file to write to
"""
outstream.write("-" * 50 + "\n")
outstream.write("commit: " + commit.id.decode('ascii') + "\n")
if len(commit.parents) > 1:
outstream.write(
"merge: " +
"...".join([c.decode('ascii') for c in commit.parents[1:]]) + "\n")
outstream.write("Author: " + decode(commit.author) + "\n")
if commit.author != commit.committer:
outstream.write("Committer: " + decode(commit.committer) + "\n")
time_tuple = time.gmtime(commit.author_time + commit.author_timezone)
time_str = time.strftime("%a %b %d %Y %H:%M:%S", time_tuple)
timezone_str = format_timezone(commit.author_timezone).decode('ascii')
outstream.write("Date: " + time_str + " " + timezone_str + "\n")
outstream.write("\n")
outstream.write(decode(commit.message) + "\n")
outstream.write("\n")
def print_tag(tag, decode, outstream=sys.stdout):
"""Write a human-readable tag.
- :param tag: A `Tag` object
- :param decode: Function for decoding bytes to unicode string
- :param outstream: A stream to write to
+ Args:
+ tag: A `Tag` object
+ decode: Function for decoding bytes to unicode string
+ outstream: A stream to write to
"""
outstream.write("Tagger: " + decode(tag.tagger) + "\n")
time_tuple = time.gmtime(tag.tag_time + tag.tag_timezone)
time_str = time.strftime("%a %b %d %Y %H:%M:%S", time_tuple)
timezone_str = format_timezone(tag.tag_timezone).decode('ascii')
outstream.write("Date: " + time_str + " " + timezone_str + "\n")
outstream.write("\n")
outstream.write(decode(tag.message) + "\n")
outstream.write("\n")
def show_blob(repo, blob, decode, outstream=sys.stdout):
"""Write a blob to a stream.
- :param repo: A `Repo` object
- :param blob: A `Blob` object
- :param decode: Function for decoding bytes to unicode string
- :param outstream: A stream file to write to
+ Args:
+ repo: A `Repo` object
+ blob: A `Blob` object
+ decode: Function for decoding bytes to unicode string
+ outstream: A stream file to write to
"""
outstream.write(decode(blob.data))
def show_commit(repo, commit, decode, outstream=sys.stdout):
"""Show a commit to a stream.
- :param repo: A `Repo` object
- :param commit: A `Commit` object
- :param decode: Function for decoding bytes to unicode string
- :param outstream: Stream to write to
+ Args:
+ repo: A `Repo` object
+ commit: A `Commit` object
+ decode: Function for decoding bytes to unicode string
+ outstream: Stream to write to
"""
print_commit(commit, decode=decode, outstream=outstream)
if commit.parents:
parent_commit = repo[commit.parents[0]]
base_tree = parent_commit.tree
else:
base_tree = None
diffstream = BytesIO()
write_tree_diff(
diffstream,
repo.object_store, base_tree, commit.tree)
diffstream.seek(0)
outstream.write(
diffstream.getvalue().decode(
commit.encoding or DEFAULT_ENCODING, 'replace'))
def show_tree(repo, tree, decode, outstream=sys.stdout):
"""Print a tree to a stream.
- :param repo: A `Repo` object
- :param tree: A `Tree` object
- :param decode: Function for decoding bytes to unicode string
- :param outstream: Stream to write to
+ Args:
+ repo: A `Repo` object
+ tree: A `Tree` object
+ decode: Function for decoding bytes to unicode string
+ outstream: Stream to write to
"""
for n in tree:
outstream.write(decode(n) + "\n")
def show_tag(repo, tag, decode, outstream=sys.stdout):
"""Print a tag to a stream.
- :param repo: A `Repo` object
- :param tag: A `Tag` object
- :param decode: Function for decoding bytes to unicode string
- :param outstream: Stream to write to
+ Args:
+ repo: A `Repo` object
+ tag: A `Tag` object
+ decode: Function for decoding bytes to unicode string
+ outstream: Stream to write to
"""
print_tag(tag, decode, outstream)
show_object(repo, repo[tag.object[1]], decode, outstream)
def show_object(repo, obj, decode, outstream):
return {
b"tree": show_tree,
b"blob": show_blob,
b"commit": show_commit,
b"tag": show_tag,
}[obj.type_name](repo, obj, decode, outstream)
def print_name_status(changes):
"""Print a simple status summary, listing changed files.
"""
for change in changes:
if not change:
continue
if isinstance(change, list):
change = change[0]
if change.type == CHANGE_ADD:
path1 = change.new.path
path2 = ''
kind = 'A'
elif change.type == CHANGE_DELETE:
path1 = change.old.path
path2 = ''
kind = 'D'
elif change.type == CHANGE_MODIFY:
path1 = change.new.path
path2 = ''
kind = 'M'
elif change.type in RENAME_CHANGE_TYPES:
path1 = change.old.path
path2 = change.new.path
if change.type == CHANGE_RENAME:
kind = 'R'
elif change.type == CHANGE_COPY:
kind = 'C'
yield '%-8s%-20s%-20s' % (kind, path1, path2)
def log(repo=".", paths=None, outstream=sys.stdout, max_entries=None,
reverse=False, name_status=False):
"""Write commit logs.
- :param repo: Path to repository
- :param paths: Optional set of specific paths to print entries for
- :param outstream: Stream to write log output to
- :param reverse: Reverse order in which entries are printed
- :param name_status: Print name status
- :param max_entries: Optional maximum number of entries to display
+ Args:
+ repo: Path to repository
+ paths: Optional set of specific paths to print entries for
+ outstream: Stream to write log output to
+ reverse: Reverse order in which entries are printed
+ name_status: Print name status
+ max_entries: Optional maximum number of entries to display
"""
with open_repo_closing(repo) as r:
walker = r.get_walker(
max_entries=max_entries, paths=paths, reverse=reverse)
for entry in walker:
def decode(x):
return commit_decode(entry.commit, x)
print_commit(entry.commit, decode, outstream)
if name_status:
outstream.writelines(
[l+'\n' for l in print_name_status(entry.changes())])
# TODO(jelmer): better default for encoding?
def show(repo=".", objects=None, outstream=sys.stdout,
default_encoding=DEFAULT_ENCODING):
"""Print the changes in a commit.
- :param repo: Path to repository
- :param objects: Objects to show (defaults to [HEAD])
- :param outstream: Stream to write to
- :param default_encoding: Default encoding to use if none is set in the
+ Args:
+ repo: Path to repository
+ objects: Objects to show (defaults to [HEAD])
+ outstream: Stream to write to
+ default_encoding: Default encoding to use if none is set in the
commit
"""
if objects is None:
objects = ["HEAD"]
if not isinstance(objects, list):
objects = [objects]
with open_repo_closing(repo) as r:
for objectish in objects:
o = parse_object(r, objectish)
if isinstance(o, Commit):
def decode(x):
return commit_decode(o, x, default_encoding)
else:
def decode(x):
return x.decode(default_encoding)
show_object(r, o, decode, outstream)
def diff_tree(repo, old_tree, new_tree, outstream=sys.stdout):
"""Compares the content and mode of blobs found via two tree objects.
- :param repo: Path to repository
- :param old_tree: Id of old tree
- :param new_tree: Id of new tree
- :param outstream: Stream to write to
+ Args:
+ repo: Path to repository
+ old_tree: Id of old tree
+ new_tree: Id of new tree
+ outstream: Stream to write to
"""
with open_repo_closing(repo) as r:
write_tree_diff(outstream, r.object_store, old_tree, new_tree)
def rev_list(repo, commits, outstream=sys.stdout):
"""Lists commit objects in reverse chronological order.
- :param repo: Path to repository
- :param commits: Commits over which to iterate
- :param outstream: Stream to write to
+ Args:
+ repo: Path to repository
+ commits: Commits over which to iterate
+ outstream: Stream to write to
"""
with open_repo_closing(repo) as r:
for entry in r.get_walker(include=[r[c].id for c in commits]):
outstream.write(entry.commit.id + b"\n")
def tag(*args, **kwargs):
import warnings
warnings.warn("tag has been deprecated in favour of tag_create.",
DeprecationWarning)
return tag_create(*args, **kwargs)
def tag_create(
repo, tag, author=None, message=None, annotated=False,
objectish="HEAD", tag_time=None, tag_timezone=None,
sign=False):
"""Creates a tag in git via dulwich calls:
- :param repo: Path to repository
- :param tag: tag string
- :param author: tag author (optional, if annotated is set)
- :param message: tag message (optional)
- :param annotated: whether to create an annotated tag
- :param objectish: object the tag should point at, defaults to HEAD
- :param tag_time: Optional time for annotated tag
- :param tag_timezone: Optional timezone for annotated tag
- :param sign: GPG Sign the tag
+ Args:
+ repo: Path to repository
+ tag: tag string
+ author: tag author (optional, if annotated is set)
+ message: tag message (optional)
+ annotated: whether to create an annotated tag
+ objectish: object the tag should point at, defaults to HEAD
+ tag_time: Optional time for annotated tag
+ tag_timezone: Optional timezone for annotated tag
+ sign: GPG Sign the tag
"""
with open_repo_closing(repo) as r:
object = parse_object(r, objectish)
if annotated:
# Create the tag object
tag_obj = Tag()
if author is None:
# TODO(jelmer): Don't use repo private method.
author = r._get_user_identity(r.get_config_stack())
tag_obj.tagger = author
tag_obj.message = message
tag_obj.name = tag
tag_obj.object = (type(object), object.id)
if tag_time is None:
tag_time = int(time.time())
tag_obj.tag_time = tag_time
if tag_timezone is None:
# TODO(jelmer) Use current user timezone rather than UTC
tag_timezone = 0
elif isinstance(tag_timezone, str):
tag_timezone = parse_timezone(tag_timezone)
tag_obj.tag_timezone = tag_timezone
if sign:
import gpg
with gpg.Context(armor=True) as c:
tag_obj.signature, unused_result = c.sign(
tag_obj.as_raw_string())
r.object_store.add_object(tag_obj)
tag_id = tag_obj.id
else:
tag_id = object.id
r.refs[_make_tag_ref(tag)] = tag_id
def list_tags(*args, **kwargs):
import warnings
warnings.warn("list_tags has been deprecated in favour of tag_list.",
DeprecationWarning)
return tag_list(*args, **kwargs)
def tag_list(repo, outstream=sys.stdout):
"""List all tags.
- :param repo: Path to repository
- :param outstream: Stream to write tags to
+ Args:
+ repo: Path to repository
+ outstream: Stream to write tags to
"""
with open_repo_closing(repo) as r:
tags = sorted(r.refs.as_dict(b"refs/tags"))
return tags
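# A sketch of tag creation and listing, reusing the hypothetical
# /tmp/demo-repo from the commit example above.
def _example_tags():
    tag_create('/tmp/demo-repo', 'v0.1')  # lightweight tag at HEAD
    return tag_list('/tmp/demo-repo')     # e.g. [b'v0.1']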
def tag_delete(repo, name):
"""Remove a tag.
- :param repo: Path to repository
- :param name: Name of tag to remove
+ Args:
+ repo: Path to repository
+ name: Name of tag to remove
"""
with open_repo_closing(repo) as r:
if isinstance(name, bytes):
names = [name]
elif isinstance(name, list):
names = name
else:
raise TypeError("Unexpected tag name type %r" % name)
for name in names:
del r.refs[_make_tag_ref(name)]
def reset(repo, mode, treeish="HEAD"):
"""Reset current HEAD to the specified state.
- :param repo: Path to repository
- :param mode: Mode ("hard", "soft", "mixed")
- :param treeish: Treeish to reset to
+ Args:
+ repo: Path to repository
+ mode: Mode ("hard", "soft", "mixed")
+ treeish: Treeish to reset to
"""
if mode != "hard":
raise ValueError("hard is the only mode currently supported")
with open_repo_closing(repo) as r:
tree = parse_tree(r, treeish)
r.reset_index(tree.id)
def push(repo, remote_location, refspecs,
outstream=default_bytes_out_stream,
errstream=default_bytes_err_stream, **kwargs):
"""Remote push with dulwich via dulwich.client
- :param repo: Path to repository
- :param remote_location: Location of the remote
- :param refspecs: Refs to push to remote
- :param outstream: A stream file to write output
- :param errstream: A stream file to write errors
+ Args:
+ repo: Path to repository
+ remote_location: Location of the remote
+ refspecs: Refs to push to remote
+ outstream: A stream file to write output
+ errstream: A stream file to write errors
"""
# Open the repo
with open_repo_closing(repo) as r:
# Get the client and path
client, path = get_transport_and_path(
remote_location, config=r.get_config_stack(), **kwargs)
selected_refs = []
def update_refs(refs):
selected_refs.extend(parse_reftuples(r.refs, refs, refspecs))
new_refs = {}
# TODO: Handle selected_refs == {None: None}
for (lh, rh, force) in selected_refs:
if lh is None:
new_refs[rh] = ZERO_SHA
else:
new_refs[rh] = r.refs[lh]
return new_refs
err_encoding = getattr(errstream, 'encoding', None) or DEFAULT_ENCODING
remote_location_bytes = client.get_url(path).encode(err_encoding)
try:
client.send_pack(
path, update_refs,
generate_pack_data=r.object_store.generate_pack_data,
progress=errstream.write)
errstream.write(
b"Push to " + remote_location_bytes + b" successful.\n")
except (UpdateRefsError, SendPackError) as e:
errstream.write(b"Push to " + remote_location_bytes +
b" failed -> " + e.message.encode(err_encoding) +
b"\n")
def pull(repo, remote_location=None, refspecs=None,
outstream=default_bytes_out_stream,
errstream=default_bytes_err_stream, **kwargs):
"""Pull from remote via dulwich.client
- :param repo: Path to repository
- :param remote_location: Location of the remote
- :param refspec: refspecs to fetch
- :param outstream: A stream file to write to output
- :param errstream: A stream file to write to errors
+ Args:
+ repo: Path to repository
+ remote_location: Location of the remote
+ refspecs: refspecs to fetch
+ outstream: A stream file to write to output
+ errstream: A stream file to write to errors
"""
# Open the repo
with open_repo_closing(repo) as r:
if remote_location is None:
# TODO(jelmer): Lookup 'remote' for current branch in config
raise NotImplementedError(
"looking up remote from branch config not supported yet")
if refspecs is None:
refspecs = [b"HEAD"]
selected_refs = []
def determine_wants(remote_refs):
selected_refs.extend(
parse_reftuples(remote_refs, r.refs, refspecs))
return [remote_refs[lh] for (lh, rh, force) in selected_refs]
client, path = get_transport_and_path(
remote_location, config=r.get_config_stack(), **kwargs)
fetch_result = client.fetch(
path, r, progress=errstream.write, determine_wants=determine_wants)
for (lh, rh, force) in selected_refs:
r.refs[rh] = fetch_result.refs[lh]
if selected_refs:
r[b'HEAD'] = fetch_result.refs[selected_refs[0][1]]
# Perform 'git checkout .' - syncs staged changes
tree = r[b"HEAD"].tree
r.reset_index(tree=tree)
def status(repo=".", ignored=False):
"""Returns staged, unstaged, and untracked changes relative to the HEAD.
- :param repo: Path to repository or repository object
- :param ignored: Whether to include ignored files in `untracked`
- :return: GitStatus tuple,
+ Args:
+ repo: Path to repository or repository object
+ ignored: Whether to include ignored files in `untracked`
+ Returns: GitStatus tuple,
staged - dict with lists of staged paths (diff index/HEAD)
unstaged - list of unstaged paths (diff index/working-tree)
untracked - list of untracked, un-ignored & non-.git paths
"""
with open_repo_closing(repo) as r:
# 1. Get status of staged
tracked_changes = get_tree_changes(r)
# 2. Get status of unstaged
index = r.open_index()
normalizer = r.get_blob_normalizer()
filter_callback = normalizer.checkin_normalize
unstaged_changes = list(
get_unstaged_changes(index, r.path, filter_callback)
)
ignore_manager = IgnoreFilterManager.from_repo(r)
untracked_paths = get_untracked_paths(r.path, r.path, index)
if ignored:
untracked_changes = list(untracked_paths)
else:
untracked_changes = [
p for p in untracked_paths
if not ignore_manager.is_ignored(p)]
return GitStatus(tracked_changes, unstaged_changes, untracked_changes)
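# A sketch of reading the GitStatus named tuple, again on the hypothetical
# /tmp/demo-repo.
def _example_status():
    st = status('/tmp/demo-repo')
    return st.staged['add'], st.unstaged, st.untracked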
def _walk_working_dir_paths(frompath, basepath):
"""Get path, is_dir for files in working dir from frompath
- :param frompath: Path to begin walk
- :param basepath: Path to compare to
+ Args:
+ frompath: Path to begin walk
+ basepath: Path to compare to
"""
for dirpath, dirnames, filenames in os.walk(frompath):
# Skip .git and below.
if '.git' in dirnames:
dirnames.remove('.git')
if dirpath != basepath:
continue
if '.git' in filenames:
filenames.remove('.git')
if dirpath != basepath:
continue
if dirpath != frompath:
yield dirpath, True
for filename in filenames:
filepath = os.path.join(dirpath, filename)
yield filepath, False
def get_untracked_paths(frompath, basepath, index):
"""Get untracked paths.
+ Args:
- ;param frompath: Path to walk
+ frompath: Path to walk
- :param basepath: Path to compare to
- :param index: Index to check against
+ basepath: Path to compare to
+ index: Index to check against
"""
for ap, is_dir in _walk_working_dir_paths(frompath, basepath):
if not is_dir:
ip = path_to_tree_path(basepath, ap)
if ip not in index:
yield os.path.relpath(ap, frompath)
def get_tree_changes(repo):
"""Return add/delete/modify changes to tree by comparing index to HEAD.
- :param repo: repo path or object
- :return: dict with lists for each type of change
+ Args:
+ repo: repo path or object
+ Returns: dict with lists for each type of change
"""
with open_repo_closing(repo) as r:
index = r.open_index()
# Compares the Index to the HEAD & determines changes
# Iterate through the changes and report add/delete/modify
# TODO: call out to dulwich.diff_tree somehow.
tracked_changes = {
'add': [],
'delete': [],
'modify': [],
}
try:
tree_id = r[b'HEAD'].tree
except KeyError:
tree_id = None
for change in index.changes_from_tree(r.object_store, tree_id):
if not change[0][0]:
tracked_changes['add'].append(change[0][1])
elif not change[0][1]:
tracked_changes['delete'].append(change[0][0])
elif change[0][0] == change[0][1]:
tracked_changes['modify'].append(change[0][0])
else:
raise AssertionError('git mv ops not yet supported')
return tracked_changes
def daemon(path=".", address=None, port=None):
"""Run a daemon serving Git requests over TCP/IP.
- :param path: Path to the directory to serve.
- :param address: Optional address to listen on (defaults to ::)
- :param port: Optional port to listen on (defaults to TCP_GIT_PORT)
+ Args:
+ path: Path to the directory to serve.
+ address: Optional address to listen on (defaults to ::)
+ port: Optional port to listen on (defaults to TCP_GIT_PORT)
"""
# TODO(jelmer): Support git-daemon-export-ok and --export-all.
backend = FileSystemBackend(path)
server = TCPGitServer(backend, address, port)
server.serve_forever()
def web_daemon(path=".", address=None, port=None):
"""Run a daemon serving Git requests over HTTP.
- :param path: Path to the directory to serve
- :param address: Optional address to listen on (defaults to ::)
- :param port: Optional port to listen on (defaults to 80)
+ Args:
+ path: Path to the directory to serve
+ address: Optional address to listen on (defaults to ::)
+ port: Optional port to listen on (defaults to 80)
"""
from dulwich.web import (
make_wsgi_chain,
make_server,
WSGIRequestHandlerLogger,
WSGIServerLogger)
backend = FileSystemBackend(path)
app = make_wsgi_chain(backend)
server = make_server(address, port, app,
handler_class=WSGIRequestHandlerLogger,
server_class=WSGIServerLogger)
server.serve_forever()
def upload_pack(path=".", inf=None, outf=None):
"""Upload a pack file after negotiating its contents using smart protocol.
- :param path: Path to the repository
- :param inf: Input stream to communicate with client
- :param outf: Output stream to communicate with client
+ Args:
+ path: Path to the repository
+ inf: Input stream to communicate with client
+ outf: Output stream to communicate with client
"""
if outf is None:
outf = getattr(sys.stdout, 'buffer', sys.stdout)
if inf is None:
inf = getattr(sys.stdin, 'buffer', sys.stdin)
path = os.path.expanduser(path)
backend = FileSystemBackend(path)
def send_fn(data):
outf.write(data)
outf.flush()
proto = Protocol(inf.read, send_fn)
handler = UploadPackHandler(backend, [path], proto)
# FIXME: Catch exceptions and write a single-line summary to outf.
handler.handle()
return 0
def receive_pack(path=".", inf=None, outf=None):
"""Receive a pack file after negotiating its contents using smart protocol.
- :param path: Path to the repository
- :param inf: Input stream to communicate with client
- :param outf: Output stream to communicate with client
+ Args:
+ path: Path to the repository
+ inf: Input stream to communicate with client
+ outf: Output stream to communicate with client
"""
if outf is None:
outf = getattr(sys.stdout, 'buffer', sys.stdout)
if inf is None:
inf = getattr(sys.stdin, 'buffer', sys.stdin)
path = os.path.expanduser(path)
backend = FileSystemBackend(path)
def send_fn(data):
outf.write(data)
outf.flush()
proto = Protocol(inf.read, send_fn)
handler = ReceivePackHandler(backend, [path], proto)
# FIXME: Catch exceptions and write a single-line summary to outf.
handler.handle()
return 0
def _make_branch_ref(name):
if getattr(name, 'encode', None):
name = name.encode(DEFAULT_ENCODING)
return LOCAL_BRANCH_PREFIX + name
def _make_tag_ref(name):
if getattr(name, 'encode', None):
name = name.encode(DEFAULT_ENCODING)
return b"refs/tags/" + name
def branch_delete(repo, name):
"""Delete a branch.
- :param repo: Path to the repository
- :param name: Name of the branch
+ Args:
+ repo: Path to the repository
+ name: Name of the branch
"""
with open_repo_closing(repo) as r:
if isinstance(name, list):
names = name
else:
names = [name]
for name in names:
del r.refs[_make_branch_ref(name)]
def branch_create(repo, name, objectish=None, force=False):
"""Create a branch.
- :param repo: Path to the repository
- :param name: Name of the new branch
- :param objectish: Target object to point new branch at (defaults to HEAD)
- :param force: Force creation of branch, even if it already exists
+ Args:
+ repo: Path to the repository
+ name: Name of the new branch
+ objectish: Target object to point new branch at (defaults to HEAD)
+ force: Force creation of branch, even if it already exists
"""
with open_repo_closing(repo) as r:
if objectish is None:
objectish = "HEAD"
object = parse_object(r, objectish)
refname = _make_branch_ref(name)
ref_message = b"branch: Created from " + objectish.encode('utf-8')
if force:
r.refs.set_if_equals(refname, None, object.id, message=ref_message)
else:
if not r.refs.add_if_new(refname, object.id, message=ref_message):
raise KeyError("Branch with name %s already exists." % name)
def branch_list(repo):
"""List all branches.
- :param repo: Path to the repository
+ Args:
+ repo: Path to the repository
"""
with open_repo_closing(repo) as r:
return r.refs.keys(base=LOCAL_BRANCH_PREFIX)
def active_branch(repo):
"""Return the active branch in the repository, if any.
Args:
repo: Repository to open
Returns:
branch name
Raises:
KeyError: if the repository does not have a working tree
IndexError: if HEAD is floating
"""
with open_repo_closing(repo) as r:
active_ref = r.refs.follow(b'HEAD')[0][1]
if not active_ref.startswith(LOCAL_BRANCH_PREFIX):
raise ValueError(active_ref)
return active_ref[len(LOCAL_BRANCH_PREFIX):]
def fetch(repo, remote_location, remote_name=b'origin', outstream=sys.stdout,
errstream=default_bytes_err_stream, message=None, depth=None,
prune=False, prune_tags=False, **kwargs):
"""Fetch objects from a remote server.
Args:
repo: Path to the repository
remote_location: String identifying a remote server
remote_name: Name for remote server
outstream: Output stream (defaults to stdout)
errstream: Error stream (defaults to stderr)
message: Reflog message (defaults to b"fetch: from <remote_location>")
depth: Depth to fetch at
prune: Prune remote-tracking refs that were removed on the remote
prune_tags: Prune tags that were removed on the remote
Returns:
Dictionary with refs on the remote
"""
if message is None:
message = b'fetch: from ' + remote_location.encode("utf-8")
with open_repo_closing(repo) as r:
client, path = get_transport_and_path(
remote_location, config=r.get_config_stack(), **kwargs)
fetch_result = client.fetch(path, r, progress=errstream.write,
depth=depth)
stripped_refs = strip_peeled_refs(fetch_result.refs)
branches = {
n[len(LOCAL_BRANCH_PREFIX):]: v for (n, v) in stripped_refs.items()
if n.startswith(LOCAL_BRANCH_PREFIX)}
r.refs.import_refs(
b'refs/remotes/' + remote_name, branches, message=message,
prune=prune)
tags = {
n[len(b'refs/tags/'):]: v for (n, v) in stripped_refs.items()
if n.startswith(b'refs/tags/') and
not n.endswith(ANNOTATED_TAG_SUFFIX)}
r.refs.import_refs(
b'refs/tags', tags, message=message,
prune=prune_tags)
return fetch_result.refs
def ls_remote(remote, config=None, **kwargs):
"""List the refs in a remote.
Args:
remote: Remote repository location
config: Configuration to use
Returns:
Dictionary with remote refs
"""
if config is None:
config = StackedConfig.default()
client, host_path = get_transport_and_path(remote, config=config, **kwargs)
return client.get_refs(host_path)
def repack(repo):
"""Repack loose files in a repository.
Currently this only packs loose objects.
- :param repo: Path to the repository
+ Args:
+ repo: Path to the repository
"""
with open_repo_closing(repo) as r:
r.object_store.pack_loose_objects()
def pack_objects(repo, object_ids, packf, idxf, delta_window_size=None):
"""Pack objects into a file.
- :param repo: Path to the repository
- :param object_ids: List of object ids to write
- :param packf: File-like object to write to
- :param idxf: File-like object to write to (can be None)
+ Args:
+ repo: Path to the repository
+ object_ids: List of object ids to write
+ packf: File-like object to write the pack to
+ idxf: File-like object to write the pack index to (can be None)
+ delta_window_size: Sliding window size for delta compression (optional)
"""
with open_repo_closing(repo) as r:
entries, data_sum = write_pack_objects(
packf,
r.object_store.iter_shas((oid, None) for oid in object_ids),
delta_window_size=delta_window_size)
if idxf is not None:
entries = sorted([(k, v[0], v[1]) for (k, v) in entries.items()])
write_pack_index(idxf, entries, data_sum)
def ls_tree(repo, treeish=b"HEAD", outstream=sys.stdout, recursive=False,
name_only=False):
"""List contents of a tree.
- :param repo: Path to the repository
- :param tree_ish: Tree id to list
- :param outstream: Output stream (defaults to stdout)
- :param recursive: Whether to recursively list files
- :param name_only: Only print item name
+ Args:
+ repo: Path to the repository
+ treeish: Tree id to list
+ outstream: Output stream (defaults to stdout)
+ recursive: Whether to recursively list files
+ name_only: Only print item name
"""
def list_tree(store, treeid, base):
for (name, mode, sha) in store[treeid].iteritems():
if base:
name = posixpath.join(base, name)
if name_only:
outstream.write(name + b"\n")
else:
outstream.write(pretty_format_tree_entry(name, mode, sha))
if stat.S_ISDIR(mode) and recursive:
list_tree(store, sha, name)
with open_repo_closing(repo) as r:
tree = parse_tree(r, treeish)
list_tree(r.object_store, tree.id, "")
def remote_add(repo, name, url):
"""Add a remote.
- :param repo: Path to the repository
- :param name: Remote name
- :param url: Remote URL
+ Args:
+ repo: Path to the repository
+ name: Remote name
+ url: Remote URL
"""
if not isinstance(name, bytes):
name = name.encode(DEFAULT_ENCODING)
if not isinstance(url, bytes):
url = url.encode(DEFAULT_ENCODING)
with open_repo_closing(repo) as r:
c = r.get_config()
section = (b'remote', name)
if c.has_section(section):
raise RemoteExists(section)
c.set(section, b"url", url)
c.write_to_path()
def check_ignore(repo, paths, no_index=False):
"""Debug gitignore files.
- :param repo: Path to the repository
- :param paths: List of paths to check for
- :param no_index: Don't check index
- :return: List of ignored files
+ Args:
+ repo: Path to the repository
+ paths: List of paths to check for
+ no_index: Don't check index
+ Returns: List of ignored files
"""
with open_repo_closing(repo) as r:
index = r.open_index()
ignore_manager = IgnoreFilterManager.from_repo(r)
for path in paths:
if not no_index and path_to_tree_path(r.path, path) in index:
continue
if os.path.isabs(path):
path = os.path.relpath(path, r.path)
if ignore_manager.is_ignored(path):
yield path
def update_head(repo, target, detached=False, new_branch=None):
"""Update HEAD to point at a new branch/commit.
Note that this does not actually update the working tree.
- :param repo: Path to the repository
- :param detach: Create a detached head
- :param target: Branch or committish to switch to
- :param new_branch: New branch to create
+ Args:
+ repo: Path to the repository
+ detached: Create a detached head
+ target: Branch or committish to switch to
+ new_branch: New branch to create
"""
with open_repo_closing(repo) as r:
if new_branch is not None:
to_set = _make_branch_ref(new_branch)
else:
to_set = b"HEAD"
if detached:
# TODO(jelmer): Provide some way so that the actual ref gets
# updated rather than what it points to, so the delete isn't
# necessary.
del r.refs[to_set]
r.refs[to_set] = parse_commit(r, target).id
else:
r.refs.set_symbolic_ref(to_set, parse_ref(r, target))
if new_branch is not None:
r.refs.set_symbolic_ref(b"HEAD", to_set)
def check_mailmap(repo, contact):
"""Check canonical name and email of contact.
- :param repo: Path to the repository
- :param contact: Contact name and/or email
- :return: Canonical contact data
+ Args:
+ repo: Path to the repository
+ contact: Contact name and/or email
+ Returns: Canonical contact data
"""
with open_repo_closing(repo) as r:
from dulwich.mailmap import Mailmap
import errno
try:
mailmap = Mailmap.from_path(os.path.join(r.path, '.mailmap'))
except IOError as e:
if e.errno != errno.ENOENT:
raise
mailmap = Mailmap()
return mailmap.lookup(contact)
def fsck(repo):
"""Check a repository.
- :param repo: A path to the repository
- :return: Iterator over errors/warnings
+ Args:
+ repo: A path to the repository
+ Returns: Iterator over errors/warnings
"""
with open_repo_closing(repo) as r:
# TODO(jelmer): check pack files
# TODO(jelmer): check graph
# TODO(jelmer): check refs
for sha in r.object_store:
o = r.object_store[sha]
try:
o.check()
except Exception as e:
yield (sha, e)
def stash_list(repo):
"""List all stashes in a repository."""
with open_repo_closing(repo) as r:
from dulwich.stash import Stash
stash = Stash.from_repo(r)
return enumerate(list(stash.stashes()))
def stash_push(repo):
"""Push a new stash onto the stack."""
with open_repo_closing(repo) as r:
from dulwich.stash import Stash
stash = Stash.from_repo(r)
stash.push()
def stash_pop(repo):
"""Pop a new stash from the stack."""
with open_repo_closing(repo) as r:
from dulwich.stash import Stash
stash = Stash.from_repo(r)
stash.pop()
def ls_files(repo):
"""List all files in an index."""
with open_repo_closing(repo) as r:
return sorted(r.open_index())
def describe(repo):
"""Describe the repository version.
- :param projdir: git repository root
- :returns: a string description of the current git revision
+ Args:
+ repo: git repository root
+ Returns: a string description of the current git revision
Examples: "gabcdefh", "v0.1" or "v0.1-5-gabcdefh".
"""
# Get the repository
with open_repo_closing(repo) as r:
# Get a list of all tags
refs = r.get_refs()
tags = {}
for key, value in refs.items():
key = key.decode()
obj = r.get_object(value)
if u'tags' not in key:
continue
_, tag = key.rsplit(u'/', 1)
try:
commit = obj.object
except AttributeError:
continue
else:
commit = r.get_object(commit[1])
tags[tag] = [
datetime.datetime(*time.gmtime(commit.commit_time)[:6]),
commit.id.decode('ascii'),
]
sorted_tags = sorted(tags.items(),
key=lambda tag: tag[1][0],
reverse=True)
# If there are no tags, return the current commit
if len(sorted_tags) == 0:
return 'g{}'.format(r[r.head()].id.decode('ascii')[:7])
# We're now 0 commits from the top
commit_count = 0
# Get the latest commit
latest_commit = r[r.head()]
# Walk through all commits
walker = r.get_walker()
for entry in walker:
# Check if tag
commit_id = entry.commit.id.decode('ascii')
for tag in sorted_tags:
tag_name = tag[0]
tag_commit = tag[1][1]
if commit_id == tag_commit:
if commit_count == 0:
return tag_name
else:
return '{}-{}-g{}'.format(
tag_name,
commit_count,
latest_commit.id.decode('ascii')[:7])
commit_count += 1
# Return plain commit if no parent tag can be found
return 'g{}'.format(latest_commit.id.decode('ascii')[:7])
def get_object_by_path(repo, path, committish=None):
"""Get an object by path.
- :param repo: A path to the repository
- :param path: Path to look up
- :param committish: Commit to look up path in
- :return: A `ShaFile` object
+ Args:
+ repo: A path to the repository
+ path: Path to look up
+ committish: Commit to look up path in
+ Returns: A `ShaFile` object
"""
if committish is None:
committish = "HEAD"
# Get the repository
with open_repo_closing(repo) as r:
commit = parse_commit(r, committish)
base_tree = commit.tree
if not isinstance(path, bytes):
path = path.encode(commit.encoding or DEFAULT_ENCODING)
(mode, sha) = tree_lookup_path(
r.object_store.__getitem__,
base_tree, path)
return r[sha]
def write_tree(repo):
"""Write a tree object from the index.
- :param repo: Repository for which to write tree
- :return: tree id for the tree that was written
+ Args:
+ repo: Repository for which to write tree
+ Returns: tree id for the tree that was written
"""
with open_repo_closing(repo) as r:
return r.open_index().commit(r.object_store)
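# Typical porcelain-level use of the function above (path is illustrative):
#
#   from dulwich import porcelain
#   tree_id = porcelain.write_tree('/path/to/repo')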
diff --git a/dulwich/protocol.py b/dulwich/protocol.py
index df61b34e..756fe663 100644
--- a/dulwich/protocol.py
+++ b/dulwich/protocol.py
@@ -1,550 +1,559 @@
# protocol.py -- Shared parts of the git protocols
# Copyright (C) 2008 John Carr
# Copyright (C) 2008-2012 Jelmer Vernooij
#
# Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU
# General Public License as public by the Free Software Foundation; version 2.0
# or (at your option) any later version. You can redistribute it and/or
# modify it under the terms of either of these two licenses.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# You should have received a copy of the licenses; if not, see
# <http://www.gnu.org/licenses/> for a copy of the GNU General Public License
# and <http://www.apache.org/licenses/LICENSE-2.0> for a copy of the Apache
# License, Version 2.0.
#
"""Generic functions for talking the git smart server protocol."""
from io import BytesIO
from os import (
SEEK_END,
)
import socket
import dulwich
from dulwich.errors import (
HangupException,
GitProtocolError,
)
TCP_GIT_PORT = 9418
ZERO_SHA = b"0" * 40
SINGLE_ACK = 0
MULTI_ACK = 1
MULTI_ACK_DETAILED = 2
# pack data
SIDE_BAND_CHANNEL_DATA = 1
# progress messages
SIDE_BAND_CHANNEL_PROGRESS = 2
# fatal error message just before stream aborts
SIDE_BAND_CHANNEL_FATAL = 3
CAPABILITY_DEEPEN_SINCE = b'deepen-since'
CAPABILITY_DEEPEN_NOT = b'deepen-not'
CAPABILITY_DEEPEN_RELATIVE = b'deepen-relative'
CAPABILITY_DELETE_REFS = b'delete-refs'
CAPABILITY_INCLUDE_TAG = b'include-tag'
CAPABILITY_MULTI_ACK = b'multi_ack'
CAPABILITY_MULTI_ACK_DETAILED = b'multi_ack_detailed'
CAPABILITY_NO_DONE = b'no-done'
CAPABILITY_NO_PROGRESS = b'no-progress'
CAPABILITY_OFS_DELTA = b'ofs-delta'
CAPABILITY_QUIET = b'quiet'
CAPABILITY_REPORT_STATUS = b'report-status'
CAPABILITY_SHALLOW = b'shallow'
CAPABILITY_SIDE_BAND = b'side-band'
CAPABILITY_SIDE_BAND_64K = b'side-band-64k'
CAPABILITY_THIN_PACK = b'thin-pack'
CAPABILITY_AGENT = b'agent'
CAPABILITY_SYMREF = b'symref'
# Magic ref that is used to attach capabilities to when
# there are no refs. Should always be set to ZERO_SHA.
CAPABILITIES_REF = b'capabilities^{}'
COMMON_CAPABILITIES = [
CAPABILITY_OFS_DELTA,
CAPABILITY_SIDE_BAND,
CAPABILITY_SIDE_BAND_64K,
CAPABILITY_AGENT,
CAPABILITY_NO_PROGRESS]
KNOWN_UPLOAD_CAPABILITIES = set(COMMON_CAPABILITIES + [
CAPABILITY_THIN_PACK,
CAPABILITY_MULTI_ACK,
CAPABILITY_MULTI_ACK_DETAILED,
CAPABILITY_INCLUDE_TAG,
CAPABILITY_DEEPEN_SINCE,
CAPABILITY_SYMREF,
CAPABILITY_SHALLOW,
CAPABILITY_DEEPEN_NOT,
CAPABILITY_DEEPEN_RELATIVE,
])
KNOWN_RECEIVE_CAPABILITIES = set(COMMON_CAPABILITIES + [
CAPABILITY_REPORT_STATUS])
def agent_string():
return ('dulwich/%d.%d.%d' % dulwich.__version__).encode('ascii')
def capability_agent():
return CAPABILITY_AGENT + b'=' + agent_string()
def capability_symref(from_ref, to_ref):
return CAPABILITY_SYMREF + b'=' + from_ref + b':' + to_ref
def extract_capability_names(capabilities):
return set(parse_capability(c)[0] for c in capabilities)
def parse_capability(capability):
parts = capability.split(b'=', 1)
if len(parts) == 1:
return (parts[0], None)
return tuple(parts)
def symref_capabilities(symrefs):
return [capability_symref(*k) for k in symrefs]
COMMAND_DEEPEN = b'deepen'
COMMAND_SHALLOW = b'shallow'
COMMAND_UNSHALLOW = b'unshallow'
COMMAND_DONE = b'done'
COMMAND_WANT = b'want'
COMMAND_HAVE = b'have'
class ProtocolFile(object):
"""A dummy file for network ops that expect file-like objects."""
def __init__(self, read, write):
self.read = read
self.write = write
def tell(self):
pass
def close(self):
pass
def pkt_line(data):
"""Wrap data in a pkt-line.
- :param data: The data to wrap, as a str or None.
- :return: The data prefixed with its length in pkt-line format; if data was
+ Args:
+ data: The data to wrap, as bytes or None.
+ Returns: The data prefixed with its length in pkt-line format; if data was
None, returns the flush-pkt ('0000').
"""
if data is None:
return b'0000'
return ('%04x' % (len(data) + 4)).encode('ascii') + data
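# Quick doctest-style sketch of the framing described above (illustrative):
#
#   >>> pkt_line(b'want 0123\n')
#   b'000ewant 0123\n'
#   >>> pkt_line(None)
#   b'0000'
#
# The 4-byte hex header counts itself: 4 + len(b'want 0123\n') == 0x0e.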
class Protocol(object):
"""Class for interacting with a remote git process over the wire.
Parts of the git wire protocol use 'pkt-lines' to communicate. A pkt-line
consists of the length of the line as a 4-byte hex string, followed by the
payload data. The length includes the 4-byte header. The special line
'0000' indicates the end of a section of input and is called a 'flush-pkt'.
For details on the pkt-line format, see the cgit distribution:
Documentation/technical/protocol-common.txt
"""
def __init__(self, read, write, close=None, report_activity=None):
self.read = read
self.write = write
self._close = close
self.report_activity = report_activity
self._readahead = None
def close(self):
if self._close:
self._close()
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()
def read_pkt_line(self):
"""Reads a pkt-line from the remote git process.
This method may read from the readahead buffer; see unread_pkt_line.
- :return: The next string from the stream, without the length prefix, or
+ Returns: The next string from the stream, without the length prefix, or
None for a flush-pkt ('0000').
"""
if self._readahead is None:
read = self.read
else:
read = self._readahead.read
self._readahead = None
try:
sizestr = read(4)
if not sizestr:
raise HangupException()
size = int(sizestr, 16)
if size == 0:
if self.report_activity:
self.report_activity(4, 'read')
return None
if self.report_activity:
self.report_activity(size, 'read')
pkt_contents = read(size-4)
except socket.error as e:
raise GitProtocolError(e)
else:
if len(pkt_contents) + 4 != size:
raise GitProtocolError(
'Length of pkt read %04x does not match length prefix %04x'
% (len(pkt_contents) + 4, size))
return pkt_contents
def eof(self):
"""Test whether the protocol stream has reached EOF.
Note that this refers to the actual stream EOF and not just a
flush-pkt.
- :return: True if the stream is at EOF, False otherwise.
+ Returns: True if the stream is at EOF, False otherwise.
"""
try:
next_line = self.read_pkt_line()
except HangupException:
return True
self.unread_pkt_line(next_line)
return False
def unread_pkt_line(self, data):
"""Unread a single line of data into the readahead buffer.
This method can be used to unread a single pkt-line into a fixed
readahead buffer.
- :param data: The data to unread, without the length prefix.
- :raise ValueError: If more than one pkt-line is unread.
+ Args:
+ data: The data to unread, without the length prefix.
+ Raises:
+ ValueError: If more than one pkt-line is unread.
"""
if self._readahead is not None:
raise ValueError('Attempted to unread multiple pkt-lines.')
self._readahead = BytesIO(pkt_line(data))
def read_pkt_seq(self):
"""Read a sequence of pkt-lines from the remote git process.
- :return: Yields each line of data up to but not including the next
+ Returns: Yields each line of data up to but not including the next
flush-pkt.
"""
pkt = self.read_pkt_line()
while pkt:
yield pkt
pkt = self.read_pkt_line()
def write_pkt_line(self, line):
"""Sends a pkt-line to the remote git process.
- :param line: A string containing the data to send, without the length
+ Args:
+ line: A string containing the data to send, without the length
prefix.
"""
try:
line = pkt_line(line)
self.write(line)
if self.report_activity:
self.report_activity(len(line), 'write')
except socket.error as e:
raise GitProtocolError(e)
def write_file(self):
"""Return a writable file-like object for this protocol."""
class ProtocolFile(object):
def __init__(self, proto):
self._proto = proto
self._offset = 0
def write(self, data):
self._proto.write(data)
self._offset += len(data)
def tell(self):
return self._offset
def close(self):
pass
return ProtocolFile(self)
def write_sideband(self, channel, blob):
"""Write multiplexed data to the sideband.
- :param channel: An int specifying the channel to write to.
- :param blob: A blob of data (as a string) to send on this channel.
+ Args:
+ channel: An int specifying the channel to write to.
+ blob: A blob of data (as bytes) to send on this channel.
"""
# a pktline can be a max of 65520. a sideband line can therefore be
# 65520-5 = 65515
# Note the asymmetry: the length is ASCII hex, but the channel is a raw binary byte.
while blob:
self.write_pkt_line(bytes(bytearray([channel])) + blob[:65515])
blob = blob[65515:]
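# For example (illustrative), write_sideband(SIDE_BAND_CHANNEL_PROGRESS,
# b'counting objects') emits the single pkt-line
#   b'0015\x02counting objects'
# i.e. four ASCII hex length bytes, one binary channel byte, then the
# payload; larger blobs are split into 65515-byte slices by the loop above.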
def send_cmd(self, cmd, *args):
"""Send a command and some arguments to a git server.
Only used for the TCP git protocol (git://).
- :param cmd: The remote service to access.
- :param args: List of arguments to send to remove service.
+ Args:
+ cmd: The remote service to access.
+ args: List of arguments to send to the remote service.
"""
self.write_pkt_line(cmd + b" " + b"".join([(a + b"\0") for a in args]))
def read_cmd(self):
"""Read a command and some arguments from the git client
Only used for the TCP git protocol (git://).
- :return: A tuple of (command, [list of arguments]).
+ Returns: A tuple of (command, [list of arguments]).
"""
line = self.read_pkt_line()
splice_at = line.find(b" ")
cmd, args = line[:splice_at], line[splice_at+1:]
assert args[-1:] == b"\x00"
return cmd, args[:-1].split(b"\0")
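# Sketch of the git:// request framing these two methods speak (host and
# path values are made up for illustration):
#
#   proto.send_cmd(b'git-upload-pack', b'/project.git', b'host=example.com')
#   # wire bytes: b'0032git-upload-pack /project.git\x00host=example.com\x00'
#   # read_cmd() on the receiving side then returns
#   # (b'git-upload-pack', [b'/project.git', b'host=example.com'])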
_RBUFSIZE = 8192 # Default read buffer size.
class ReceivableProtocol(Protocol):
"""Variant of Protocol that allows reading up to a size without blocking.
This class has a recv() method that behaves like socket.recv() in addition
to a read() method.
If you want to read n bytes from the wire and block until exactly n bytes
(or EOF) are read, use read(n). If you want to read at most n bytes from
the wire but don't care if you get less, use recv(n). Note that recv(n)
will still block until at least one byte is read.
"""
def __init__(self, recv, write, close=None, report_activity=None,
rbufsize=_RBUFSIZE):
super(ReceivableProtocol, self).__init__(
self.read, write, close=close, report_activity=report_activity)
self._recv = recv
self._rbuf = BytesIO()
self._rbufsize = rbufsize
def read(self, size):
# From _fileobj.read in socket.py in the Python 2.6.5 standard library,
# with the following modifications:
# - omit the size <= 0 branch
# - seek back to start rather than 0 in case some buffer has been
# consumed.
# - use SEEK_END instead of the magic number.
# Copyright (c) 2001-2010 Python Software Foundation; All Rights
# Reserved
# Licensed under the Python Software Foundation License.
# TODO: see if buffer is more efficient than cBytesIO.
assert size > 0
# Our use of BytesIO rather than lists of string objects returned by
# recv() minimizes memory usage and fragmentation that occurs when
# rbufsize is large compared to the typical return value of recv().
buf = self._rbuf
start = buf.tell()
buf.seek(0, SEEK_END)
# buffer may have been partially consumed by recv()
buf_len = buf.tell() - start
if buf_len >= size:
# Already have size bytes in our buffer? Extract and return.
buf.seek(start)
rv = buf.read(size)
self._rbuf = BytesIO()
self._rbuf.write(buf.read())
self._rbuf.seek(0)
return rv
self._rbuf = BytesIO() # reset _rbuf. we consume it via buf.
while True:
left = size - buf_len
# recv() will malloc the amount of memory given as its
# parameter even though it often returns much less data
# than that. The returned data string is short lived
# as we copy it into a BytesIO and free it. This avoids
# fragmentation issues on many platforms.
data = self._recv(left)
if not data:
break
n = len(data)
if n == size and not buf_len:
# Shortcut. Avoid buffer data copies when:
# - We have no data in our buffer.
# AND
# - Our call to recv returned exactly the
# number of bytes we were asked to read.
return data
if n == left:
buf.write(data)
del data # explicit free
break
assert n <= left, "_recv(%d) returned %d bytes" % (left, n)
buf.write(data)
buf_len += n
del data # explicit free
# assert buf_len == buf.tell()
buf.seek(start)
return buf.read()
def recv(self, size):
assert size > 0
buf = self._rbuf
start = buf.tell()
buf.seek(0, SEEK_END)
buf_len = buf.tell()
buf.seek(start)
left = buf_len - start
if not left:
# only read from the wire if our read buffer is exhausted
data = self._recv(self._rbufsize)
if len(data) == size:
# shortcut: skip the buffer if we read exactly size bytes
return data
buf = BytesIO()
buf.write(data)
buf.seek(0)
del data # explicit free
self._rbuf = buf
return buf.read(size)
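# Rough usage contrast for the two methods above (stream contents assumed):
#
#   p.read(10)   # blocks until exactly 10 bytes (or EOF) have been read
#   p.recv(10)   # returns between 1 and 10 bytes, whatever is available,
#                # blocking only until at least one byte can be returned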
def extract_capabilities(text):
"""Extract a capabilities list from a string, if present.
- :param text: String to extract from
- :return: Tuple with text with capabilities removed and list of capabilities
+ Args:
+ text: String to extract from
+ Returns: Tuple with text with capabilities removed and list of capabilities
"""
if b"\0" not in text:
return text, []
text, capabilities = text.rstrip().split(b"\0")
return (text, capabilities.strip().split(b" "))
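# Example of the split performed above (illustrative ref line):
#
#   >>> extract_capabilities(b'refs/heads/master\x00multi_ack side-band-64k')
#   (b'refs/heads/master', [b'multi_ack', b'side-band-64k'])
#   >>> extract_capabilities(b'refs/heads/master')
#   (b'refs/heads/master', [])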
def extract_want_line_capabilities(text):
"""Extract a capabilities list from a want line, if present.
Note that want lines have capabilities separated from the rest of the line
by a space instead of a null byte. Thus want lines have the form:
want obj-id cap1 cap2 ...
- :param text: Want line to extract from
- :return: Tuple with text with capabilities removed and list of capabilities
+ Args:
+ text: Want line to extract from
+ Returns: Tuple with text with capabilities removed and list of capabilities
"""
split_text = text.rstrip().split(b" ")
if len(split_text) < 3:
return text, []
return (b" ".join(split_text[:2]), split_text[2:])
def ack_type(capabilities):
"""Extract the ack type from a capabilities list."""
if b'multi_ack_detailed' in capabilities:
return MULTI_ACK_DETAILED
elif b'multi_ack' in capabilities:
return MULTI_ACK
return SINGLE_ACK
class BufferedPktLineWriter(object):
"""Writer that wraps its data in pkt-lines and has an independent buffer.
Consecutive calls to write() wrap the data in a pkt-line and then buffer
it until enough lines have been written that their total length (including
length prefixes) reaches the buffer size.
"""
def __init__(self, write, bufsize=65515):
"""Initialize the BufferedPktLineWriter.
- :param write: A write callback for the underlying writer.
- :param bufsize: The internal buffer size, including length prefixes.
+ Args:
+ write: A write callback for the underlying writer.
+ bufsize: The internal buffer size, including length prefixes.
"""
self._write = write
self._bufsize = bufsize
self._wbuf = BytesIO()
self._buflen = 0
def write(self, data):
"""Write data, wrapping it in a pkt-line."""
line = pkt_line(data)
line_len = len(line)
over = self._buflen + line_len - self._bufsize
if over >= 0:
start = line_len - over
self._wbuf.write(line[:start])
self.flush()
else:
start = 0
saved = line[start:]
self._wbuf.write(saved)
self._buflen += len(saved)
def flush(self):
"""Flush all data from the buffer."""
data = self._wbuf.getvalue()
if data:
self._write(data)
self._buflen = 0
self._wbuf = BytesIO()
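# Minimal usage sketch (the list-append sink is assumed for illustration);
# both lines fit in the buffer, so one underlying write happens on flush():
#
#   chunks = []
#   w = BufferedPktLineWriter(chunks.append, bufsize=64)
#   w.write(b'first')
#   w.write(b'second')
#   w.flush()
#   # chunks == [b'0009first000asecond']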
class PktLineParser(object):
"""Packet line parser that hands completed packets off to a callback.
"""
def __init__(self, handle_pkt):
self.handle_pkt = handle_pkt
self._readahead = BytesIO()
def parse(self, data):
"""Parse a fragment of data and call back for any completed packets.
"""
self._readahead.write(data)
buf = self._readahead.getvalue()
if len(buf) < 4:
return
while len(buf) >= 4:
size = int(buf[:4], 16)
if size == 0:
self.handle_pkt(None)
buf = buf[4:]
elif size <= len(buf):
self.handle_pkt(buf[4:size])
buf = buf[size:]
else:
break
self._readahead = BytesIO()
self._readahead.write(buf)
def get_tail(self):
"""Read back any unused data."""
return self._readahead.getvalue()
diff --git a/dulwich/reflog.py b/dulwich/reflog.py
index aec32e60..37a2ff8c 100644
--- a/dulwich/reflog.py
+++ b/dulwich/reflog.py
@@ -1,76 +1,79 @@
# reflog.py -- Parsing and writing reflog files
# Copyright (C) 2015 Jelmer Vernooij and others.
#
# Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU
# General Public License as public by the Free Software Foundation; version 2.0
# or (at your option) any later version. You can redistribute it and/or
# modify it under the terms of either of these two licenses.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# You should have received a copy of the licenses; if not, see
# <http://www.gnu.org/licenses/> for a copy of the GNU General Public License
# and <http://www.apache.org/licenses/LICENSE-2.0> for a copy of the Apache
# License, Version 2.0.
#
"""Utilities for reading and generating reflogs.
"""
import collections
from dulwich.objects import (
format_timezone,
parse_timezone,
ZERO_SHA,
)
Entry = collections.namedtuple(
'Entry', ['old_sha', 'new_sha', 'committer', 'timestamp', 'timezone',
'message'])
def format_reflog_line(old_sha, new_sha, committer, timestamp, timezone,
message):
"""Generate a single reflog line.
- :param old_sha: Old Commit SHA
- :param new_sha: New Commit SHA
- :param committer: Committer name and e-mail
- :param timestamp: Timestamp
- :param timezone: Timezone
- :param message: Message
+ Args:
+ old_sha: Old Commit SHA
+ new_sha: New Commit SHA
+ committer: Committer name and e-mail
+ timestamp: Timestamp
+ timezone: Timezone
+ message: Message
"""
if old_sha is None:
old_sha = ZERO_SHA
return (old_sha + b' ' + new_sha + b' ' + committer + b' ' +
str(int(timestamp)).encode('ascii') + b' ' +
format_timezone(timezone) + b'\t' + message)
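# Doctest-style sketch of the produced line (illustrative values; SHAs are
# abbreviated here, the real output carries all 40 hex characters):
#
#   >>> format_reflog_line(None, b'1' * 40, b'Jane <jane@example.com>',
#   ...                    1446552482, 0, b'clone: from example')
#   b'00...00 11...11 Jane <jane@example.com> 1446552482 +0000\tclone: from example'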
def parse_reflog_line(line):
"""Parse a reflog line.
- :param line: Line to parse
- :return: Tuple of (old_sha, new_sha, committer, timestamp, timezone,
+ Args:
+ line: Line to parse
+ Returns: Tuple of (old_sha, new_sha, committer, timestamp, timezone,
message)
"""
(begin, message) = line.split(b'\t', 1)
(old_sha, new_sha, rest) = begin.split(b' ', 2)
(committer, timestamp_str, timezone_str) = rest.rsplit(b' ', 2)
return Entry(old_sha, new_sha, committer, int(timestamp_str),
parse_timezone(timezone_str)[0], message)
def read_reflog(f):
"""Read reflog.
- :param f: File-like object
- :returns: Iterator over Entry objects
+ Args:
+ f: File-like object
+ Returns: Iterator over Entry objects
"""
for l in f:
yield parse_reflog_line(l)
diff --git a/dulwich/refs.py b/dulwich/refs.py
index 75f283fb..d597289d 100644
--- a/dulwich/refs.py
+++ b/dulwich/refs.py
@@ -1,946 +1,972 @@
# refs.py -- For dealing with git refs
# Copyright (C) 2008-2013 Jelmer Vernooij
#
# Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU
# General Public License as public by the Free Software Foundation; version 2.0
# or (at your option) any later version. You can redistribute it and/or
# modify it under the terms of either of these two licenses.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# You should have received a copy of the licenses; if not, see
# <http://www.gnu.org/licenses/> for a copy of the GNU General Public License
# and <http://www.apache.org/licenses/LICENSE-2.0> for a copy of the Apache
# License, Version 2.0.
#
"""Ref handling.
"""
import errno
import os
import sys
from dulwich.errors import (
PackedRefsException,
RefFormatError,
)
from dulwich.objects import (
git_line,
valid_hexsha,
ZERO_SHA,
)
from dulwich.file import (
GitFile,
ensure_dir_exists,
)
SYMREF = b'ref: '
LOCAL_BRANCH_PREFIX = b'refs/heads/'
LOCAL_TAG_PREFIX = b'refs/tags/'
BAD_REF_CHARS = set(b'\177 ~^:?*[')
ANNOTATED_TAG_SUFFIX = b'^{}'
def parse_symref_value(contents):
"""Parse a symref value.
- :param contents: Contents to parse
- :return: Destination
+ Args:
+ contents: Contents to parse
+ Returns: Destination
"""
if contents.startswith(SYMREF):
return contents[len(SYMREF):].rstrip(b'\r\n')
raise ValueError(contents)
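# Example (illustrative) of the raw contents of a symbolic ref file:
#
#   >>> parse_symref_value(b'ref: refs/heads/master\n')
#   b'refs/heads/master'
#
# Anything not starting with 'ref: ' raises ValueError instead.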
def check_ref_format(refname):
"""Check if a refname is correctly formatted.
Implements all the same rules as git-check-ref-format[1].
[1]
http://www.kernel.org/pub/software/scm/git/docs/git-check-ref-format.html
- :param refname: The refname to check
- :return: True if refname is valid, False otherwise
+ Args:
+ refname: The refname to check
+ Returns: True if refname is valid, False otherwise
"""
# These could be combined into one big expression, but are listed
# separately to parallel [1].
if b'/.' in refname or refname.startswith(b'.'):
return False
if b'/' not in refname:
return False
if b'..' in refname:
return False
for i, c in enumerate(refname):
if ord(refname[i:i+1]) < 0o40 or c in BAD_REF_CHARS:
return False
if refname[-1] in b'/.':
return False
if refname.endswith(b'.lock'):
return False
if b'@{' in refname:
return False
if b'\\' in refname:
return False
return True
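# A few illustrative probes of the rules above:
#
#   >>> check_ref_format(b'refs/heads/feature/x')
#   True
#   >>> check_ref_format(b'HEAD')              # rejected: contains no '/'
#   False
#   >>> check_ref_format(b'refs/heads/a..b')   # rejected: contains '..'
#   False
#   >>> check_ref_format(b'refs/heads/x.lock') # rejected: '.lock' suffix
#   False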
class RefsContainer(object):
"""A container for refs."""
def __init__(self, logger=None):
self._logger = logger
def _log(self, ref, old_sha, new_sha, committer=None, timestamp=None,
timezone=None, message=None):
if self._logger is None:
return
if message is None:
return
self._logger(ref, old_sha, new_sha, committer, timestamp,
timezone, message)
def set_symbolic_ref(self, name, other, committer=None, timestamp=None,
timezone=None, message=None):
"""Make a ref point at another ref.
- :param name: Name of the ref to set
- :param other: Name of the ref to point at
- :param message: Optional message
+ Args:
+ name: Name of the ref to set
+ other: Name of the ref to point at
+ message: Optional message
"""
raise NotImplementedError(self.set_symbolic_ref)
def get_packed_refs(self):
"""Get contents of the packed-refs file.
- :return: Dictionary mapping ref names to SHA1s
+ Returns: Dictionary mapping ref names to SHA1s
- :note: Will return an empty dictionary when no packed-refs file is
+ Note: Will return an empty dictionary when no packed-refs file is
present.
"""
raise NotImplementedError(self.get_packed_refs)
def get_peeled(self, name):
"""Return the cached peeled value of a ref, if available.
- :param name: Name of the ref to peel
- :return: The peeled value of the ref. If the ref is known not point to
+ Args:
+ name: Name of the ref to peel
+ Returns: The peeled value of the ref. If the ref is known not to point to
a tag, this will be the SHA the ref refers to. If the ref may point
to a tag, but no cached information is available, None is returned.
"""
return None
def import_refs(self, base, other, committer=None, timestamp=None,
timezone=None, message=None, prune=False):
if prune:
to_delete = set(self.subkeys(base))
else:
to_delete = set()
for name, value in other.items():
self.set_if_equals(b'/'.join((base, name)), None, value,
message=message)
if to_delete:
try:
to_delete.remove(name)
except KeyError:
pass
for ref in to_delete:
self.remove_if_equals(b'/'.join((base, ref)), None)
def allkeys(self):
"""All refs present in this container."""
raise NotImplementedError(self.allkeys)
def __iter__(self):
return iter(self.allkeys())
def keys(self, base=None):
"""Refs present in this container.
- :param base: An optional base to return refs under.
- :return: An unsorted set of valid refs in this container, including
+ Args:
+ base: An optional base to return refs under.
+ Returns: An unsorted set of valid refs in this container, including
packed refs.
"""
if base is not None:
return self.subkeys(base)
else:
return self.allkeys()
def subkeys(self, base):
"""Refs present in this container under a base.
- :param base: The base to return refs under.
- :return: A set of valid refs in this container under the base; the base
+ Args:
+ base: The base to return refs under.
+ Returns: A set of valid refs in this container under the base; the base
prefix is stripped from the ref names returned.
"""
keys = set()
base_len = len(base) + 1
for refname in self.allkeys():
if refname.startswith(base):
keys.add(refname[base_len:])
return keys
def as_dict(self, base=None):
"""Return the contents of this container as a dictionary.
"""
ret = {}
keys = self.keys(base)
if base is None:
base = b''
else:
base = base.rstrip(b'/')
for key in keys:
try:
ret[key] = self[(base + b'/' + key).strip(b'/')]
except KeyError:
continue # Unable to resolve
return ret
def _check_refname(self, name):
"""Ensure a refname is valid and lives in refs or is HEAD.
HEAD is not a valid refname according to git-check-ref-format, but this
class needs to be able to touch HEAD. Also, check_ref_format expects
refnames without the leading 'refs/', but this class requires the prefix
so that it cannot touch anything outside the refs dir (or HEAD).
- :param name: The name of the reference.
- :raises KeyError: if a refname is not HEAD or is otherwise not valid.
+ Args:
+ name: The name of the reference.
+ Raises:
+ KeyError: if a refname is not HEAD or is otherwise not valid.
"""
if name in (b'HEAD', b'refs/stash'):
return
if not name.startswith(b'refs/') or not check_ref_format(name[5:]):
raise RefFormatError(name)
def read_ref(self, refname):
"""Read a reference without following any references.
- :param refname: The name of the reference
- :return: The contents of the ref file, or None if it does
+ Args:
+ refname: The name of the reference
+ Returns: The contents of the ref file, or None if it does
not exist.
"""
contents = self.read_loose_ref(refname)
if not contents:
contents = self.get_packed_refs().get(refname, None)
return contents
def read_loose_ref(self, name):
"""Read a loose reference and return its contents.
- :param name: the refname to read
- :return: The contents of the ref file, or None if it does
+ Args:
+ name: the refname to read
+ Returns: The contents of the ref file, or None if it does
not exist.
"""
raise NotImplementedError(self.read_loose_ref)
def follow(self, name):
"""Follow a reference name.
- :return: a tuple of (refnames, sha), wheres refnames are the names of
+ Returns: a tuple of (refnames, sha), where refnames are the names of
references in the chain
"""
contents = SYMREF + name
depth = 0
refnames = []
while contents.startswith(SYMREF):
refname = contents[len(SYMREF):]
refnames.append(refname)
contents = self.read_ref(refname)
if not contents:
break
depth += 1
if depth > 5:
raise KeyError(name)
return refnames, contents
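# For example (illustrative), with HEAD a symref to refs/heads/master:
#
#   refs.follow(b'HEAD')
#   # -> ([b'HEAD', b'refs/heads/master'], b'<sha of master>')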
def _follow(self, name):
import warnings
warnings.warn(
"RefsContainer._follow is deprecated. Use RefsContainer.follow "
"instead.", DeprecationWarning)
refnames, contents = self.follow(name)
if not refnames:
return (None, contents)
return (refnames[-1], contents)
def __contains__(self, refname):
if self.read_ref(refname):
return True
return False
def __getitem__(self, name):
"""Get the SHA1 for a reference name.
This method follows all symbolic references.
"""
_, sha = self.follow(name)
if sha is None:
raise KeyError(name)
return sha
def set_if_equals(self, name, old_ref, new_ref, committer=None,
timestamp=None, timezone=None, message=None):
"""Set a refname to new_ref only if it currently equals old_ref.
This method follows all symbolic references if applicable for the
subclass, and can be used to perform an atomic compare-and-swap
operation.
- :param name: The refname to set.
- :param old_ref: The old sha the refname must refer to, or None to set
+ Args:
+ name: The refname to set.
+ old_ref: The old sha the refname must refer to, or None to set
unconditionally.
- :param new_ref: The new sha the refname will refer to.
- :param message: Message for reflog
- :return: True if the set was successful, False otherwise.
+ new_ref: The new sha the refname will refer to.
+ message: Message for reflog
+ Returns: True if the set was successful, False otherwise.
"""
raise NotImplementedError(self.set_if_equals)
def add_if_new(self, name, ref):
"""Add a new reference only if it does not already exist.
- :param name: Ref name
- :param ref: Ref value
- :param message: Message for reflog
+ Args:
+ name: Ref name
+ ref: Ref value
+ message: Message for reflog
"""
raise NotImplementedError(self.add_if_new)
def __setitem__(self, name, ref):
"""Set a reference name to point to the given SHA1.
This method follows all symbolic references if applicable for the
subclass.
- :note: This method unconditionally overwrites the contents of a
+ Note: This method unconditionally overwrites the contents of a
reference. To update atomically only if the reference has not
changed, use set_if_equals().
- :param name: The refname to set.
- :param ref: The new sha the refname will refer to.
+
+ Args:
+ name: The refname to set.
+ ref: The new sha the refname will refer to.
"""
self.set_if_equals(name, None, ref)
def remove_if_equals(self, name, old_ref, committer=None,
timestamp=None, timezone=None, message=None):
"""Remove a refname only if it currently equals old_ref.
This method does not follow symbolic references, even if applicable for
the subclass. It can be used to perform an atomic compare-and-delete
operation.
- :param name: The refname to delete.
- :param old_ref: The old sha the refname must refer to, or None to
+ Args:
+ name: The refname to delete.
+ old_ref: The old sha the refname must refer to, or None to
delete unconditionally.
- :param message: Message for reflog
- :return: True if the delete was successful, False otherwise.
+ message: Message for reflog
+ Returns: True if the delete was successful, False otherwise.
"""
raise NotImplementedError(self.remove_if_equals)
def __delitem__(self, name):
"""Remove a refname.
This method does not follow symbolic references, even if applicable for
the subclass.
- :note: This method unconditionally deletes the contents of a reference.
+ Note: This method unconditionally deletes the contents of a reference.
To delete atomically only if the reference has not changed, use
remove_if_equals().
- :param name: The refname to delete.
+ Args:
+ name: The refname to delete.
"""
self.remove_if_equals(name, None)
def get_symrefs(self):
"""Get a dict with all symrefs in this container.
- :return: Dictionary mapping source ref to target ref
+ Returns: Dictionary mapping source ref to target ref
"""
ret = {}
for src in self.allkeys():
try:
dst = parse_symref_value(self.read_ref(src))
except ValueError:
pass
else:
ret[src] = dst
return ret
class DictRefsContainer(RefsContainer):
"""RefsContainer backed by a simple dict.
This container does not support symbolic or packed references and is not
threadsafe.
"""
def __init__(self, refs, logger=None):
super(DictRefsContainer, self).__init__(logger=logger)
self._refs = refs
self._peeled = {}
def allkeys(self):
return self._refs.keys()
def read_loose_ref(self, name):
return self._refs.get(name, None)
def get_packed_refs(self):
return {}
def set_symbolic_ref(self, name, other, committer=None,
timestamp=None, timezone=None, message=None):
old = self.follow(name)[-1]
self._refs[name] = SYMREF + other
self._log(name, old, old, committer=committer, timestamp=timestamp,
timezone=timezone, message=message)
def set_if_equals(self, name, old_ref, new_ref, committer=None,
timestamp=None, timezone=None, message=None):
if old_ref is not None and self._refs.get(name, ZERO_SHA) != old_ref:
return False
realnames, _ = self.follow(name)
for realname in realnames:
self._check_refname(realname)
old = self._refs.get(realname)
self._refs[realname] = new_ref
self._log(realname, old, new_ref, committer=committer,
timestamp=timestamp, timezone=timezone, message=message)
return True
def add_if_new(self, name, ref, committer=None, timestamp=None,
timezone=None, message=None):
if name in self._refs:
return False
self._refs[name] = ref
self._log(name, None, ref, committer=committer, timestamp=timestamp,
timezone=timezone, message=message)
return True
def remove_if_equals(self, name, old_ref, committer=None, timestamp=None,
timezone=None, message=None):
if old_ref is not None and self._refs.get(name, ZERO_SHA) != old_ref:
return False
try:
old = self._refs.pop(name)
except KeyError:
pass
else:
self._log(name, old, None, committer=committer,
timestamp=timestamp, timezone=timezone, message=message)
return True
def get_peeled(self, name):
return self._peeled.get(name)
def _update(self, refs):
"""Update multiple refs; intended only for testing."""
# TODO(dborowitz): replace this with a public function that uses
# set_if_equal.
self._refs.update(refs)
def _update_peeled(self, peeled):
"""Update cached peeled refs; intended only for testing."""
self._peeled.update(peeled)
class InfoRefsContainer(RefsContainer):
"""Refs container that reads refs from a info/refs file."""
def __init__(self, f):
self._refs = {}
self._peeled = {}
for l in f.readlines():
sha, name = l.rstrip(b'\n').split(b'\t')
if name.endswith(ANNOTATED_TAG_SUFFIX):
name = name[:-3]
if not check_ref_format(name):
raise ValueError("invalid ref name %r" % name)
self._peeled[name] = sha
else:
if not check_ref_format(name):
raise ValueError("invalid ref name %r" % name)
self._refs[name] = sha
def allkeys(self):
return self._refs.keys()
def read_loose_ref(self, name):
return self._refs.get(name, None)
def get_packed_refs(self):
return {}
def get_peeled(self, name):
try:
return self._peeled[name]
except KeyError:
return self._refs[name]
class DiskRefsContainer(RefsContainer):
"""Refs container that reads refs from disk."""
def __init__(self, path, worktree_path=None, logger=None):
super(DiskRefsContainer, self).__init__(logger=logger)
if getattr(path, 'encode', None) is not None:
path = path.encode(sys.getfilesystemencoding())
self.path = path
if worktree_path is None:
worktree_path = path
if getattr(worktree_path, 'encode', None) is not None:
worktree_path = worktree_path.encode(sys.getfilesystemencoding())
self.worktree_path = worktree_path
self._packed_refs = None
self._peeled_refs = None
def __repr__(self):
return "%s(%r)" % (self.__class__.__name__, self.path)
def subkeys(self, base):
subkeys = set()
path = self.refpath(base)
for root, unused_dirs, files in os.walk(path):
dir = root[len(path):]
if os.path.sep != '/':
dir = dir.replace(os.path.sep.encode(
sys.getfilesystemencoding()), b"/")
dir = dir.strip(b'/')
for filename in files:
refname = b"/".join(([dir] if dir else []) + [filename])
# check_ref_format requires at least one /, so we prepend the
# base before calling it.
if check_ref_format(base + b'/' + refname):
subkeys.add(refname)
for key in self.get_packed_refs():
if key.startswith(base):
subkeys.add(key[len(base):].strip(b'/'))
return subkeys
def allkeys(self):
allkeys = set()
if os.path.exists(self.refpath(b'HEAD')):
allkeys.add(b'HEAD')
path = self.refpath(b'')
refspath = self.refpath(b'refs')
for root, unused_dirs, files in os.walk(refspath):
dir = root[len(path):]
if os.path.sep != '/':
dir = dir.replace(
os.path.sep.encode(sys.getfilesystemencoding()), b"/")
for filename in files:
refname = b"/".join([dir, filename])
if check_ref_format(refname):
allkeys.add(refname)
allkeys.update(self.get_packed_refs())
return allkeys
def refpath(self, name):
"""Return the disk path of a ref.
"""
if os.path.sep != "/":
name = name.replace(
b"/",
os.path.sep.encode(sys.getfilesystemencoding()))
# TODO: as the 'HEAD' reference is working tree specific, it
# should actually not be a part of RefsContainer
if name == b'HEAD':
return os.path.join(self.worktree_path, name)
else:
return os.path.join(self.path, name)
def get_packed_refs(self):
"""Get contents of the packed-refs file.
- :return: Dictionary mapping ref names to SHA1s
+ Returns: Dictionary mapping ref names to SHA1s
- :note: Will return an empty dictionary when no packed-refs file is
+ Note: Will return an empty dictionary when no packed-refs file is
present.
"""
# TODO: invalidate the cache on repacking
if self._packed_refs is None:
# set both to empty because we want _peeled_refs to be
# None if and only if _packed_refs is also None.
self._packed_refs = {}
self._peeled_refs = {}
path = os.path.join(self.path, b'packed-refs')
try:
f = GitFile(path, 'rb')
except IOError as e:
if e.errno == errno.ENOENT:
return {}
raise
with f:
first_line = next(iter(f)).rstrip()
if (first_line.startswith(b'# pack-refs') and b' peeled' in
first_line):
for sha, name, peeled in read_packed_refs_with_peeled(f):
self._packed_refs[name] = sha
if peeled:
self._peeled_refs[name] = peeled
else:
f.seek(0)
for sha, name in read_packed_refs(f):
self._packed_refs[name] = sha
return self._packed_refs
def get_peeled(self, name):
"""Return the cached peeled value of a ref, if available.
- :param name: Name of the ref to peel
- :return: The peeled value of the ref. If the ref is known not point to
+ Args:
+ name: Name of the ref to peel
+ Returns: The peeled value of the ref. If the ref is known not to point to
a tag, this will be the SHA the ref refers to. If the ref may point
to a tag, but no cached information is available, None is returned.
"""
self.get_packed_refs()
if self._peeled_refs is None or name not in self._packed_refs:
# No cache: no peeled refs were read, or this ref is loose
return None
if name in self._peeled_refs:
return self._peeled_refs[name]
else:
# Known not peelable
return self[name]
def read_loose_ref(self, name):
"""Read a reference file and return its contents.
If the reference file is a symbolic reference, only read the first line of
the file. Otherwise, only read the first 40 bytes.
- :param name: the refname to read, relative to refpath
- :return: The contents of the ref file, or None if the file does not
+ Args:
+ name: the refname to read, relative to refpath
+ Returns: The contents of the ref file, or None if the file does not
exist.
- :raises IOError: if any other error occurs
+ Raises:
+ IOError: if any other error occurs
"""
filename = self.refpath(name)
try:
with GitFile(filename, 'rb') as f:
header = f.read(len(SYMREF))
if header == SYMREF:
# Read only the first line
return header + next(iter(f)).rstrip(b'\r\n')
else:
# Read only the first 40 bytes
return header + f.read(40 - len(SYMREF))
except IOError as e:
if e.errno in (errno.ENOENT, errno.EISDIR, errno.ENOTDIR):
return None
raise
def _remove_packed_ref(self, name):
if self._packed_refs is None:
return
filename = os.path.join(self.path, b'packed-refs')
# reread cached refs from disk, while holding the lock
f = GitFile(filename, 'wb')
try:
self._packed_refs = None
self.get_packed_refs()
if name not in self._packed_refs:
return
del self._packed_refs[name]
if name in self._peeled_refs:
del self._peeled_refs[name]
write_packed_refs(f, self._packed_refs, self._peeled_refs)
f.close()
finally:
f.abort()
def set_symbolic_ref(self, name, other, committer=None, timestamp=None,
timezone=None, message=None):
"""Make a ref point at another ref.
- :param name: Name of the ref to set
- :param other: Name of the ref to point at
- :param message: Optional message to describe the change
+ Args:
+ name: Name of the ref to set
+ other: Name of the ref to point at
+ message: Optional message to describe the change
"""
self._check_refname(name)
self._check_refname(other)
filename = self.refpath(name)
f = GitFile(filename, 'wb')
try:
f.write(SYMREF + other + b'\n')
sha = self.follow(name)[-1]
self._log(name, sha, sha, committer=committer,
timestamp=timestamp, timezone=timezone,
message=message)
except BaseException:
f.abort()
raise
else:
f.close()
def set_if_equals(self, name, old_ref, new_ref, committer=None,
timestamp=None, timezone=None, message=None):
"""Set a refname to new_ref only if it currently equals old_ref.
This method follows all symbolic references, and can be used to perform
an atomic compare-and-swap operation.
- :param name: The refname to set.
- :param old_ref: The old sha the refname must refer to, or None to set
+ Args:
+ name: The refname to set.
+ old_ref: The old sha the refname must refer to, or None to set
unconditionally.
- :param new_ref: The new sha the refname will refer to.
- :param message: Set message for reflog
- :return: True if the set was successful, False otherwise.
+ new_ref: The new sha the refname will refer to.
+ message: Set message for reflog
+ Returns: True if the set was successful, False otherwise.
"""
self._check_refname(name)
try:
realnames, _ = self.follow(name)
realname = realnames[-1]
except (KeyError, IndexError):
realname = name
filename = self.refpath(realname)
# make sure none of the ancestor folders is in packed refs
probe_ref = os.path.dirname(realname)
packed_refs = self.get_packed_refs()
while probe_ref:
if packed_refs.get(probe_ref, None) is not None:
raise OSError(errno.ENOTDIR,
'Not a directory: {}'.format(filename))
probe_ref = os.path.dirname(probe_ref)
ensure_dir_exists(os.path.dirname(filename))
with GitFile(filename, 'wb') as f:
if old_ref is not None:
try:
# read again while holding the lock
orig_ref = self.read_loose_ref(realname)
if orig_ref is None:
orig_ref = self.get_packed_refs().get(
realname, ZERO_SHA)
if orig_ref != old_ref:
f.abort()
return False
except (OSError, IOError):
f.abort()
raise
try:
f.write(new_ref + b'\n')
except (OSError, IOError):
f.abort()
raise
self._log(realname, old_ref, new_ref, committer=committer,
timestamp=timestamp, timezone=timezone, message=message)
return True
def add_if_new(self, name, ref, committer=None, timestamp=None,
timezone=None, message=None):
"""Add a new reference only if it does not already exist.
This method follows symrefs, and only ensures that the last ref in the
chain does not exist.
- :param name: The refname to set.
- :param ref: The new sha the refname will refer to.
- :param message: Optional message for reflog
- :return: True if the add was successful, False otherwise.
+ Args:
+ name: The refname to set.
+ ref: The new sha the refname will refer to.
+ message: Optional message for reflog
+ Returns: True if the add was successful, False otherwise.
"""
try:
realnames, contents = self.follow(name)
if contents is not None:
return False
realname = realnames[-1]
except (KeyError, IndexError):
realname = name
self._check_refname(realname)
filename = self.refpath(realname)
ensure_dir_exists(os.path.dirname(filename))
with GitFile(filename, 'wb') as f:
if os.path.exists(filename) or name in self.get_packed_refs():
f.abort()
return False
try:
f.write(ref + b'\n')
except (OSError, IOError):
f.abort()
raise
else:
self._log(name, None, ref, committer=committer,
timestamp=timestamp, timezone=timezone,
message=message)
return True
def remove_if_equals(self, name, old_ref, committer=None, timestamp=None,
timezone=None, message=None):
"""Remove a refname only if it currently equals old_ref.
This method does not follow symbolic references. It can be used to
perform an atomic compare-and-delete operation.
- :param name: The refname to delete.
- :param old_ref: The old sha the refname must refer to, or None to
+ Args:
+ name: The refname to delete.
+ old_ref: The old sha the refname must refer to, or None to
delete unconditionally.
- :param message: Optional message
- :return: True if the delete was successful, False otherwise.
+ message: Optional message
+ Returns: True if the delete was successful, False otherwise.
"""
self._check_refname(name)
filename = self.refpath(name)
ensure_dir_exists(os.path.dirname(filename))
f = GitFile(filename, 'wb')
try:
if old_ref is not None:
orig_ref = self.read_loose_ref(name)
if orig_ref is None:
orig_ref = self.get_packed_refs().get(name, ZERO_SHA)
if orig_ref != old_ref:
return False
# remove the reference file itself
try:
os.remove(filename)
except OSError as e:
if e.errno != errno.ENOENT: # may only be packed
raise
self._remove_packed_ref(name)
self._log(name, old_ref, None, committer=committer,
timestamp=timestamp, timezone=timezone, message=message)
finally:
# never write, we just wanted the lock
f.abort()
# outside of the lock, clean-up any parent directory that might now
# be empty. this ensures that re-creating a reference of the same
# name of what was previously a directory works as expected
parent = name
while True:
try:
parent, _ = parent.rsplit(b'/', 1)
except ValueError:
break
parent_filename = self.refpath(parent)
try:
os.rmdir(parent_filename)
except OSError:
# this can be caused by the parent directory being
# removed by another process, being not empty, etc.
# in any case, this is non fatal because we already
# removed the reference, just ignore it
break
return True
def _split_ref_line(line):
"""Split a single ref line into a tuple of SHA1 and name."""
fields = line.rstrip(b'\n\r').split(b' ')
if len(fields) != 2:
raise PackedRefsException("invalid ref line %r" % line)
sha, name = fields
if not valid_hexsha(sha):
raise PackedRefsException("Invalid hex sha %r" % sha)
if not check_ref_format(name):
raise PackedRefsException("invalid ref name %r" % name)
return (sha, name)
def read_packed_refs(f):
"""Read a packed refs file.
- :param f: file-like object to read from
- :return: Iterator over tuples with SHA1s and ref names.
+ Args:
+ f: file-like object to read from
+ Returns: Iterator over tuples with SHA1s and ref names.
"""
for l in f:
if l.startswith(b'#'):
# Comment
continue
if l.startswith(b'^'):
raise PackedRefsException(
"found peeled ref in packed-refs without peeled")
yield _split_ref_line(l)
def read_packed_refs_with_peeled(f):
"""Read a packed refs file including peeled refs.
Assumes the "# pack-refs with: peeled" line was already read. Yields tuples
with ref names, SHA1s, and peeled SHA1s (or None).
- :param f: file-like object to read from, seek'ed to the second line
+ Args:
+ f: file-like object to read from, seek'ed to the second line
"""
last = None
for line in f:
if line.startswith(b'#'):
continue
line = line.rstrip(b'\r\n')
if line.startswith(b'^'):
if not last:
raise PackedRefsException("unexpected peeled ref line")
if not valid_hexsha(line[1:]):
raise PackedRefsException("Invalid hex sha %r" % line[1:])
sha, name = _split_ref_line(last)
last = None
yield (sha, name, line[1:])
else:
if last:
sha, name = _split_ref_line(last)
yield (sha, name, None)
last = line
if last:
sha, name = _split_ref_line(last)
yield (sha, name, None)
def write_packed_refs(f, packed_refs, peeled_refs=None):
"""Write a packed refs file.
- :param f: empty file-like object to write to
- :param packed_refs: dict of refname to sha of packed refs to write
- :param peeled_refs: dict of refname to peeled value of sha
+ Args:
+ f: empty file-like object to write to
+ packed_refs: dict of refname to sha of packed refs to write
+ peeled_refs: dict of refname to peeled value of sha
"""
if peeled_refs is None:
peeled_refs = {}
else:
f.write(b'# pack-refs with: peeled\n')
for refname in sorted(packed_refs.keys()):
f.write(git_line(packed_refs[refname], refname))
if refname in peeled_refs:
f.write(b'^' + peeled_refs[refname] + b'\n')
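A quick round trip through an in-memory file shows the format the writer and the peeled reader agree on (placeholder shas; the header line is skipped by hand, since the reader expects to start at the second line):
    from io import BytesIO
    from dulwich.refs import read_packed_refs_with_peeled, write_packed_refs

    sha, peeled = b"1" * 40, b"2" * 40         # placeholder hex shas
    buf = BytesIO()
    write_packed_refs(buf, {b"refs/tags/v1.0": sha},
                      peeled_refs={b"refs/tags/v1.0": peeled})
    buf.seek(0)
    buf.readline()   # skip the "# pack-refs with: peeled" header line
    print(list(read_packed_refs_with_peeled(buf)))
    # [(b'111...', b'refs/tags/v1.0', b'222...')]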
def read_info_refs(f):
ret = {}
for l in f.readlines():
(sha, name) = l.rstrip(b"\r\n").split(b"\t", 1)
ret[name] = sha
return ret
def write_info_refs(refs, store):
"""Generate info refs."""
for name, sha in sorted(refs.items()):
# get_refs() includes HEAD as a special case, but we don't want to
# advertise it
if name == b'HEAD':
continue
try:
o = store[sha]
except KeyError:
continue
peeled = store.peel_sha(sha)
yield o.id + b'\t' + name + b'\n'
if o.id != peeled.id:
yield peeled.id + b'\t' + name + ANNOTATED_TAG_SUFFIX + b'\n'
def is_local_branch(x):
return x.startswith(LOCAL_BRANCH_PREFIX)
def strip_peeled_refs(refs):
"""Remove all peeled refs"""
return {ref: sha for (ref, sha) in refs.items()
if not ref.endswith(ANNOTATED_TAG_SUFFIX)}
diff --git a/dulwich/repo.py b/dulwich/repo.py
index ccb30f1a..1e1a4c42 100644
--- a/dulwich/repo.py
+++ b/dulwich/repo.py
@@ -1,1438 +1,1480 @@
# repo.py -- For dealing with git repositories.
# Copyright (C) 2007 James Westby
# Copyright (C) 2008-2013 Jelmer Vernooij
#
# Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU
# General Public License as public by the Free Software Foundation; version 2.0
# or (at your option) any later version. You can redistribute it and/or
# modify it under the terms of either of these two licenses.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# You should have received a copy of the licenses; if not, see
# for a copy of the GNU General Public License
# and for a copy of the Apache
# License, Version 2.0.
#
"""Repository access.
This module contains the base class for git repositories
(BaseRepo) and an implementation which uses a repository on
local disk (Repo).
"""
from io import BytesIO
import errno
import os
import sys
import stat
import time
from dulwich.errors import (
NoIndexPresent,
NotBlobError,
NotCommitError,
NotGitRepository,
NotTreeError,
NotTagError,
CommitError,
RefFormatError,
HookError,
)
from dulwich.file import (
GitFile,
)
from dulwich.object_store import (
DiskObjectStore,
MemoryObjectStore,
ObjectStoreGraphWalker,
)
from dulwich.objects import (
check_hexsha,
Blob,
Commit,
ShaFile,
Tag,
Tree,
)
from dulwich.pack import (
pack_objects_to_data,
)
from dulwich.hooks import (
PreCommitShellHook,
PostCommitShellHook,
CommitMsgShellHook,
)
from dulwich.line_ending import BlobNormalizer
from dulwich.refs import ( # noqa: F401
ANNOTATED_TAG_SUFFIX,
check_ref_format,
RefsContainer,
DictRefsContainer,
InfoRefsContainer,
DiskRefsContainer,
read_packed_refs,
read_packed_refs_with_peeled,
write_packed_refs,
SYMREF,
)
import warnings
CONTROLDIR = '.git'
OBJECTDIR = 'objects'
REFSDIR = 'refs'
REFSDIR_TAGS = 'tags'
REFSDIR_HEADS = 'heads'
INDEX_FILENAME = "index"
COMMONDIR = 'commondir'
GITDIR = 'gitdir'
WORKTREES = 'worktrees'
BASE_DIRECTORIES = [
["branches"],
[REFSDIR],
[REFSDIR, REFSDIR_TAGS],
[REFSDIR, REFSDIR_HEADS],
["hooks"],
["info"]
]
DEFAULT_REF = b'refs/heads/master'
class InvalidUserIdentity(Exception):
"""User identity is not of the format 'user '"""
def __init__(self, identity):
self.identity = identity
def _get_default_identity():
import getpass
import socket
username = getpass.getuser()
try:
import pwd
except ImportError:
fullname = None
else:
try:
gecos = pwd.getpwnam(username).pw_gecos
except KeyError:
fullname = None
else:
fullname = gecos.split(',')[0]
if not fullname:
fullname = username
email = os.environ.get('EMAIL')
if email is None:
email = "{}@{}".format(username, socket.gethostname())
return (fullname, email)
def get_user_identity(config, kind=None):
"""Determine the identity to use for new commits.
"""
if kind:
user = os.environ.get("GIT_" + kind + "_NAME")
if user is not None:
user = user.encode('utf-8')
email = os.environ.get("GIT_" + kind + "_EMAIL")
if email is not None:
email = email.encode('utf-8')
else:
user = None
email = None
if user is None:
try:
user = config.get(("user", ), "name")
except KeyError:
user = None
if email is None:
try:
email = config.get(("user", ), "email")
except KeyError:
email = None
default_user, default_email = _get_default_identity()
if user is None:
user = default_user
if not isinstance(user, bytes):
user = user.encode('utf-8')
if email is None:
email = default_email
if not isinstance(email, bytes):
email = email.encode('utf-8')
if email.startswith(b'<') and email.endswith(b'>'):
email = email[1:-1]
return (user + b" <" + email + b">")
def check_user_identity(identity):
"""Verify that a user identity is formatted correctly.
- :param identity: User identity bytestring
- :raise InvalidUserIdentity: Raised when identity is invalid
+ Args:
+ identity: User identity bytestring
+ Raises:
+ InvalidUserIdentity: Raised when identity is invalid
"""
try:
fst, snd = identity.split(b' <', 1)
except ValueError:
raise InvalidUserIdentity(identity)
if b'>' not in snd:
raise InvalidUserIdentity(identity)
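For example (a minimal sketch):
    from dulwich.repo import InvalidUserIdentity, check_user_identity

    check_user_identity(b"Jane Doe <jane@example.com>")   # valid: returns None
    try:
        check_user_identity(b"missing-angle-brackets")
    except InvalidUserIdentity as e:
        print("rejected:", e.identity)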
def parse_graftpoints(graftpoints):
"""Convert a list of graftpoints into a dict
- :param graftpoints: Iterator of graftpoint lines
+ Args:
+ graftpoints: Iterator of graftpoint lines
Each line is formatted as:
<commit sha1> <parent sha1> [<parent sha1>]*
Resulting dictionary is:
<commit sha1>: [<parent sha1>*]
https://git.wiki.kernel.org/index.php/GraftPoint
"""
grafts = {}
for l in graftpoints:
raw_graft = l.split(None, 1)
commit = raw_graft[0]
if len(raw_graft) == 2:
parents = raw_graft[1].split()
else:
parents = []
for sha in [commit] + parents:
check_hexsha(sha, 'Invalid graftpoint')
grafts[commit] = parents
return grafts
def serialize_graftpoints(graftpoints):
"""Convert a dictionary of grafts into string
The graft dictionary is:
<commit sha1>: [<parent sha1>*]
Each line is formatted as:
<commit sha1> <parent sha1> [<parent sha1>]*
https://git.wiki.kernel.org/index.php/GraftPoint
"""
graft_lines = []
for commit, parents in graftpoints.items():
if parents:
graft_lines.append(commit + b' ' + b' '.join(parents))
else:
graft_lines.append(commit)
return b'\n'.join(graft_lines)
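The two functions are inverses for a single graft line; a sketch with placeholder shas:
    from dulwich.repo import parse_graftpoints, serialize_graftpoints

    commit, parent = b"1" * 40, b"2" * 40      # placeholder hex shas
    grafts = parse_graftpoints([commit + b" " + parent])
    assert grafts == {commit: [parent]}
    assert serialize_graftpoints(grafts) == commit + b" " + parent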
def _set_filesystem_hidden(path):
"""Mark path as to be hidden if supported by platform and filesystem.
On win32 uses SetFileAttributesW api:
"""
if sys.platform == 'win32':
import ctypes
from ctypes.wintypes import BOOL, DWORD, LPCWSTR
FILE_ATTRIBUTE_HIDDEN = 2
SetFileAttributesW = ctypes.WINFUNCTYPE(BOOL, LPCWSTR, DWORD)(
("SetFileAttributesW", ctypes.windll.kernel32))
if isinstance(path, bytes):
path = path.decode(sys.getfilesystemencoding())
if not SetFileAttributesW(path, FILE_ATTRIBUTE_HIDDEN):
pass # Could raise or log `ctypes.WinError()` here
# Could implement other platform specific filesystem hiding here
class BaseRepo(object):
"""Base class for a git repository.
:ivar object_store: Dictionary-like object for accessing
the objects
:ivar refs: Dictionary-like object with the refs in this
repository
"""
def __init__(self, object_store, refs):
"""Open a repository.
This shouldn't be called directly, but rather through one of the
subclasses, such as MemoryRepo or Repo.
- :param object_store: Object store to use
- :param refs: Refs container to use
+ Args:
+ object_store: Object store to use
+ refs: Refs container to use
"""
self.object_store = object_store
self.refs = refs
self._graftpoints = {}
self.hooks = {}
def _determine_file_mode(self):
"""Probe the file-system to determine whether permissions can be trusted.
- :return: True if permissions can be trusted, False otherwise.
+ Returns: True if permissions can be trusted, False otherwise.
"""
raise NotImplementedError(self._determine_file_mode)
def _init_files(self, bare):
"""Initialize a default set of named files."""
from dulwich.config import ConfigFile
self._put_named_file('description', b"Unnamed repository")
f = BytesIO()
cf = ConfigFile()
cf.set("core", "repositoryformatversion", "0")
if self._determine_file_mode():
cf.set("core", "filemode", True)
else:
cf.set("core", "filemode", False)
cf.set("core", "bare", bare)
cf.set("core", "logallrefupdates", True)
cf.write_to_file(f)
self._put_named_file('config', f.getvalue())
self._put_named_file(os.path.join('info', 'exclude'), b'')
def get_named_file(self, path):
"""Get a file from the control dir with a specific name.
Although the filename should be interpreted as a filename relative to
the control dir in a disk-based Repo, the object returned need not be
pointing to a file in that location.
- :param path: The path to the file, relative to the control dir.
- :return: An open file object, or None if the file does not exist.
+ Args:
+ path: The path to the file, relative to the control dir.
+ Returns: An open file object, or None if the file does not exist.
"""
raise NotImplementedError(self.get_named_file)
def _put_named_file(self, path, contents):
"""Write a file to the control dir with the given name and contents.
- :param path: The path to the file, relative to the control dir.
- :param contents: A string to write to the file.
+ Args:
+ path: The path to the file, relative to the control dir.
+ contents: A string to write to the file.
"""
raise NotImplementedError(self._put_named_file)
def _del_named_file(self, path):
"""Delete a file in the contrl directory with the given name."""
raise NotImplementedError(self._del_named_file)
def open_index(self):
"""Open the index for this repository.
- :raise NoIndexPresent: If no index is present
- :return: The matching `Index`
+ Raises:
+ NoIndexPresent: If no index is present
+ Returns: The matching `Index`
"""
raise NotImplementedError(self.open_index)
def fetch(self, target, determine_wants=None, progress=None, depth=None):
"""Fetch objects into another repository.
- :param target: The target repository
- :param determine_wants: Optional function to determine what refs to
+ Args:
+ target: The target repository
+ determine_wants: Optional function to determine what refs to
fetch.
- :param progress: Optional progress function
- :param depth: Optional shallow fetch depth
- :return: The local refs
+ progress: Optional progress function
+ depth: Optional shallow fetch depth
+ Returns: The local refs
"""
if determine_wants is None:
determine_wants = target.object_store.determine_wants_all
count, pack_data = self.fetch_pack_data(
determine_wants, target.get_graph_walker(), progress=progress,
depth=depth)
target.object_store.add_pack_data(count, pack_data, progress)
return self.get_refs()
def fetch_pack_data(self, determine_wants, graph_walker, progress,
get_tagged=None, depth=None):
"""Fetch the pack data required for a set of revisions.
- :param determine_wants: Function that takes a dictionary with heads
+ Args:
+ determine_wants: Function that takes a dictionary with heads
and returns the list of heads to fetch.
- :param graph_walker: Object that can iterate over the list of revisions
+ graph_walker: Object that can iterate over the list of revisions
to fetch and has an "ack" method that will be called to acknowledge
that a revision is present.
- :param progress: Simple progress function that will be called with
+ progress: Simple progress function that will be called with
updated progress strings.
- :param get_tagged: Function that returns a dict of pointed-to sha ->
+ get_tagged: Function that returns a dict of pointed-to sha ->
tag sha for including tags.
- :param depth: Shallow fetch depth
- :return: count and iterator over pack data
+ depth: Shallow fetch depth
+ Returns: count and iterator over pack data
"""
# TODO(jelmer): Fetch pack data directly, don't create objects first.
objects = self.fetch_objects(determine_wants, graph_walker, progress,
get_tagged, depth=depth)
return pack_objects_to_data(objects)
def fetch_objects(self, determine_wants, graph_walker, progress,
get_tagged=None, depth=None):
"""Fetch the missing objects required for a set of revisions.
- :param determine_wants: Function that takes a dictionary with heads
+ Args:
+ determine_wants: Function that takes a dictionary with heads
and returns the list of heads to fetch.
- :param graph_walker: Object that can iterate over the list of revisions
+ graph_walker: Object that can iterate over the list of revisions
to fetch and has an "ack" method that will be called to acknowledge
that a revision is present.
- :param progress: Simple progress function that will be called with
+ progress: Simple progress function that will be called with
updated progress strings.
- :param get_tagged: Function that returns a dict of pointed-to sha ->
+ get_tagged: Function that returns a dict of pointed-to sha ->
tag sha for including tags.
- :param depth: Shallow fetch depth
- :return: iterator over objects, with __len__ implemented
+ depth: Shallow fetch depth
+ Returns: iterator over objects, with __len__ implemented
"""
if depth not in (None, 0):
raise NotImplementedError("depth not supported yet")
refs = {}
for ref, sha in self.get_refs().items():
try:
obj = self.object_store[sha]
except KeyError:
warnings.warn(
'ref %s points at non-present sha %s' % (
ref.decode('utf-8', 'replace'), sha.decode('ascii')),
UserWarning)
continue
else:
if isinstance(obj, Tag):
refs[ref + ANNOTATED_TAG_SUFFIX] = obj.object[1]
refs[ref] = sha
wants = determine_wants(refs)
if not isinstance(wants, list):
raise TypeError("determine_wants() did not return a list")
shallows = getattr(graph_walker, 'shallow', frozenset())
unshallows = getattr(graph_walker, 'unshallow', frozenset())
if wants == []:
# TODO(dborowitz): find a way to short-circuit that doesn't change
# this interface.
if shallows or unshallows:
# Do not send a pack in shallow short-circuit path
return None
return []
# If the graph walker is set up with an implementation that can
# ACK/NAK to the wire, it will write data to the client through
# this call as a side-effect.
haves = self.object_store.find_common_revisions(graph_walker)
# Deal with shallow requests separately because the haves do
# not reflect what objects are missing
if shallows or unshallows:
# TODO: filter the haves commits from iter_shas. the specific
# commits aren't missing.
haves = []
def get_parents(commit):
if commit.id in shallows:
return []
return self.get_parents(commit.id, commit)
return self.object_store.iter_shas(
self.object_store.find_missing_objects(
haves, wants, progress,
get_tagged,
get_parents=get_parents))
def get_graph_walker(self, heads=None):
"""Retrieve a graph walker.
A graph walker is used by a remote repository (or proxy)
to find out which objects are present in this repository.
- :param heads: Repository heads to use (optional)
- :return: A graph walker object
+ Args:
+ heads: Repository heads to use (optional)
+ Returns: A graph walker object
"""
if heads is None:
heads = [
sha for sha in self.refs.as_dict(b'refs/heads').values()
if sha in self.object_store]
return ObjectStoreGraphWalker(
heads, self.get_parents, shallow=self.get_shallow())
def get_refs(self):
"""Get dictionary with all refs.
- :return: A ``dict`` mapping ref names to SHA1s
+ Returns: A ``dict`` mapping ref names to SHA1s
"""
return self.refs.as_dict()
def head(self):
"""Return the SHA1 pointed at by HEAD."""
return self.refs[b'HEAD']
def _get_object(self, sha, cls):
assert len(sha) in (20, 40)
ret = self.get_object(sha)
if not isinstance(ret, cls):
if cls is Commit:
raise NotCommitError(ret)
elif cls is Blob:
raise NotBlobError(ret)
elif cls is Tree:
raise NotTreeError(ret)
elif cls is Tag:
raise NotTagError(ret)
else:
raise Exception("Type invalid: %r != %r" % (
ret.type_name, cls.type_name))
return ret
def get_object(self, sha):
"""Retrieve the object with the specified SHA.
- :param sha: SHA to retrieve
- :return: A ShaFile object
- :raise KeyError: when the object can not be found
+ Args:
+ sha: SHA to retrieve
+ Returns: A ShaFile object
+ Raises:
+ KeyError: when the object can not be found
"""
return self.object_store[sha]
def get_parents(self, sha, commit=None):
"""Retrieve the parents of a specific commit.
If the specific commit is a graftpoint, the graft parents
will be returned instead.
- :param sha: SHA of the commit for which to retrieve the parents
- :param commit: Optional commit matching the sha
- :return: List of parents
+ Args:
+ sha: SHA of the commit for which to retrieve the parents
+ commit: Optional commit matching the sha
+ Returns: List of parents
"""
try:
return self._graftpoints[sha]
except KeyError:
if commit is None:
commit = self[sha]
return commit.parents
def get_config(self):
"""Retrieve the config object.
- :return: `ConfigFile` object for the ``.git/config`` file.
+ Returns: `ConfigFile` object for the ``.git/config`` file.
"""
raise NotImplementedError(self.get_config)
def get_description(self):
"""Retrieve the description for this repository.
- :return: String with the description of the repository
+ Returns: String with the description of the repository
as set by the user.
"""
raise NotImplementedError(self.get_description)
def set_description(self, description):
"""Set the description for this repository.
- :param description: Text to set as description for this repository.
+ Args:
+ description: Text to set as description for this repository.
"""
raise NotImplementedError(self.set_description)
def get_config_stack(self):
"""Return a config stack for this repository.
This stack accesses the configuration for both this repository
itself (.git/config) and the global configuration, which usually
lives in ~/.gitconfig.
- :return: `Config` instance for this repository
+ Returns: `Config` instance for this repository
"""
from dulwich.config import StackedConfig
backends = [self.get_config()] + StackedConfig.default_backends()
return StackedConfig(backends, writable=backends[0])
def get_shallow(self):
"""Get the set of shallow commits.
- :return: Set of shallow commits.
+ Returns: Set of shallow commits.
"""
f = self.get_named_file('shallow')
if f is None:
return set()
with f:
return set(l.strip() for l in f)
def update_shallow(self, new_shallow, new_unshallow):
"""Update the list of shallow objects.
- :param new_shallow: Newly shallow objects
- :param new_unshallow: Newly no longer shallow objects
+ Args:
+ new_shallow: Newly shallow objects
+ new_unshallow: Newly no longer shallow objects
"""
shallow = self.get_shallow()
if new_shallow:
shallow.update(new_shallow)
if new_unshallow:
shallow.difference_update(new_unshallow)
self._put_named_file(
'shallow',
b''.join([sha + b'\n' for sha in shallow]))
def get_peeled(self, ref):
"""Get the peeled value of a ref.
- :param ref: The refname to peel.
- :return: The fully-peeled SHA1 of a tag object, after peeling all
+ Args:
+ ref: The refname to peel.
+ Returns: The fully-peeled SHA1 of a tag object, after peeling all
intermediate tags; if the original ref does not point to a tag,
this will equal the original SHA1.
"""
cached = self.refs.get_peeled(ref)
if cached is not None:
return cached
return self.object_store.peel_sha(self.refs[ref]).id
def get_walker(self, include=None, *args, **kwargs):
"""Obtain a walker for this repository.
- :param include: Iterable of SHAs of commits to include along with their
+ Args:
+ include: Iterable of SHAs of commits to include along with their
ancestors. Defaults to [HEAD]
- :param exclude: Iterable of SHAs of commits to exclude along with their
+ exclude: Iterable of SHAs of commits to exclude along with their
ancestors, overriding includes.
- :param order: ORDER_* constant specifying the order of results.
+ order: ORDER_* constant specifying the order of results.
Anything other than ORDER_DATE may result in O(n) memory usage.
- :param reverse: If True, reverse the order of output, requiring O(n)
+ reverse: If True, reverse the order of output, requiring O(n)
memory.
- :param max_entries: The maximum number of entries to yield, or None for
+ max_entries: The maximum number of entries to yield, or None for
no limit.
- :param paths: Iterable of file or subtree paths to show entries for.
- :param rename_detector: diff.RenameDetector object for detecting
+ paths: Iterable of file or subtree paths to show entries for.
+ rename_detector: diff.RenameDetector object for detecting
renames.
- :param follow: If True, follow path across renames/copies. Forces a
+ follow: If True, follow path across renames/copies. Forces a
default rename_detector.
- :param since: Timestamp to list commits after.
- :param until: Timestamp to list commits before.
- :param queue_cls: A class to use for a queue of commits, supporting the
+ since: Timestamp to list commits after.
+ until: Timestamp to list commits before.
+ queue_cls: A class to use for a queue of commits, supporting the
iterator protocol. The constructor takes a single argument, the
Walker.
- :return: A `Walker` object
+ Returns: A `Walker` object
"""
from dulwich.walk import Walker
if include is None:
include = [self.head()]
if isinstance(include, str):
include = [include]
kwargs['get_parents'] = lambda commit: self.get_parents(
commit.id, commit)
return Walker(self.object_store, include, *args, **kwargs)
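A minimal sketch against an in-memory repository, with placeholder identities and contents (do_commit is documented further below):
    from dulwich.objects import Blob, Tree
    from dulwich.repo import MemoryRepo

    repo = MemoryRepo()
    blob = Blob.from_string(b"data\n")
    tree = Tree()
    tree.add(b"a.txt", 0o100644, blob.id)
    repo.object_store.add_object(blob)
    repo.object_store.add_object(tree)
    ident = b"Example <user@example.com>"
    for msg in (b"first", b"second"):
        repo.do_commit(msg, tree=tree.id, committer=ident, author=ident)
    # include defaults to [HEAD]; the walk yields b'second', then b'first'
    for entry in repo.get_walker():
        print(entry.commit.message)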
def __getitem__(self, name):
"""Retrieve a Git object by SHA1 or ref.
- :param name: A Git object SHA1 or a ref name
- :return: A `ShaFile` object, such as a Commit or Blob
- :raise KeyError: when the specified ref or object does not exist
+ Args:
+ name: A Git object SHA1 or a ref name
+ Returns: A `ShaFile` object, such as a Commit or Blob
+ Raises:
+ KeyError: when the specified ref or object does not exist
"""
if not isinstance(name, bytes):
raise TypeError("'name' must be bytestring, not %.80s" %
type(name).__name__)
if len(name) in (20, 40):
try:
return self.object_store[name]
except (KeyError, ValueError):
pass
try:
return self.object_store[self.refs[name]]
except RefFormatError:
raise KeyError(name)
def __contains__(self, name):
"""Check if a specific Git object or ref is present.
- :param name: Git object SHA1 or ref name
+ Args:
+ name: Git object SHA1 or ref name
"""
if len(name) in (20, 40):
return name in self.object_store or name in self.refs
else:
return name in self.refs
def __setitem__(self, name, value):
"""Set a ref.
- :param name: ref name
- :param value: Ref value - either a ShaFile object, or a hex sha
+ Args:
+ name: ref name
+ value: Ref value - either a ShaFile object, or a hex sha
"""
if name.startswith(b"refs/") or name == b'HEAD':
if isinstance(value, ShaFile):
self.refs[name] = value.id
elif isinstance(value, bytes):
self.refs[name] = value
else:
raise TypeError(value)
else:
raise ValueError(name)
def __delitem__(self, name):
"""Remove a ref.
- :param name: Name of the ref to remove
+ Args:
+ name: Name of the ref to remove
"""
if name.startswith(b"refs/") or name == b"HEAD":
del self.refs[name]
else:
raise ValueError(name)
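Together with __getitem__, these methods give repositories a small dict-like surface; a sketch (a blob-valued ref keeps it short, although branch refs normally point at commits):
    from dulwich.objects import Blob
    from dulwich.repo import MemoryRepo

    repo = MemoryRepo()
    blob = Blob.from_string(b"data")
    repo.object_store.add_object(blob)
    repo[b"refs/heads/demo"] = blob.id        # __setitem__: hex sha value
    assert b"refs/heads/demo" in repo         # __contains__: ref name
    assert repo[blob.id].data == b"data"      # __getitem__: 40-byte hex sha
    del repo[b"refs/heads/demo"]              # __delitem__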
def _get_user_identity(self, config, kind=None):
"""Determine the identity to use for new commits.
"""
# TODO(jelmer): Deprecate this function in favor of get_user_identity
return get_user_identity(config)
def _add_graftpoints(self, updated_graftpoints):
"""Add or modify graftpoints
- :param updated_graftpoints: Dict of commit shas to list of parent shas
+ Args:
+ updated_graftpoints: Dict of commit shas to list of parent shas
"""
# Simple validation
for commit, parents in updated_graftpoints.items():
for sha in [commit] + parents:
check_hexsha(sha, 'Invalid graftpoint')
self._graftpoints.update(updated_graftpoints)
def _remove_graftpoints(self, to_remove=[]):
"""Remove graftpoints
- :param to_remove: List of commit shas
+ Args:
+ to_remove: List of commit shas
"""
for sha in to_remove:
del self._graftpoints[sha]
def _read_heads(self, name):
f = self.get_named_file(name)
if f is None:
return []
with f:
return [l.strip() for l in f.readlines() if l.strip()]
def do_commit(self, message=None, committer=None,
author=None, commit_timestamp=None,
commit_timezone=None, author_timestamp=None,
author_timezone=None, tree=None, encoding=None,
ref=b'HEAD', merge_heads=None):
"""Create a new commit.
- :param message: Commit message
- :param committer: Committer fullname
- :param author: Author fullname (defaults to committer)
- :param commit_timestamp: Commit timestamp (defaults to now)
- :param commit_timezone: Commit timestamp timezone (defaults to GMT)
- :param author_timestamp: Author timestamp (defaults to commit
+ Args:
+ message: Commit message
+ committer: Committer fullname
+ author: Author fullname (defaults to committer)
+ commit_timestamp: Commit timestamp (defaults to now)
+ commit_timezone: Commit timestamp timezone (defaults to GMT)
+ author_timestamp: Author timestamp (defaults to commit
timestamp)
- :param author_timezone: Author timestamp timezone
+ author_timezone: Author timestamp timezone
(defaults to commit timestamp timezone)
- :param tree: SHA1 of the tree root to use (if not specified the
+ tree: SHA1 of the tree root to use (if not specified the
current index will be committed).
- :param encoding: Encoding
- :param ref: Optional ref to commit to (defaults to current branch)
- :param merge_heads: Merge heads (defaults to .git/MERGE_HEADS)
- :return: New commit SHA1
+ encoding: Encoding
+ ref: Optional ref to commit to (defaults to current branch)
+ merge_heads: Merge heads (defaults to .git/MERGE_HEADS)
+ Returns: New commit SHA1
"""
import time
c = Commit()
if tree is None:
index = self.open_index()
c.tree = index.commit(self.object_store)
else:
if len(tree) != 40:
raise ValueError("tree must be a 40-byte hex sha string")
c.tree = tree
try:
self.hooks['pre-commit'].execute()
except HookError as e:
raise CommitError(e)
except KeyError: # no hook defined, silent fallthrough
pass
config = self.get_config_stack()
if merge_heads is None:
merge_heads = self._read_heads('MERGE_HEADS')
if committer is None:
committer = get_user_identity(config, kind='COMMITTER')
check_user_identity(committer)
c.committer = committer
if commit_timestamp is None:
# FIXME: Support GIT_COMMITTER_DATE environment variable
commit_timestamp = time.time()
c.commit_time = int(commit_timestamp)
if commit_timezone is None:
# FIXME: Use current user timezone rather than UTC
commit_timezone = 0
c.commit_timezone = commit_timezone
if author is None:
author = get_user_identity(config, kind='AUTHOR')
c.author = author
check_user_identity(author)
if author_timestamp is None:
# FIXME: Support GIT_AUTHOR_DATE environment variable
author_timestamp = commit_timestamp
c.author_time = int(author_timestamp)
if author_timezone is None:
author_timezone = commit_timezone
c.author_timezone = author_timezone
if encoding is None:
try:
encoding = config.get(('i18n', ), 'commitEncoding')
except KeyError:
pass # No dice
if encoding is not None:
c.encoding = encoding
if message is None:
# FIXME: Try to read commit message from .git/MERGE_MSG
raise ValueError("No commit message specified")
try:
c.message = self.hooks['commit-msg'].execute(message)
if c.message is None:
c.message = message
except HookError as e:
raise CommitError(e)
except KeyError: # no hook defined, message not modified
c.message = message
if ref is None:
# Create a dangling commit
c.parents = merge_heads
self.object_store.add_object(c)
else:
try:
old_head = self.refs[ref]
c.parents = [old_head] + merge_heads
self.object_store.add_object(c)
ok = self.refs.set_if_equals(
ref, old_head, c.id, message=b"commit: " + message,
committer=committer, timestamp=commit_timestamp,
timezone=commit_timezone)
except KeyError:
c.parents = merge_heads
self.object_store.add_object(c)
ok = self.refs.add_if_new(
ref, c.id, message=b"commit: " + message,
committer=committer, timestamp=commit_timestamp,
timezone=commit_timezone)
if not ok:
# Fail if the atomic compare-and-swap failed, leaving the
# commit and all its objects as garbage.
raise CommitError("%s changed during commit" % (ref,))
self._del_named_file('MERGE_HEADS')
try:
self.hooks['post-commit'].execute()
except HookError as e: # silent failure
warnings.warn("post-commit hook failed: %s" % e, UserWarning)
except KeyError: # no hook defined, silent fallthrough
pass
return c.id
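A sketch of both ref modes against a MemoryRepo, with placeholder identities and contents:
    from dulwich.objects import Blob, Tree
    from dulwich.repo import MemoryRepo

    repo = MemoryRepo()
    blob = Blob.from_string(b"hello\n")
    tree = Tree()
    tree.add(b"hello.txt", 0o100644, blob.id)
    repo.object_store.add_object(blob)
    repo.object_store.add_object(tree)
    ident = b"Example <user@example.com>"
    # Default ref=b'HEAD': the ref is updated via add_if_new/set_if_equals
    head = repo.do_commit(b"initial", tree=tree.id,
                          committer=ident, author=ident)
    assert repo[b"HEAD"].id == head
    # ref=None: creates a dangling commit that no ref points at
    dangling = repo.do_commit(b"dangling", tree=tree.id, ref=None,
                              committer=ident, author=ident)
    assert repo[dangling].tree == tree.id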
def read_gitfile(f):
"""Read a ``.git`` file.
The first line of the file should start with "gitdir: "
- :param f: File-like object to read from
- :return: A path
+ Args:
+ f: File-like object to read from
+ Returns: A path
"""
cs = f.read()
if not cs.startswith("gitdir: "):
raise ValueError("Expected file to start with 'gitdir: '")
return cs[len("gitdir: "):].rstrip("\n")
class Repo(BaseRepo):
"""A git repository backed by local disk.
To open an existing repository, call the constructor with
the path of the repository.
To create a new repository, use the Repo.init class method.
"""
def __init__(self, root):
hidden_path = os.path.join(root, CONTROLDIR)
if os.path.isdir(os.path.join(hidden_path, OBJECTDIR)):
self.bare = False
self._controldir = hidden_path
elif (os.path.isdir(os.path.join(root, OBJECTDIR)) and
os.path.isdir(os.path.join(root, REFSDIR))):
self.bare = True
self._controldir = root
elif os.path.isfile(hidden_path):
with open(hidden_path, 'r') as f:
path = read_gitfile(f)
self.bare = False
self._controldir = os.path.join(root, path)
else:
raise NotGitRepository(
"No git repository was found at %(path)s" % dict(path=root)
)
commondir = self.get_named_file(COMMONDIR)
if commondir is not None:
with commondir:
self._commondir = os.path.join(
self.controldir(),
commondir.read().rstrip(b"\r\n").decode(
sys.getfilesystemencoding()))
else:
self._commondir = self._controldir
self.path = root
object_store = DiskObjectStore(
os.path.join(self.commondir(), OBJECTDIR))
refs = DiskRefsContainer(self.commondir(), self._controldir,
logger=self._write_reflog)
BaseRepo.__init__(self, object_store, refs)
self._graftpoints = {}
graft_file = self.get_named_file(os.path.join("info", "grafts"),
basedir=self.commondir())
if graft_file:
with graft_file:
self._graftpoints.update(parse_graftpoints(graft_file))
graft_file = self.get_named_file("shallow",
basedir=self.commondir())
if graft_file:
with graft_file:
self._graftpoints.update(parse_graftpoints(graft_file))
self.hooks['pre-commit'] = PreCommitShellHook(self.controldir())
self.hooks['commit-msg'] = CommitMsgShellHook(self.controldir())
self.hooks['post-commit'] = PostCommitShellHook(self.controldir())
def _write_reflog(self, ref, old_sha, new_sha, committer, timestamp,
timezone, message):
from .reflog import format_reflog_line
path = os.path.join(
self.controldir(), 'logs',
ref.decode(sys.getfilesystemencoding()))
try:
os.makedirs(os.path.dirname(path))
except OSError as e:
if e.errno != errno.EEXIST:
raise
if committer is None:
config = self.get_config_stack()
committer = self._get_user_identity(config)
check_user_identity(committer)
if timestamp is None:
timestamp = int(time.time())
if timezone is None:
timezone = 0 # FIXME
with open(path, 'ab') as f:
f.write(format_reflog_line(old_sha, new_sha, committer,
timestamp, timezone, message) + b'\n')
@classmethod
def discover(cls, start='.'):
"""Iterate parent directories to discover a repository
Return a Repo object for the first parent directory that looks like a
Git repository.
- :param start: The directory to start discovery from (defaults to '.')
+ Args:
+ start: The directory to start discovery from (defaults to '.')
"""
remaining = True
path = os.path.abspath(start)
while remaining:
try:
return cls(path)
except NotGitRepository:
path, remaining = os.path.split(path)
raise NotGitRepository(
"No git repository was found at %(path)s" % dict(path=start)
)
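A minimal sketch using a temporary directory:
    import os
    import tempfile
    from dulwich.repo import Repo

    root = tempfile.mkdtemp()
    Repo.init(root)                      # root already exists, so no mkdir
    nested = os.path.join(root, "src", "deep")
    os.makedirs(nested)
    found = Repo.discover(nested)        # walks up until root/.git is found
    assert os.path.samefile(found.path, root)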
def controldir(self):
"""Return the path of the control directory."""
return self._controldir
def commondir(self):
"""Return the path of the common directory.
For a main working tree, it is identical to controldir().
For a linked working tree, it is the control directory of the
main working tree."""
return self._commondir
def _determine_file_mode(self):
"""Probe the file-system to determine whether permissions can be trusted.
- :return: True if permissions can be trusted, False otherwise.
+ Returns: True if permissions can be trusted, False otherwise.
"""
fname = os.path.join(self.path, '.probe-permissions')
with open(fname, 'w') as f:
f.write('')
st1 = os.lstat(fname)
try:
os.chmod(fname, st1.st_mode ^ stat.S_IXUSR)
except EnvironmentError as e:
if e.errno == errno.EPERM:
return False
raise
st2 = os.lstat(fname)
os.unlink(fname)
mode_differs = st1.st_mode != st2.st_mode
st2_has_exec = (st2.st_mode & stat.S_IXUSR) != 0
return mode_differs and st2_has_exec
def _put_named_file(self, path, contents):
"""Write a file to the control dir with the given name and contents.
- :param path: The path to the file, relative to the control dir.
- :param contents: A string to write to the file.
+ Args:
+ path: The path to the file, relative to the control dir.
+ contents: A string to write to the file.
"""
path = path.lstrip(os.path.sep)
with GitFile(os.path.join(self.controldir(), path), 'wb') as f:
f.write(contents)
def _del_named_file(self, path):
try:
os.unlink(os.path.join(self.controldir(), path))
except (IOError, OSError) as e:
if e.errno == errno.ENOENT:
return
raise
def get_named_file(self, path, basedir=None):
"""Get a file from the control dir with a specific name.
Although the filename should be interpreted as a filename relative to
the control dir in a disk-based Repo, the object returned need not be
pointing to a file in that location.
- :param path: The path to the file, relative to the control dir.
- :param basedir: Optional argument that specifies an alternative to the
+ Args:
+ path: The path to the file, relative to the control dir.
+ basedir: Optional argument that specifies an alternative to the
control dir.
- :return: An open file object, or None if the file does not exist.
+ Returns: An open file object, or None if the file does not exist.
"""
# TODO(dborowitz): sanitize filenames, since this is used directly by
# the dumb web serving code.
if basedir is None:
basedir = self.controldir()
path = path.lstrip(os.path.sep)
try:
return open(os.path.join(basedir, path), 'rb')
except (IOError, OSError) as e:
if e.errno == errno.ENOENT:
return None
raise
def index_path(self):
"""Return path to the index file."""
return os.path.join(self.controldir(), INDEX_FILENAME)
def open_index(self):
"""Open the index for this repository.
- :raise NoIndexPresent: If no index is present
- :return: The matching `Index`
+ Raises:
+ NoIndexPresent: If no index is present
+ Returns: The matching `Index`
"""
from dulwich.index import Index
if not self.has_index():
raise NoIndexPresent()
return Index(self.index_path())
def has_index(self):
"""Check if an index is present."""
# Bare repos must never have index files; non-bare repos may have a
# missing index file, which is treated as empty.
return not self.bare
def stage(self, fs_paths):
"""Stage a set of paths.
- :param fs_paths: List of paths, relative to the repository path
+ Args:
+ fs_paths: List of paths, relative to the repository path
"""
root_path_bytes = self.path.encode(sys.getfilesystemencoding())
if not isinstance(fs_paths, list):
fs_paths = [fs_paths]
from dulwich.index import (
blob_from_path_and_stat,
index_entry_from_stat,
_fs_to_tree_path,
)
index = self.open_index()
blob_normalizer = self.get_blob_normalizer()
for fs_path in fs_paths:
if not isinstance(fs_path, bytes):
fs_path = fs_path.encode(sys.getfilesystemencoding())
if os.path.isabs(fs_path):
raise ValueError(
"path %r should be relative to "
"repository root, not absolute" % fs_path)
tree_path = _fs_to_tree_path(fs_path)
full_path = os.path.join(root_path_bytes, fs_path)
try:
st = os.lstat(full_path)
except OSError:
# File no longer exists
try:
del index[tree_path]
except KeyError:
pass # already removed
else:
if not stat.S_ISDIR(st.st_mode):
blob = blob_from_path_and_stat(full_path, st)
blob = blob_normalizer.checkin_normalize(blob, fs_path)
self.object_store.add_object(blob)
index[tree_path] = index_entry_from_stat(st, blob.id, 0)
else:
try:
del index[tree_path]
except KeyError:
pass
index.write()
def clone(self, target_path, mkdir=True, bare=False,
origin=b"origin", checkout=None):
"""Clone this repository.
- :param target_path: Target path
- :param mkdir: Create the target directory
- :param bare: Whether to create a bare repository
- :param origin: Base name for refs in target repository
+ Args:
+ target_path: Target path
+ mkdir: Create the target directory
+ bare: Whether to create a bare repository
+ origin: Base name for refs in target repository
cloned from this repository
- :return: Created repository as `Repo`
+ Returns: Created repository as `Repo`
"""
if not bare:
target = self.init(target_path, mkdir=mkdir)
else:
if checkout:
raise ValueError("checkout and bare are incompatible")
target = self.init_bare(target_path, mkdir=mkdir)
self.fetch(target)
encoded_path = self.path
if not isinstance(encoded_path, bytes):
encoded_path = encoded_path.encode(sys.getfilesystemencoding())
ref_message = b"clone: from " + encoded_path
target.refs.import_refs(
b'refs/remotes/' + origin, self.refs.as_dict(b'refs/heads'),
message=ref_message)
target.refs.import_refs(
b'refs/tags', self.refs.as_dict(b'refs/tags'),
message=ref_message)
try:
target.refs.add_if_new(
DEFAULT_REF, self.refs[DEFAULT_REF],
message=ref_message)
except KeyError:
pass
target_config = target.get_config()
target_config.set(('remote', 'origin'), 'url', encoded_path)
target_config.set(('remote', 'origin'), 'fetch',
'+refs/heads/*:refs/remotes/origin/*')
target_config.write_to_path()
# Update target head
head_chain, head_sha = self.refs.follow(b'HEAD')
if head_chain and head_sha is not None:
target.refs.set_symbolic_ref(b'HEAD', head_chain[-1],
message=ref_message)
target[b'HEAD'] = head_sha
if checkout is None:
checkout = (not bare)
if checkout:
# Checkout HEAD to target dir
target.reset_index()
return target
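A minimal end-to-end sketch on disk, with placeholder paths and identity:
    import os
    import tempfile
    from dulwich.repo import Repo

    source_path = tempfile.mkdtemp()
    source = Repo.init(source_path)
    with open(os.path.join(source_path, "README"), "wb") as f:
        f.write(b"hello\n")
    source.stage([b"README"])
    source.do_commit(b"initial", committer=b"Example <user@example.com>")
    # Clone into another (already existing) empty directory
    target = source.clone(tempfile.mkdtemp(), mkdir=False)
    assert target[b"HEAD"].id == source[b"HEAD"].id   # same tip after checkout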
def reset_index(self, tree=None):
"""Reset the index back to a specific tree.
- :param tree: Tree SHA to reset to, None for current HEAD tree.
+ Args:
+ tree: Tree SHA to reset to, None for current HEAD tree.
"""
from dulwich.index import (
build_index_from_tree,
validate_path_element_default,
validate_path_element_ntfs,
)
if tree is None:
tree = self[b'HEAD'].tree
config = self.get_config()
honor_filemode = config.get_boolean(
b'core', b'filemode', os.name != "nt")
if config.get_boolean(b'core', b'protectNTFS', os.name == "nt"):
validate_path_element = validate_path_element_ntfs
else:
validate_path_element = validate_path_element_default
return build_index_from_tree(
self.path, self.index_path(), self.object_store, tree,
honor_filemode=honor_filemode,
validate_path_element=validate_path_element)
def get_config(self):
"""Retrieve the config object.
- :return: `ConfigFile` object for the ``.git/config`` file.
+ Returns: `ConfigFile` object for the ``.git/config`` file.
"""
from dulwich.config import ConfigFile
path = os.path.join(self._controldir, 'config')
try:
return ConfigFile.from_path(path)
except (IOError, OSError) as e:
if e.errno != errno.ENOENT:
raise
ret = ConfigFile()
ret.path = path
return ret
def get_description(self):
"""Retrieve the description of this repository.
- :return: A string describing the repository or None.
+ Returns: A string describing the repository or None.
"""
path = os.path.join(self._controldir, 'description')
try:
with GitFile(path, 'rb') as f:
return f.read()
except (IOError, OSError) as e:
if e.errno != errno.ENOENT:
raise
return None
def __repr__(self):
return "" % self.path
def set_description(self, description):
"""Set the description for this repository.
- :param description: Text to set as description for this repository.
+ Args:
+ description: Text to set as description for this repository.
"""
self._put_named_file('description', description)
@classmethod
def _init_maybe_bare(cls, path, bare):
for d in BASE_DIRECTORIES:
os.mkdir(os.path.join(path, *d))
DiskObjectStore.init(os.path.join(path, OBJECTDIR))
ret = cls(path)
ret.refs.set_symbolic_ref(b'HEAD', DEFAULT_REF)
ret._init_files(bare)
return ret
@classmethod
def init(cls, path, mkdir=False):
"""Create a new repository.
- :param path: Path in which to create the repository
- :param mkdir: Whether to create the directory
- :return: `Repo` instance
+ Args:
+ path: Path in which to create the repository
+ mkdir: Whether to create the directory
+ Returns: `Repo` instance
"""
if mkdir:
os.mkdir(path)
controldir = os.path.join(path, CONTROLDIR)
os.mkdir(controldir)
_set_filesystem_hidden(controldir)
cls._init_maybe_bare(controldir, False)
return cls(path)
@classmethod
def _init_new_working_directory(cls, path, main_repo, identifier=None,
mkdir=False):
"""Create a new working directory linked to a repository.
- :param path: Path in which to create the working tree.
- :param main_repo: Main repository to reference
- :param identifier: Worktree identifier
- :param mkdir: Whether to create the directory
- :return: `Repo` instance
+ Args:
+ path: Path in which to create the working tree.
+ main_repo: Main repository to reference
+ identifier: Worktree identifier
+ mkdir: Whether to create the directory
+ Returns: `Repo` instance
"""
if mkdir:
os.mkdir(path)
if identifier is None:
identifier = os.path.basename(path)
main_worktreesdir = os.path.join(main_repo.controldir(), WORKTREES)
worktree_controldir = os.path.join(main_worktreesdir, identifier)
gitdirfile = os.path.join(path, CONTROLDIR)
with open(gitdirfile, 'wb') as f:
f.write(b'gitdir: ' +
worktree_controldir.encode(sys.getfilesystemencoding()) +
b'\n')
try:
os.mkdir(main_worktreesdir)
except OSError as e:
if e.errno != errno.EEXIST:
raise
try:
os.mkdir(worktree_controldir)
except OSError as e:
if e.errno != errno.EEXIST:
raise
with open(os.path.join(worktree_controldir, GITDIR), 'wb') as f:
f.write(gitdirfile.encode(sys.getfilesystemencoding()) + b'\n')
with open(os.path.join(worktree_controldir, COMMONDIR), 'wb') as f:
f.write(b'../..\n')
with open(os.path.join(worktree_controldir, 'HEAD'), 'wb') as f:
f.write(main_repo.head() + b'\n')
r = cls(path)
r.reset_index()
return r
@classmethod
def init_bare(cls, path, mkdir=False):
"""Create a new bare repository.
``path`` should already exist and be an empty directory.
- :param path: Path to create bare repository in
- :return: a `Repo` instance
+ Args:
+ path: Path to create bare repository in
+ Returns: a `Repo` instance
"""
if mkdir:
os.mkdir(path)
return cls._init_maybe_bare(path, True)
create = init_bare
def close(self):
"""Close any files opened by this repository."""
self.object_store.close()
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()
def get_blob_normalizer(self):
""" Return a BlobNormalizer object
"""
# TODO Parse the git attributes files
git_attributes = {}
return BlobNormalizer(
self.get_config_stack(), git_attributes
)
class MemoryRepo(BaseRepo):
"""Repo that stores refs, objects, and named files in memory.
MemoryRepos are always bare: they have no working tree and no index, since
those have a stronger dependency on the filesystem.
"""
def __init__(self):
from dulwich.config import ConfigFile
self._reflog = []
refs_container = DictRefsContainer({}, logger=self._append_reflog)
BaseRepo.__init__(self, MemoryObjectStore(), refs_container)
self._named_files = {}
self.bare = True
self._config = ConfigFile()
self._description = None
def _append_reflog(self, *args):
self._reflog.append(args)
def set_description(self, description):
self._description = description
def get_description(self):
return self._description
def _determine_file_mode(self):
"""Probe the file-system to determine whether permissions can be trusted.
- :return: True if permissions can be trusted, False otherwise.
+ Returns: True if permissions can be trusted, False otherwise.
"""
return sys.platform != 'win32'
def _put_named_file(self, path, contents):
"""Write a file to the control dir with the given name and contents.
- :param path: The path to the file, relative to the control dir.
- :param contents: A string to write to the file.
+ Args:
+ path: The path to the file, relative to the control dir.
+ contents: A string to write to the file.
"""
self._named_files[path] = contents
def _del_named_file(self, path):
try:
del self._named_files[path]
except KeyError:
pass
def get_named_file(self, path, basedir=None):
"""Get a file from the control dir with a specific name.
Although the filename should be interpreted as a filename relative to
the control dir in a disk-based Repo, the object returned need not be
pointing to a file in that location.
- :param path: The path to the file, relative to the control dir.
- :return: An open file object, or None if the file does not exist.
+ Args:
+ path: The path to the file, relative to the control dir.
+ Returns: An open file object, or None if the file does not exist.
"""
contents = self._named_files.get(path, None)
if contents is None:
return None
return BytesIO(contents)
def open_index(self):
"""Fail to open index for this repo, since it is bare.
- :raise NoIndexPresent: Raised when no index is present
+ Raises:
+ NoIndexPresent: Raised when no index is present
"""
raise NoIndexPresent()
def get_config(self):
"""Retrieve the config object.
- :return: `ConfigFile` object.
+ Returns: `ConfigFile` object.
"""
return self._config
@classmethod
def init_bare(cls, objects, refs):
"""Create a new bare repository in memory.
- :param objects: Objects for the new repository,
+ Args:
+ objects: Objects for the new repository,
as iterable
- :param refs: Refs as dictionary, mapping names
+ refs: Refs as dictionary, mapping names
to object SHA1s
"""
ret = cls()
for obj in objects:
ret.object_store.add_object(obj)
for refname, sha in refs.items():
ret.refs.add_if_new(refname, sha)
ret._init_files(bare=True)
return ret
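A minimal sketch (a blob-valued ref keeps it short; refs normally point at commits):
    from dulwich.objects import Blob
    from dulwich.repo import MemoryRepo

    blob = Blob.from_string(b"content")
    repo = MemoryRepo.init_bare([blob], {b"refs/heads/master": blob.id})
    assert repo[b"refs/heads/master"].data == b"content"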
diff --git a/dulwich/server.py b/dulwich/server.py
index db6ebb7f..f5af537e 100644
--- a/dulwich/server.py
+++ b/dulwich/server.py
@@ -1,1186 +1,1200 @@
# server.py -- Implementation of the server side git protocols
# Copyright (C) 2008 John Carr
# Copyright (C) 2011-2012 Jelmer Vernooij
#
# Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU
# General Public License as public by the Free Software Foundation; version 2.0
# or (at your option) any later version. You can redistribute it and/or
# modify it under the terms of either of these two licenses.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# You should have received a copy of the licenses; if not, see
# for a copy of the GNU General Public License
# and for a copy of the Apache
# License, Version 2.0.
#
"""Git smart network protocol server implementation.
For more details on the network protocol, see the
Documentation/technical directory in the cgit distribution, and in particular:
* Documentation/technical/protocol-capabilities.txt
* Documentation/technical/pack-protocol.txt
Currently supported capabilities:
* include-tag
* thin-pack
* multi_ack_detailed
* multi_ack
* side-band-64k
* ofs-delta
* no-progress
* report-status
* delete-refs
* shallow
* symref
"""
import collections
import os
import socket
import sys
import time
import zlib
try:
import SocketServer
except ImportError:
import socketserver as SocketServer
from dulwich.archive import tar_stream
from dulwich.errors import (
ApplyDeltaError,
ChecksumMismatch,
GitProtocolError,
NotGitRepository,
UnexpectedCommandError,
ObjectFormatException,
)
from dulwich import log_utils
from dulwich.objects import (
Commit,
valid_hexsha,
)
from dulwich.pack import (
write_pack_objects,
)
from dulwich.protocol import ( # noqa: F401
BufferedPktLineWriter,
capability_agent,
CAPABILITIES_REF,
CAPABILITY_DELETE_REFS,
CAPABILITY_INCLUDE_TAG,
CAPABILITY_MULTI_ACK_DETAILED,
CAPABILITY_MULTI_ACK,
CAPABILITY_NO_DONE,
CAPABILITY_NO_PROGRESS,
CAPABILITY_OFS_DELTA,
CAPABILITY_QUIET,
CAPABILITY_REPORT_STATUS,
CAPABILITY_SHALLOW,
CAPABILITY_SIDE_BAND_64K,
CAPABILITY_THIN_PACK,
COMMAND_DEEPEN,
COMMAND_DONE,
COMMAND_HAVE,
COMMAND_SHALLOW,
COMMAND_UNSHALLOW,
COMMAND_WANT,
MULTI_ACK,
MULTI_ACK_DETAILED,
Protocol,
ProtocolFile,
ReceivableProtocol,
SIDE_BAND_CHANNEL_DATA,
SIDE_BAND_CHANNEL_PROGRESS,
SIDE_BAND_CHANNEL_FATAL,
SINGLE_ACK,
TCP_GIT_PORT,
ZERO_SHA,
ack_type,
extract_capabilities,
extract_want_line_capabilities,
symref_capabilities,
)
from dulwich.refs import (
ANNOTATED_TAG_SUFFIX,
write_info_refs,
)
from dulwich.repo import (
Repo,
)
logger = log_utils.getLogger(__name__)
class Backend(object):
"""A backend for the Git smart server implementation."""
def open_repository(self, path):
"""Open the repository at a path.
- :param path: Path to the repository
- :raise NotGitRepository: no git repository was found at path
- :return: Instance of BackendRepo
+ Args:
+ path: Path to the repository
+ Raises:
+ NotGitRepository: no git repository was found at path
+ Returns: Instance of BackendRepo
"""
raise NotImplementedError(self.open_repository)
class BackendRepo(object):
"""Repository abstraction used by the Git server.
The methods required here are a subset of those provided by
dulwich.repo.Repo.
"""
object_store = None
refs = None
def get_refs(self):
"""
Get all the refs in the repository
- :return: dict of name -> sha
+ Returns: dict of name -> sha
"""
raise NotImplementedError
def get_peeled(self, name):
"""Return the cached peeled value of a ref, if available.
- :param name: Name of the ref to peel
- :return: The peeled value of the ref. If the ref is known not point to
+ Args:
+ name: Name of the ref to peel
+ Returns: The peeled value of the ref. If the ref is known not to point to
a tag, this will be the SHA the ref refers to. If no cached
information about a tag is available, this method may return None,
but it should attempt to peel the tag if possible.
"""
return None
def fetch_objects(self, determine_wants, graph_walker, progress,
get_tagged=None):
"""
Yield the objects required for a list of commits.
- :param progress: is a callback to send progress messages to the client
- :param get_tagged: Function that returns a dict of pointed-to sha ->
+ Args:
+ progress: a callback to send progress messages to the client
+ get_tagged: Function that returns a dict of pointed-to sha ->
tag sha for including tags.
"""
raise NotImplementedError
class DictBackend(Backend):
"""Trivial backend that looks up Git repositories in a dictionary."""
def __init__(self, repos):
self.repos = repos
def open_repository(self, path):
logger.debug('Opening repository at %s', path)
try:
return self.repos[path]
except KeyError:
raise NotGitRepository(
"No git repository was found at %(path)s" % dict(path=path)
)
class FileSystemBackend(Backend):
"""Simple backend looking up Git repositories in the local file system."""
def __init__(self, root=os.sep):
super(FileSystemBackend, self).__init__()
self.root = (os.path.abspath(root) + os.sep).replace(
os.sep * 2, os.sep)
def open_repository(self, path):
logger.debug('opening repository at %s', path)
abspath = os.path.abspath(os.path.join(self.root, path)) + os.sep
normcase_abspath = os.path.normcase(abspath)
normcase_root = os.path.normcase(self.root)
if not normcase_abspath.startswith(normcase_root):
raise NotGitRepository(
"Path %r not inside root %r" %
(path, self.root))
return Repo(abspath)
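A minimal sketch of the dictionary lookup; the commented lines hint at how such a backend is typically handed to a server class like dulwich.server.TCPGitServer (not part of this hunk):
    from dulwich.repo import MemoryRepo
    from dulwich.server import DictBackend

    backend = DictBackend({b"/demo": MemoryRepo()})
    repo = backend.open_repository(b"/demo")      # exact key lookup
    # A backend like this is what the smart-protocol handlers build on, e.g.:
    # server = TCPGitServer(backend, 'localhost')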
class Handler(object):
"""Smart protocol command handler base class."""
def __init__(self, backend, proto, http_req=None):
self.backend = backend
self.proto = proto
self.http_req = http_req
def handle(self):
raise NotImplementedError(self.handle)
class PackHandler(Handler):
"""Protocol handler for packs."""
def __init__(self, backend, proto, http_req=None):
super(PackHandler, self).__init__(backend, proto, http_req)
self._client_capabilities = None
# Flags needed for the no-done capability
self._done_received = False
@classmethod
def capability_line(cls, capabilities):
logger.info('Sending capabilities: %s', capabilities)
return b"".join([b" " + c for c in capabilities])
@classmethod
def capabilities(cls):
raise NotImplementedError(cls.capabilities)
@classmethod
def innocuous_capabilities(cls):
return [CAPABILITY_INCLUDE_TAG, CAPABILITY_THIN_PACK,
CAPABILITY_NO_PROGRESS, CAPABILITY_OFS_DELTA,
capability_agent()]
@classmethod
def required_capabilities(cls):
"""Return a list of capabilities that we require the client to have."""
return []
def set_client_capabilities(self, caps):
allowable_caps = set(self.innocuous_capabilities())
allowable_caps.update(self.capabilities())
for cap in caps:
if cap not in allowable_caps:
raise GitProtocolError('Client asked for capability %s that '
'was not advertised.' % cap)
for cap in self.required_capabilities():
if cap not in caps:
raise GitProtocolError('Client does not support required '
'capability %s.' % cap)
self._client_capabilities = set(caps)
logger.info('Client capabilities: %s', caps)
def has_capability(self, cap):
if self._client_capabilities is None:
raise GitProtocolError('Server attempted to access capability %s '
'before asking client' % cap)
return cap in self._client_capabilities
def notify_done(self):
self._done_received = True
class UploadPackHandler(PackHandler):
"""Protocol handler for uploading a pack to the client."""
def __init__(self, backend, args, proto, http_req=None,
advertise_refs=False):
super(UploadPackHandler, self).__init__(
backend, proto, http_req=http_req)
self.repo = backend.open_repository(args[0])
self._graph_walker = None
self.advertise_refs = advertise_refs
# A state variable for denoting that the have list is still
# being processed, and the client is not accepting any other
# data (such as side-band, see the progress method here).
self._processing_have_lines = False
@classmethod
def capabilities(cls):
return [CAPABILITY_MULTI_ACK_DETAILED, CAPABILITY_MULTI_ACK,
CAPABILITY_SIDE_BAND_64K, CAPABILITY_THIN_PACK,
CAPABILITY_OFS_DELTA, CAPABILITY_NO_PROGRESS,
CAPABILITY_INCLUDE_TAG, CAPABILITY_SHALLOW, CAPABILITY_NO_DONE]
@classmethod
def required_capabilities(cls):
return (CAPABILITY_SIDE_BAND_64K, CAPABILITY_THIN_PACK,
CAPABILITY_OFS_DELTA)
def progress(self, message):
if (self.has_capability(CAPABILITY_NO_PROGRESS) or
self._processing_have_lines):
return
self.proto.write_sideband(SIDE_BAND_CHANNEL_PROGRESS, message)
def get_tagged(self, refs=None, repo=None):
"""Get a dict of peeled values of tags to their original tag shas.
- :param refs: dict of refname -> sha of possible tags; defaults to all
+ Args:
+ refs: dict of refname -> sha of possible tags; defaults to all
of the backend's refs.
- :param repo: optional Repo instance for getting peeled refs; defaults
+ repo: optional Repo instance for getting peeled refs; defaults
to the backend's repo, if available
- :return: dict of peeled_sha -> tag_sha, where tag_sha is the sha of a
+ Returns: dict of peeled_sha -> tag_sha, where tag_sha is the sha of a
tag whose peeled value is peeled_sha.
"""
if not self.has_capability(CAPABILITY_INCLUDE_TAG):
return {}
if refs is None:
refs = self.repo.get_refs()
if repo is None:
repo = getattr(self.repo, "repo", None)
if repo is None:
# Bail if we don't have a Repo available; this is ok since
# clients must be able to handle the case where the server
# doesn't include all relevant tags.
# TODO: fix behavior when missing
return {}
# TODO(jelmer): Integrate this with the refs logic in
# Repo.fetch_objects
tagged = {}
for name, sha in refs.items():
peeled_sha = repo.get_peeled(name)
if peeled_sha != sha:
tagged[peeled_sha] = sha
return tagged
def handle(self):
def write(x):
return self.proto.write_sideband(SIDE_BAND_CHANNEL_DATA, x)
graph_walker = _ProtocolGraphWalker(
self, self.repo.object_store, self.repo.get_peeled,
self.repo.refs.get_symrefs)
objects_iter = self.repo.fetch_objects(
graph_walker.determine_wants, graph_walker, self.progress,
get_tagged=self.get_tagged)
# Note that the client is only processing responses related to
# the have lines it sent, and any other data (including side-
# band) will be considered a fatal error.
self._processing_have_lines = True
# Did the process short-circuit (e.g. in a stateless RPC call)? Note
# that the client still expects a 0-object pack in most cases.
# Also, if objects_iter happens to be instantiated with a graph
# walker whose implementation talks over the wire (as the graph
# walker created above does), taking its length will actually
# iterate through everything and write things out to the wire.
if len(objects_iter) == 0:
return
# The provided haves are processed, and it is safe to send side-
# band data now.
self._processing_have_lines = False
if not graph_walker.handle_done(
not self.has_capability(CAPABILITY_NO_DONE),
self._done_received):
return
self.progress(
("counting objects: %d, done.\n" % len(objects_iter)).encode(
'ascii'))
write_pack_objects(ProtocolFile(None, write), objects_iter)
# we are done
self.proto.write_pkt_line(None)
def _split_proto_line(line, allowed):
"""Split a line read from the wire.
- :param line: The line read from the wire.
- :param allowed: An iterable of command names that should be allowed.
+ Args:
+ line: The line read from the wire.
+ allowed: An iterable of command names that should be allowed.
Command names not listed below as possible return values will be
ignored. If None, any commands from the possible return values are
allowed.
- :return: a tuple having one of the following forms:
+ Returns: a tuple having one of the following forms:
('want', obj_id)
('have', obj_id)
('done', None)
(None, None) (for a flush-pkt)
- :raise UnexpectedCommandError: if the line cannot be parsed into one of the
+ Raises:
+ UnexpectedCommandError: if the line cannot be parsed into one of the
allowed return values.
"""
if not line:
fields = [None]
else:
fields = line.rstrip(b'\n').split(b' ', 1)
command = fields[0]
if allowed is not None and command not in allowed:
raise UnexpectedCommandError(command)
if len(fields) == 1 and command in (COMMAND_DONE, None):
return (command, None)
elif len(fields) == 2:
if command in (COMMAND_WANT, COMMAND_HAVE, COMMAND_SHALLOW,
COMMAND_UNSHALLOW):
if not valid_hexsha(fields[1]):
raise GitProtocolError("Invalid sha")
return tuple(fields)
elif command == COMMAND_DEEPEN:
return command, int(fields[1])
raise GitProtocolError('Received invalid line from client: %r' % line)
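# A sketch of the parsing contract above, using a hypothetical sha:
#
#   sha = b'a' * 40
#   _split_proto_line(b'want ' + sha + b'\n', (COMMAND_WANT, None))
#   # -> (b'want', sha)
#   _split_proto_line(b'', (COMMAND_WANT, None))
#   # -> (None, None), i.e. a flush-pkt
#   _split_proto_line(b'deepen 1\n', (COMMAND_DEEPEN,))
#   # -> (b'deepen', 1), the depth parsed as an int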
def _find_shallow(store, heads, depth):
"""Find shallow commits according to a given depth.
- :param store: An ObjectStore for looking up objects.
- :param heads: Iterable of head SHAs to start walking from.
- :param depth: The depth of ancestors to include. A depth of one includes
+ Args:
+ store: An ObjectStore for looking up objects.
+ heads: Iterable of head SHAs to start walking from.
+ depth: The depth of ancestors to include. A depth of one includes
only the heads themselves.
- :return: A tuple of (shallow, not_shallow), sets of SHAs that should be
+ Returns: A tuple of (shallow, not_shallow), sets of SHAs that should be
considered shallow and unshallow according to the arguments. Note that
these sets may overlap if a commit is reachable along multiple paths.
"""
parents = {}
def get_parents(sha):
result = parents.get(sha, None)
if not result:
result = store[sha].parents
parents[sha] = result
return result
todo = [] # stack of (sha, depth)
for head_sha in heads:
obj = store.peel_sha(head_sha)
if isinstance(obj, Commit):
todo.append((obj.id, 1))
not_shallow = set()
shallow = set()
while todo:
sha, cur_depth = todo.pop()
if cur_depth < depth:
not_shallow.add(sha)
new_depth = cur_depth + 1
todo.extend((p, new_depth) for p in get_parents(sha))
else:
shallow.add(sha)
return shallow, not_shallow
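# A sketch of the depth semantics on a hypothetical linear history
# c1 <- c2 <- c3, where c3 is the head:
#
#   _find_shallow(store, [c3.id], 2)
#   # -> ({c2.id}, {c3.id}): c3 is within the depth-2 window, so c2,
#   # the first commit past it, is reported as shallow; c1 is never
#   # visited at all.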
def _want_satisfied(store, haves, want, earliest):
o = store[want]
pending = collections.deque([o])
known = set([want])
while pending:
commit = pending.popleft()
if commit.id in haves:
return True
if commit.type_name != b"commit":
# non-commit wants are assumed to be satisfied
continue
for parent in commit.parents:
if parent in known:
continue
known.add(parent)
parent_obj = store[parent]
# TODO: handle parents with later commit times than children
if parent_obj.commit_time >= earliest:
pending.append(parent_obj)
return False
def _all_wants_satisfied(store, haves, wants):
"""Check whether all the current wants are satisfied by a set of haves.
- :param store: Object store to retrieve objects from
- :param haves: A set of commits we know the client has.
- :param wants: A set of commits the client wants
- :note: Wants are specified with set_wants rather than passed in since
+ Args:
+ store: Object store to retrieve objects from
+ haves: A set of commits we know the client has.
+ wants: A set of commits the client wants
+ Note: Wants are specified with set_wants rather than passed in since
in the current interface they are determined outside this class.
"""
haves = set(haves)
if haves:
earliest = min([store[h].commit_time for h in haves])
else:
earliest = 0
for want in wants:
if not _want_satisfied(store, haves, want, earliest):
return False
return True
class _ProtocolGraphWalker(object):
"""A graph walker that knows the git protocol.
As a graph walker, this class implements ack(), next(), and reset(). It
also contains some base methods for interacting with the wire and walking
the commit tree.
The work of determining which acks to send is passed on to the
implementation instance stored in _impl. The reason for this is that we do
not know at object creation time what ack level the protocol requires. A
call to set_ack_type() is required to set up the implementation, before
any calls to next() or ack() are made.
"""
def __init__(self, handler, object_store, get_peeled, get_symrefs):
self.handler = handler
self.store = object_store
self.get_peeled = get_peeled
self.get_symrefs = get_symrefs
self.proto = handler.proto
self.http_req = handler.http_req
self.advertise_refs = handler.advertise_refs
self._wants = []
self.shallow = set()
self.client_shallow = set()
self.unshallow = set()
self._cached = False
self._cache = []
self._cache_index = 0
self._impl = None
def determine_wants(self, heads):
"""Determine the wants for a set of heads.
The given heads are advertised to the client, who then specifies which
refs he wants using 'want' lines. This portion of the protocol is the
same regardless of ack type, and in fact is used to set the ack type of
the ProtocolGraphWalker.
If the client has the 'shallow' capability, this method also reads and
responds to the 'shallow' and 'deepen' lines from the client. These are
not part of the wants per se, but they set up necessary state for
walking the graph. Additionally, later code depends on this method
consuming everything up to the first 'have' line.
- :param heads: a dict of refname->SHA1 to advertise
- :return: a list of SHA1s requested by the client
+ Args:
+ heads: a dict of refname->SHA1 to advertise
+ Returns: a list of SHA1s requested by the client
"""
symrefs = self.get_symrefs()
values = set(heads.values())
if self.advertise_refs or not self.http_req:
for i, (ref, sha) in enumerate(sorted(heads.items())):
try:
peeled_sha = self.get_peeled(ref)
except KeyError:
# Skip refs that are inaccessible
# TODO(jelmer): Integrate with Repo.fetch_objects refs
# logic.
continue
line = sha + b' ' + ref
if not i:
line += (b'\x00' +
self.handler.capability_line(
self.handler.capabilities() +
symref_capabilities(symrefs.items())))
self.proto.write_pkt_line(line + b'\n')
if peeled_sha != sha:
self.proto.write_pkt_line(
peeled_sha + b' ' + ref + ANNOTATED_TAG_SUFFIX + b'\n')
# done sending the ref advertisement
self.proto.write_pkt_line(None)
if self.advertise_refs:
return []
# Now the client will send us a series of want lines
want = self.proto.read_pkt_line()
if not want:
return []
line, caps = extract_want_line_capabilities(want)
self.handler.set_client_capabilities(caps)
self.set_ack_type(ack_type(caps))
allowed = (COMMAND_WANT, COMMAND_SHALLOW, COMMAND_DEEPEN, None)
command, sha = _split_proto_line(line, allowed)
want_revs = []
while command == COMMAND_WANT:
if sha not in values:
raise GitProtocolError(
'Client wants invalid object %s' % sha)
want_revs.append(sha)
command, sha = self.read_proto_line(allowed)
self.set_wants(want_revs)
if command in (COMMAND_SHALLOW, COMMAND_DEEPEN):
self.unread_proto_line(command, sha)
self._handle_shallow_request(want_revs)
if self.http_req and self.proto.eof():
# The client may close the socket at this point, expecting a
# flush-pkt from the server. We might be ready to send a packfile
# at this point, so we need to explicitly short-circuit in this
# case.
return []
return want_revs
def unread_proto_line(self, command, value):
if isinstance(value, int):
value = str(value).encode('ascii')
self.proto.unread_pkt_line(command + b' ' + value)
def ack(self, have_ref):
if len(have_ref) != 40:
raise ValueError("invalid sha %r" % have_ref)
return self._impl.ack(have_ref)
def reset(self):
self._cached = True
self._cache_index = 0
def next(self):
if not self._cached:
if not self._impl and self.http_req:
return None
return next(self._impl)
self._cache_index += 1
if self._cache_index > len(self._cache):
return None
return self._cache[self._cache_index]
__next__ = next
def read_proto_line(self, allowed):
"""Read a line from the wire.
- :param allowed: An iterable of command names that should be allowed.
- :return: A tuple of (command, value); see _split_proto_line.
- :raise UnexpectedCommandError: If an error occurred reading the line.
+ Args:
+ allowed: An iterable of command names that should be allowed.
+ Returns: A tuple of (command, value); see _split_proto_line.
+ Raises:
+ UnexpectedCommandError: If an error occurred reading the line.
"""
return _split_proto_line(self.proto.read_pkt_line(), allowed)
def _handle_shallow_request(self, wants):
while True:
command, val = self.read_proto_line(
(COMMAND_DEEPEN, COMMAND_SHALLOW))
if command == COMMAND_DEEPEN:
depth = val
break
self.client_shallow.add(val)
self.read_proto_line((None,)) # consume client's flush-pkt
shallow, not_shallow = _find_shallow(self.store, wants, depth)
# Update self.shallow instead of reassigning it since we passed a
# reference to it before this method was called.
self.shallow.update(shallow - not_shallow)
new_shallow = self.shallow - self.client_shallow
unshallow = self.unshallow = not_shallow & self.client_shallow
for sha in sorted(new_shallow):
self.proto.write_pkt_line(COMMAND_SHALLOW + b' ' + sha)
for sha in sorted(unshallow):
self.proto.write_pkt_line(COMMAND_UNSHALLOW + b' ' + sha)
self.proto.write_pkt_line(None)
def notify_done(self):
# relay the message down to the handler.
self.handler.notify_done()
def send_ack(self, sha, ack_type=b''):
if ack_type:
ack_type = b' ' + ack_type
self.proto.write_pkt_line(b'ACK ' + sha + ack_type + b'\n')
def send_nak(self):
self.proto.write_pkt_line(b'NAK\n')
def handle_done(self, done_required, done_received):
# Delegate this to the implementation.
return self._impl.handle_done(done_required, done_received)
def set_wants(self, wants):
self._wants = wants
def all_wants_satisfied(self, haves):
"""Check whether all the current wants are satisfied by a set of haves.
- :param haves: A set of commits we know the client has.
- :note: Wants are specified with set_wants rather than passed in since
+ Args:
+ haves: A set of commits we know the client has.
+ Note: Wants are specified with set_wants rather than passed in since
in the current interface they are determined outside this class.
"""
return _all_wants_satisfied(self.store, haves, self._wants)
def set_ack_type(self, ack_type):
impl_classes = {
MULTI_ACK: MultiAckGraphWalkerImpl,
MULTI_ACK_DETAILED: MultiAckDetailedGraphWalkerImpl,
SINGLE_ACK: SingleAckGraphWalkerImpl,
}
self._impl = impl_classes[ack_type](self)
_GRAPH_WALKER_COMMANDS = (COMMAND_HAVE, COMMAND_DONE, None)
class SingleAckGraphWalkerImpl(object):
"""Graph walker implementation that speaks the single-ack protocol."""
def __init__(self, walker):
self.walker = walker
self._common = []
def ack(self, have_ref):
if not self._common:
self.walker.send_ack(have_ref)
self._common.append(have_ref)
def next(self):
command, sha = self.walker.read_proto_line(_GRAPH_WALKER_COMMANDS)
if command in (None, COMMAND_DONE):
# defer the handling of done
self.walker.notify_done()
return None
elif command == COMMAND_HAVE:
return sha
__next__ = next
def handle_done(self, done_required, done_received):
if not self._common:
self.walker.send_nak()
if done_required and not done_received:
# we are not done when done is required but was not received;
# skip sending the pack for this request and do not handle
# the done.
return False
if not done_received and not self._common:
# We are not actually done then, since the walker picked up
# no haves. This is usually triggered when the client attempts
# to pull from a source that has no common base_commit.
# See: test_server.MultiAckDetailedGraphWalkerImplTestCase.\
# test_multi_ack_stateless_nodone
return False
return True
class MultiAckGraphWalkerImpl(object):
"""Graph walker implementation that speaks the multi-ack protocol."""
def __init__(self, walker):
self.walker = walker
self._found_base = False
self._common = []
def ack(self, have_ref):
self._common.append(have_ref)
if not self._found_base:
self.walker.send_ack(have_ref, b'continue')
if self.walker.all_wants_satisfied(self._common):
self._found_base = True
# else we blind ack within next
def next(self):
while True:
command, sha = self.walker.read_proto_line(_GRAPH_WALKER_COMMANDS)
if command is None:
self.walker.send_nak()
# in multi-ack mode, a flush-pkt indicates the client wants to
# flush but more have lines are still coming
continue
elif command == COMMAND_DONE:
self.walker.notify_done()
return None
elif command == COMMAND_HAVE:
if self._found_base:
# blind ack
self.walker.send_ack(sha, b'continue')
return sha
__next__ = next
def handle_done(self, done_required, done_received):
if done_required and not done_received:
# we are not done when done is required but was not received;
# skip sending the pack for this request and do not handle
# the done.
return False
if not done_received and not self._common:
# We are not actually done then, since the walker picked up
# no haves. This is usually triggered when the client attempts
# to pull from a source that has no common base_commit.
# See: test_server.MultiAckDetailedGraphWalkerImplTestCase.\
# test_multi_ack_stateless_nodone
return False
# don't nak unless no common commits were found, even if not
# everything is satisfied
if self._common:
self.walker.send_ack(self._common[-1])
else:
self.walker.send_nak()
return True
class MultiAckDetailedGraphWalkerImpl(object):
"""Graph walker implementation speaking the multi-ack-detailed protocol."""
def __init__(self, walker):
self.walker = walker
self._common = []
def ack(self, have_ref):
# Should only be called if have_ref is common
self._common.append(have_ref)
self.walker.send_ack(have_ref, b'common')
def next(self):
while True:
command, sha = self.walker.read_proto_line(_GRAPH_WALKER_COMMANDS)
if command is None:
if self.walker.all_wants_satisfied(self._common):
self.walker.send_ack(self._common[-1], b'ready')
self.walker.send_nak()
if self.walker.http_req:
# In the HTTP version of this request, a flush-pkt always
# signifies the end of the request, so we also return
# nothing here as if we were done (but not really, as it
# depends on whether the no-done capability was specified;
# that is handled in handle_done, which may or may not call
# post_nodone_check accordingly).
return None
elif command == COMMAND_DONE:
# Let the walker know that we got a done.
self.walker.notify_done()
break
elif command == COMMAND_HAVE:
# return the sha and let the caller ACK it with the
# above ack method.
return sha
# don't nak unless no common commits were found, even if not
# everything is satisfied
__next__ = next
def handle_done(self, done_required, done_received):
if done_required and not done_received:
# we are not done when done is required but was not received;
# skip sending the pack for this request and do not handle
# the done.
return False
if not done_received and not self._common:
# We are not actually done then, since the walker picked up
# no haves. This is usually triggered when the client attempts
# to pull from a source that has no common base_commit.
# See: test_server.MultiAckDetailedGraphWalkerImplTestCase.\
# test_multi_ack_stateless_nodone
return False
# don't nak unless no common commits were found, even if not
# everything is satisfied
if self._common:
self.walker.send_ack(self._common[-1])
else:
self.walker.send_nak()
return True
class ReceivePackHandler(PackHandler):
"""Protocol handler for downloading a pack from the client."""
def __init__(self, backend, args, proto, http_req=None,
advertise_refs=False):
super(ReceivePackHandler, self).__init__(
backend, proto, http_req=http_req)
self.repo = backend.open_repository(args[0])
self.advertise_refs = advertise_refs
@classmethod
def capabilities(cls):
return [CAPABILITY_REPORT_STATUS, CAPABILITY_DELETE_REFS,
CAPABILITY_QUIET, CAPABILITY_OFS_DELTA,
CAPABILITY_SIDE_BAND_64K, CAPABILITY_NO_DONE]
def _apply_pack(self, refs):
all_exceptions = (IOError, OSError, ChecksumMismatch, ApplyDeltaError,
AssertionError, socket.error, zlib.error,
ObjectFormatException)
status = []
will_send_pack = False
for command in refs:
if command[1] != ZERO_SHA:
will_send_pack = True
if will_send_pack:
# TODO: more informative error messages than just the exception
# string
try:
recv = getattr(self.proto, "recv", None)
self.repo.object_store.add_thin_pack(self.proto.read, recv)
status.append((b'unpack', b'ok'))
except all_exceptions as e:
status.append((b'unpack', str(e).replace('\n', '')))
# The pack may still have been moved in, but it may contain
# broken objects. We trust a later GC to clean it up.
else:
# The git protocol wants to find a status entry related to the
# unpack process even if no pack data has been sent.
status.append((b'unpack', b'ok'))
for oldsha, sha, ref in refs:
ref_status = b'ok'
try:
if sha == ZERO_SHA:
if CAPABILITY_DELETE_REFS not in self.capabilities():
raise GitProtocolError(
'Attempted to delete refs without delete-refs '
'capability.')
try:
self.repo.refs.remove_if_equals(ref, oldsha)
except all_exceptions:
ref_status = b'failed to delete'
else:
try:
self.repo.refs.set_if_equals(ref, oldsha, sha)
except all_exceptions:
ref_status = b'failed to write'
except KeyError:
ref_status = b'bad ref'
status.append((ref, ref_status))
return status
def _report_status(self, status):
if self.has_capability(CAPABILITY_SIDE_BAND_64K):
writer = BufferedPktLineWriter(
lambda d: self.proto.write_sideband(SIDE_BAND_CHANNEL_DATA, d))
write = writer.write
def flush():
writer.flush()
self.proto.write_pkt_line(None)
else:
write = self.proto.write_pkt_line
def flush():
pass
for name, msg in status:
if name == b'unpack':
write(b'unpack ' + msg + b'\n')
elif msg == b'ok':
write(b'ok ' + name + b'\n')
else:
write(b'ng ' + name + b' ' + msg + b'\n')
write(None)
flush()
def handle(self):
if self.advertise_refs or not self.http_req:
refs = sorted(self.repo.get_refs().items())
symrefs = sorted(self.repo.refs.get_symrefs().items())
if not refs:
refs = [(CAPABILITIES_REF, ZERO_SHA)]
self.proto.write_pkt_line(
refs[0][1] + b' ' + refs[0][0] + b'\0' +
self.capability_line(
self.capabilities() + symref_capabilities(symrefs)) + b'\n')
for i in range(1, len(refs)):
ref = refs[i]
self.proto.write_pkt_line(ref[1] + b' ' + ref[0] + b'\n')
self.proto.write_pkt_line(None)
if self.advertise_refs:
return
client_refs = []
ref = self.proto.read_pkt_line()
# if ref is None then the client doesn't want to send us anything
if ref is None:
return
ref, caps = extract_capabilities(ref)
self.set_client_capabilities(caps)
# client will now send us a list of (oldsha, newsha, ref)
while ref:
client_refs.append(ref.split())
ref = self.proto.read_pkt_line()
# the backend can now deal with these refs and read a pack using self.read
status = self._apply_pack(client_refs)
# when we have read the full pack from the client, send a status
# report if the client asked for it
if self.has_capability(CAPABILITY_REPORT_STATUS):
self._report_status(status)
class UploadArchiveHandler(Handler):
def __init__(self, backend, args, proto, http_req=None):
super(UploadArchiveHandler, self).__init__(backend, proto, http_req)
self.repo = backend.open_repository(args[0])
def handle(self):
def write(x):
return self.proto.write_sideband(SIDE_BAND_CHANNEL_DATA, x)
arguments = []
for pkt in self.proto.read_pkt_seq():
(key, value) = pkt.split(b' ', 1)
if key != b'argument':
raise GitProtocolError('unknown command %s' % key)
arguments.append(value.rstrip(b'\n'))
prefix = b''
format = 'tar'
i = 0
store = self.repo.object_store
while i < len(arguments):
argument = arguments[i]
if argument == b'--prefix':
i += 1
prefix = arguments[i]
elif argument == b'--format':
i += 1
format = arguments[i].decode('ascii')
else:
commit_sha = self.repo.refs[argument]
tree = store[store[commit_sha].tree]
i += 1
self.proto.write_pkt_line(b'ACK\n')
self.proto.write_pkt_line(None)
for chunk in tar_stream(
store, tree, mtime=time.time(), prefix=prefix, format=format):
write(chunk)
self.proto.write_pkt_line(None)
# Default handler classes for git services.
DEFAULT_HANDLERS = {
b'git-upload-pack': UploadPackHandler,
b'git-receive-pack': ReceivePackHandler,
b'git-upload-archive': UploadArchiveHandler,
}
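# The mapping can be extended per server instance; a hypothetical sketch
# (MyUploadPackHandler is an assumed UploadPackHandler subclass):
#
#   handlers = dict(DEFAULT_HANDLERS)
#   handlers[b'git-upload-pack'] = MyUploadPackHandler
#   TCPGitServer(backend, 'localhost', handlers=handlers).serve()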
class TCPGitRequestHandler(SocketServer.StreamRequestHandler):
def __init__(self, handlers, *args, **kwargs):
self.handlers = handlers
SocketServer.StreamRequestHandler.__init__(self, *args, **kwargs)
def handle(self):
proto = ReceivableProtocol(self.connection.recv, self.wfile.write)
command, args = proto.read_cmd()
logger.info('Handling %s request, args=%s', command, args)
cls = self.handlers.get(command, None)
if not callable(cls):
raise GitProtocolError('Invalid service %s' % command)
h = cls(self.server.backend, args, proto)
h.handle()
class TCPGitServer(SocketServer.TCPServer):
allow_reuse_address = True
serve = SocketServer.TCPServer.serve_forever
def _make_handler(self, *args, **kwargs):
return TCPGitRequestHandler(self.handlers, *args, **kwargs)
def __init__(self, backend, listen_addr, port=TCP_GIT_PORT, handlers=None):
self.handlers = dict(DEFAULT_HANDLERS)
if handlers is not None:
self.handlers.update(handlers)
self.backend = backend
logger.info('Listening for TCP connections on %s:%d',
listen_addr, port)
SocketServer.TCPServer.__init__(self, (listen_addr, port),
self._make_handler)
def verify_request(self, request, client_address):
logger.info('Handling request from %s', client_address)
return True
def handle_error(self, request, client_address):
logger.exception('Exception happened during processing of request '
'from %s', client_address)
def main(argv=sys.argv):
"""Entry point for starting a TCP git server."""
import optparse
parser = optparse.OptionParser()
parser.add_option("-l", "--listen_address", dest="listen_address",
default="localhost",
help="Binding IP address.")
parser.add_option("-p", "--port", dest="port", type=int,
default=TCP_GIT_PORT,
help="Binding TCP port.")
options, args = parser.parse_args(argv)
log_utils.default_logging_config()
if len(args) > 1:
gitdir = args[1]
else:
gitdir = '.'
# TODO(jelmer): Support git-daemon-export-ok and --export-all.
backend = FileSystemBackend(gitdir)
server = TCPGitServer(backend, options.listen_address, options.port)
server.serve_forever()
def serve_command(handler_cls, argv=sys.argv, backend=None, inf=sys.stdin,
outf=sys.stdout):
"""Serve a single command.
This is mostly useful for the implementation of commands used by e.g.
git+ssh.
- :param handler_cls: `Handler` class to use for the request
- :param argv: execv-style command-line arguments. Defaults to sys.argv.
- :param backend: `Backend` to use
- :param inf: File-like object to read from, defaults to standard input.
- :param outf: File-like object to write to, defaults to standard output.
- :return: Exit code for use with sys.exit. 0 on success, 1 on failure.
+ Args:
+ handler_cls: `Handler` class to use for the request
+ argv: execv-style command-line arguments. Defaults to sys.argv.
+ backend: `Backend` to use
+ inf: File-like object to read from, defaults to standard input.
+ outf: File-like object to write to, defaults to standard output.
+ Returns: Exit code for use with sys.exit. 0 on success, 1 on failure.
"""
if backend is None:
backend = FileSystemBackend()
def send_fn(data):
outf.write(data)
outf.flush()
proto = Protocol(inf.read, send_fn)
handler = handler_cls(backend, argv[1:], proto)
# FIXME: Catch exceptions and write a single-line summary to outf.
handler.handle()
return 0
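# A minimal sketch of a git+ssh style entry point built on this helper,
# along the lines of the scripts shipped in bin/:
#
#   if __name__ == '__main__':
#       sys.exit(serve_command(ReceivePackHandler))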
def generate_info_refs(repo):
"""Generate an info refs file."""
refs = repo.get_refs()
return write_info_refs(refs, repo.object_store)
def generate_objects_info_packs(repo):
"""Generate an index for for packs."""
for pack in repo.object_store.packs:
yield (
b'P ' + pack.data.filename.encode(sys.getfilesystemencoding()) +
b'\n')
def update_server_info(repo):
"""Generate server info for dumb file access.
This generates info/refs and objects/info/packs,
similar to "git update-server-info".
"""
repo._put_named_file(
os.path.join('info', 'refs'),
b"".join(generate_info_refs(repo)))
repo._put_named_file(
os.path.join('objects', 'info', 'packs'),
b"".join(generate_objects_info_packs(repo)))
if __name__ == '__main__':
main()
diff --git a/dulwich/stash.py b/dulwich/stash.py
index 63f2a364..fdf9e7bc 100644
--- a/dulwich/stash.py
+++ b/dulwich/stash.py
@@ -1,119 +1,120 @@
# stash.py
# Copyright (C) 2018 Jelmer Vernooij
#
# Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU
# General Public License as public by the Free Software Foundation; version 2.0
# or (at your option) any later version. You can redistribute it and/or
# modify it under the terms of either of these two licenses.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# You should have received a copy of the licenses; if not, see
# <http://www.gnu.org/licenses/> for a copy of the GNU General Public License
# and <http://www.apache.org/licenses/LICENSE-2.0> for a copy of the Apache
# License, Version 2.0.
#
"""Stash handling."""
from __future__ import absolute_import
import errno
import os
from dulwich.file import GitFile
from dulwich.index import (
commit_tree,
iter_fresh_objects,
)
from dulwich.reflog import read_reflog
DEFAULT_STASH_REF = b"refs/stash"
class Stash(object):
"""A Git stash.
Note that this doesn't currently update the working tree.
"""
def __init__(self, repo, ref=DEFAULT_STASH_REF):
self._ref = ref
self._repo = repo
def stashes(self):
reflog_path = os.path.join(
self._repo.commondir(), 'logs', self._ref)
try:
with GitFile(reflog_path, 'rb') as f:
return reversed(list(read_reflog(f)))
except EnvironmentError as e:
if e.errno == errno.ENOENT:
return []
raise
@classmethod
def from_repo(cls, repo):
"""Create a new stash from a Repo object."""
return cls(repo)
def drop(self, index):
"""Drop entry with specified index."""
raise NotImplementedError(self.drop)
def pop(self, index):
"""Pop entry with specified index."""
raise NotImplementedError(self.pop)
def push(self, committer=None, author=None, message=None):
"""Create a new stash.
- :param committer: Optional committer name to use
- :param author: Optional author name to use
- :param message: Optional commit message
+ Args:
+ committer: Optional committer name to use
+ author: Optional author name to use
+ message: Optional commit message
"""
# First, create the index commit.
commit_kwargs = {}
if committer is not None:
commit_kwargs['committer'] = committer
if author is not None:
commit_kwargs['author'] = author
index = self._repo.open_index()
index_tree_id = index.commit(self._repo.object_store)
index_commit_id = self._repo.do_commit(
ref=None, tree=index_tree_id,
message=b"Index stash",
merge_heads=[self._repo.head()],
**commit_kwargs)
# Then, the working tree one.
stash_tree_id = commit_tree(
self._repo.object_store,
iter_fresh_objects(
index, self._repo.path,
object_store=self._repo.object_store))
if message is None:
message = b"A stash on " + self._repo.head()
# TODO(jelmer): Just pass parents into do_commit()?
self._repo.refs[self._ref] = self._repo.head()
cid = self._repo.do_commit(
ref=self._ref, tree=stash_tree_id,
message=message,
merge_heads=[index_commit_id],
**commit_kwargs)
return cid
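# A minimal usage sketch (the repository path is hypothetical):
#
#   from dulwich.repo import Repo
#   stash = Stash.from_repo(Repo('/path/to/repo'))
#   stash.push(message=b'WIP')
#   len(stash)  # number of entries in the reflog-backed stash list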
def __getitem__(self, index):
return self.stashes()[index]
def __len__(self):
return len(self.stashes())
diff --git a/dulwich/tests/__init__.py b/dulwich/tests/__init__.py
index 2984efc9..eea22fa3 100644
--- a/dulwich/tests/__init__.py
+++ b/dulwich/tests/__init__.py
@@ -1,195 +1,197 @@
# __init__.py -- The tests for dulwich
# Copyright (C) 2007 James Westby
#
# Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU
# General Public License as public by the Free Software Foundation; version 2.0
# or (at your option) any later version. You can redistribute it and/or
# modify it under the terms of either of these two licenses.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# You should have received a copy of the licenses; if not, see
# <http://www.gnu.org/licenses/> for a copy of the GNU General Public License
# and <http://www.apache.org/licenses/LICENSE-2.0> for a copy of the Apache
# License, Version 2.0.
#
"""Tests for Dulwich."""
import doctest
import os
import shutil
import subprocess
import sys
import tempfile
# If Python itself provides an exception, use that
import unittest
from unittest import ( # noqa: F401
SkipTest,
TestCase as _TestCase,
skipIf,
expectedFailure,
)
class TestCase(_TestCase):
def setUp(self):
super(TestCase, self).setUp()
self._old_home = os.environ.get("HOME")
os.environ["HOME"] = "/nonexistant"
def tearDown(self):
super(TestCase, self).tearDown()
if self._old_home:
os.environ["HOME"] = self._old_home
else:
del os.environ["HOME"]
class BlackboxTestCase(TestCase):
"""Blackbox testing."""
# TODO(jelmer): Include more possible binary paths.
bin_directories = [os.path.abspath(os.path.join(
os.path.dirname(__file__), "..", "..", "bin")), '/usr/bin',
'/usr/local/bin']
def bin_path(self, name):
"""Determine the full path of a binary.
- :param name: Name of the script
- :return: Full path
+ Args:
+ name: Name of the script
+ Returns: Full path
"""
for d in self.bin_directories:
p = os.path.join(d, name)
if os.path.isfile(p):
return p
else:
raise SkipTest("Unable to find binary %s" % name)
def run_command(self, name, args):
"""Run a Dulwich command.
- :param name: Name of the command, as it exists in bin/
- :param args: Arguments to the command
+ Args:
+ name: Name of the command, as it exists in bin/
+ args: Arguments to the command
"""
env = dict(os.environ)
env["PYTHONPATH"] = os.pathsep.join(sys.path)
# Since they don't have any extensions, Windows can't recognize the
# executability of the Python files in /bin. Even then, we'd have to
# expect the user to set up file associations for .py files.
#
# Save us from all that headache and call python with the bin script.
argv = [sys.executable, self.bin_path(name)] + args
return subprocess.Popen(
argv,
stdout=subprocess.PIPE,
stdin=subprocess.PIPE, stderr=subprocess.PIPE,
env=env)
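# A sketch of how a blackbox test might drive a script from bin/
# (the command name and argument are hypothetical):
#
#   proc = self.run_command('dul-receive-pack', ['/path/to/repo'])
#   (stdout, stderr) = proc.communicate()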
def self_test_suite():
names = [
'archive',
'blackbox',
'client',
'config',
'diff_tree',
'fastexport',
'file',
'grafts',
'greenthreads',
'hooks',
'ignore',
'index',
'line_ending',
'lru_cache',
'mailmap',
'objects',
'objectspec',
'object_store',
'missing_obj_finder',
'pack',
'patch',
'porcelain',
'protocol',
'reflog',
'refs',
'repository',
'server',
'stash',
'utils',
'walk',
'web',
]
module_names = ['dulwich.tests.test_' + name for name in names]
loader = unittest.TestLoader()
return loader.loadTestsFromNames(module_names)
def tutorial_test_suite():
import dulwich.client # noqa: F401
import dulwich.config # noqa: F401
import dulwich.index # noqa: F401
import dulwich.reflog # noqa: F401
import dulwich.repo # noqa: F401
import dulwich.server # noqa: F401
import dulwich.patch # noqa: F401
tutorial = [
'introduction',
'file-format',
'repo',
'object-store',
'remote',
'conclusion',
]
tutorial_files = ["../../docs/tutorial/%s.txt" % name for name in tutorial]
def setup(test):
test.__old_cwd = os.getcwd()
test.tempdir = tempfile.mkdtemp()
test.globs.update({'tempdir': test.tempdir})
os.chdir(test.tempdir)
def teardown(test):
os.chdir(test.__old_cwd)
shutil.rmtree(test.tempdir)
return doctest.DocFileSuite(
module_relative=True, package='dulwich.tests',
setUp=setup, tearDown=teardown, *tutorial_files)
def nocompat_test_suite():
result = unittest.TestSuite()
result.addTests(self_test_suite())
result.addTests(tutorial_test_suite())
from dulwich.contrib import test_suite as contrib_test_suite
result.addTests(contrib_test_suite())
return result
def compat_test_suite():
result = unittest.TestSuite()
from dulwich.tests.compat import test_suite as compat_test_suite
result.addTests(compat_test_suite())
return result
def test_suite():
result = unittest.TestSuite()
result.addTests(self_test_suite())
if sys.platform != 'win32':
result.addTests(tutorial_test_suite())
from dulwich.tests.compat import test_suite as compat_test_suite
result.addTests(compat_test_suite())
from dulwich.contrib import test_suite as contrib_test_suite
result.addTests(contrib_test_suite())
return result
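# Since these suite factories are plain callables returning a TestSuite,
# the stock unittest loader can resolve them by dotted name, e.g.:
#
#   python -m unittest dulwich.tests.test_suite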
diff --git a/dulwich/tests/compat/test_repository.py b/dulwich/tests/compat/test_repository.py
index 4d50f92e..3bd4107d 100644
--- a/dulwich/tests/compat/test_repository.py
+++ b/dulwich/tests/compat/test_repository.py
@@ -1,217 +1,218 @@
# test_repo.py -- Git repo compatibility tests
# Copyright (C) 2010 Google, Inc.
#
# Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU
# General Public License as public by the Free Software Foundation; version 2.0
# or (at your option) any later version. You can redistribute it and/or
# modify it under the terms of either of these two licenses.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# You should have received a copy of the licenses; if not, see
# <http://www.gnu.org/licenses/> for a copy of the GNU General Public License
# and <http://www.apache.org/licenses/LICENSE-2.0> for a copy of the Apache
# License, Version 2.0.
#
"""Compatibility tests for dulwich repositories."""
from io import BytesIO
from itertools import chain
import os
import tempfile
from dulwich.objects import (
hex_to_sha,
)
from dulwich.repo import (
check_ref_format,
Repo,
)
from dulwich.tests.compat.utils import (
require_git_version,
rmtree_ro,
run_git_or_fail,
CompatTestCase,
)
class ObjectStoreTestCase(CompatTestCase):
"""Tests for git repository compatibility."""
def setUp(self):
super(ObjectStoreTestCase, self).setUp()
self._repo = self.import_repo('server_new.export')
def _run_git(self, args):
return run_git_or_fail(args, cwd=self._repo.path)
def _parse_refs(self, output):
refs = {}
for line in BytesIO(output):
fields = line.rstrip(b'\n').split(b' ')
self.assertEqual(3, len(fields))
refname, type_name, sha = fields
check_ref_format(refname[5:])
hex_to_sha(sha)
refs[refname] = (type_name, sha)
return refs
def _parse_objects(self, output):
return set(s.rstrip(b'\n').split(b' ')[0] for s in BytesIO(output))
def test_bare(self):
self.assertTrue(self._repo.bare)
self.assertFalse(os.path.exists(os.path.join(self._repo.path, '.git')))
def test_head(self):
output = self._run_git(['rev-parse', 'HEAD'])
head_sha = output.rstrip(b'\n')
hex_to_sha(head_sha)
self.assertEqual(head_sha, self._repo.refs[b'HEAD'])
def test_refs(self):
output = self._run_git(
['for-each-ref', '--format=%(refname) %(objecttype) %(objectname)'])
expected_refs = self._parse_refs(output)
actual_refs = {}
for refname, sha in self._repo.refs.as_dict().items():
if refname == b'HEAD':
continue # handled in test_head
obj = self._repo[sha]
self.assertEqual(sha, obj.id)
actual_refs[refname] = (obj.type_name, obj.id)
self.assertEqual(expected_refs, actual_refs)
# TODO(dborowitz): peeled ref tests
def _get_loose_shas(self):
output = self._run_git(
['rev-list', '--all', '--objects', '--unpacked'])
return self._parse_objects(output)
def _get_all_shas(self):
output = self._run_git(['rev-list', '--all', '--objects'])
return self._parse_objects(output)
def assertShasMatch(self, expected_shas, actual_shas_iter):
actual_shas = set()
for sha in actual_shas_iter:
obj = self._repo[sha]
self.assertEqual(sha, obj.id)
actual_shas.add(sha)
self.assertEqual(expected_shas, actual_shas)
def test_loose_objects(self):
# TODO(dborowitz): This is currently not very useful since
# fast-imported repos only contained packed objects.
expected_shas = self._get_loose_shas()
self.assertShasMatch(expected_shas,
self._repo.object_store._iter_loose_objects())
def test_packed_objects(self):
expected_shas = self._get_all_shas() - self._get_loose_shas()
self.assertShasMatch(expected_shas,
chain(*self._repo.object_store.packs))
def test_all_objects(self):
expected_shas = self._get_all_shas()
self.assertShasMatch(expected_shas, iter(self._repo.object_store))
class WorkingTreeTestCase(ObjectStoreTestCase):
"""Test for compatibility with git-worktree."""
min_git_version = (2, 5, 0)
def create_new_worktree(self, repo_dir, branch):
"""Create a new worktree using git-worktree.
- :param repo_dir: The directory of the main working tree.
- :param branch: The branch or commit to checkout in the new worktree.
+ Args:
+ repo_dir: The directory of the main working tree.
+ branch: The branch or commit to checkout in the new worktree.
- :returns: The path to the new working tree.
+ Returns: The path to the new working tree.
"""
temp_dir = tempfile.mkdtemp()
run_git_or_fail(['worktree', 'add', temp_dir, branch],
cwd=repo_dir)
self.addCleanup(rmtree_ro, temp_dir)
return temp_dir
def setUp(self):
super(WorkingTreeTestCase, self).setUp()
self._worktree_path = self.create_new_worktree(
self._repo.path, 'branch')
self._worktree_repo = Repo(self._worktree_path)
self.addCleanup(self._worktree_repo.close)
self._mainworktree_repo = self._repo
self._number_of_working_tree = 2
self._repo = self._worktree_repo
def test_refs(self):
super(WorkingTreeTestCase, self).test_refs()
self.assertEqual(self._mainworktree_repo.refs.allkeys(),
self._repo.refs.allkeys())
def test_head_equality(self):
self.assertNotEqual(self._repo.refs[b'HEAD'],
self._mainworktree_repo.refs[b'HEAD'])
def test_bare(self):
self.assertFalse(self._repo.bare)
self.assertTrue(os.path.isfile(os.path.join(self._repo.path, '.git')))
def _parse_worktree_list(self, output):
worktrees = []
for line in BytesIO(output):
fields = line.rstrip(b'\n').split()
worktrees.append(tuple(f.decode() for f in fields))
return worktrees
def test_git_worktree_list(self):
# 'git worktree list' was introduced in 2.7.0
require_git_version((2, 7, 0))
output = run_git_or_fail(['worktree', 'list'], cwd=self._repo.path)
worktrees = self._parse_worktree_list(output)
self.assertEqual(len(worktrees), self._number_of_working_tree)
self.assertEqual(worktrees[0][1], '(bare)')
self.assertEqual(os.path.normcase(worktrees[0][0]),
os.path.normcase(self._mainworktree_repo.path))
output = run_git_or_fail(
['worktree', 'list'], cwd=self._mainworktree_repo.path)
worktrees = self._parse_worktree_list(output)
self.assertEqual(len(worktrees), self._number_of_working_tree)
self.assertEqual(worktrees[0][1], '(bare)')
self.assertEqual(os.path.normcase(worktrees[0][0]),
os.path.normcase(self._mainworktree_repo.path))
class InitNewWorkingDirectoryTestCase(WorkingTreeTestCase):
"""Test compatibility of Repo.init_new_working_directory."""
min_git_version = (2, 5, 0)
def setUp(self):
super(InitNewWorkingDirectoryTestCase, self).setUp()
self._other_worktree = self._repo
worktree_repo_path = tempfile.mkdtemp()
self.addCleanup(rmtree_ro, worktree_repo_path)
self._repo = Repo._init_new_working_directory(
worktree_repo_path, self._mainworktree_repo)
self.addCleanup(self._repo.close)
self._number_of_working_tree = 3
def test_head_equality(self):
self.assertEqual(self._repo.refs[b'HEAD'],
self._mainworktree_repo.refs[b'HEAD'])
def test_bare(self):
self.assertFalse(self._repo.bare)
self.assertTrue(os.path.isfile(os.path.join(self._repo.path, '.git')))
diff --git a/dulwich/tests/compat/utils.py b/dulwich/tests/compat/utils.py
index 6f9f6915..1f36a219 100644
--- a/dulwich/tests/compat/utils.py
+++ b/dulwich/tests/compat/utils.py
@@ -1,256 +1,264 @@
# utils.py -- Git compatibility utilities
# Copyright (C) 2010 Google, Inc.
#
# Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU
# General Public License as public by the Free Software Foundation; version 2.0
# or (at your option) any later version. You can redistribute it and/or
# modify it under the terms of either of these two licenses.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# You should have received a copy of the licenses; if not, see
# <http://www.gnu.org/licenses/> for a copy of the GNU General Public License
# and <http://www.apache.org/licenses/LICENSE-2.0> for a copy of the Apache
# License, Version 2.0.
#
"""Utilities for interacting with cgit."""
import errno
import functools
import os
import shutil
import socket
import stat
import subprocess
import sys
import tempfile
import time
from dulwich.repo import Repo
from dulwich.protocol import TCP_GIT_PORT
from dulwich.tests import (
SkipTest,
TestCase,
)
_DEFAULT_GIT = 'git'
_VERSION_LEN = 4
_REPOS_DATA_DIR = os.path.abspath(os.path.join(
os.path.dirname(__file__), os.pardir, 'data', 'repos'))
def git_version(git_path=_DEFAULT_GIT):
"""Attempt to determine the version of git currently installed.
- :param git_path: Path to the git executable; defaults to the version in
+ Args:
+ git_path: Path to the git executable; defaults to the version in
the system path.
- :return: A tuple of ints of the form (major, minor, point, sub-point), or
+ Returns: A tuple of ints of the form (major, minor, point, sub-point), or
None if no git installation was found.
"""
try:
output = run_git_or_fail(['--version'], git_path=git_path)
except OSError:
return None
version_prefix = b'git version '
if not output.startswith(version_prefix):
return None
parts = output[len(version_prefix):].split(b'.')
nums = []
for part in parts:
try:
nums.append(int(part))
except ValueError:
break
while len(nums) < _VERSION_LEN:
nums.append(0)
return tuple(nums[:_VERSION_LEN])
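# For example, a hypothetical `git version 2.20.1` would be reported as
# (2, 20, 1, 0): the missing sub-point component is zero-padded.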
def require_git_version(required_version, git_path=_DEFAULT_GIT):
"""Require git version >= version, or skip the calling test.
- :param required_version: A tuple of ints of the form (major, minor, point,
+ Args:
+ required_version: A tuple of ints of the form (major, minor, point,
sub-point); omitted components default to 0.
- :param git_path: Path to the git executable; defaults to the version in
+ git_path: Path to the git executable; defaults to the version in
the system path.
- :raise ValueError: if the required version tuple has too many parts.
- :raise SkipTest: if no suitable git version was found at the given path.
+ Raises:
+ ValueError: if the required version tuple has too many parts.
+ SkipTest: if no suitable git version was found at the given path.
"""
found_version = git_version(git_path=git_path)
if found_version is None:
raise SkipTest('Test requires git >= %s, but C git not found' %
(required_version, ))
if len(required_version) > _VERSION_LEN:
raise ValueError('Invalid version tuple %s, expected %i parts' %
(required_version, _VERSION_LEN))
required_version = list(required_version)
while len(found_version) < len(required_version):
required_version.append(0)
required_version = tuple(required_version)
if found_version < required_version:
required_version = '.'.join(map(str, required_version))
found_version = '.'.join(map(str, found_version))
raise SkipTest('Test requires git >= %s, found %s' %
(required_version, found_version))
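# Typical use at the top of a compat test (the version tuple here is
# illustrative):
#
#   def test_worktree_list(self):
#       require_git_version((2, 7, 0))  # skips on older installations
#       ...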
def run_git(args, git_path=_DEFAULT_GIT, input=None, capture_stdout=False,
**popen_kwargs):
"""Run a git command.
Input is piped from the input parameter and output is sent to the standard
streams, unless capture_stdout is set.
- :param args: A list of args to the git command.
- :param git_path: Path to to the git executable.
- :param input: Input data to be sent to stdin.
- :param capture_stdout: Whether to capture and return stdout.
- :param popen_kwargs: Additional kwargs for subprocess.Popen;
+ Args:
+ args: A list of args to the git command.
+ git_path: Path to the git executable.
+ input: Input data to be sent to stdin.
+ capture_stdout: Whether to capture and return stdout.
+ popen_kwargs: Additional kwargs for subprocess.Popen;
stdin/stdout args are ignored.
- :return: A tuple of (returncode, stdout contents). If capture_stdout is
+ Returns: A tuple of (returncode, stdout contents). If capture_stdout is
False, None will be returned as stdout contents.
- :raise OSError: if the git executable was not found.
+ Raises:
+ OSError: if the git executable was not found.
"""
env = popen_kwargs.pop('env', {})
env['LC_ALL'] = env['LANG'] = 'C'
args = [git_path] + args
popen_kwargs['stdin'] = subprocess.PIPE
if capture_stdout:
popen_kwargs['stdout'] = subprocess.PIPE
else:
popen_kwargs.pop('stdout', None)
p = subprocess.Popen(args, env=env, **popen_kwargs)
stdout, stderr = p.communicate(input=input)
return (p.returncode, stdout)
def run_git_or_fail(args, git_path=_DEFAULT_GIT, input=None, **popen_kwargs):
"""Run a git command, capture stdout/stderr, and fail if git fails."""
if 'stderr' not in popen_kwargs:
popen_kwargs['stderr'] = subprocess.STDOUT
returncode, stdout = run_git(args, git_path=git_path, input=input,
capture_stdout=True, **popen_kwargs)
if returncode != 0:
raise AssertionError("git with args %r failed with %d: %r" % (
args, returncode, stdout))
return stdout
def import_repo_to_dir(name):
"""Import a repo from a fast-export file in a temporary directory.
These are used rather than binary repos for compat tests because they are
more compact and human-editable, and we already depend on git.
- :param name: The name of the repository export file, relative to
+ Args:
+ name: The name of the repository export file, relative to
dulwich/tests/data/repos.
- :returns: The path to the imported repository.
+ Returns: The path to the imported repository.
"""
temp_dir = tempfile.mkdtemp()
export_path = os.path.join(_REPOS_DATA_DIR, name)
temp_repo_dir = os.path.join(temp_dir, name)
export_file = open(export_path, 'rb')
run_git_or_fail(['init', '--quiet', '--bare', temp_repo_dir])
run_git_or_fail(['fast-import'], input=export_file.read(),
cwd=temp_repo_dir)
export_file.close()
return temp_repo_dir
def check_for_daemon(limit=10, delay=0.1, timeout=0.1, port=TCP_GIT_PORT):
"""Check for a running TCP daemon.
Defaults to checking 10 times with a delay of 0.1 sec between tries.
- :param limit: Number of attempts before deciding no daemon is running.
- :param delay: Delay between connection attempts.
- :param timeout: Socket timeout for connection attempts.
- :param port: Port on which we expect the daemon to appear.
- :returns: A boolean, true if a daemon is running on the specified port,
+ Args:
+ limit: Number of attempts before deciding no daemon is running.
+ delay: Delay between connection attempts.
+ timeout: Socket timeout for connection attempts.
+ port: Port on which we expect the daemon to appear.
+ Returns: A boolean, true if a daemon is running on the specified port,
false if not.
"""
for _ in range(limit):
time.sleep(delay)
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.settimeout(delay)
try:
s.connect(('localhost', port))
return True
except socket.timeout:
pass
except socket.error as e:
if getattr(e, 'errno', False) and e.errno != errno.ECONNREFUSED:
raise
elif e.args[0] != errno.ECONNREFUSED:
raise
finally:
s.close()
return False
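# A sketch of how a compat test might wait for git-daemon to come up
# before talking to it:
#
#   if not check_for_daemon():
#       raise SkipTest('git daemon not listening on port %d' % TCP_GIT_PORT)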
class CompatTestCase(TestCase):
"""Test case that requires git for compatibility checks.
Subclasses can change the git version required by overriding
min_git_version.
"""
min_git_version = (1, 5, 0)
def setUp(self):
super(CompatTestCase, self).setUp()
require_git_version(self.min_git_version)
def assertObjectStoreEqual(self, store1, store2):
self.assertEqual(sorted(set(store1)), sorted(set(store2)))
def assertReposEqual(self, repo1, repo2):
self.assertEqual(repo1.get_refs(), repo2.get_refs())
self.assertObjectStoreEqual(repo1.object_store, repo2.object_store)
def assertReposNotEqual(self, repo1, repo2):
refs1 = repo1.get_refs()
objs1 = set(repo1.object_store)
refs2 = repo2.get_refs()
objs2 = set(repo2.object_store)
self.assertFalse(refs1 == refs2 and objs1 == objs2)
def import_repo(self, name):
"""Import a repo from a fast-export file in a temporary directory.
- :param name: The name of the repository export file, relative to
+ Args:
+ name: The name of the repository export file, relative to
dulwich/tests/data/repos.
- :returns: An initialized Repo object that lives in a temporary
+ Returns: An initialized Repo object that lives in a temporary
directory.
"""
path = import_repo_to_dir(name)
repo = Repo(path)
def cleanup():
repo.close()
rmtree_ro(os.path.dirname(path.rstrip(os.sep)))
self.addCleanup(cleanup)
return repo
if sys.platform == 'win32':
def remove_ro(action, name, exc):
os.chmod(name, stat.S_IWRITE)
os.remove(name)
rmtree_ro = functools.partial(shutil.rmtree, onerror=remove_ro)
else:
rmtree_ro = shutil.rmtree
diff --git a/dulwich/tests/test_fastexport.py b/dulwich/tests/test_fastexport.py
index 0f43efd4..16813061 100644
--- a/dulwich/tests/test_fastexport.py
+++ b/dulwich/tests/test_fastexport.py
@@ -1,260 +1,261 @@
# test_fastexport.py -- Fast export/import functionality
# Copyright (C) 2010 Jelmer Vernooij
#
# Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU
# General Public License as public by the Free Software Foundation; version 2.0
# or (at your option) any later version. You can redistribute it and/or
# modify it under the terms of either of these two licenses.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# You should have received a copy of the licenses; if not, see
# <http://www.gnu.org/licenses/> for a copy of the GNU General Public License
# and <http://www.apache.org/licenses/LICENSE-2.0> for a copy of the Apache
# License, Version 2.0.
#
from io import BytesIO
import stat
from dulwich.object_store import (
MemoryObjectStore,
)
from dulwich.objects import (
Blob,
Commit,
Tree,
ZERO_SHA,
)
from dulwich.repo import (
MemoryRepo,
)
from dulwich.tests import (
SkipTest,
TestCase,
)
from dulwich.tests.utils import (
build_commit_graph,
)
class GitFastExporterTests(TestCase):
"""Tests for the GitFastExporter tests."""
def setUp(self):
super(GitFastExporterTests, self).setUp()
self.store = MemoryObjectStore()
self.stream = BytesIO()
try:
from dulwich.fastexport import GitFastExporter
except ImportError:
raise SkipTest("python-fastimport not available")
self.fastexporter = GitFastExporter(self.stream, self.store)
def test_emit_blob(self):
b = Blob()
b.data = b"fooBAR"
self.fastexporter.emit_blob(b)
self.assertEqual(b'blob\nmark :1\ndata 6\nfooBAR\n',
self.stream.getvalue())
def test_emit_commit(self):
b = Blob()
b.data = b"FOO"
t = Tree()
t.add(b"foo", stat.S_IFREG | 0o644, b.id)
c = Commit()
c.committer = c.author = b"Jelmer <jelmer@samba.org>"
c.author_time = c.commit_time = 1271345553
c.author_timezone = c.commit_timezone = 0
c.message = b"msg"
c.tree = t.id
self.store.add_objects([(b, None), (t, None), (c, None)])
self.fastexporter.emit_commit(c, b"refs/heads/master")
self.assertEqual(b"""blob
mark :1
data 3
FOO
commit refs/heads/master
mark :2
author Jelmer <jelmer@samba.org> 1271345553 +0000
committer Jelmer <jelmer@samba.org> 1271345553 +0000
data 3
msg
M 644 :1 foo
""", self.stream.getvalue())
class GitImportProcessorTests(TestCase):
"""Tests for the GitImportProcessor tests."""
def setUp(self):
super(GitImportProcessorTests, self).setUp()
self.repo = MemoryRepo()
try:
from dulwich.fastexport import GitImportProcessor
except ImportError:
raise SkipTest("python-fastimport not available")
self.processor = GitImportProcessor(self.repo)
def test_reset_handler(self):
from fastimport import commands
[c1] = build_commit_graph(self.repo.object_store, [[1]])
cmd = commands.ResetCommand(b"refs/heads/foo", c1.id)
self.processor.reset_handler(cmd)
self.assertEqual(c1.id, self.repo.get_refs()[b"refs/heads/foo"])
self.assertEqual(c1.id, self.processor.last_commit)
def test_reset_handler_marker(self):
from fastimport import commands
[c1, c2] = build_commit_graph(self.repo.object_store, [[1], [2]])
self.processor.markers[b'10'] = c1.id
cmd = commands.ResetCommand(b"refs/heads/foo", b':10')
self.processor.reset_handler(cmd)
self.assertEqual(c1.id, self.repo.get_refs()[b"refs/heads/foo"])
def test_reset_handler_default(self):
from fastimport import commands
[c1, c2] = build_commit_graph(self.repo.object_store, [[1], [2]])
cmd = commands.ResetCommand(b"refs/heads/foo", None)
self.processor.reset_handler(cmd)
self.assertEqual(ZERO_SHA, self.repo.get_refs()[b"refs/heads/foo"])
def test_commit_handler(self):
from fastimport import commands
cmd = commands.CommitCommand(
b"refs/heads/foo", b"mrkr",
(b"Jelmer", b"jelmer@samba.org", 432432432.0, 3600),
(b"Jelmer", b"jelmer@samba.org", 432432432.0, 3600),
b"FOO", None, [], [])
self.processor.commit_handler(cmd)
commit = self.repo[self.processor.last_commit]
self.assertEqual(b"Jelmer ", commit.author)
self.assertEqual(b"Jelmer ", commit.committer)
self.assertEqual(b"FOO", commit.message)
self.assertEqual([], commit.parents)
self.assertEqual(432432432.0, commit.commit_time)
self.assertEqual(432432432.0, commit.author_time)
self.assertEqual(3600, commit.commit_timezone)
self.assertEqual(3600, commit.author_timezone)
self.assertEqual(commit, self.repo[b"refs/heads/foo"])
def test_commit_handler_markers(self):
from fastimport import commands
[c1, c2, c3] = build_commit_graph(self.repo.object_store,
[[1], [2], [3]])
self.processor.markers[b'10'] = c1.id
self.processor.markers[b'42'] = c2.id
self.processor.markers[b'98'] = c3.id
cmd = commands.CommitCommand(
b"refs/heads/foo", b"mrkr",
(b"Jelmer", b"jelmer@samba.org", 432432432.0, 3600),
(b"Jelmer", b"jelmer@samba.org", 432432432.0, 3600),
b"FOO", b':10', [b':42', b':98'], [])
self.processor.commit_handler(cmd)
commit = self.repo[self.processor.last_commit]
self.assertEqual(c1.id, commit.parents[0])
self.assertEqual(c2.id, commit.parents[1])
self.assertEqual(c3.id, commit.parents[2])
def test_import_stream(self):
markers = self.processor.import_stream(BytesIO(b"""blob
mark :1
data 11
text for a
commit refs/heads/master
mark :2
committer Joe Foo <joe@foo.com> 1288287382 +0000
data 20
<The commit message>
M 100644 :1 a
"""))
self.assertEqual(2, len(markers))
self.assertTrue(isinstance(self.repo[markers[b"1"]], Blob))
self.assertTrue(isinstance(self.repo[markers[b"2"]], Commit))
def test_file_add(self):
from fastimport import commands
cmd = commands.BlobCommand(b"23", b"data")
self.processor.blob_handler(cmd)
cmd = commands.CommitCommand(
b"refs/heads/foo", b"mrkr",
(b"Jelmer", b"jelmer@samba.org", 432432432.0, 3600),
(b"Jelmer", b"jelmer@samba.org", 432432432.0, 3600),
b"FOO", None, [],
[commands.FileModifyCommand(b"path", 0o100644, b":23", None)])
self.processor.commit_handler(cmd)
commit = self.repo[self.processor.last_commit]
self.assertEqual([
(b'path', 0o100644, b'6320cd248dd8aeaab759d5871f8781b5c0505172')],
self.repo[commit.tree].items())
def simple_commit(self):
from fastimport import commands
cmd = commands.BlobCommand(b"23", b"data")
self.processor.blob_handler(cmd)
cmd = commands.CommitCommand(
b"refs/heads/foo", b"mrkr",
(b"Jelmer", b"jelmer@samba.org", 432432432.0, 3600),
(b"Jelmer", b"jelmer@samba.org", 432432432.0, 3600),
b"FOO", None, [],
[commands.FileModifyCommand(b"path", 0o100644, b":23", None)])
self.processor.commit_handler(cmd)
commit = self.repo[self.processor.last_commit]
return commit
def make_file_commit(self, file_cmds):
"""Create a trivial commit with the specified file commands.
- :param file_cmds: File commands to run.
- :return: The created commit object
+ Args:
+ file_cmds: File commands to run.
+ Returns: The created commit object
"""
from fastimport import commands
cmd = commands.CommitCommand(
b"refs/heads/foo", b"mrkr",
(b"Jelmer", b"jelmer@samba.org", 432432432.0, 3600),
(b"Jelmer", b"jelmer@samba.org", 432432432.0, 3600),
b"FOO", None, [], file_cmds)
self.processor.commit_handler(cmd)
return self.repo[self.processor.last_commit]
def test_file_copy(self):
from fastimport import commands
self.simple_commit()
commit = self.make_file_commit(
[commands.FileCopyCommand(b"path", b"new_path")])
self.assertEqual([
(b'new_path', 0o100644,
b'6320cd248dd8aeaab759d5871f8781b5c0505172'),
(b'path', 0o100644,
b'6320cd248dd8aeaab759d5871f8781b5c0505172'),
], self.repo[commit.tree].items())
def test_file_move(self):
from fastimport import commands
self.simple_commit()
commit = self.make_file_commit(
[commands.FileRenameCommand(b"path", b"new_path")])
self.assertEqual([
(b'new_path', 0o100644,
b'6320cd248dd8aeaab759d5871f8781b5c0505172'),
], self.repo[commit.tree].items())
def test_file_delete(self):
from fastimport import commands
self.simple_commit()
commit = self.make_file_commit([commands.FileDeleteCommand(b"path")])
self.assertEqual([], self.repo[commit.tree].items())
def test_file_deleteall(self):
from fastimport import commands
self.simple_commit()
commit = self.make_file_commit([commands.FileDeleteAllCommand()])
self.assertEqual([], self.repo[commit.tree].items())
diff --git a/dulwich/tests/utils.py b/dulwich/tests/utils.py
index cfe446b7..a49119d5 100644
--- a/dulwich/tests/utils.py
+++ b/dulwich/tests/utils.py
@@ -1,363 +1,371 @@
# utils.py -- Test utilities for Dulwich.
# Copyright (C) 2010 Google, Inc.
#
# Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU
# General Public License as public by the Free Software Foundation; version 2.0
# or (at your option) any later version. You can redistribute it and/or
# modify it under the terms of either of these two licenses.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# You should have received a copy of the licenses; if not, see
# <http://www.gnu.org/licenses/> for a copy of the GNU General Public License
# and <http://www.apache.org/licenses/LICENSE-2.0> for a copy of the Apache
# License, Version 2.0.
#
"""Utility functions common to Dulwich tests."""
import datetime
import os
import shutil
import tempfile
import time
import types
import warnings
from dulwich.index import (
commit_tree,
)
from dulwich.objects import (
FixedSha,
Commit,
Tag,
object_class,
)
from dulwich.pack import (
OFS_DELTA,
REF_DELTA,
DELTA_TYPES,
obj_sha,
SHA1Writer,
write_pack_header,
write_pack_object,
create_delta,
)
from dulwich.repo import Repo
from dulwich.tests import ( # noqa: F401
skipIf,
SkipTest,
)
# Plain files are very frequently used in tests, so let the mode be very short.
F = 0o100644 # Shorthand mode for Files.
def open_repo(name, temp_dir=None):
"""Open a copy of a repo in a temporary directory.
Use this function for accessing repos in dulwich/tests/data/repos to avoid
accidentally or intentionally modifying those repos in place. Use
tear_down_repo to delete any temp files created.
- :param name: The name of the repository, relative to
+ Args:
+ name: The name of the repository, relative to
dulwich/tests/data/repos
- :param temp_dir: temporary directory to initialize to. If not provided, a
+ temp_dir: temporary directory to initialize to. If not provided, a
temporary directory will be created.
- :returns: An initialized Repo object that lives in a temporary directory.
+ Returns: An initialized Repo object that lives in a temporary directory.
"""
if temp_dir is None:
temp_dir = tempfile.mkdtemp()
repo_dir = os.path.join(os.path.dirname(__file__), 'data', 'repos', name)
temp_repo_dir = os.path.join(temp_dir, name)
shutil.copytree(repo_dir, temp_repo_dir, symlinks=True)
return Repo(temp_repo_dir)
def tear_down_repo(repo):
"""Tear down a test repository."""
repo.close()
temp_dir = os.path.dirname(repo.path.rstrip(os.sep))
shutil.rmtree(temp_dir)
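# A hedged usage sketch (not part of the original module): the typical pairing
# of open_repo and tear_down_repo inside a TestCase. The repo name 'a.git' is
# assumed to exist under dulwich/tests/data/repos.
def _example_open_repo(test_case):
    repo = open_repo('a.git')
    test_case.addCleanup(tear_down_repo, repo)
    return repo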
def make_object(cls, **attrs):
"""Make an object for testing and assign some members.
This method creates a new subclass to allow arbitrary attribute
reassignment, which is not otherwise possible with objects having
__slots__.
- :param attrs: dict of attributes to set on the new object.
- :return: A newly initialized object of type cls.
+ Args:
+ attrs: dict of attributes to set on the new object.
+ Returns: A newly initialized object of type cls.
"""
class TestObject(cls):
"""Class that inherits from the given class, but without __slots__.
Note that classes with __slots__ can't have arbitrary attributes
monkey-patched in, so this is a class that is exactly the same only
with a __dict__ instead of __slots__.
"""
pass
TestObject.__name__ = 'TestObject_' + cls.__name__
obj = TestObject()
for name, value in attrs.items():
if name == 'id':
# id property is read-only, so we overwrite sha instead.
sha = FixedSha(value)
obj.sha = lambda: sha
else:
setattr(obj, name, value)
return obj
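# Minimal sketch of why make_object exists: Blob uses __slots__, so a fixed
# id can only be attached through this helper (which swaps in a FixedSha).
# Illustrative only; the 40-hex sha below is fake.
def _example_make_object():
    from dulwich.objects import Blob
    fake_sha = b'1' * 40
    blob = make_object(Blob, data=b'some data', id=fake_sha)
    return blob.id  # fake_sha, served by the FixedSha wrapper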
def make_commit(**attrs):
"""Make a Commit object with a default set of members.
- :param attrs: dict of attributes to overwrite from the default values.
- :return: A newly initialized Commit object.
+ Args:
+ attrs: dict of attributes to overwrite from the default values.
+ Returns: A newly initialized Commit object.
"""
default_time = 1262304000 # 2010-01-01 00:00:00
    all_attrs = {'author': b'Test Author <test@nodomain.com>',
'author_time': default_time,
'author_timezone': 0,
                 'committer': b'Test Committer <test@nodomain.com>',
'commit_time': default_time,
'commit_timezone': 0,
'message': b'Test message.',
'parents': [],
'tree': b'0' * 40}
all_attrs.update(attrs)
return make_object(Commit, **all_attrs)
def make_tag(target, **attrs):
"""Make a Tag object with a default set of values.
- :param target: object to be tagged (Commit, Blob, Tree, etc)
- :param attrs: dict of attributes to overwrite from the default values.
- :return: A newly initialized Tag object.
+ Args:
+ target: object to be tagged (Commit, Blob, Tree, etc)
+ attrs: dict of attributes to overwrite from the default values.
+ Returns: A newly initialized Tag object.
"""
target_id = target.id
target_type = object_class(target.type_name)
default_time = int(time.mktime(datetime.datetime(2010, 1, 1).timetuple()))
    all_attrs = {'tagger': b'Test Author <test@nodomain.com>',
'tag_time': default_time,
'tag_timezone': 0,
'message': b'Test message.',
'object': (target_type, target_id),
'name': b'Test Tag',
}
all_attrs.update(attrs)
return make_object(Tag, **all_attrs)
def functest_builder(method, func):
"""Generate a test method that tests the given function."""
def do_test(self):
method(self, func)
return do_test
def ext_functest_builder(method, func):
"""Generate a test method that tests the given extension function.
This is intended to generate test methods that test both a pure-Python
version and an extension version using common test code. The extension test
will raise SkipTest if the extension is not found.
Sample usage:
      class MyTest(TestCase):
def _do_some_test(self, func_impl):
self.assertEqual('foo', func_impl())
test_foo = functest_builder(_do_some_test, foo_py)
test_foo_extension = ext_functest_builder(_do_some_test, _foo_c)
- :param method: The method to run. It must must two parameters, self and the
+ Args:
+      method: The method to run. It must take two parameters, self and the
function implementation to test.
- :param func: The function implementation to pass to method.
+ func: The function implementation to pass to method.
"""
def do_test(self):
if not isinstance(func, types.BuiltinFunctionType):
raise SkipTest("%s extension not found" % func)
method(self, func)
return do_test
def build_pack(f, objects_spec, store=None):
"""Write test pack data from a concise spec.
- :param f: A file-like object to write the pack to.
- :param objects_spec: A list of (type_num, obj). For non-delta types, obj
+ Args:
+ f: A file-like object to write the pack to.
+ objects_spec: A list of (type_num, obj). For non-delta types, obj
is the string of that object's data.
For delta types, obj is a tuple of (base, data), where:
        * base can be either an index in objects_spec of the base for that
          delta; or for a ref delta, a SHA, in which case the resulting pack
          will be thin and the base will be an external ref.
* data is a string of the full, non-deltified data for that object.
Note that offsets/refs and deltas are computed within this function.
- :param store: An optional ObjectStore for looking up external refs.
- :return: A list of tuples in the order specified by objects_spec:
+ store: An optional ObjectStore for looking up external refs.
+ Returns: A list of tuples in the order specified by objects_spec:
(offset, type num, data, sha, CRC32)
"""
sf = SHA1Writer(f)
num_objects = len(objects_spec)
write_pack_header(sf, num_objects)
full_objects = {}
offsets = {}
crc32s = {}
while len(full_objects) < num_objects:
for i, (type_num, data) in enumerate(objects_spec):
if type_num not in DELTA_TYPES:
full_objects[i] = (type_num, data,
obj_sha(type_num, [data]))
continue
base, data = data
if isinstance(base, int):
if base not in full_objects:
continue
base_type_num, _, _ = full_objects[base]
else:
base_type_num, _ = store.get_raw(base)
full_objects[i] = (base_type_num, data,
obj_sha(base_type_num, [data]))
for i, (type_num, obj) in enumerate(objects_spec):
offset = f.tell()
if type_num == OFS_DELTA:
base_index, data = obj
base = offset - offsets[base_index]
_, base_data, _ = full_objects[base_index]
obj = (base, create_delta(base_data, data))
elif type_num == REF_DELTA:
base_ref, data = obj
if isinstance(base_ref, int):
_, base_data, base = full_objects[base_ref]
else:
base_type_num, base_data = store.get_raw(base_ref)
base = obj_sha(base_type_num, base_data)
obj = (base, create_delta(base_data, data))
crc32 = write_pack_object(sf, type_num, obj)
offsets[i] = offset
crc32s[i] = crc32
expected = []
for i in range(num_objects):
type_num, data, sha = full_objects[i]
assert len(sha) == 20
expected.append((offsets[i], type_num, data, sha, crc32s[i]))
sf.write_sha()
f.seek(0)
return expected
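# A short sketch of the objects_spec format described above: one full blob
# (git type number 3) followed by an offset delta whose base is entry 0.
# Illustrative only; create_delta is applied inside build_pack itself.
def _example_build_pack():
    from io import BytesIO
    f = BytesIO()
    return build_pack(f, [
        (3, b'base blob data'),
        (OFS_DELTA, (0, b'base blob data plus a bit more')),
    ])  # -> [(offset, type_num, data, sha, crc32), ...]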
def build_commit_graph(object_store, commit_spec, trees=None, attrs=None):
"""Build a commit graph from a concise specification.
Sample usage:
>>> c1, c2, c3 = build_commit_graph(store, [[1], [2, 1], [3, 1, 2]])
>>> store[store[c3].parents[0]] == c1
True
>>> store[store[c3].parents[1]] == c2
True
If not otherwise specified, commits will refer to the empty tree and have
commit times increasing in the same order as the commit spec.
- :param object_store: An ObjectStore to commit objects to.
- :param commit_spec: An iterable of iterables of ints defining the commit
+ Args:
+ object_store: An ObjectStore to commit objects to.
+ commit_spec: An iterable of iterables of ints defining the commit
graph. Each entry defines one commit, and entries must be in
topological order. The first element of each entry is a commit number,
and the remaining elements are its parents. The commit numbers are only
meaningful for the call to make_commits; since real commit objects are
created, they will get created with real, opaque SHAs.
- :param trees: An optional dict of commit number -> tree spec for building
+ trees: An optional dict of commit number -> tree spec for building
trees for commits. The tree spec is an iterable of (path, blob, mode)
or (path, blob) entries; if mode is omitted, it defaults to the normal
file mode (0100644).
- :param attrs: A dict of commit number -> (dict of attribute -> value) for
+ attrs: A dict of commit number -> (dict of attribute -> value) for
assigning additional values to the commits.
- :return: The list of commit objects created.
- :raise ValueError: If an undefined commit identifier is listed as a parent.
+ Returns: The list of commit objects created.
+ Raises:
+ ValueError: If an undefined commit identifier is listed as a parent.
"""
if trees is None:
trees = {}
if attrs is None:
attrs = {}
commit_time = 0
nums = {}
commits = []
for commit in commit_spec:
commit_num = commit[0]
try:
parent_ids = [nums[pn] for pn in commit[1:]]
except KeyError as e:
missing_parent, = e.args
raise ValueError('Unknown parent %i' % missing_parent)
blobs = []
for entry in trees.get(commit_num, []):
if len(entry) == 2:
path, blob = entry
entry = (path, blob, F)
path, blob, mode = entry
blobs.append((path, blob.id, mode))
object_store.add_object(blob)
tree_id = commit_tree(object_store, blobs)
commit_attrs = {
'message': ('Commit %i' % commit_num).encode('ascii'),
'parents': parent_ids,
'tree': tree_id,
'commit_time': commit_time,
}
commit_attrs.update(attrs.get(commit_num, {}))
commit_obj = make_commit(**commit_attrs)
# By default, increment the time by a lot. Out-of-order commits should
# be closer together than this because their main cause is clock skew.
commit_time = commit_attrs['commit_time'] + 100
nums[commit_num] = commit_obj.id
object_store.add_object(commit_obj)
commits.append(commit_obj)
return commits
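# Hedged sketch of the richer spec in action: commit 2 descends from commit 1,
# carries a one-file tree, and overrides the default message via attrs.
def _example_build_commit_graph(object_store):
    from dulwich.objects import Blob
    blob = Blob.from_string(b'contents')
    c1, c2 = build_commit_graph(
        object_store, [[1], [2, 1]],
        trees={2: [(b'a', blob)]},
        attrs={2: {'message': b'Custom message'}})
    return c2  # c2.parents == [c1.id]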
def setup_warning_catcher():
"""Wrap warnings.showwarning with code that records warnings."""
caught_warnings = []
original_showwarning = warnings.showwarning
def custom_showwarning(*args, **kwargs):
caught_warnings.append(args[0])
warnings.showwarning = custom_showwarning
def restore_showwarning():
warnings.showwarning = original_showwarning
return caught_warnings, restore_showwarning
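# Typical usage sketch: install the catcher, provoke a warning, and restore
# the original hook afterwards. Assumes the default warning filters let a
# plain UserWarning through.
def _example_warning_catcher():
    caught, restore = setup_warning_catcher()
    try:
        warnings.warn('something happened')
    finally:
        restore()
    return caught  # list of caught warning instances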
diff --git a/dulwich/walk.py b/dulwich/walk.py
index 2f6565e7..40d0a040 100644
--- a/dulwich/walk.py
+++ b/dulwich/walk.py
@@ -1,414 +1,419 @@
# walk.py -- General implementation of walking commits and their contents.
# Copyright (C) 2010 Google, Inc.
#
# Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU
# General Public License as public by the Free Software Foundation; version 2.0
# or (at your option) any later version. You can redistribute it and/or
# modify it under the terms of either of these two licenses.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# You should have received a copy of the licenses; if not, see
# <http://www.gnu.org/licenses/> for a copy of the GNU General Public License
# and <http://www.apache.org/licenses/LICENSE-2.0> for a copy of the Apache
# License, Version 2.0.
#
"""General implementation of walking commits and their contents."""
import collections
import heapq
from itertools import chain
from dulwich.diff_tree import (
RENAME_CHANGE_TYPES,
tree_changes,
tree_changes_for_merge,
RenameDetector,
)
from dulwich.errors import (
MissingCommitError,
)
from dulwich.objects import (
Tag,
)
ORDER_DATE = 'date'
ORDER_TOPO = 'topo'
ALL_ORDERS = (ORDER_DATE, ORDER_TOPO)
# Maximum number of commits to walk past a commit time boundary.
_MAX_EXTRA_COMMITS = 5
class WalkEntry(object):
"""Object encapsulating a single result from a walk."""
def __init__(self, walker, commit):
self.commit = commit
self._store = walker.store
self._get_parents = walker.get_parents
self._changes = {}
self._rename_detector = walker.rename_detector
def changes(self, path_prefix=None):
"""Get the tree changes for this entry.
- :param path_prefix: Portion of the path in the repository to
+ Args:
+ path_prefix: Portion of the path in the repository to
use to filter changes. Must be a directory name. Must be
a full, valid, path reference (no partial names or wildcards).
- :return: For commits with up to one parent, a list of TreeChange
+ Returns: For commits with up to one parent, a list of TreeChange
objects; if the commit has no parents, these will be relative to
the empty tree. For merge commits, a list of lists of TreeChange
objects; see dulwich.diff.tree_changes_for_merge.
"""
cached = self._changes.get(path_prefix)
if cached is None:
commit = self.commit
if not self._get_parents(commit):
changes_func = tree_changes
parent = None
elif len(self._get_parents(commit)) == 1:
changes_func = tree_changes
parent = self._store[self._get_parents(commit)[0]].tree
if path_prefix:
mode, subtree_sha = parent.lookup_path(
self._store.__getitem__,
path_prefix,
)
parent = self._store[subtree_sha]
else:
changes_func = tree_changes_for_merge
parent = [
self._store[p].tree for p in self._get_parents(commit)]
if path_prefix:
parent_trees = [self._store[p] for p in parent]
parent = []
for p in parent_trees:
try:
mode, st = p.lookup_path(
self._store.__getitem__,
path_prefix,
)
except KeyError:
pass
else:
parent.append(st)
commit_tree_sha = commit.tree
if path_prefix:
commit_tree = self._store[commit_tree_sha]
mode, commit_tree_sha = commit_tree.lookup_path(
self._store.__getitem__,
path_prefix,
)
cached = list(changes_func(
self._store, parent, commit_tree_sha,
rename_detector=self._rename_detector))
self._changes[path_prefix] = cached
return self._changes[path_prefix]
def __repr__(self):
        return '<WalkEntry commit=%s, changes=%r>' % (
self.commit.id, self.changes())
class _CommitTimeQueue(object):
"""Priority queue of WalkEntry objects by commit time."""
def __init__(self, walker):
self._walker = walker
self._store = walker.store
self._get_parents = walker.get_parents
self._excluded = walker.excluded
self._pq = []
self._pq_set = set()
self._seen = set()
self._done = set()
self._min_time = walker.since
self._last = None
self._extra_commits_left = _MAX_EXTRA_COMMITS
self._is_finished = False
for commit_id in chain(walker.include, walker.excluded):
self._push(commit_id)
def _push(self, object_id):
try:
obj = self._store[object_id]
except KeyError:
raise MissingCommitError(object_id)
if isinstance(obj, Tag):
self._push(obj.object[1])
return
# TODO(jelmer): What to do about non-Commit and non-Tag objects?
commit = obj
if commit.id not in self._pq_set and commit.id not in self._done:
heapq.heappush(self._pq, (-commit.commit_time, commit))
self._pq_set.add(commit.id)
self._seen.add(commit.id)
def _exclude_parents(self, commit):
excluded = self._excluded
seen = self._seen
todo = [commit]
while todo:
commit = todo.pop()
for parent in self._get_parents(commit):
if parent not in excluded and parent in seen:
# TODO: This is inefficient unless the object store does
# some caching (which DiskObjectStore currently does not).
# We could either add caching in this class or pass around
# parsed queue entry objects instead of commits.
todo.append(self._store[parent])
excluded.add(parent)
def next(self):
if self._is_finished:
return None
while self._pq:
_, commit = heapq.heappop(self._pq)
sha = commit.id
self._pq_set.remove(sha)
if sha in self._done:
continue
self._done.add(sha)
for parent_id in self._get_parents(commit):
self._push(parent_id)
reset_extra_commits = True
is_excluded = sha in self._excluded
if is_excluded:
self._exclude_parents(commit)
if self._pq and all(c.id in self._excluded
for _, c in self._pq):
_, n = self._pq[0]
if self._last and n.commit_time >= self._last.commit_time:
# If the next commit is newer than the last one, we
# need to keep walking in case its parents (which we
# may not have seen yet) are excluded. This gives the
# excluded set a chance to "catch up" while the commit
# is still in the Walker's output queue.
reset_extra_commits = True
else:
reset_extra_commits = False
if (self._min_time is not None and
commit.commit_time < self._min_time):
# We want to stop walking at min_time, but commits at the
# boundary may be out of order with respect to their parents.
# So we walk _MAX_EXTRA_COMMITS more commits once we hit this
# boundary.
reset_extra_commits = False
if reset_extra_commits:
# We're not at a boundary, so reset the counter.
self._extra_commits_left = _MAX_EXTRA_COMMITS
else:
self._extra_commits_left -= 1
if not self._extra_commits_left:
break
if not is_excluded:
self._last = commit
return WalkEntry(self._walker, commit)
self._is_finished = True
return None
__next__ = next
class Walker(object):
"""Object for performing a walk of commits in a store.
Walker objects are initialized with a store and other options and can then
be treated as iterators of Commit objects.
"""
def __init__(self, store, include, exclude=None, order=ORDER_DATE,
reverse=False, max_entries=None, paths=None,
rename_detector=None, follow=False, since=None, until=None,
get_parents=lambda commit: commit.parents,
queue_cls=_CommitTimeQueue):
"""Constructor.
- :param store: ObjectStore instance for looking up objects.
- :param include: Iterable of SHAs of commits to include along with their
+ Args:
+ store: ObjectStore instance for looking up objects.
+ include: Iterable of SHAs of commits to include along with their
ancestors.
- :param exclude: Iterable of SHAs of commits to exclude along with their
+ exclude: Iterable of SHAs of commits to exclude along with their
ancestors, overriding includes.
- :param order: ORDER_* constant specifying the order of results.
+ order: ORDER_* constant specifying the order of results.
Anything other than ORDER_DATE may result in O(n) memory usage.
- :param reverse: If True, reverse the order of output, requiring O(n)
+ reverse: If True, reverse the order of output, requiring O(n)
memory.
- :param max_entries: The maximum number of entries to yield, or None for
+ max_entries: The maximum number of entries to yield, or None for
no limit.
- :param paths: Iterable of file or subtree paths to show entries for.
- :param rename_detector: diff.RenameDetector object for detecting
+ paths: Iterable of file or subtree paths to show entries for.
+ rename_detector: diff.RenameDetector object for detecting
renames.
- :param follow: If True, follow path across renames/copies. Forces a
+ follow: If True, follow path across renames/copies. Forces a
default rename_detector.
- :param since: Timestamp to list commits after.
- :param until: Timestamp to list commits before.
- :param get_parents: Method to retrieve the parents of a commit
- :param queue_cls: A class to use for a queue of commits, supporting the
+ since: Timestamp to list commits after.
+ until: Timestamp to list commits before.
+ get_parents: Method to retrieve the parents of a commit
+ queue_cls: A class to use for a queue of commits, supporting the
iterator protocol. The constructor takes a single argument, the
Walker.
"""
# Note: when adding arguments to this method, please also update
# dulwich.repo.BaseRepo.get_walker
if order not in ALL_ORDERS:
raise ValueError('Unknown walk order %s' % order)
self.store = store
if isinstance(include, bytes):
# TODO(jelmer): Really, this should require a single type.
# Print deprecation warning here?
include = [include]
self.include = include
self.excluded = set(exclude or [])
self.order = order
self.reverse = reverse
self.max_entries = max_entries
self.paths = paths and set(paths) or None
if follow and not rename_detector:
rename_detector = RenameDetector(store)
self.rename_detector = rename_detector
self.get_parents = get_parents
self.follow = follow
self.since = since
self.until = until
self._num_entries = 0
self._queue = queue_cls(self)
self._out_queue = collections.deque()
def _path_matches(self, changed_path):
if changed_path is None:
return False
for followed_path in self.paths:
if changed_path == followed_path:
return True
if (changed_path.startswith(followed_path) and
changed_path[len(followed_path)] == b'/'[0]):
return True
return False
def _change_matches(self, change):
if not change:
return False
old_path = change.old.path
new_path = change.new.path
if self._path_matches(new_path):
if self.follow and change.type in RENAME_CHANGE_TYPES:
self.paths.add(old_path)
self.paths.remove(new_path)
return True
elif self._path_matches(old_path):
return True
return False
def _should_return(self, entry):
"""Determine if a walk entry should be returned..
- :param entry: The WalkEntry to consider.
- :return: True if the WalkEntry should be returned by this walk, or
+ Args:
+ entry: The WalkEntry to consider.
+ Returns: True if the WalkEntry should be returned by this walk, or
False otherwise (e.g. if it doesn't match any requested paths).
"""
commit = entry.commit
if self.since is not None and commit.commit_time < self.since:
return False
if self.until is not None and commit.commit_time > self.until:
return False
if commit.id in self.excluded:
return False
if self.paths is None:
return True
if len(self.get_parents(commit)) > 1:
for path_changes in entry.changes():
# For merge commits, only include changes with conflicts for
# this path. Since a rename conflict may include different
# old.paths, we have to check all of them.
for change in path_changes:
if self._change_matches(change):
return True
else:
for change in entry.changes():
if self._change_matches(change):
return True
return None
def _next(self):
max_entries = self.max_entries
while max_entries is None or self._num_entries < max_entries:
entry = next(self._queue)
if entry is not None:
self._out_queue.append(entry)
if entry is None or len(self._out_queue) > _MAX_EXTRA_COMMITS:
if not self._out_queue:
return None
entry = self._out_queue.popleft()
if self._should_return(entry):
self._num_entries += 1
return entry
return None
def _reorder(self, results):
"""Possibly reorder a results iterator.
- :param results: An iterator of WalkEntry objects, in the order returned
+ Args:
+ results: An iterator of WalkEntry objects, in the order returned
from the queue_cls.
- :return: An iterator or list of WalkEntry objects, in the order
+ Returns: An iterator or list of WalkEntry objects, in the order
required by the Walker.
"""
if self.order == ORDER_TOPO:
results = _topo_reorder(results, self.get_parents)
if self.reverse:
results = reversed(list(results))
return results
def __iter__(self):
return iter(self._reorder(iter(self._next, None)))
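# A hedged usage sketch, not from the original module: walk the history of
# HEAD newest-first, limited to changes under b'docs'. Assumes a repo object
# exposing .object_store and .head() as dulwich.repo.Repo does.
def _example_walk(repo):
    walker = Walker(repo.object_store, [repo.head()],
                    paths=[b'docs'], max_entries=10)
    return [entry.commit.id for entry in walker]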
def _topo_reorder(entries, get_parents=lambda commit: commit.parents):
"""Reorder an iterable of entries topologically.
This works best assuming the entries are already in almost-topological
order, e.g. in commit time order.
- :param entries: An iterable of WalkEntry objects.
- :param get_parents: Optional function for getting the parents of a commit.
- :return: iterator over WalkEntry objects from entries in FIFO order, except
+ Args:
+ entries: An iterable of WalkEntry objects.
+ get_parents: Optional function for getting the parents of a commit.
+ Returns: iterator over WalkEntry objects from entries in FIFO order, except
where a parent would be yielded before any of its children.
"""
todo = collections.deque()
pending = {}
num_children = collections.defaultdict(int)
for entry in entries:
todo.append(entry)
for p in get_parents(entry.commit):
num_children[p] += 1
while todo:
entry = todo.popleft()
commit = entry.commit
commit_id = commit.id
if num_children[commit_id]:
pending[commit_id] = entry
continue
for parent_id in get_parents(commit):
num_children[parent_id] -= 1
if not num_children[parent_id]:
parent_entry = pending.pop(parent_id, None)
if parent_entry:
todo.appendleft(parent_entry)
yield entry
diff --git a/dulwich/web.py b/dulwich/web.py
index 3dc971a7..3e4469e6 100644
--- a/dulwich/web.py
+++ b/dulwich/web.py
@@ -1,522 +1,524 @@
# web.py -- WSGI smart-http server
# Copyright (C) 2010 Google, Inc.
# Copyright (C) 2012 Jelmer Vernooij
#
# Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU
# General Public License as public by the Free Software Foundation; version 2.0
# or (at your option) any later version. You can redistribute it and/or
# modify it under the terms of either of these two licenses.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# You should have received a copy of the licenses; if not, see
# <http://www.gnu.org/licenses/> for a copy of the GNU General Public License
# and <http://www.apache.org/licenses/LICENSE-2.0> for a copy of the Apache
# License, Version 2.0.
#
"""HTTP server for dulwich that implements the git smart HTTP protocol."""
from io import BytesIO
import shutil
import tempfile
import gzip
import os
import re
import sys
import time
from wsgiref.simple_server import (
WSGIRequestHandler,
ServerHandler,
WSGIServer,
make_server,
)
try:
from urlparse import parse_qs
except ImportError:
from urllib.parse import parse_qs
from dulwich import log_utils
from dulwich.protocol import (
ReceivableProtocol,
)
from dulwich.repo import (
NotGitRepository,
Repo,
)
from dulwich.server import (
DictBackend,
DEFAULT_HANDLERS,
generate_info_refs,
generate_objects_info_packs,
)
logger = log_utils.getLogger(__name__)
# HTTP error strings
HTTP_OK = '200 OK'
HTTP_NOT_FOUND = '404 Not Found'
HTTP_FORBIDDEN = '403 Forbidden'
HTTP_ERROR = '500 Internal Server Error'
def date_time_string(timestamp=None):
# From BaseHTTPRequestHandler.date_time_string in BaseHTTPServer.py in the
# Python 2.6.5 standard library, following modifications:
# - Made a global rather than an instance method.
# - weekdayname and monthname are renamed and locals rather than class
# variables.
# Copyright (c) 2001-2010 Python Software Foundation; All Rights Reserved
weekdays = ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun']
months = [None,
'Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun',
'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec']
if timestamp is None:
timestamp = time.time()
year, month, day, hh, mm, ss, wd, y, z = time.gmtime(timestamp)
    return '%s, %02d %3s %4d %02d:%02d:%02d GMT' % (
weekdays[wd], day, months[month], year, hh, mm, ss)
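# A quick sanity check of the format above (not from the original source):
# date_time_string(0) returns 'Thu, 01 Jan 1970 00:00:00 GMT'.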
def url_prefix(mat):
"""Extract the URL prefix from a regex match.
- :param mat: A regex match object.
- :returns: The URL prefix, defined as the text before the match in the
+ Args:
+ mat: A regex match object.
+ Returns: The URL prefix, defined as the text before the match in the
original string. Normalized to start with one leading slash and end
        with zero trailing slashes.
"""
return '/' + mat.string[:mat.start()].strip('/')
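# Illustrative sketch of the normalization: matching the '/info/refs' service
# pattern against '/myrepo/info/refs' leaves '/myrepo' as the prefix.
def _example_url_prefix():
    mat = re.compile('/info/refs$').search('/myrepo/info/refs')
    return url_prefix(mat)  # '/myrepo'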
def get_repo(backend, mat):
"""Get a Repo instance for the given backend and URL regex match."""
return backend.open_repository(url_prefix(mat))
def send_file(req, f, content_type):
"""Send a file-like object to the request output.
- :param req: The HTTPGitRequest object to send output to.
- :param f: An open file-like object to send; will be closed.
- :param content_type: The MIME type for the file.
- :return: Iterator over the contents of the file, as chunks.
+ Args:
+ req: The HTTPGitRequest object to send output to.
+ f: An open file-like object to send; will be closed.
+ content_type: The MIME type for the file.
+ Returns: Iterator over the contents of the file, as chunks.
"""
if f is None:
yield req.not_found('File not found')
return
try:
req.respond(HTTP_OK, content_type)
while True:
data = f.read(10240)
if not data:
break
yield data
except IOError:
yield req.error('Error reading file')
finally:
f.close()
def _url_to_path(url):
return url.replace('/', os.path.sep)
def get_text_file(req, backend, mat):
req.nocache()
path = _url_to_path(mat.group())
logger.info('Sending plain text file %s', path)
return send_file(req, get_repo(backend, mat).get_named_file(path),
'text/plain')
def get_loose_object(req, backend, mat):
sha = (mat.group(1) + mat.group(2)).encode('ascii')
logger.info('Sending loose object %s', sha)
object_store = get_repo(backend, mat).object_store
if not object_store.contains_loose(sha):
yield req.not_found('Object not found')
return
try:
data = object_store[sha].as_legacy_object()
except IOError:
yield req.error('Error reading object')
return
req.cache_forever()
req.respond(HTTP_OK, 'application/x-git-loose-object')
yield data
def get_pack_file(req, backend, mat):
req.cache_forever()
path = _url_to_path(mat.group())
logger.info('Sending pack file %s', path)
return send_file(req, get_repo(backend, mat).get_named_file(path),
'application/x-git-packed-objects')
def get_idx_file(req, backend, mat):
req.cache_forever()
path = _url_to_path(mat.group())
    logger.info('Sending pack index file %s', path)
return send_file(req, get_repo(backend, mat).get_named_file(path),
'application/x-git-packed-objects-toc')
def get_info_refs(req, backend, mat):
params = parse_qs(req.environ['QUERY_STRING'])
service = params.get('service', [None])[0]
try:
repo = get_repo(backend, mat)
except NotGitRepository as e:
yield req.not_found(str(e))
return
if service and not req.dumb:
handler_cls = req.handlers.get(service.encode('ascii'), None)
if handler_cls is None:
yield req.forbidden('Unsupported service')
return
req.nocache()
write = req.respond(
HTTP_OK, 'application/x-%s-advertisement' % service)
proto = ReceivableProtocol(BytesIO().read, write)
handler = handler_cls(backend, [url_prefix(mat)], proto,
http_req=req, advertise_refs=True)
handler.proto.write_pkt_line(
b'# service=' + service.encode('ascii') + b'\n')
handler.proto.write_pkt_line(None)
handler.handle()
else:
# non-smart fallback
# TODO: select_getanyfile() (see http-backend.c)
req.nocache()
req.respond(HTTP_OK, 'text/plain')
logger.info('Emulating dumb info/refs')
for text in generate_info_refs(repo):
yield text
def get_info_packs(req, backend, mat):
req.nocache()
req.respond(HTTP_OK, 'text/plain')
logger.info('Emulating dumb info/packs')
return generate_objects_info_packs(get_repo(backend, mat))
class _LengthLimitedFile(object):
"""Wrapper class to limit the length of reads from a file-like object.
This is used to ensure EOF is read from the wsgi.input object once
Content-Length bytes are read. This behavior is required by the WSGI spec
but not implemented in wsgiref as of 2.5.
"""
def __init__(self, input, max_bytes):
self._input = input
self._bytes_avail = max_bytes
def read(self, size=-1):
if self._bytes_avail <= 0:
return b''
if size == -1 or size > self._bytes_avail:
size = self._bytes_avail
self._bytes_avail -= size
return self._input.read(size)
# TODO: support more methods as necessary
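# Minimal sketch of the wrapper in isolation: reads are capped at max_bytes
# even though the underlying stream holds more data.
def _example_length_limited():
    limited = _LengthLimitedFile(BytesIO(b'abcdef'), 4)
    first = limited.read()   # b'abcd'
    second = limited.read()  # b'' -- the limit is exhausted
    return first, second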
def handle_service_request(req, backend, mat):
service = mat.group().lstrip('/')
logger.info('Handling service request for %s', service)
handler_cls = req.handlers.get(service.encode('ascii'), None)
if handler_cls is None:
yield req.forbidden('Unsupported service')
return
try:
get_repo(backend, mat)
except NotGitRepository as e:
yield req.not_found(str(e))
return
req.nocache()
write = req.respond(HTTP_OK, 'application/x-%s-result' % service)
proto = ReceivableProtocol(req.environ['wsgi.input'].read, write)
# TODO(jelmer): Find a way to pass in repo, rather than having handler_cls
# reopen.
handler = handler_cls(backend, [url_prefix(mat)], proto, http_req=req)
handler.handle()
class HTTPGitRequest(object):
"""Class encapsulating the state of a single git HTTP request.
:ivar environ: the WSGI environment for the request.
"""
def __init__(self, environ, start_response, dumb=False, handlers=None):
self.environ = environ
self.dumb = dumb
self.handlers = handlers
self._start_response = start_response
self._cache_headers = []
self._headers = []
def add_header(self, name, value):
"""Add a header to the response."""
self._headers.append((name, value))
def respond(self, status=HTTP_OK, content_type=None, headers=None):
"""Begin a response with the given status and other headers."""
if headers:
self._headers.extend(headers)
if content_type:
self._headers.append(('Content-Type', content_type))
self._headers.extend(self._cache_headers)
return self._start_response(status, self._headers)
def not_found(self, message):
"""Begin a HTTP 404 response and return the text of a message."""
self._cache_headers = []
logger.info('Not found: %s', message)
self.respond(HTTP_NOT_FOUND, 'text/plain')
return message.encode('ascii')
def forbidden(self, message):
"""Begin a HTTP 403 response and return the text of a message."""
self._cache_headers = []
logger.info('Forbidden: %s', message)
self.respond(HTTP_FORBIDDEN, 'text/plain')
return message.encode('ascii')
def error(self, message):
"""Begin a HTTP 500 response and return the text of a message."""
self._cache_headers = []
logger.error('Error: %s', message)
self.respond(HTTP_ERROR, 'text/plain')
return message.encode('ascii')
def nocache(self):
"""Set the response to never be cached by the client."""
self._cache_headers = [
('Expires', 'Fri, 01 Jan 1980 00:00:00 GMT'),
('Pragma', 'no-cache'),
('Cache-Control', 'no-cache, max-age=0, must-revalidate'),
]
def cache_forever(self):
"""Set the response to be cached forever by the client."""
now = time.time()
self._cache_headers = [
('Date', date_time_string(now)),
('Expires', date_time_string(now + 31536000)),
('Cache-Control', 'public, max-age=31536000'),
]
class HTTPGitApplication(object):
"""Class encapsulating the state of a git WSGI application.
:ivar backend: the Backend object backing this application
"""
services = {
('GET', re.compile('/HEAD$')): get_text_file,
('GET', re.compile('/info/refs$')): get_info_refs,
('GET', re.compile('/objects/info/alternates$')): get_text_file,
('GET', re.compile('/objects/info/http-alternates$')): get_text_file,
('GET', re.compile('/objects/info/packs$')): get_info_packs,
('GET', re.compile('/objects/([0-9a-f]{2})/([0-9a-f]{38})$')):
get_loose_object,
('GET', re.compile('/objects/pack/pack-([0-9a-f]{40})\\.pack$')):
get_pack_file,
('GET', re.compile('/objects/pack/pack-([0-9a-f]{40})\\.idx$')):
get_idx_file,
('POST', re.compile('/git-upload-pack$')): handle_service_request,
('POST', re.compile('/git-receive-pack$')): handle_service_request,
}
def __init__(self, backend, dumb=False, handlers=None, fallback_app=None):
self.backend = backend
self.dumb = dumb
self.handlers = dict(DEFAULT_HANDLERS)
self.fallback_app = fallback_app
if handlers is not None:
self.handlers.update(handlers)
def __call__(self, environ, start_response):
path = environ['PATH_INFO']
method = environ['REQUEST_METHOD']
req = HTTPGitRequest(environ, start_response, dumb=self.dumb,
handlers=self.handlers)
# environ['QUERY_STRING'] has qs args
handler = None
for smethod, spath in self.services.keys():
if smethod != method:
continue
mat = spath.search(path)
if mat:
handler = self.services[smethod, spath]
break
if handler is None:
if self.fallback_app is not None:
return self.fallback_app(environ, start_response)
else:
return [req.not_found('Sorry, that method is not supported')]
return handler(req, self.backend, mat)
class GunzipFilter(object):
"""WSGI middleware that unzips gzip-encoded requests before
passing on to the underlying application.
"""
def __init__(self, application):
self.app = application
def __call__(self, environ, start_response):
if environ.get('HTTP_CONTENT_ENCODING', '') == 'gzip':
try:
environ['wsgi.input'].tell()
wsgi_input = environ['wsgi.input']
except (AttributeError, IOError, NotImplementedError):
# The gzip implementation in the standard library of Python 2.x
# requires working '.seek()' and '.tell()' methods on the input
# stream. Read the data into a temporary file to work around
# this limitation.
wsgi_input = tempfile.SpooledTemporaryFile(16 * 1024 * 1024)
shutil.copyfileobj(environ['wsgi.input'], wsgi_input)
wsgi_input.seek(0)
environ['wsgi.input'] = gzip.GzipFile(
filename=None, fileobj=wsgi_input, mode='r')
del environ['HTTP_CONTENT_ENCODING']
if 'CONTENT_LENGTH' in environ:
del environ['CONTENT_LENGTH']
return self.app(environ, start_response)
class LimitedInputFilter(object):
"""WSGI middleware that limits the input length of a request to that
specified in Content-Length.
"""
def __init__(self, application):
self.app = application
def __call__(self, environ, start_response):
# This is not necessary if this app is run from a conforming WSGI
# server. Unfortunately, there's no way to tell that at this point.
        # TODO: git may use HTTP/1.1 chunked encoding instead of specifying
# content-length
content_length = environ.get('CONTENT_LENGTH', '')
if content_length:
environ['wsgi.input'] = _LengthLimitedFile(
environ['wsgi.input'], int(content_length))
return self.app(environ, start_response)
def make_wsgi_chain(*args, **kwargs):
"""Factory function to create an instance of HTTPGitApplication,
correctly wrapped with needed middleware.
"""
app = HTTPGitApplication(*args, **kwargs)
wrapped_app = LimitedInputFilter(GunzipFilter(app))
return wrapped_app
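# Hedged embedding sketch: the chain is a plain WSGI callable, so it can be
# handed to any WSGI server, not just the one in main() below. The repository
# path '/tmp/repo' is a hypothetical placeholder.
def _example_wsgi_app():
    backend = DictBackend({'/': Repo('/tmp/repo')})
    return make_wsgi_chain(backend)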
class ServerHandlerLogger(ServerHandler):
"""ServerHandler that uses dulwich's logger for logging exceptions."""
def log_exception(self, exc_info):
if sys.version_info < (2, 7):
logger.exception('Exception happened during processing of request')
else:
logger.exception('Exception happened during processing of request',
exc_info=exc_info)
def log_message(self, format, *args):
logger.info(format, *args)
def log_error(self, *args):
logger.error(*args)
class WSGIRequestHandlerLogger(WSGIRequestHandler):
"""WSGIRequestHandler that uses dulwich's logger for logging exceptions."""
def log_exception(self, exc_info):
logger.exception('Exception happened during processing of request',
exc_info=exc_info)
def log_message(self, format, *args):
logger.info(format, *args)
def log_error(self, *args):
logger.error(*args)
def handle(self):
"""Handle a single HTTP request"""
self.raw_requestline = self.rfile.readline()
if not self.parse_request(): # An error code has been sent, just exit
return
handler = ServerHandlerLogger(
self.rfile, self.wfile, self.get_stderr(), self.get_environ()
)
handler.request_handler = self # backpointer for logging
handler.run(self.server.get_app())
class WSGIServerLogger(WSGIServer):
def handle_error(self, request, client_address):
"""Handle an error. """
logger.exception(
'Exception happened during processing of request from %s' %
str(client_address))
def main(argv=sys.argv):
"""Entry point for starting an HTTP git server."""
import optparse
parser = optparse.OptionParser()
parser.add_option("-l", "--listen_address", dest="listen_address",
default="localhost",
help="Binding IP address.")
parser.add_option("-p", "--port", dest="port", type=int,
default=8000,
help="Port to listen on.")
options, args = parser.parse_args(argv)
if len(args) > 1:
gitdir = args[1]
else:
gitdir = os.getcwd()
log_utils.default_logging_config()
backend = DictBackend({'/': Repo(gitdir)})
app = make_wsgi_chain(backend)
server = make_server(options.listen_address, options.port, app,
handler_class=WSGIRequestHandlerLogger,
server_class=WSGIServerLogger)
logger.info('Listening for HTTP connections on %s:%d',
options.listen_address, options.port)
server.serve_forever()
if __name__ == '__main__':
main()