diff --git a/.github/workflows/pythonpackage.yml b/.github/workflows/pythonpackage.yml
new file mode 100644
index 00000000..5c47f986
--- /dev/null
+++ b/.github/workflows/pythonpackage.yml
@@ -0,0 +1,32 @@
+name: Python package
+
+on: [push, pull_request]
+
+jobs:
+ build:
+
+ runs-on: ${{ matrix.os }}
+ strategy:
+ matrix:
+ os: [ubuntu-latest, macos-latest, windows-latest]
+ python-version: [3.6, 3.7, 3.8, pypy3]
+
+ steps:
+ - uses: actions/checkout@v2
+ - name: Set up Python ${{ matrix.python-version }}
+ uses: actions/setup-python@v2
+ with:
+ python-version: ${{ matrix.python-version }}
+ - name: Install dependencies
+ run: |
+ python -m pip install --upgrade pip
+ pip install -U coverage codecov flake8 fastimport
+ - name: Style checks
+ run: |
+ python -m flake8
+ - name: Build
+ run: |
+ python setup.py build_ext -i
+ - name: Coverage test suite run
+ run: |
+ python -m coverage run -p -m unittest dulwich.tests.test_suite
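The coverage step above runs the same suite as the Makefile's `check` target. A minimal local equivalent, as a sketch (assumes dulwich and its test suite are importable and the C extensions were built in place with `python setup.py build_ext -i`):

import unittest

# dulwich.tests exposes a test_suite() callable; that is why
# `python -m unittest dulwich.tests.test_suite` works in the step above.
from dulwich.tests import test_suite

runner = unittest.TextTestRunner(verbosity=2)
runner.run(test_suite())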
diff --git a/.github/workflows/pythonpublish.yml b/.github/workflows/pythonpublish.yml
new file mode 100644
index 00000000..27f39248
--- /dev/null
+++ b/.github/workflows/pythonpublish.yml
@@ -0,0 +1,32 @@
+name: Upload Python Package
+
+on:
+ release:
+ types: [created]
+
+jobs:
+ deploy:
+
+ runs-on: ${{ matrix.os }}
+ strategy:
+ matrix:
+ os: [ubuntu-latest, macos-latest, windows-latest]
+ python-version: ['3.x']
+
+ steps:
+ - uses: actions/checkout@v2
+ - name: Set up Python
+ uses: actions/setup-python@v2
+ with:
+ python-version: ${{ matrix.python-version }}
+ - name: Install dependencies
+ run: |
+ python -m pip install --upgrade pip
+ pip install setuptools wheel twine
+ - name: Build and publish
+ env:
+ TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
+ TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
+ run: |
+ python setup.py sdist bdist_wheel
+ twine upload dist/*
diff --git a/Makefile b/Makefile
index c0c7047d..b676540d 100644
--- a/Makefile
+++ b/Makefile
@@ -1,69 +1,69 @@
PYTHON = python3
-PYFLAKES = pyflakes
+PYFLAKES = $(PYTHON) -m pyflakes
PEP8 = pep8
-FLAKE8 ?= flake8
+FLAKE8 ?= $(PYTHON) -m flake8
SETUP = $(PYTHON) setup.py
TESTRUNNER ?= unittest
RUNTEST = PYTHONHASHSEED=random PYTHONPATH=$(shell pwd)$(if $(PYTHONPATH),:$(PYTHONPATH),) $(PYTHON) -m $(TESTRUNNER) $(TEST_OPTIONS)
COVERAGE = python3-coverage
DESTDIR=/
all: build
doc:: sphinx
sphinx::
$(MAKE) -C docs html
build::
$(SETUP) build
$(SETUP) build_ext -i
install::
$(SETUP) install --root="$(DESTDIR)"
check:: build
$(RUNTEST) dulwich.tests.test_suite
check-tutorial:: build
$(RUNTEST) dulwich.tests.tutorial_test_suite
check-nocompat:: build
$(RUNTEST) dulwich.tests.nocompat_test_suite
check-compat:: build
$(RUNTEST) dulwich.tests.compat_test_suite
check-pypy:: clean
$(MAKE) check-noextensions PYTHON=pypy
check-noextensions:: clean
$(RUNTEST) dulwich.tests.test_suite
check-all: check check-pypy check-noextensions
typing:
mypy dulwich
clean::
$(SETUP) clean --all
rm -f dulwich/*.so
flakes:
$(PYFLAKES) dulwich
pep8:
$(PEP8) dulwich
style:
$(FLAKE8)
before-push: check
git diff origin/master | $(PEP8) --diff
coverage:
$(COVERAGE) run -m unittest dulwich.tests.test_suite dulwich.contrib.test_suite
coverage-html: coverage
$(COVERAGE) html
diff --git a/dulwich/_diff_tree.c b/dulwich/_diff_tree.c
index 068b4f63..ebed2d7f 100644
--- a/dulwich/_diff_tree.c
+++ b/dulwich/_diff_tree.c
@@ -1,505 +1,510 @@
/*
* Copyright (C) 2010 Google, Inc.
*
* Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU
* General Public License as published by the Free Software Foundation; version 2.0
* or (at your option) any later version. You can redistribute it and/or
* modify it under the terms of either of these two licenses.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* You should have received a copy of the licenses; if not, see
* <http://www.gnu.org/licenses/> for a copy of the GNU General Public License
* and <http://www.apache.org/licenses/LICENSE-2.0> for a copy of the Apache
* License, Version 2.0.
*/
#define PY_SSIZE_T_CLEAN
#include <Python.h>
#include <sys/stat.h>
#ifdef _MSC_VER
typedef unsigned short mode_t;
#endif
#if PY_MAJOR_VERSION < 3
typedef long Py_hash_t;
#endif
#if PY_MAJOR_VERSION >= 3
#define PyInt_FromLong PyLong_FromLong
#define PyInt_AsLong PyLong_AsLong
#define PyInt_AS_LONG PyLong_AS_LONG
#define PyString_AS_STRING PyBytes_AS_STRING
#define PyString_AsStringAndSize PyBytes_AsStringAndSize
#define PyString_Check PyBytes_Check
#define PyString_CheckExact PyBytes_CheckExact
#define PyString_FromStringAndSize PyBytes_FromStringAndSize
#define PyString_FromString PyBytes_FromString
#define PyString_GET_SIZE PyBytes_GET_SIZE
#define PyString_Size PyBytes_Size
#define _PyString_Join _PyBytes_Join
#endif
static PyObject *tree_entry_cls = NULL, *null_entry = NULL,
*defaultdict_cls = NULL, *int_cls = NULL;
static int block_size;
/**
* Free an array of PyObject pointers, decrementing any references.
*/
static void free_objects(PyObject **objs, Py_ssize_t n)
{
Py_ssize_t i;
for (i = 0; i < n; i++)
Py_XDECREF(objs[i]);
PyMem_Free(objs);
}
/**
* Get the entries of a tree, prepending the given path.
*
* Args:
* path: The path to prepend, without trailing slashes.
* path_len: The length of path.
* tree: The Tree object to iterate.
* n: Set to the length of result.
* Returns: A (C) array of PyObject pointers to TreeEntry objects for each path
* in tree.
*/
static PyObject **tree_entries(char *path, Py_ssize_t path_len, PyObject *tree,
Py_ssize_t *n)
{
PyObject *iteritems, *items, **result = NULL;
PyObject *old_entry, *name, *sha;
Py_ssize_t i = 0, name_len, new_path_len;
char *new_path;
if (tree == Py_None) {
*n = 0;
result = PyMem_New(PyObject*, 0);
if (!result) {
PyErr_NoMemory();
return NULL;
}
return result;
}
iteritems = PyObject_GetAttrString(tree, "iteritems");
if (!iteritems)
return NULL;
items = PyObject_CallFunctionObjArgs(iteritems, Py_True, NULL);
Py_DECREF(iteritems);
if (items == NULL) {
return NULL;
}
/* The C implementation of iteritems returns a list, so depend on that. */
if (!PyList_Check(items)) {
PyErr_SetString(PyExc_TypeError,
"Tree.iteritems() did not return a list");
goto error;
}
*n = PyList_Size(items);
result = PyMem_New(PyObject*, *n);
if (!result) {
PyErr_NoMemory();
goto error;
}
for (i = 0; i < *n; i++) {
old_entry = PyList_GetItem(items, i);
if (!old_entry)
goto error;
sha = PyTuple_GetItem(old_entry, 2);
if (!sha)
goto error;
name = PyTuple_GET_ITEM(old_entry, 0);
name_len = PyString_Size(name);
if (PyErr_Occurred())
goto error;
new_path_len = name_len;
if (path_len)
new_path_len += path_len + 1;
new_path = PyMem_Malloc(new_path_len);
if (!new_path) {
PyErr_NoMemory();
goto error;
}
if (path_len) {
memcpy(new_path, path, path_len);
new_path[path_len] = '/';
memcpy(new_path + path_len + 1, PyString_AS_STRING(name), name_len);
} else {
memcpy(new_path, PyString_AS_STRING(name), name_len);
}
#if PY_MAJOR_VERSION >= 3
result[i] = PyObject_CallFunction(tree_entry_cls, "y#OO", new_path,
new_path_len, PyTuple_GET_ITEM(old_entry, 1), sha);
#else
result[i] = PyObject_CallFunction(tree_entry_cls, "s#OO", new_path,
new_path_len, PyTuple_GET_ITEM(old_entry, 1), sha);
#endif
PyMem_Free(new_path);
if (!result[i]) {
goto error;
}
}
Py_DECREF(items);
return result;
error:
if (result)
free_objects(result, i);
Py_DECREF(items);
return NULL;
}
/**
* Use strcmp to compare the paths of two TreeEntry objects.
*/
static int entry_path_cmp(PyObject *entry1, PyObject *entry2)
{
PyObject *path1 = NULL, *path2 = NULL;
int result = 0;
path1 = PyObject_GetAttrString(entry1, "path");
if (!path1)
goto done;
if (!PyString_Check(path1)) {
PyErr_SetString(PyExc_TypeError, "path is not a (byte)string");
goto done;
}
path2 = PyObject_GetAttrString(entry2, "path");
if (!path2)
goto done;
if (!PyString_Check(path2)) {
PyErr_SetString(PyExc_TypeError, "path is not a (byte)string");
goto done;
}
result = strcmp(PyString_AS_STRING(path1), PyString_AS_STRING(path2));
done:
Py_XDECREF(path1);
Py_XDECREF(path2);
return result;
}
static PyObject *py_merge_entries(PyObject *self, PyObject *args)
{
PyObject *tree1, *tree2, **entries1 = NULL, **entries2 = NULL;
PyObject *e1, *e2, *pair, *result = NULL;
Py_ssize_t n1 = 0, n2 = 0, i1 = 0, i2 = 0, path_len;
char *path_str;
int cmp;
#if PY_MAJOR_VERSION >= 3
if (!PyArg_ParseTuple(args, "y#OO", &path_str, &path_len, &tree1, &tree2))
#else
if (!PyArg_ParseTuple(args, "s#OO", &path_str, &path_len, &tree1, &tree2))
#endif
return NULL;
entries1 = tree_entries(path_str, path_len, tree1, &n1);
if (!entries1)
goto error;
entries2 = tree_entries(path_str, path_len, tree2, &n2);
if (!entries2)
goto error;
result = PyList_New(0);
if (!result)
goto error;
while (i1 < n1 && i2 < n2) {
cmp = entry_path_cmp(entries1[i1], entries2[i2]);
if (PyErr_Occurred())
goto error;
if (!cmp) {
e1 = entries1[i1++];
e2 = entries2[i2++];
} else if (cmp < 0) {
e1 = entries1[i1++];
e2 = null_entry;
} else {
e1 = null_entry;
e2 = entries2[i2++];
}
pair = PyTuple_Pack(2, e1, e2);
if (!pair)
goto error;
PyList_Append(result, pair);
Py_DECREF(pair);
}
while (i1 < n1) {
pair = PyTuple_Pack(2, entries1[i1++], null_entry);
if (!pair)
goto error;
PyList_Append(result, pair);
Py_DECREF(pair);
}
while (i2 < n2) {
pair = PyTuple_Pack(2, null_entry, entries2[i2++]);
if (!pair)
goto error;
PyList_Append(result, pair);
Py_DECREF(pair);
}
goto done;
error:
Py_XDECREF(result);
result = NULL;
done:
if (entries1)
free_objects(entries1, n1);
if (entries2)
free_objects(entries2, n2);
return result;
}
+/* Not all environments define S_ISDIR */
+#if !defined(S_ISDIR) && defined(S_IFMT) && defined(S_IFDIR)
+#define S_ISDIR(m) (((m) & S_IFMT) == S_IFDIR)
+#endif
+
static PyObject *py_is_tree(PyObject *self, PyObject *args)
{
PyObject *entry, *mode, *result;
long lmode;
if (!PyArg_ParseTuple(args, "O", &entry))
return NULL;
mode = PyObject_GetAttrString(entry, "mode");
if (!mode)
return NULL;
if (mode == Py_None) {
result = Py_False;
Py_INCREF(result);
} else {
lmode = PyInt_AsLong(mode);
if (lmode == -1 && PyErr_Occurred()) {
Py_DECREF(mode);
return NULL;
}
result = PyBool_FromLong(S_ISDIR((mode_t)lmode));
}
Py_DECREF(mode);
return result;
}
static Py_hash_t add_hash(PyObject *get, PyObject *set, char *str, int n)
{
PyObject *str_obj = NULL, *hash_obj = NULL, *value = NULL,
*set_value = NULL;
Py_hash_t hash;
/* It would be nice to hash without copying str into a PyString, but that
* isn't exposed by the API. */
str_obj = PyString_FromStringAndSize(str, n);
if (!str_obj)
goto error;
hash = PyObject_Hash(str_obj);
if (hash == -1)
goto error;
hash_obj = PyInt_FromLong(hash);
if (!hash_obj)
goto error;
value = PyObject_CallFunctionObjArgs(get, hash_obj, NULL);
if (!value)
goto error;
set_value = PyObject_CallFunction(set, "(Ol)", hash_obj,
PyInt_AS_LONG(value) + n);
if (!set_value)
goto error;
Py_DECREF(str_obj);
Py_DECREF(hash_obj);
Py_DECREF(value);
Py_DECREF(set_value);
return 0;
error:
Py_XDECREF(str_obj);
Py_XDECREF(hash_obj);
Py_XDECREF(value);
Py_XDECREF(set_value);
return -1;
}
static PyObject *py_count_blocks(PyObject *self, PyObject *args)
{
PyObject *obj, *chunks = NULL, *chunk, *counts = NULL, *get = NULL,
*set = NULL;
char *chunk_str, *block = NULL;
Py_ssize_t num_chunks, chunk_len;
int i, j, n = 0;
char c;
if (!PyArg_ParseTuple(args, "O", &obj))
goto error;
counts = PyObject_CallFunctionObjArgs(defaultdict_cls, int_cls, NULL);
if (!counts)
goto error;
get = PyObject_GetAttrString(counts, "__getitem__");
set = PyObject_GetAttrString(counts, "__setitem__");
chunks = PyObject_CallMethod(obj, "as_raw_chunks", NULL);
if (!chunks)
goto error;
if (!PyList_Check(chunks)) {
PyErr_SetString(PyExc_TypeError,
"as_raw_chunks() did not return a list");
goto error;
}
num_chunks = PyList_GET_SIZE(chunks);
block = PyMem_New(char, block_size);
if (!block) {
PyErr_NoMemory();
goto error;
}
for (i = 0; i < num_chunks; i++) {
chunk = PyList_GET_ITEM(chunks, i);
if (!PyString_Check(chunk)) {
PyErr_SetString(PyExc_TypeError, "chunk is not a string");
goto error;
}
if (PyString_AsStringAndSize(chunk, &chunk_str, &chunk_len) == -1)
goto error;
for (j = 0; j < chunk_len; j++) {
c = chunk_str[j];
block[n++] = c;
if (c == '\n' || n == block_size) {
if (add_hash(get, set, block, n) == -1)
goto error;
n = 0;
}
}
}
if (n && add_hash(get, set, block, n) == -1)
goto error;
Py_DECREF(chunks);
Py_DECREF(get);
Py_DECREF(set);
PyMem_Free(block);
return counts;
error:
Py_XDECREF(chunks);
Py_XDECREF(get);
Py_XDECREF(set);
Py_XDECREF(counts);
PyMem_Free(block);
return NULL;
}
static PyMethodDef py_diff_tree_methods[] = {
{ "_is_tree", (PyCFunction)py_is_tree, METH_VARARGS, NULL },
{ "_merge_entries", (PyCFunction)py_merge_entries, METH_VARARGS, NULL },
{ "_count_blocks", (PyCFunction)py_count_blocks, METH_VARARGS, NULL },
{ NULL, NULL, 0, NULL }
};
static PyObject *
moduleinit(void)
{
PyObject *m, *objects_mod = NULL, *diff_tree_mod = NULL;
PyObject *block_size_obj = NULL;
#if PY_MAJOR_VERSION >= 3
static struct PyModuleDef moduledef = {
PyModuleDef_HEAD_INIT,
"_diff_tree", /* m_name */
NULL, /* m_doc */
-1, /* m_size */
py_diff_tree_methods, /* m_methods */
NULL, /* m_reload */
NULL, /* m_traverse */
NULL, /* m_clear*/
NULL, /* m_free */
};
m = PyModule_Create(&moduledef);
#else
m = Py_InitModule("_diff_tree", py_diff_tree_methods);
#endif
if (!m)
goto error;
objects_mod = PyImport_ImportModule("dulwich.objects");
if (!objects_mod)
goto error;
tree_entry_cls = PyObject_GetAttrString(objects_mod, "TreeEntry");
Py_DECREF(objects_mod);
if (!tree_entry_cls)
goto error;
diff_tree_mod = PyImport_ImportModule("dulwich.diff_tree");
if (!diff_tree_mod)
goto error;
null_entry = PyObject_GetAttrString(diff_tree_mod, "_NULL_ENTRY");
if (!null_entry)
goto error;
block_size_obj = PyObject_GetAttrString(diff_tree_mod, "_BLOCK_SIZE");
if (!block_size_obj)
goto error;
block_size = (int)PyInt_AsLong(block_size_obj);
if (PyErr_Occurred())
goto error;
defaultdict_cls = PyObject_GetAttrString(diff_tree_mod, "defaultdict");
if (!defaultdict_cls)
goto error;
/* This is kind of hacky, but I don't know of a better way to get the
* PyObject* version of int. */
int_cls = PyDict_GetItemString(PyEval_GetBuiltins(), "int");
if (!int_cls) {
PyErr_SetString(PyExc_NameError, "int");
goto error;
}
Py_DECREF(diff_tree_mod);
return m;
error:
Py_XDECREF(objects_mod);
Py_XDECREF(diff_tree_mod);
Py_XDECREF(null_entry);
Py_XDECREF(block_size_obj);
Py_XDECREF(defaultdict_cls);
Py_XDECREF(int_cls);
return NULL;
}
#if PY_MAJOR_VERSION >= 3
PyMODINIT_FUNC
PyInit__diff_tree(void)
{
return moduleinit();
}
#else
PyMODINIT_FUNC
init_diff_tree(void)
{
moduleinit();
}
#endif
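The C `_count_blocks` above is a drop-in replacement for the pure-Python version in dulwich.diff_tree: it splits a blob into newline-terminated blocks (capped at `_BLOCK_SIZE` bytes) and sums bytes per block hash, which the rename detector uses for similarity scoring. A short sketch of the same behaviour from Python:

from dulwich.diff_tree import _count_blocks
from dulwich.objects import Blob

blob = Blob.from_string(b"one\ntwo\none\n")
counts = _count_blocks(blob)
# counts maps hash(block) -> bytes seen for that block;
# b"one\n" occurs twice, so its hash maps to 8.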
diff --git a/dulwich/_objects.c b/dulwich/_objects.c
index 417e189d..eb8b9e5b 100644
--- a/dulwich/_objects.c
+++ b/dulwich/_objects.c
@@ -1,330 +1,335 @@
/*
* Copyright (C) 2009 Jelmer Vernooij
*
* Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU
* General Public License as published by the Free Software Foundation; version 2.0
* or (at your option) any later version. You can redistribute it and/or
* modify it under the terms of either of these two licenses.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* You should have received a copy of the licenses; if not, see
* <http://www.gnu.org/licenses/> for a copy of the GNU General Public License
* and <http://www.apache.org/licenses/LICENSE-2.0> for a copy of the Apache
* License, Version 2.0.
*/
#define PY_SSIZE_T_CLEAN
#include <Python.h>
#include <stdlib.h>
#include <sys/stat.h>
#if PY_MAJOR_VERSION >= 3
#define PyInt_Check(obj) 0
#define PyInt_CheckExact(obj) 0
#define PyInt_AsLong PyLong_AsLong
#define PyString_AS_STRING PyBytes_AS_STRING
#define PyString_Check PyBytes_Check
#define PyString_FromStringAndSize PyBytes_FromStringAndSize
#endif
#if defined(__MINGW32_VERSION) || defined(__APPLE__)
size_t rep_strnlen(char *text, size_t maxlen);
size_t rep_strnlen(char *text, size_t maxlen)
{
const char *last = memchr(text, '\0', maxlen);
return last ? (size_t) (last - text) : maxlen;
}
#define strnlen rep_strnlen
#endif
#define bytehex(x) (((x)<0xa)?('0'+(x)):('a'-0xa+(x)))
static PyObject *tree_entry_cls;
static PyObject *object_format_exception_cls;
static PyObject *sha_to_pyhex(const unsigned char *sha)
{
char hexsha[41];
int i;
for (i = 0; i < 20; i++) {
hexsha[i*2] = bytehex((sha[i] & 0xF0) >> 4);
hexsha[i*2+1] = bytehex(sha[i] & 0x0F);
}
return PyString_FromStringAndSize(hexsha, 40);
}
static PyObject *py_parse_tree(PyObject *self, PyObject *args, PyObject *kw)
{
char *text, *start, *end;
Py_ssize_t len; int strict;
size_t namelen;
PyObject *ret, *item, *name, *sha, *py_strict = NULL;
static char *kwlist[] = {"text", "strict", NULL};
#if PY_MAJOR_VERSION >= 3
if (!PyArg_ParseTupleAndKeywords(args, kw, "y#|O", kwlist,
&text, &len, &py_strict))
#else
if (!PyArg_ParseTupleAndKeywords(args, kw, "s#|O", kwlist,
&text, &len, &py_strict))
#endif
return NULL;
strict = py_strict ? PyObject_IsTrue(py_strict) : 0;
/* TODO: currently this returns a list; if memory usage is a concern,
* consider rewriting as a custom iterator object */
ret = PyList_New(0);
if (ret == NULL) {
return NULL;
}
start = text;
end = text + len;
while (text < end) {
long mode;
if (strict && text[0] == '0') {
PyErr_SetString(object_format_exception_cls,
"Illegal leading zero on mode");
Py_DECREF(ret);
return NULL;
}
mode = strtol(text, &text, 8);
if (*text != ' ') {
PyErr_SetString(PyExc_ValueError, "Expected space");
Py_DECREF(ret);
return NULL;
}
text++;
namelen = strnlen(text, len - (text - start));
name = PyString_FromStringAndSize(text, namelen);
if (name == NULL) {
Py_DECREF(ret);
return NULL;
}
if (text + namelen + 20 >= end) {
PyErr_SetString(PyExc_ValueError, "SHA truncated");
Py_DECREF(ret);
Py_DECREF(name);
return NULL;
}
sha = sha_to_pyhex((unsigned char *)text+namelen+1);
if (sha == NULL) {
Py_DECREF(ret);
Py_DECREF(name);
return NULL;
}
item = Py_BuildValue("(NlN)", name, mode, sha);
if (item == NULL) {
Py_DECREF(ret);
Py_DECREF(sha);
Py_DECREF(name);
return NULL;
}
if (PyList_Append(ret, item) == -1) {
Py_DECREF(ret);
Py_DECREF(item);
return NULL;
}
Py_DECREF(item);
text += namelen+21;
}
return ret;
}
struct tree_item {
const char *name;
int mode;
PyObject *tuple;
};
+/* Not all environments define S_ISDIR */
+#if !defined(S_ISDIR) && defined(S_IFMT) && defined(S_IFDIR)
+#define S_ISDIR(m) (((m) & S_IFMT) == S_IFDIR)
+#endif
+
int cmp_tree_item(const void *_a, const void *_b)
{
const struct tree_item *a = _a, *b = _b;
const char *remain_a, *remain_b;
int ret;
size_t common;
if (strlen(a->name) > strlen(b->name)) {
common = strlen(b->name);
remain_a = a->name + common;
remain_b = (S_ISDIR(b->mode)?"/":"");
} else if (strlen(b->name) > strlen(a->name)) {
common = strlen(a->name);
remain_a = (S_ISDIR(a->mode)?"/":"");
remain_b = b->name + common;
} else { /* strlen(a->name) == strlen(b->name) */
common = 0;
remain_a = a->name;
remain_b = b->name;
}
ret = strncmp(a->name, b->name, common);
if (ret != 0)
return ret;
return strcmp(remain_a, remain_b);
}
int cmp_tree_item_name_order(const void *_a, const void *_b) {
const struct tree_item *a = _a, *b = _b;
return strcmp(a->name, b->name);
}
static PyObject *py_sorted_tree_items(PyObject *self, PyObject *args)
{
struct tree_item *qsort_entries = NULL;
int name_order, n = 0, i;
PyObject *entries, *py_name_order, *ret, *key, *value, *py_mode, *py_sha;
Py_ssize_t pos = 0, num_entries;
int (*cmp)(const void *, const void *);
if (!PyArg_ParseTuple(args, "OO", &entries, &py_name_order))
goto error;
if (!PyDict_Check(entries)) {
PyErr_SetString(PyExc_TypeError, "Argument not a dictionary");
goto error;
}
name_order = PyObject_IsTrue(py_name_order);
if (name_order == -1)
goto error;
cmp = name_order ? cmp_tree_item_name_order : cmp_tree_item;
num_entries = PyDict_Size(entries);
if (PyErr_Occurred())
goto error;
qsort_entries = PyMem_New(struct tree_item, num_entries);
if (!qsort_entries) {
PyErr_NoMemory();
goto error;
}
while (PyDict_Next(entries, &pos, &key, &value)) {
if (!PyString_Check(key)) {
PyErr_SetString(PyExc_TypeError, "Name is not a string");
goto error;
}
if (PyTuple_Size(value) != 2) {
PyErr_SetString(PyExc_ValueError, "Tuple has invalid size");
goto error;
}
py_mode = PyTuple_GET_ITEM(value, 0);
if (!PyInt_Check(py_mode) && !PyLong_Check(py_mode)) {
PyErr_SetString(PyExc_TypeError, "Mode is not an integral type");
goto error;
}
py_sha = PyTuple_GET_ITEM(value, 1);
if (!PyString_Check(py_sha)) {
PyErr_SetString(PyExc_TypeError, "SHA is not a string");
goto error;
}
qsort_entries[n].name = PyString_AS_STRING(key);
qsort_entries[n].mode = PyInt_AsLong(py_mode);
qsort_entries[n].tuple = PyObject_CallFunctionObjArgs(
tree_entry_cls, key, py_mode, py_sha, NULL);
if (qsort_entries[n].tuple == NULL)
goto error;
n++;
}
qsort(qsort_entries, num_entries, sizeof(struct tree_item), cmp);
ret = PyList_New(num_entries);
if (ret == NULL) {
PyErr_NoMemory();
goto error;
}
for (i = 0; i < num_entries; i++) {
PyList_SET_ITEM(ret, i, qsort_entries[i].tuple);
}
PyMem_Free(qsort_entries);
return ret;
error:
for (i = 0; i < n; i++) {
Py_XDECREF(qsort_entries[i].tuple);
}
PyMem_Free(qsort_entries);
return NULL;
}
static PyMethodDef py_objects_methods[] = {
{ "parse_tree", (PyCFunction)py_parse_tree, METH_VARARGS | METH_KEYWORDS,
NULL },
{ "sorted_tree_items", py_sorted_tree_items, METH_VARARGS, NULL },
{ NULL, NULL, 0, NULL }
};
static PyObject *
moduleinit(void)
{
PyObject *m, *objects_mod, *errors_mod;
#if PY_MAJOR_VERSION >= 3
static struct PyModuleDef moduledef = {
PyModuleDef_HEAD_INIT,
"_objects", /* m_name */
NULL, /* m_doc */
-1, /* m_size */
py_objects_methods, /* m_methods */
NULL, /* m_reload */
NULL, /* m_traverse */
NULL, /* m_clear*/
NULL, /* m_free */
};
m = PyModule_Create(&moduledef);
#else
m = Py_InitModule3("_objects", py_objects_methods, NULL);
#endif
if (m == NULL) {
return NULL;
}
errors_mod = PyImport_ImportModule("dulwich.errors");
if (errors_mod == NULL) {
return NULL;
}
object_format_exception_cls = PyObject_GetAttrString(
errors_mod, "ObjectFormatException");
Py_DECREF(errors_mod);
if (object_format_exception_cls == NULL) {
return NULL;
}
/* This is a circular import but should be safe since this module is
* imported at the very bottom of objects.py. */
objects_mod = PyImport_ImportModule("dulwich.objects");
if (objects_mod == NULL) {
return NULL;
}
tree_entry_cls = PyObject_GetAttrString(objects_mod, "TreeEntry");
Py_DECREF(objects_mod);
if (tree_entry_cls == NULL) {
return NULL;
}
return m;
}
#if PY_MAJOR_VERSION >= 3
PyMODINIT_FUNC
PyInit__objects(void)
{
return moduleinit();
}
#else
PyMODINIT_FUNC
init_objects(void)
{
moduleinit();
}
#endif
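`cmp_tree_item` above reproduces git's tree ordering: entries sort as if directory names carried a trailing '/', while `cmp_tree_item_name_order` is a plain strcmp. A sketch of the visible effect, using the Tree class from dulwich.objects (the module this file imports TreeEntry from):

from dulwich.objects import Tree

tree = Tree()
tree.add(b"a.c", 0o100644, b"0" * 40)
tree.add(b"a", 0o040000, b"1" * 40)  # directory entry
tree.add(b"a.d", 0o100644, b"2" * 40)
# The directory b"a" sorts as b"a/" ('/' is 0x2f > '.' 0x2e), so it lands last:
print([entry.path for entry in tree.iteritems()])
# [b'a.c', b'a.d', b'a']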
diff --git a/dulwich/client.py b/dulwich/client.py
index 42726ef6..712f542d 100644
--- a/dulwich/client.py
+++ b/dulwich/client.py
@@ -1,1896 +1,1897 @@
# client.py -- Implementation of the client side git protocols
# Copyright (C) 2008-2013 Jelmer Vernooij
#
# Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU
# General Public License as published by the Free Software Foundation; version 2.0
# or (at your option) any later version. You can redistribute it and/or
# modify it under the terms of either of these two licenses.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# You should have received a copy of the licenses; if not, see
# <http://www.gnu.org/licenses/> for a copy of the GNU General Public License
# and <http://www.apache.org/licenses/LICENSE-2.0> for a copy of the Apache
# License, Version 2.0.
#
"""Client side support for the Git protocol.
The Dulwich client supports the following capabilities:
* thin-pack
* multi_ack_detailed
* multi_ack
* side-band-64k
* ofs-delta
* quiet
* report-status
* delete-refs
* shallow
Known capabilities that are not supported:
* no-progress
* include-tag
"""
from contextlib import closing
from io import BytesIO, BufferedReader
import errno
import os
import select
import socket
import subprocess
import sys
from urllib.parse import (
quote as urlquote,
unquote as urlunquote,
urlparse,
urljoin,
urlunparse,
urlunsplit,
)
import dulwich
from dulwich.config import get_xdg_config_home_path
from dulwich.errors import (
GitProtocolError,
NotGitRepository,
SendPackError,
UpdateRefsError,
)
from dulwich.protocol import (
HangupException,
_RBUFSIZE,
agent_string,
capability_agent,
extract_capability_names,
CAPABILITY_AGENT,
CAPABILITY_DELETE_REFS,
CAPABILITY_INCLUDE_TAG,
CAPABILITY_MULTI_ACK,
CAPABILITY_MULTI_ACK_DETAILED,
CAPABILITY_OFS_DELTA,
CAPABILITY_QUIET,
CAPABILITY_REPORT_STATUS,
CAPABILITY_SHALLOW,
CAPABILITY_SYMREF,
CAPABILITY_SIDE_BAND_64K,
CAPABILITY_THIN_PACK,
CAPABILITIES_REF,
KNOWN_RECEIVE_CAPABILITIES,
KNOWN_UPLOAD_CAPABILITIES,
COMMAND_DEEPEN,
COMMAND_SHALLOW,
COMMAND_UNSHALLOW,
COMMAND_DONE,
COMMAND_HAVE,
COMMAND_WANT,
SIDE_BAND_CHANNEL_DATA,
SIDE_BAND_CHANNEL_PROGRESS,
SIDE_BAND_CHANNEL_FATAL,
PktLineParser,
Protocol,
ProtocolFile,
TCP_GIT_PORT,
ZERO_SHA,
extract_capabilities,
parse_capability,
)
from dulwich.pack import (
write_pack_data,
write_pack_objects,
)
from dulwich.refs import (
read_info_refs,
ANNOTATED_TAG_SUFFIX,
)
class InvalidWants(Exception):
"""Invalid wants."""
def __init__(self, wants):
Exception.__init__(
self,
"requested wants not in server provided refs: %r" % wants)
def _fileno_can_read(fileno):
"""Check if a file descriptor is readable.
"""
return len(select.select([fileno], [], [], 0)[0]) > 0
def _win32_peek_avail(handle):
"""Wrapper around PeekNamedPipe to check how many bytes are available.
"""
from ctypes import GetLastError, byref, wintypes, windll
c_avail = wintypes.DWORD()
c_message = wintypes.DWORD()
success = windll.kernel32.PeekNamedPipe(
handle, None, 0, None, byref(c_avail),
byref(c_message))
if not success:
raise OSError(GetLastError())
return c_avail.value
COMMON_CAPABILITIES = [CAPABILITY_OFS_DELTA, CAPABILITY_SIDE_BAND_64K]
UPLOAD_CAPABILITIES = ([CAPABILITY_THIN_PACK, CAPABILITY_MULTI_ACK,
CAPABILITY_MULTI_ACK_DETAILED, CAPABILITY_SHALLOW]
+ COMMON_CAPABILITIES)
RECEIVE_CAPABILITIES = (
[CAPABILITY_REPORT_STATUS, CAPABILITY_DELETE_REFS]
+ COMMON_CAPABILITIES)
class ReportStatusParser(object):
"""Handle status as reported by servers with 'report-status' capability."""
def __init__(self):
self._done = False
self._pack_status = None
self._ref_status_ok = True
self._ref_statuses = []
def check(self):
"""Check if there were any errors and, if so, raise exceptions.
Raises:
SendPackError: Raised when the server could not unpack
UpdateRefsError: Raised when refs could not be updated
"""
if self._pack_status not in (b'unpack ok', None):
raise SendPackError(self._pack_status)
if not self._ref_status_ok:
ref_status = {}
ok = set()
for status in self._ref_statuses:
if b' ' not in status:
# malformed response, move on to the next one
continue
status, ref = status.split(b' ', 1)
if status == b'ng':
if b' ' in ref:
ref, status = ref.split(b' ', 1)
else:
ok.add(ref)
ref_status[ref] = status
# TODO(jelmer): don't assume encoding of refs is ascii.
raise UpdateRefsError(', '.join([
refname.decode('ascii') for refname in ref_status
if refname not in ok]) +
' failed to update', ref_status=ref_status)
def handle_packet(self, pkt):
"""Handle a packet.
Raises:
GitProtocolError: Raised when packets are received after a flush
packet.
"""
if self._done:
raise GitProtocolError("received more data after status report")
if pkt is None:
self._done = True
return
if self._pack_status is None:
self._pack_status = pkt.strip()
else:
ref_status = pkt.strip()
self._ref_statuses.append(ref_status)
if not ref_status.startswith(b'ok '):
self._ref_status_ok = False
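# A sketch of how send_pack drives this parser: every pkt-line of the status
# report goes to handle_packet(), the terminating flush-pkt arrives as None,
# and check() raises SendPackError/UpdateRefsError if anything failed.
#
#   parser = ReportStatusParser()
#   parser.handle_packet(b'unpack ok')
#   parser.handle_packet(b'ok refs/heads/master')
#   parser.handle_packet(None)
#   parser.check()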
def read_pkt_refs(proto):
server_capabilities = None
refs = {}
# Receive refs from server
for pkt in proto.read_pkt_seq():
(sha, ref) = pkt.rstrip(b'\n').split(None, 1)
if sha == b'ERR':
raise GitProtocolError(ref.decode('utf-8', 'replace'))
if server_capabilities is None:
(ref, server_capabilities) = extract_capabilities(ref)
refs[ref] = sha
if len(refs) == 0:
return {}, set([])
if refs == {CAPABILITIES_REF: ZERO_SHA}:
refs = {}
return refs, set(server_capabilities)
class FetchPackResult(object):
"""Result of a fetch-pack operation.
Attributes:
refs: Dictionary with all remote refs
symrefs: Dictionary with remote symrefs
agent: User agent string
"""
_FORWARDED_ATTRS = [
'clear', 'copy', 'fromkeys', 'get', 'has_key', 'items',
'iteritems', 'iterkeys', 'itervalues', 'keys', 'pop', 'popitem',
'setdefault', 'update', 'values', 'viewitems', 'viewkeys',
'viewvalues']
def __init__(self, refs, symrefs, agent, new_shallow=None,
new_unshallow=None):
self.refs = refs
self.symrefs = symrefs
self.agent = agent
self.new_shallow = new_shallow
self.new_unshallow = new_unshallow
def _warn_deprecated(self):
import warnings
warnings.warn(
"Use FetchPackResult.refs instead.",
DeprecationWarning, stacklevel=3)
def __eq__(self, other):
if isinstance(other, dict):
self._warn_deprecated()
return (self.refs == other)
return (self.refs == other.refs and
self.symrefs == other.symrefs and
self.agent == other.agent)
def __contains__(self, name):
self._warn_deprecated()
return name in self.refs
def __getitem__(self, name):
self._warn_deprecated()
return self.refs[name]
def __len__(self):
self._warn_deprecated()
return len(self.refs)
def __iter__(self):
self._warn_deprecated()
return iter(self.refs)
def __getattribute__(self, name):
if name in type(self)._FORWARDED_ATTRS:
self._warn_deprecated()
return getattr(self.refs, name)
return super(FetchPackResult, self).__getattribute__(name)
def __repr__(self):
return "%s(%r, %r, %r)" % (
self.__class__.__name__, self.refs, self.symrefs, self.agent)
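# FetchPackResult still quacks like the plain refs dict older callers got
# back, but dict-style access is deprecated; use the attributes. A sketch:
#
#   result = client.fetch_pack(...)           # hypothetical call
#   sha = result.refs[b'refs/heads/master']   # preferred
#   sha = result[b'refs/heads/master']        # works, emits DeprecationWarning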
def _read_shallow_updates(proto):
new_shallow = set()
new_unshallow = set()
for pkt in proto.read_pkt_seq():
cmd, sha = pkt.split(b' ', 1)
if cmd == COMMAND_SHALLOW:
new_shallow.add(sha.strip())
elif cmd == COMMAND_UNSHALLOW:
new_unshallow.add(sha.strip())
else:
raise GitProtocolError('unknown command %s' % pkt)
return (new_shallow, new_unshallow)
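# The wire format consumed here is one pkt-line per update, e.g.:
#
#   shallow 1111111111111111111111111111111111111111
#   unshallow 2222222222222222222222222222222222222222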
# TODO(durin42): this doesn't correctly degrade if the server doesn't
# support some capabilities. This should work properly with servers
# that don't support multi_ack.
class GitClient(object):
"""Git smart server client."""
def __init__(self, thin_packs=True, report_activity=None, quiet=False,
include_tags=False):
"""Create a new GitClient instance.
Args:
thin_packs: Whether or not thin packs should be retrieved
report_activity: Optional callback for reporting transport
activity.
include_tags: send annotated tags when sending the objects they point
to
"""
self._report_activity = report_activity
self._report_status_parser = None
self._fetch_capabilities = set(UPLOAD_CAPABILITIES)
self._fetch_capabilities.add(capability_agent())
self._send_capabilities = set(RECEIVE_CAPABILITIES)
self._send_capabilities.add(capability_agent())
if quiet:
self._send_capabilities.add(CAPABILITY_QUIET)
if not thin_packs:
self._fetch_capabilities.remove(CAPABILITY_THIN_PACK)
if include_tags:
self._fetch_capabilities.add(CAPABILITY_INCLUDE_TAG)
def get_url(self, path):
"""Retrieves full url to given path.
Args:
path: Repository path (as string)
Returns:
Url to path (as string)
"""
raise NotImplementedError(self.get_url)
@classmethod
def from_parsedurl(cls, parsedurl, **kwargs):
"""Create an instance of this client from a urlparse.parsed object.
Args:
parsedurl: Result of urlparse()
Returns:
A `GitClient` object
"""
raise NotImplementedError(cls.from_parsedurl)
def send_pack(self, path, update_refs, generate_pack_data,
progress=None):
"""Upload a pack to a remote repository.
Args:
path: Repository path (as bytestring)
update_refs: Function to determine changes to remote refs. Receives
dict with existing remote refs, returns dict with
changed refs (name -> sha, where sha=ZERO_SHA for deletions)
generate_pack_data: Function that can return a tuple
with number of objects and list of pack data to include
progress: Optional progress function
Returns:
new_refs dictionary containing the changes that were made
{refname: new_ref}, including deleted refs.
Raises:
SendPackError: if server rejects the pack data
UpdateRefsError: if the server supports report-status
and rejects ref updates
"""
raise NotImplementedError(self.send_pack)
def fetch(self, path, target, determine_wants=None, progress=None,
depth=None):
"""Fetch into a target repository.
Args:
path: Path to fetch from (as bytestring)
target: Target repository to fetch into
determine_wants: Optional function to determine what refs to fetch.
Receives dictionary of name->sha, should return
list of shas to fetch. Defaults to all shas.
progress: Optional progress function
depth: Depth to fetch at
Returns:
Dictionary with all remote refs (not just those fetched)
"""
if determine_wants is None:
determine_wants = target.object_store.determine_wants_all
if CAPABILITY_THIN_PACK in self._fetch_capabilities:
# TODO(jelmer): Avoid reading entire file into memory and
# only processing it after the whole file has been fetched.
f = BytesIO()
def commit():
if f.tell():
f.seek(0)
target.object_store.add_thin_pack(f.read, None)
def abort():
pass
else:
f, commit, abort = target.object_store.add_pack()
try:
result = self.fetch_pack(
path, determine_wants, target.get_graph_walker(), f.write,
progress=progress, depth=depth)
except BaseException:
abort()
raise
else:
commit()
target.update_shallow(result.new_shallow, result.new_unshallow)
return result
def fetch_pack(self, path, determine_wants, graph_walker, pack_data,
progress=None, depth=None):
"""Retrieve a pack from a git smart server.
Args:
path: Remote path to fetch from
determine_wants: Function to determine what refs
to fetch. Receives dictionary of name->sha, should return
list of shas to fetch.
graph_walker: Object with next() and ack().
pack_data: Callback called for each bit of data in the pack
progress: Callback for progress reports (strings)
depth: Shallow fetch depth
Returns:
FetchPackResult object
"""
raise NotImplementedError(self.fetch_pack)
def get_refs(self, path):
"""Retrieve the current refs from a git smart server.
Args:
path: Path to the repo to fetch from. (as bytestring)
Returns:
"""
raise NotImplementedError(self.get_refs)
def _parse_status_report(self, proto):
unpack = proto.read_pkt_line().strip()
if unpack != b'unpack ok':
st = True
# flush remaining error data
while st is not None:
st = proto.read_pkt_line()
raise SendPackError(unpack)
statuses = []
errs = False
ref_status = proto.read_pkt_line()
while ref_status:
ref_status = ref_status.strip()
statuses.append(ref_status)
if not ref_status.startswith(b'ok '):
errs = True
ref_status = proto.read_pkt_line()
if errs:
ref_status = {}
ok = set()
for status in statuses:
if b' ' not in status:
# malformed response, move on to the next one
continue
status, ref = status.split(b' ', 1)
if status == b'ng':
if b' ' in ref:
ref, status = ref.split(b' ', 1)
else:
ok.add(ref)
ref_status[ref] = status
raise UpdateRefsError(', '.join([
refname for refname in ref_status if refname not in ok]) +
b' failed to update', ref_status=ref_status)
def _read_side_band64k_data(self, proto, channel_callbacks):
"""Read per-channel data.
This requires the side-band-64k capability.
Args:
proto: Protocol object to read from
channel_callbacks: Dictionary mapping channels to packet
handlers to use. A None callback discards that channel's data.
"""
for pkt in proto.read_pkt_seq():
channel = ord(pkt[:1])
pkt = pkt[1:]
try:
cb = channel_callbacks[channel]
except KeyError:
raise AssertionError('Invalid sideband channel %d' % channel)
else:
if cb is not None:
cb(pkt)
def _handle_receive_pack_head(self, proto, capabilities, old_refs,
new_refs):
"""Handle the head of a 'git-receive-pack' request.
Args:
proto: Protocol object to read from
capabilities: List of negotiated capabilities
old_refs: Old refs, as received from the server
new_refs: Refs to change
Returns:
(have, want) tuple
"""
want = []
have = [x for x in old_refs.values() if not x == ZERO_SHA]
sent_capabilities = False
for refname in new_refs:
if not isinstance(refname, bytes):
raise TypeError('refname is not a bytestring: %r' % refname)
old_sha1 = old_refs.get(refname, ZERO_SHA)
if not isinstance(old_sha1, bytes):
raise TypeError('old sha1 for %s is not a bytestring: %r' %
(refname, old_sha1))
new_sha1 = new_refs.get(refname, ZERO_SHA)
if not isinstance(new_sha1, bytes):
raise TypeError('new sha1 for %s is not a bytestring: %r' %
(refname, new_sha1))
if old_sha1 != new_sha1:
if sent_capabilities:
proto.write_pkt_line(old_sha1 + b' ' + new_sha1 + b' ' +
refname)
else:
proto.write_pkt_line(
old_sha1 + b' ' + new_sha1 + b' ' + refname + b'\0' +
b' '.join(sorted(capabilities)))
sent_capabilities = True
if new_sha1 not in have and new_sha1 != ZERO_SHA:
want.append(new_sha1)
proto.write_pkt_line(None)
return (have, want)
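# On the wire the first update line carries the capability list after a NUL
# byte; later lines are bare updates, terminated by a flush-pkt, e.g.:
#
#   <old-sha> <new-sha> refs/heads/master\0report-status delete-refs agent=...
#   <old-sha> <new-sha> refs/heads/next
#   (flush-pkt)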
def _negotiate_receive_pack_capabilities(self, server_capabilities):
negotiated_capabilities = (
self._send_capabilities & server_capabilities)
unknown_capabilities = ( # noqa: F841
extract_capability_names(server_capabilities) -
KNOWN_RECEIVE_CAPABILITIES)
# TODO(jelmer): warn about unknown capabilities
return negotiated_capabilities
def _handle_receive_pack_tail(self, proto, capabilities, progress=None):
"""Handle the tail of a 'git-receive-pack' request.
Args:
proto: Protocol object to read from
capabilities: List of negotiated capabilities
progress: Optional progress reporting function
Returns:
"""
if CAPABILITY_SIDE_BAND_64K in capabilities:
if progress is None:
def progress(x):
pass
channel_callbacks = {2: progress}
if CAPABILITY_REPORT_STATUS in capabilities:
channel_callbacks[1] = PktLineParser(
self._report_status_parser.handle_packet).parse
self._read_side_band64k_data(proto, channel_callbacks)
else:
if CAPABILITY_REPORT_STATUS in capabilities:
for pkt in proto.read_pkt_seq():
self._report_status_parser.handle_packet(pkt)
if self._report_status_parser is not None:
self._report_status_parser.check()
def _negotiate_upload_pack_capabilities(self, server_capabilities):
unknown_capabilities = ( # noqa: F841
extract_capability_names(server_capabilities) -
KNOWN_UPLOAD_CAPABILITIES)
# TODO(jelmer): warn about unknown capabilities
symrefs = {}
agent = None
for capability in server_capabilities:
k, v = parse_capability(capability)
if k == CAPABILITY_SYMREF:
(src, dst) = v.split(b':', 1)
symrefs[src] = dst
if k == CAPABILITY_AGENT:
agent = v
negotiated_capabilities = (
self._fetch_capabilities & server_capabilities)
return (negotiated_capabilities, symrefs, agent)
def _handle_upload_pack_head(self, proto, capabilities, graph_walker,
wants, can_read, depth):
"""Handle the head of a 'git-upload-pack' request.
Args:
proto: Protocol object to read from
capabilities: List of negotiated capabilities
graph_walker: GraphWalker instance to call .ack() on
wants: List of commits to fetch
can_read: function that returns a boolean that indicates
whether there is extra graph data to read on proto
depth: Depth for request
Returns:
"""
assert isinstance(wants, list) and isinstance(wants[0], bytes)
proto.write_pkt_line(COMMAND_WANT + b' ' + wants[0] + b' ' +
b' '.join(sorted(capabilities)) + b'\n')
for want in wants[1:]:
proto.write_pkt_line(COMMAND_WANT + b' ' + want + b'\n')
if depth not in (0, None) or getattr(graph_walker, 'shallow', None):
if CAPABILITY_SHALLOW not in capabilities:
raise GitProtocolError(
"server does not support shallow capability required for "
"depth")
for sha in graph_walker.shallow:
proto.write_pkt_line(COMMAND_SHALLOW + b' ' + sha + b'\n')
if depth is not None:
proto.write_pkt_line(COMMAND_DEEPEN + b' ' +
str(depth).encode('ascii') + b'\n')
proto.write_pkt_line(None)
if can_read is not None:
(new_shallow, new_unshallow) = _read_shallow_updates(proto)
else:
new_shallow = new_unshallow = None
else:
new_shallow = new_unshallow = set()
proto.write_pkt_line(None)
have = next(graph_walker)
while have:
proto.write_pkt_line(COMMAND_HAVE + b' ' + have + b'\n')
if can_read is not None and can_read():
pkt = proto.read_pkt_line()
parts = pkt.rstrip(b'\n').split(b' ')
if parts[0] == b'ACK':
graph_walker.ack(parts[1])
if parts[2] in (b'continue', b'common'):
pass
elif parts[2] == b'ready':
break
else:
raise AssertionError(
"%s not in ('continue', 'ready', 'common)" %
parts[2])
have = next(graph_walker)
proto.write_pkt_line(COMMAND_DONE + b'\n')
return (new_shallow, new_unshallow)
def _handle_upload_pack_tail(self, proto, capabilities, graph_walker,
pack_data, progress=None, rbufsize=_RBUFSIZE):
"""Handle the tail of a 'git-upload-pack' request.
Args:
proto: Protocol object to read from
capabilities: List of negotiated capabilities
graph_walker: GraphWalker instance to call .ack() on
pack_data: Function to call with pack data
progress: Optional progress reporting function
rbufsize: Read buffer size
Returns:
"""
pkt = proto.read_pkt_line()
while pkt:
parts = pkt.rstrip(b'\n').split(b' ')
if parts[0] == b'ACK':
graph_walker.ack(parts[1])
if len(parts) < 3 or parts[2] not in (
b'ready', b'continue', b'common'):
break
pkt = proto.read_pkt_line()
if CAPABILITY_SIDE_BAND_64K in capabilities:
if progress is None:
# Just ignore progress data
def progress(x):
pass
self._read_side_band64k_data(proto, {
SIDE_BAND_CHANNEL_DATA: pack_data,
SIDE_BAND_CHANNEL_PROGRESS: progress}
)
else:
while True:
data = proto.read(rbufsize)
if data == b"":
break
pack_data(data)
def check_wants(wants, refs):
"""Check that a set of wants is valid.
Args:
wants: Set of object SHAs to fetch
refs: Refs dictionary to check against
Returns:
"""
missing = set(wants) - {
v for (k, v) in refs.items()
if not k.endswith(ANNOTATED_TAG_SUFFIX)}
if missing:
raise InvalidWants(missing)
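# Sketch: wants must come from the advertised refs, ignoring peeled
# (b'^{}'-suffixed) tag entries.
#
#   refs = {b'refs/heads/master': b'a' * 40}
#   check_wants([b'a' * 40], refs)   # ok
#   check_wants([b'f' * 40], refs)   # raises InvalidWants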
def remote_error_from_stderr(stderr):
if stderr is None:
return HangupException()
- for l in stderr.readlines():
- if l.startswith(b'ERROR: '):
+ for line in stderr.readlines():
+ if line.startswith(b'ERROR: '):
return GitProtocolError(
- l[len(b'ERROR: '):].decode('utf-8', 'replace'))
- return GitProtocolError(l.decode('utf-8', 'replace'))
+ line[len(b'ERROR: '):].decode('utf-8', 'replace'))
+ return GitProtocolError(line.decode('utf-8', 'replace'))
return HangupException()
class TraditionalGitClient(GitClient):
"""Traditional Git client."""
DEFAULT_ENCODING = 'utf-8'
def __init__(self, path_encoding=DEFAULT_ENCODING, **kwargs):
self._remote_path_encoding = path_encoding
super(TraditionalGitClient, self).__init__(**kwargs)
def _connect(self, cmd, path):
"""Create a connection to the server.
This method is abstract - concrete implementations should
implement their own variant which connects to the server and
returns an initialized Protocol object with the service ready
for use and a can_read function which may be used to see if
reads would block.
Args:
cmd: The git service name to which we should connect.
path: The path we should pass to the service (as bytestring).
"""
raise NotImplementedError()
def send_pack(self, path, update_refs, generate_pack_data,
progress=None):
"""Upload a pack to a remote repository.
Args:
path: Repository path (as bytestring)
update_refs: Function to determine changes to remote refs.
Receives dict with existing remote refs, returns dict with
changed refs (name -> sha, where sha=ZERO_SHA for deletions)
generate_pack_data: Function that can return a tuple with
number of objects and pack data to upload.
progress: Optional callback called with progress updates
Returns:
new_refs dictionary containing the changes that were made
{refname: new_ref}, including deleted refs.
Raises:
SendPackError: if server rejects the pack data
UpdateRefsError: if the server supports report-status
and rejects ref updates
"""
proto, unused_can_read, stderr = self._connect(b'receive-pack', path)
with proto:
try:
old_refs, server_capabilities = read_pkt_refs(proto)
except HangupException:
raise remote_error_from_stderr(stderr)
negotiated_capabilities = \
self._negotiate_receive_pack_capabilities(server_capabilities)
if CAPABILITY_REPORT_STATUS in negotiated_capabilities:
self._report_status_parser = ReportStatusParser()
report_status_parser = self._report_status_parser
try:
new_refs = orig_new_refs = update_refs(dict(old_refs))
except BaseException:
proto.write_pkt_line(None)
raise
if CAPABILITY_DELETE_REFS not in server_capabilities:
# Server does not support deletions. Fail later.
new_refs = dict(orig_new_refs)
for ref, sha in orig_new_refs.items():
if sha == ZERO_SHA:
if CAPABILITY_REPORT_STATUS in negotiated_capabilities:
report_status_parser._ref_statuses.append(
b'ng ' + ref +
b' remote does not support deleting refs')
report_status_parser._ref_status_ok = False
del new_refs[ref]
if new_refs is None:
proto.write_pkt_line(None)
return old_refs
if len(new_refs) == 0 and len(orig_new_refs):
# NOOP - Original new refs filtered out by policy
proto.write_pkt_line(None)
if report_status_parser is not None:
report_status_parser.check()
return old_refs
(have, want) = self._handle_receive_pack_head(
proto, negotiated_capabilities, old_refs, new_refs)
if (not want and
set(new_refs.items()).issubset(set(old_refs.items()))):
return new_refs
pack_data_count, pack_data = generate_pack_data(
have, want,
ofs_delta=(CAPABILITY_OFS_DELTA in negotiated_capabilities))
dowrite = bool(pack_data_count)
dowrite = dowrite or any(old_refs.get(ref) != sha
for (ref, sha) in new_refs.items()
if sha != ZERO_SHA)
if dowrite:
write_pack_data(proto.write_file(), pack_data_count, pack_data)
self._handle_receive_pack_tail(
proto, negotiated_capabilities, progress)
return new_refs
def fetch_pack(self, path, determine_wants, graph_walker, pack_data,
progress=None, depth=None):
"""Retrieve a pack from a git smart server.
Args:
path: Remote path to fetch from
determine_wants: Function to determine what refs
to fetch. Receives dictionary of name->sha, should return
list of shas to fetch.
graph_walker: Object with next() and ack().
pack_data: Callback called for each bit of data in the pack
progress: Callback for progress reports (strings)
depth: Shallow fetch depth
Returns:
FetchPackResult object
"""
proto, can_read, stderr = self._connect(b'upload-pack', path)
with proto:
try:
refs, server_capabilities = read_pkt_refs(proto)
except HangupException:
raise remote_error_from_stderr(stderr)
negotiated_capabilities, symrefs, agent = (
self._negotiate_upload_pack_capabilities(
server_capabilities))
if refs is None:
proto.write_pkt_line(None)
return FetchPackResult(refs, symrefs, agent)
try:
wants = determine_wants(refs)
except BaseException:
proto.write_pkt_line(None)
raise
if wants is not None:
wants = [cid for cid in wants if cid != ZERO_SHA]
if not wants:
proto.write_pkt_line(None)
return FetchPackResult(refs, symrefs, agent)
(new_shallow, new_unshallow) = self._handle_upload_pack_head(
proto, negotiated_capabilities, graph_walker, wants, can_read,
depth=depth)
self._handle_upload_pack_tail(
proto, negotiated_capabilities, graph_walker, pack_data,
progress)
return FetchPackResult(
refs, symrefs, agent, new_shallow, new_unshallow)
def get_refs(self, path):
"""Retrieve the current refs from a git smart server.
"""
# stock `git ls-remote` uses upload-pack
proto, _, stderr = self._connect(b'upload-pack', path)
with proto:
try:
refs, _ = read_pkt_refs(proto)
except HangupException:
raise remote_error_from_stderr(stderr)
proto.write_pkt_line(None)
return refs
def archive(self, path, committish, write_data, progress=None,
write_error=None, format=None, subdirs=None, prefix=None):
proto, can_read, stderr = self._connect(b'upload-archive', path)
with proto:
if format is not None:
proto.write_pkt_line(b"argument --format=" + format)
proto.write_pkt_line(b"argument " + committish)
if subdirs is not None:
for subdir in subdirs:
proto.write_pkt_line(b"argument " + subdir)
if prefix is not None:
proto.write_pkt_line(b"argument --prefix=" + prefix)
proto.write_pkt_line(None)
try:
pkt = proto.read_pkt_line()
except HangupException:
raise remote_error_from_stderr(stderr)
if pkt == b"NACK\n":
return
elif pkt == b"ACK\n":
pass
elif pkt.startswith(b"ERR "):
raise GitProtocolError(
pkt[4:].rstrip(b"\n").decode('utf-8', 'replace'))
else:
raise AssertionError("invalid response %r" % pkt)
ret = proto.read_pkt_line()
if ret is not None:
raise AssertionError("expected pkt tail")
self._read_side_band64k_data(proto, {
SIDE_BAND_CHANNEL_DATA: write_data,
SIDE_BAND_CHANNEL_PROGRESS: progress,
SIDE_BAND_CHANNEL_FATAL: write_error})
class TCPGitClient(TraditionalGitClient):
"""A Git Client that works over TCP directly (i.e. git://)."""
def __init__(self, host, port=None, **kwargs):
if port is None:
port = TCP_GIT_PORT
self._host = host
self._port = port
super(TCPGitClient, self).__init__(**kwargs)
@classmethod
def from_parsedurl(cls, parsedurl, **kwargs):
return cls(parsedurl.hostname, port=parsedurl.port, **kwargs)
def get_url(self, path):
netloc = self._host
if self._port is not None and self._port != TCP_GIT_PORT:
netloc += ":%d" % self._port
return urlunsplit(("git", netloc, path, '', ''))
def _connect(self, cmd, path):
if not isinstance(cmd, bytes):
raise TypeError(cmd)
if not isinstance(path, bytes):
path = path.encode(self._remote_path_encoding)
sockaddrs = socket.getaddrinfo(
self._host, self._port, socket.AF_UNSPEC, socket.SOCK_STREAM)
s = None
err = socket.error("no address found for %s" % self._host)
for (family, socktype, proto, canonname, sockaddr) in sockaddrs:
s = socket.socket(family, socktype, proto)
s.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
try:
s.connect(sockaddr)
break
except socket.error as e:
err = e
if s is not None:
s.close()
s = None
if s is None:
raise err
# -1 means system default buffering
rfile = s.makefile('rb', -1)
# 0 means unbuffered
wfile = s.makefile('wb', 0)
def close():
rfile.close()
wfile.close()
s.close()
proto = Protocol(rfile.read, wfile.write, close,
report_activity=self._report_activity)
if path.startswith(b"/~"):
path = path[1:]
# TODO(jelmer): Alternative to ascii?
proto.send_cmd(
b'git-' + cmd, path, b'host=' + self._host.encode('ascii'))
return proto, lambda: _fileno_can_read(s), None
class SubprocessWrapper(object):
"""A socket-like object that talks to a subprocess via pipes."""
def __init__(self, proc):
self.proc = proc
self.read = BufferedReader(proc.stdout).read
self.write = proc.stdin.write
@property
def stderr(self):
return self.proc.stderr
def can_read(self):
if sys.platform == 'win32':
from msvcrt import get_osfhandle
handle = get_osfhandle(self.proc.stdout.fileno())
return _win32_peek_avail(handle) != 0
else:
return _fileno_can_read(self.proc.stdout.fileno())
def close(self):
self.proc.stdin.close()
self.proc.stdout.close()
if self.proc.stderr:
self.proc.stderr.close()
self.proc.wait()
def find_git_command():
"""Find command to run for system Git (usually C Git)."""
if sys.platform == 'win32': # support .exe, .bat and .cmd
try: # to avoid overhead
import win32api
except ImportError: # run through cmd.exe with some overhead
return ['cmd', '/c', 'git']
else:
status, git = win32api.FindExecutable('git')
return [git]
else:
return ['git']
class SubprocessGitClient(TraditionalGitClient):
"""Git client that talks to a server using a subprocess."""
@classmethod
def from_parsedurl(cls, parsedurl, **kwargs):
return cls(**kwargs)
git_command = None
def _connect(self, service, path):
if not isinstance(service, bytes):
raise TypeError(service)
if isinstance(path, bytes):
path = path.decode(self._remote_path_encoding)
if self.git_command is None:
git_command = find_git_command()
argv = git_command + [service.decode('ascii'), path]
p = subprocess.Popen(argv, bufsize=0, stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
pw = SubprocessWrapper(p)
return (Protocol(pw.read, pw.write, pw.close,
report_activity=self._report_activity),
pw.can_read, p.stderr)
class LocalGitClient(GitClient):
"""Git Client that just uses a local Repo."""
def __init__(self, thin_packs=True, report_activity=None, config=None):
"""Create a new LocalGitClient instance.
Args:
thin_packs: Whether or not thin packs should be retrieved
report_activity: Optional callback for reporting transport
activity.
"""
self._report_activity = report_activity
# Ignore the thin_packs argument
def get_url(self, path):
return urlunsplit(('file', '', path, '', ''))
@classmethod
def from_parsedurl(cls, parsedurl, **kwargs):
return cls(**kwargs)
@classmethod
def _open_repo(cls, path):
from dulwich.repo import Repo
if not isinstance(path, str):
path = os.fsdecode(path)
return closing(Repo(path))
def send_pack(self, path, update_refs, generate_pack_data,
progress=None):
"""Upload a pack to a remote repository.
Args:
path: Repository path (as bytestring)
update_refs: Function to determine changes to remote refs.
Receives dict with existing remote refs, returns dict with
changed refs (name -> sha, where sha=ZERO_SHA for deletions)
generate_pack_data: Function that can return a tuple
with number of items and pack data to upload.
progress: Optional progress function
Returns:
new_refs dictionary containing the changes that were made
{refname: new_ref}, including deleted refs.
Raises:
SendPackError: if server rejects the pack data
UpdateRefsError: if the server supports report-status
and rejects ref updates
"""
if not progress:
def progress(x):
pass
with self._open_repo(path) as target:
old_refs = target.get_refs()
new_refs = update_refs(dict(old_refs))
have = [sha1 for sha1 in old_refs.values() if sha1 != ZERO_SHA]
want = []
for refname, new_sha1 in new_refs.items():
if (new_sha1 not in have and
new_sha1 not in want and
new_sha1 != ZERO_SHA):
want.append(new_sha1)
if (not want and
set(new_refs.items()).issubset(set(old_refs.items()))):
return new_refs
target.object_store.add_pack_data(
*generate_pack_data(have, want, ofs_delta=True))
for refname, new_sha1 in new_refs.items():
old_sha1 = old_refs.get(refname, ZERO_SHA)
if new_sha1 != ZERO_SHA:
if not target.refs.set_if_equals(
refname, old_sha1, new_sha1):
progress('unable to set %s to %s' %
(refname, new_sha1))
else:
if not target.refs.remove_if_equals(refname, old_sha1):
progress('unable to remove %s' % refname)
return new_refs
def fetch(self, path, target, determine_wants=None, progress=None,
depth=None):
"""Fetch into a target repository.
Args:
path: Path to fetch from (as bytestring)
target: Target repository to fetch into
determine_wants: Optional function to determine what refs
to fetch. Receives dictionary of name->sha, should return
list of shas to fetch. Defaults to all shas.
progress: Optional progress function
depth: Shallow fetch depth
Returns:
FetchPackResult object
"""
with self._open_repo(path) as r:
refs = r.fetch(target, determine_wants=determine_wants,
progress=progress, depth=depth)
return FetchPackResult(refs, r.refs.get_symrefs(),
agent_string())
def fetch_pack(self, path, determine_wants, graph_walker, pack_data,
progress=None, depth=None):
"""Retrieve a pack from a git smart server.
Args:
path: Remote path to fetch from
determine_wants: Function to determine what refs
to fetch. Receives dictionary of name->sha, should return
list of shas to fetch.
graph_walker: Object with next() and ack().
pack_data: Callback called for each bit of data in the pack
progress: Callback for progress reports (strings)
depth: Shallow fetch depth
Returns:
FetchPackResult object
"""
with self._open_repo(path) as r:
objects_iter = r.fetch_objects(
determine_wants, graph_walker, progress=progress, depth=depth)
symrefs = r.refs.get_symrefs()
agent = agent_string()
# Did the process short-circuit (e.g. in a stateless RPC call)?
# Note that the client still expects a 0-object pack in most cases.
if objects_iter is None:
return FetchPackResult(None, symrefs, agent)
protocol = ProtocolFile(None, pack_data)
write_pack_objects(protocol, objects_iter)
return FetchPackResult(r.get_refs(), symrefs, agent)
def get_refs(self, path):
"""Retrieve the current refs from a git smart server.
"""
with self._open_repo(path) as target:
return target.get_refs()
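# A hedged end-to-end sketch for LocalGitClient (paths are hypothetical):
#
#   from dulwich.repo import Repo
#   target = Repo.init('/tmp/clone', mkdir=True)
#   result = LocalGitClient().fetch('/tmp/source', target)
#   print(sorted(result.refs))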
# What Git client to use for local access
default_local_git_client_cls = LocalGitClient
class SSHVendor(object):
"""A client side SSH implementation."""
def connect_ssh(self, host, command, username=None, port=None,
password=None, key_filename=None):
# This function was deprecated in 0.9.1
import warnings
warnings.warn(
"SSHVendor.connect_ssh has been renamed to SSHVendor.run_command",
DeprecationWarning)
return self.run_command(host, command, username=username, port=port,
password=password, key_filename=key_filename)
def run_command(self, host, command, username=None, port=None,
password=None, key_filename=None):
"""Connect to an SSH server.
Run a command remotely and return a file-like object for interaction
with the remote command.
Args:
host: Host name
command: Command to run (as argv array)
username: Optional name of the user to log in as
port: Optional SSH port to use
password: Optional ssh password for login or private key
key_filename: Optional path to private keyfile
Returns:
A file-like object for interaction with the remote command.
"""
raise NotImplementedError(self.run_command)
class StrangeHostname(Exception):
"""Refusing to connect to strange SSH hostname."""
def __init__(self, hostname):
super(StrangeHostname, self).__init__(hostname)
class SubprocessSSHVendor(SSHVendor):
"""SSH vendor that shells out to the local 'ssh' command."""
def run_command(self, host, command, username=None, port=None,
password=None, key_filename=None):
if password is not None:
raise NotImplementedError(
"Setting password not supported by SubprocessSSHVendor.")
args = ['ssh', '-x']
if port:
args.extend(['-p', str(port)])
if key_filename:
args.extend(['-i', str(key_filename)])
if username:
host = '%s@%s' % (username, host)
if host.startswith('-'):
raise StrangeHostname(hostname=host)
args.append(host)
proc = subprocess.Popen(args + [command], bufsize=0,
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
return SubprocessWrapper(proc)
class PLinkSSHVendor(SSHVendor):
"""SSH vendor that shells out to the local 'plink' command."""
def run_command(self, host, command, username=None, port=None,
password=None, key_filename=None):
if sys.platform == 'win32':
args = ['plink.exe', '-ssh']
else:
args = ['plink', '-ssh']
if password is not None:
import warnings
warnings.warn(
"Invoking PLink with a password exposes the password in the "
"process list.")
args.extend(['-pw', str(password)])
if port:
args.extend(['-P', str(port)])
if key_filename:
args.extend(['-i', str(key_filename)])
if username:
host = '%s@%s' % (username, host)
if host.startswith('-'):
raise StrangeHostname(hostname=host)
args.append(host)
proc = subprocess.Popen(args + [command], bufsize=0,
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
return SubprocessWrapper(proc)
def ParamikoSSHVendor(**kwargs):
import warnings
warnings.warn(
"ParamikoSSHVendor has been moved to dulwich.contrib.paramiko_vendor.",
DeprecationWarning)
from dulwich.contrib.paramiko_vendor import ParamikoSSHVendor
return ParamikoSSHVendor(**kwargs)
# Can be overridden by users
get_ssh_vendor = SubprocessSSHVendor
class SSHGitClient(TraditionalGitClient):
def __init__(self, host, port=None, username=None, vendor=None,
config=None, password=None, key_filename=None, **kwargs):
self.host = host
self.port = port
self.username = username
self.password = password
self.key_filename = key_filename
super(SSHGitClient, self).__init__(**kwargs)
self.alternative_paths = {}
if vendor is not None:
self.ssh_vendor = vendor
else:
self.ssh_vendor = get_ssh_vendor()
def get_url(self, path):
netloc = self.host
if self.port is not None:
netloc += ":%d" % self.port
if self.username is not None:
netloc = urlquote(self.username, '@/:') + "@" + netloc
return urlunsplit(('ssh', netloc, path, '', ''))
@classmethod
def from_parsedurl(cls, parsedurl, **kwargs):
return cls(host=parsedurl.hostname, port=parsedurl.port,
username=parsedurl.username, **kwargs)
def _get_cmd_path(self, cmd):
cmd = self.alternative_paths.get(cmd, b'git-' + cmd)
assert isinstance(cmd, bytes)
return cmd
def _connect(self, cmd, path):
if not isinstance(cmd, bytes):
raise TypeError(cmd)
if isinstance(path, bytes):
path = path.decode(self._remote_path_encoding)
if path.startswith("/~"):
path = path[1:]
argv = (self._get_cmd_path(cmd).decode(self._remote_path_encoding) +
" '" + path + "'")
kwargs = {}
if self.password is not None:
kwargs['password'] = self.password
if self.key_filename is not None:
kwargs['key_filename'] = self.key_filename
con = self.ssh_vendor.run_command(
self.host, argv, port=self.port, username=self.username,
**kwargs)
return (Protocol(con.read, con.write, con.close,
report_activity=self._report_activity),
con.can_read, getattr(con, 'stderr', None))
def default_user_agent_string():
# Start user agent with "git/", because GitHub requires this. :-( See
# https://github.com/jelmer/dulwich/issues/562 for details.
return "git/dulwich/%s" % ".".join([str(x) for x in dulwich.__version__])
def default_urllib3_manager(config, pool_manager_cls=None,
proxy_manager_cls=None, **override_kwargs):
"""Return `urllib3` connection pool manager.
Honour detected proxy configurations.
Args:
config: `dulwich.config.ConfigDict` instance with Git configuration.
override_kwargs: Additional arguments for the pool manager.
Returns:
`proxy_manager_cls` (defaults to `urllib3.ProxyManager`) instance for
proxy configurations, `pool_manager_cls` (defaults to
`urllib3.PoolManager`) instance otherwise.
"""
proxy_server = user_agent = None
ca_certs = ssl_verify = None
if config is not None:
try:
proxy_server = config.get(b"http", b"proxy")
except KeyError:
pass
try:
user_agent = config.get(b"http", b"useragent")
except KeyError:
pass
# TODO(jelmer): Support per-host settings
try:
ssl_verify = config.get_boolean(b"http", b"sslVerify")
except KeyError:
ssl_verify = True
try:
ca_certs = config.get(b"http", b"sslCAInfo")
except KeyError:
ca_certs = None
if user_agent is None:
user_agent = default_user_agent_string()
headers = {"User-agent": user_agent}
kwargs = {}
if ssl_verify is True:
kwargs['cert_reqs'] = "CERT_REQUIRED"
elif ssl_verify is False:
kwargs['cert_reqs'] = 'CERT_NONE'
else:
# Default to SSL verification
kwargs['cert_reqs'] = "CERT_REQUIRED"
if ca_certs is not None:
kwargs['ca_certs'] = ca_certs
kwargs.update(override_kwargs)
# Try really hard to find a SSL certificate path
if 'ca_certs' not in kwargs and kwargs.get('cert_reqs') != 'CERT_NONE':
try:
import certifi
except ImportError:
pass
else:
kwargs['ca_certs'] = certifi.where()
import urllib3
if proxy_server is not None:
if proxy_manager_cls is None:
proxy_manager_cls = urllib3.ProxyManager
# `urllib3` requires a `str` object in both Python 2 and 3, while
# `ConfigDict` coerces entries to `bytes` on Python 3. Compensate.
if not isinstance(proxy_server, str):
proxy_server = proxy_server.decode()
manager = proxy_manager_cls(proxy_server, headers=headers, **kwargs)
else:
if pool_manager_cls is None:
pool_manager_cls = urllib3.PoolManager
manager = pool_manager_cls(headers=headers, **kwargs)
return manager
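# Illustrative sketch (assuming no Git configuration is available): with
# config=None and no proxy, this yields a plain urllib3.PoolManager with
# certificate verification enabled and a dulwich user agent.
#
#     >>> manager = default_urllib3_manager(config=None)
#     >>> manager.headers['User-agent'].startswith('git/dulwich/')
#     True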
class HttpGitClient(GitClient):
def __init__(self, base_url, dumb=None, pool_manager=None, config=None,
username=None, password=None, **kwargs):
self._base_url = base_url.rstrip("/") + "/"
self._username = username
self._password = password
self.dumb = dumb
if pool_manager is None:
self.pool_manager = default_urllib3_manager(config)
else:
self.pool_manager = pool_manager
if username is not None:
# No escaping needed: ":" is not allowed in username:
# https://tools.ietf.org/html/rfc2617#section-2
credentials = "%s:%s" % (username, password)
import urllib3.util
basic_auth = urllib3.util.make_headers(basic_auth=credentials)
self.pool_manager.headers.update(basic_auth)
GitClient.__init__(self, **kwargs)
def get_url(self, path):
return self._get_url(path).rstrip("/")
@classmethod
def from_parsedurl(cls, parsedurl, **kwargs):
password = parsedurl.password
if password is not None:
kwargs['password'] = urlunquote(password)
username = parsedurl.username
if username is not None:
kwargs['username'] = urlunquote(username)
netloc = parsedurl.hostname
if parsedurl.port:
netloc = "%s:%s" % (netloc, parsedurl.port)
if parsedurl.username:
netloc = "%s@%s" % (parsedurl.username, netloc)
parsedurl = parsedurl._replace(netloc=netloc)
return cls(urlunparse(parsedurl), **kwargs)
def __repr__(self):
return "%s(%r, dumb=%r)" % (
type(self).__name__, self._base_url, self.dumb)
def _get_url(self, path):
if not isinstance(path, str):
# urllib3.util.url._encode_invalid_chars() converts the path back
# to bytes using the utf-8 codec.
path = path.decode('utf-8')
return urljoin(self._base_url, path).rstrip("/") + "/"
def _http_request(self, url, headers=None, data=None,
allow_compression=False):
"""Perform HTTP request.
Args:
url: Request URL.
headers: Optional custom headers to override defaults.
data: Request data.
allow_compression: Allow GZipped communication.
Returns:
Tuple (`response`, `read`), where response is an `urllib3`
response object with additional `content_type` and
`redirect_location` properties, and `read` is a consumable read
method for the response data.
"""
req_headers = self.pool_manager.headers.copy()
if headers is not None:
req_headers.update(headers)
req_headers["Pragma"] = "no-cache"
if allow_compression:
req_headers["Accept-Encoding"] = "gzip"
else:
req_headers["Accept-Encoding"] = "identity"
if data is None:
resp = self.pool_manager.request("GET", url, headers=req_headers)
else:
resp = self.pool_manager.request("POST", url, headers=req_headers,
body=data)
if resp.status == 404:
raise NotGitRepository()
elif resp.status != 200:
raise GitProtocolError("unexpected http resp %d for %s" %
(resp.status, url))
# TODO: Optimization available by adding `preload_content=False` to the
# request and just passing the `read` method on instead of going via
# `BytesIO`, if we can guarantee that the entire response is consumed
# before issuing the next to still allow for connection reuse from the
# pool.
read = BytesIO(resp.data).read
resp.content_type = resp.getheader("Content-Type")
# Check if geturl() is available (urllib3 version >= 1.23)
try:
resp_url = resp.geturl()
except AttributeError:
# get_redirect_location() is available for urllib3 >= 1.1
resp.redirect_location = resp.get_redirect_location()
else:
resp.redirect_location = resp_url if resp_url != url else ''
return resp, read
def _discover_references(self, service, base_url):
assert base_url[-1] == "/"
tail = "info/refs"
headers = {"Accept": "*/*"}
if self.dumb is not True:
tail += "?service=%s" % service.decode('ascii')
url = urljoin(base_url, tail)
resp, read = self._http_request(url, headers, allow_compression=True)
if resp.redirect_location:
# Something changed (redirect!), so let's update the base URL
if not resp.redirect_location.endswith(tail):
raise GitProtocolError(
"Redirected from URL %s to URL %s without %s" % (
url, resp.redirect_location, tail))
base_url = resp.redirect_location[:-len(tail)]
try:
self.dumb = not resp.content_type.startswith("application/x-git-")
if not self.dumb:
proto = Protocol(read, None)
# The first line should mention the service
try:
[pkt] = list(proto.read_pkt_seq())
except ValueError:
raise GitProtocolError(
"unexpected number of packets received")
if pkt.rstrip(b'\n') != (b'# service=' + service):
raise GitProtocolError(
"unexpected first line %r from smart server" % pkt)
return read_pkt_refs(proto) + (base_url, )
else:
return read_info_refs(resp), set(), base_url
finally:
resp.close()
def _smart_request(self, service, url, data):
assert url[-1] == "/"
url = urljoin(url, service)
result_content_type = "application/x-%s-result" % service
headers = {
"Content-Type": "application/x-%s-request" % service,
"Accept": result_content_type,
"Content-Length": str(len(data)),
}
resp, read = self._http_request(url, headers, data)
if resp.content_type != result_content_type:
raise GitProtocolError("Invalid content-type from server: %s"
% resp.content_type)
return resp, read
def send_pack(self, path, update_refs, generate_pack_data,
progress=None):
"""Upload a pack to a remote repository.
Args:
path: Repository path (as bytestring)
update_refs: Function to determine changes to remote refs.
Receives dict with existing remote refs, returns dict with
changed refs (name -> sha, where sha=ZERO_SHA for deletions)
generate_pack_data: Function that can return a tuple
with number of elements and pack data to upload.
progress: Optional progress function
Returns:
new_refs dictionary containing the changes that were made
{refname: new_ref}, including deleted refs.
Raises:
SendPackError: if server rejects the pack data
UpdateRefsError: if the server supports report-status
and rejects ref updates
"""
url = self._get_url(path)
old_refs, server_capabilities, url = self._discover_references(
b"git-receive-pack", url)
negotiated_capabilities = self._negotiate_receive_pack_capabilities(
server_capabilities)
negotiated_capabilities.add(capability_agent())
if CAPABILITY_REPORT_STATUS in negotiated_capabilities:
self._report_status_parser = ReportStatusParser()
new_refs = update_refs(dict(old_refs))
if new_refs is None:
# Determine wants function is aborting the push.
return old_refs
if self.dumb:
raise NotImplementedError(self.send_pack)
req_data = BytesIO()
req_proto = Protocol(None, req_data.write)
(have, want) = self._handle_receive_pack_head(
req_proto, negotiated_capabilities, old_refs, new_refs)
if not want and set(new_refs.items()).issubset(set(old_refs.items())):
return new_refs
pack_data_count, pack_data = generate_pack_data(
have, want,
ofs_delta=(CAPABILITY_OFS_DELTA in negotiated_capabilities))
if pack_data_count:
write_pack_data(req_proto.write_file(), pack_data_count, pack_data)
resp, read = self._smart_request("git-receive-pack", url,
data=req_data.getvalue())
try:
resp_proto = Protocol(read, None)
self._handle_receive_pack_tail(
resp_proto, negotiated_capabilities, progress)
return new_refs
finally:
resp.close()
def fetch_pack(self, path, determine_wants, graph_walker, pack_data,
progress=None, depth=None):
"""Retrieve a pack from a git smart server.
Args:
path: Path to fetch from
determine_wants: Callback that returns list of commits to fetch
graph_walker: Object with next() and ack().
pack_data: Callback called for each bit of data in the pack
progress: Callback for progress reports (strings)
depth: Depth for request
Returns:
FetchPackResult object
"""
url = self._get_url(path)
refs, server_capabilities, url = self._discover_references(
b"git-upload-pack", url)
negotiated_capabilities, symrefs, agent = (
self._negotiate_upload_pack_capabilities(
server_capabilities))
wants = determine_wants(refs)
if wants is not None:
wants = [cid for cid in wants if cid != ZERO_SHA]
if not wants:
return FetchPackResult(refs, symrefs, agent)
if self.dumb:
raise NotImplementedError(self.fetch_pack)
req_data = BytesIO()
req_proto = Protocol(None, req_data.write)
(new_shallow, new_unshallow) = self._handle_upload_pack_head(
req_proto, negotiated_capabilities, graph_walker, wants,
can_read=None, depth=depth)
resp, read = self._smart_request(
"git-upload-pack", url, data=req_data.getvalue())
try:
resp_proto = Protocol(read, None)
if new_shallow is None and new_unshallow is None:
(new_shallow, new_unshallow) = _read_shallow_updates(
resp_proto)
self._handle_upload_pack_tail(
resp_proto, negotiated_capabilities, graph_walker, pack_data,
progress)
return FetchPackResult(
refs, symrefs, agent, new_shallow, new_unshallow)
finally:
resp.close()
def get_refs(self, path):
"""Retrieve the current refs from a git smart server.
"""
url = self._get_url(path)
refs, _, _ = self._discover_references(
b"git-upload-pack", url)
return refs
def get_transport_and_path_from_url(url, config=None, **kwargs):
"""Obtain a git client from a URL.
Args:
url: URL to open (a unicode string)
config: Optional config object
thin_packs: Whether or not thin packs should be retrieved
report_activity: Optional callback for reporting transport
activity.
Returns:
Tuple with client instance and relative path.
"""
parsed = urlparse(url)
if parsed.scheme == 'git':
return (TCPGitClient.from_parsedurl(parsed, **kwargs),
parsed.path)
elif parsed.scheme in ('git+ssh', 'ssh'):
return SSHGitClient.from_parsedurl(parsed, **kwargs), parsed.path
elif parsed.scheme in ('http', 'https'):
return HttpGitClient.from_parsedurl(
parsed, config=config, **kwargs), parsed.path
elif parsed.scheme == 'file':
return default_local_git_client_cls.from_parsedurl(
parsed, **kwargs), parsed.path
raise ValueError("unknown scheme '%s'" % parsed.scheme)
def parse_rsync_url(location):
"""Parse a rsync-style URL.
"""
if ':' in location and '@' not in location:
# SSH with no user@, zero or one leading slash.
(host, path) = location.split(':', 1)
user = None
elif ':' in location:
# SSH with user@host:foo.
user_host, path = location.split(':', 1)
if '@' in user_host:
user, host = user_host.rsplit('@', 1)
else:
user = None
host = user_host
else:
raise ValueError('not a valid rsync-style URL')
return (user, host, path)
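# Illustrative sketch (hypothetical locations) of the accepted shapes:
#
#     >>> parse_rsync_url('git@example.com:path/repo.git')
#     ('git', 'example.com', 'path/repo.git')
#     >>> parse_rsync_url('example.com:path/repo.git')
#     (None, 'example.com', 'path/repo.git')
#     >>> parse_rsync_url('/local/path')   # no colon: raises ValueError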
def get_transport_and_path(location, **kwargs):
"""Obtain a git client from a URL.
Args:
location: URL or path (a string)
config: Optional config object
thin_packs: Whether or not thin packs should be retrieved
report_activity: Optional callback for reporting transport
activity.
Returns:
Tuple with client instance and relative path.
"""
# First, try to parse it as a URL
try:
return get_transport_and_path_from_url(location, **kwargs)
except ValueError:
pass
if (sys.platform == 'win32' and
location[0].isalpha() and location[1:3] == ':\\'):
# Windows local path
return default_local_git_client_cls(**kwargs), location
try:
(username, hostname, path) = parse_rsync_url(location)
except ValueError:
# Otherwise, assume it's a local path.
return default_local_git_client_cls(**kwargs), location
else:
return SSHGitClient(hostname, username=username, **kwargs), path
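# Illustrative sketch (hypothetical URLs): the location string determines
# which client class is returned.
#
#     >>> client, path = get_transport_and_path('https://example.com/repo.git')
#     >>> type(client).__name__, path
#     ('HttpGitClient', '/repo.git')
#     >>> client, path = get_transport_and_path('git@example.com:repo.git')
#     >>> type(client).__name__, path
#     ('SSHGitClient', 'repo.git')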
DEFAULT_GIT_CREDENTIALS_PATHS = [
os.path.expanduser('~/.git-credentials'),
get_xdg_config_home_path('git', 'credentials')]
def get_credentials_from_store(scheme, hostname, username=None,
fnames=DEFAULT_GIT_CREDENTIALS_PATHS):
for fname in fnames:
try:
with open(fname, 'rb') as f:
for line in f:
parsed_line = urlparse(line.strip())
if (parsed_line.scheme == scheme and
parsed_line.hostname == hostname and
(username is None or
parsed_line.username == username)):
return parsed_line.username, parsed_line.password
except OSError as e:
if e.errno != errno.ENOENT:
raise
# If the file doesn't exist, try the next one.
continue
diff --git a/dulwich/objects.py b/dulwich/objects.py
index a0300dc6..83bfd849 100644
--- a/dulwich/objects.py
+++ b/dulwich/objects.py
@@ -1,1443 +1,1442 @@
# objects.py -- Access to base git objects
# Copyright (C) 2007 James Westby
# Copyright (C) 2008-2013 Jelmer Vernooij
#
# Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU
# General Public License as public by the Free Software Foundation; version 2.0
# or (at your option) any later version. You can redistribute it and/or
# modify it under the terms of either of these two licenses.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# You should have received a copy of the licenses; if not, see
# <http://www.gnu.org/licenses/> for a copy of the GNU General Public License
# and <http://www.apache.org/licenses/LICENSE-2.0> for a copy of the Apache
# License, Version 2.0.
#
"""Access to base git objects."""
import binascii
from io import BytesIO
from collections import namedtuple
import os
import posixpath
import stat
from typing import (
Optional,
Dict,
Union,
Type,
)
import warnings
import zlib
from hashlib import sha1
from dulwich.errors import (
ChecksumMismatch,
NotBlobError,
NotCommitError,
NotTagError,
NotTreeError,
ObjectFormatException,
FileFormatException,
)
from dulwich.file import GitFile
ZERO_SHA = b'0' * 40
# Header fields for commits
_TREE_HEADER = b'tree'
_PARENT_HEADER = b'parent'
_AUTHOR_HEADER = b'author'
_COMMITTER_HEADER = b'committer'
_ENCODING_HEADER = b'encoding'
_MERGETAG_HEADER = b'mergetag'
_GPGSIG_HEADER = b'gpgsig'
# Header fields for objects
_OBJECT_HEADER = b'object'
_TYPE_HEADER = b'type'
_TAG_HEADER = b'tag'
_TAGGER_HEADER = b'tagger'
S_IFGITLINK = 0o160000
MAX_TIME = 9223372036854775807 # (2**63) - 1 - signed long int max
BEGIN_PGP_SIGNATURE = b"-----BEGIN PGP SIGNATURE-----"
class EmptyFileException(FileFormatException):
"""An unexpectedly empty file was encountered."""
def S_ISGITLINK(m):
"""Check if a mode indicates a submodule.
Args:
m: Mode to check
Returns: a ``boolean``
"""
return (stat.S_IFMT(m) == S_IFGITLINK)
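# Illustrative sketch: submodule (gitlink) entries carry mode 0o160000,
# regular files do not.
#
#     >>> S_ISGITLINK(0o160000)
#     True
#     >>> S_ISGITLINK(0o100644)
#     False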
def _decompress(string):
dcomp = zlib.decompressobj()
dcomped = dcomp.decompress(string)
dcomped += dcomp.flush()
return dcomped
def sha_to_hex(sha):
"""Takes a string and returns the hex of the sha within"""
hexsha = binascii.hexlify(sha)
assert len(hexsha) == 40, "Incorrect length of sha1 string: %s" % hexsha
return hexsha
def hex_to_sha(hex):
"""Takes a hex sha and returns a binary sha"""
assert len(hex) == 40, "Incorrect length of hexsha: %s" % hex
try:
return binascii.unhexlify(hex)
except TypeError as exc:
if not isinstance(hex, bytes):
raise
raise ValueError(exc.args[0])
def valid_hexsha(hex):
if len(hex) != 40:
return False
try:
binascii.unhexlify(hex)
except (TypeError, binascii.Error):
return False
else:
return True
def hex_to_filename(path, hex):
"""Takes a hex sha and returns its filename relative to the given path."""
# os.path.join accepts bytes or unicode, but all args must be of the same
# type. Make sure that hex which is expected to be bytes, is the same type
# as path.
if getattr(path, 'encode', None) is not None:
hex = hex.decode('ascii')
dir = hex[:2]
file = hex[2:]
# Check from object dir
return os.path.join(path, dir, file)
def filename_to_hex(filename):
"""Takes an object filename and returns its corresponding hex sha."""
# grab the last (up to) two path components
names = filename.rsplit(os.path.sep, 2)[-2:]
errmsg = "Invalid object filename: %s" % filename
assert len(names) == 2, errmsg
base, rest = names
assert len(base) == 2 and len(rest) == 38, errmsg
hex = (base + rest).encode('ascii')
hex_to_sha(hex)
return hex
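# Illustrative sketch (hypothetical sha, POSIX separators): the two helpers
# are inverses; the first two hex digits become the fan-out directory.
#
#     >>> hex_to_filename('objects', 'ab' * 20)
#     'objects/ab/abab...ab'   (the remaining 38 hex digits)
#     >>> filename_to_hex('objects/ab/' + 'ab' * 19)
#     b'abab...ab'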
def object_header(num_type: int, length: int) -> bytes:
"""Return an object header for the given numeric type and text length."""
return (object_class(num_type).type_name +
b' ' + str(length).encode('ascii') + b'\0')
def serializable_property(name: str, docstring: Optional[str] = None):
"""A property that helps tracking whether serialization is necessary.
"""
def set(obj, value):
setattr(obj, "_"+name, value)
obj._needs_serialization = True
def get(obj):
return getattr(obj, "_"+name)
return property(get, set, doc=docstring)
def object_class(type):
"""Get the object class corresponding to the given type.
Args:
type: Either a type name string or a numeric type.
Returns: The ShaFile subclass corresponding to the given type, or None if
type is not a valid type name/number.
"""
return _TYPE_MAP.get(type, None)
def check_hexsha(hex, error_msg):
"""Check if a string is a valid hex sha string.
Args:
hex: Hex string to check
error_msg: Error message to use in exception
Raises:
ObjectFormatException: Raised when the string is not valid
"""
if not valid_hexsha(hex):
raise ObjectFormatException("%s %s" % (error_msg, hex))
def check_identity(identity, error_msg):
"""Check if the specified identity is valid.
This will raise an exception if the identity is not valid.
Args:
identity: Identity string
error_msg: Error message to use in exception
"""
email_start = identity.find(b'<')
email_end = identity.find(b'>')
if (email_start < 0 or email_end < 0 or email_end <= email_start
or identity.find(b'<', email_start + 1) >= 0
or identity.find(b'>', email_end + 1) >= 0
or not identity.endswith(b'>')):
raise ObjectFormatException(error_msg)
def check_time(time_seconds):
"""Check if the specified time is not prone to overflow error.
This will raise an exception if the time is not valid.
Args:
time_seconds: author/committer/tagger timestamp, in seconds since the epoch
"""
# Prevent overflow error
if time_seconds > MAX_TIME:
raise ObjectFormatException(
'Date field should not exceed %s' % MAX_TIME)
def git_line(*items):
"""Formats items into a space separated line."""
return b' '.join(items) + b'\n'
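# Illustrative sketch: git_line() joins byte strings with single spaces and
# appends a newline, the format used for commit and tag headers below.
#
#     >>> git_line(b'tree', b'abc123')
#     b'tree abc123\n'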
class FixedSha(object):
"""SHA object that behaves like hashlib's but is given a fixed value."""
__slots__ = ('_hexsha', '_sha')
def __init__(self, hexsha):
if getattr(hexsha, 'encode', None) is not None:
hexsha = hexsha.encode('ascii')
if not isinstance(hexsha, bytes):
raise TypeError('Expected bytes for hexsha, got %r' % hexsha)
self._hexsha = hexsha
self._sha = hex_to_sha(hexsha)
def digest(self):
"""Return the raw SHA digest."""
return self._sha
def hexdigest(self):
"""Return the hex SHA digest."""
return self._hexsha.decode('ascii')
class ShaFile(object):
"""A git SHA file."""
__slots__ = ('_chunked_text', '_sha', '_needs_serialization')
type_name: bytes
type_num: int
@staticmethod
def _parse_legacy_object_header(magic, f):
"""Parse a legacy object, creating it but not reading the file."""
bufsize = 1024
decomp = zlib.decompressobj()
header = decomp.decompress(magic)
start = 0
end = -1
while end < 0:
extra = f.read(bufsize)
header += decomp.decompress(extra)
magic += extra
end = header.find(b'\0', start)
start = len(header)
header = header[:end]
type_name, size = header.split(b' ', 1)
try:
int(size) # sanity check
except ValueError as e:
raise ObjectFormatException("Object size not an integer: %s" % e)
obj_class = object_class(type_name)
if not obj_class:
raise ObjectFormatException("Not a known type: %s" % type_name)
return obj_class()
def _parse_legacy_object(self, map):
"""Parse a legacy object, setting the raw string."""
text = _decompress(map)
header_end = text.find(b'\0')
if header_end < 0:
raise ObjectFormatException("Invalid object header, no \\0")
self.set_raw_string(text[header_end+1:])
def as_legacy_object_chunks(self, compression_level=-1):
"""Return chunks representing the object in the experimental format.
Returns: List of strings
"""
compobj = zlib.compressobj(compression_level)
yield compobj.compress(self._header())
for chunk in self.as_raw_chunks():
yield compobj.compress(chunk)
yield compobj.flush()
def as_legacy_object(self, compression_level=-1):
"""Return string representing the object in the experimental format.
"""
return b''.join(self.as_legacy_object_chunks(
compression_level=compression_level))
def as_raw_chunks(self):
"""Return chunks with serialization of the object.
Returns: List of strings, not necessarily one per line
"""
if self._needs_serialization:
self._sha = None
self._chunked_text = self._serialize()
self._needs_serialization = False
return self._chunked_text
def as_raw_string(self):
"""Return raw string with serialization of the object.
Returns: String object
"""
return b''.join(self.as_raw_chunks())
def __bytes__(self):
"""Return raw string serialization of this object."""
return self.as_raw_string()
def __hash__(self):
"""Return unique hash for this object."""
return hash(self.id)
def as_pretty_string(self):
"""Return a string representing this object, fit for display."""
return self.as_raw_string()
def set_raw_string(self, text, sha=None):
"""Set the contents of this object from a serialized string."""
if not isinstance(text, bytes):
raise TypeError('Expected bytes for text, got %r' % text)
self.set_raw_chunks([text], sha)
def set_raw_chunks(self, chunks, sha=None):
"""Set the contents of this object from a list of chunks."""
self._chunked_text = chunks
self._deserialize(chunks)
if sha is None:
self._sha = None
else:
self._sha = FixedSha(sha)
self._needs_serialization = False
@staticmethod
def _parse_object_header(magic, f):
"""Parse a new style object, creating it but not reading the file."""
num_type = (ord(magic[0:1]) >> 4) & 7
obj_class = object_class(num_type)
if not obj_class:
raise ObjectFormatException("Not a known type %d" % num_type)
return obj_class()
def _parse_object(self, map):
"""Parse a new style object, setting self._text."""
# skip type and size; type must have already been determined, and
# we trust zlib to fail if it's otherwise corrupted
byte = ord(map[0:1])
used = 1
while (byte & 0x80) != 0:
byte = ord(map[used:used+1])
used += 1
raw = map[used:]
self.set_raw_string(_decompress(raw))
@classmethod
def _is_legacy_object(cls, magic):
b0 = ord(magic[0:1])
b1 = ord(magic[1:2])
word = (b0 << 8) + b1
return (b0 & 0x8F) == 0x08 and (word % 31) == 0
@classmethod
def _parse_file(cls, f):
map = f.read()
if not map:
raise EmptyFileException('Corrupted empty file detected')
if cls._is_legacy_object(map):
obj = cls._parse_legacy_object_header(map, f)
obj._parse_legacy_object(map)
else:
obj = cls._parse_object_header(map, f)
obj._parse_object(map)
return obj
def __init__(self):
"""Don't call this directly"""
self._sha = None
self._chunked_text = []
self._needs_serialization = True
def _deserialize(self, chunks):
raise NotImplementedError(self._deserialize)
def _serialize(self):
raise NotImplementedError(self._serialize)
@classmethod
def from_path(cls, path):
"""Open a SHA file from disk."""
with GitFile(path, 'rb') as f:
return cls.from_file(f)
@classmethod
def from_file(cls, f):
"""Get the contents of a SHA file on disk."""
try:
obj = cls._parse_file(f)
obj._sha = None
return obj
except (IndexError, ValueError):
raise ObjectFormatException("invalid object header")
@staticmethod
def from_raw_string(type_num, string, sha=None):
"""Creates an object of the indicated type from the raw string given.
Args:
type_num: The numeric type of the object.
string: The raw uncompressed contents.
sha: Optional known sha for the object
"""
obj = object_class(type_num)()
obj.set_raw_string(string, sha)
return obj
@staticmethod
def from_raw_chunks(type_num, chunks, sha=None):
"""Creates an object of the indicated type from the raw chunks given.
Args:
type_num: The numeric type of the object.
chunks: An iterable of the raw uncompressed contents.
sha: Optional known sha for the object
"""
obj = object_class(type_num)()
obj.set_raw_chunks(chunks, sha)
return obj
@classmethod
def from_string(cls, string):
"""Create a ShaFile from a string."""
obj = cls()
obj.set_raw_string(string)
return obj
def _check_has_member(self, member, error_msg):
"""Check that the object has a given member variable.
Args:
member: the member variable to check for
error_msg: the message for an error if the member is missing
Raises:
ObjectFormatException: with the given error_msg if member is
missing or is None
"""
if getattr(self, member, None) is None:
raise ObjectFormatException(error_msg)
def check(self):
"""Check this object for internal consistency.
Raises:
ObjectFormatException: if the object is malformed in some way
ChecksumMismatch: if the object was created with a SHA that does
not match its contents
"""
# TODO: if we find that error-checking during object parsing is a
# performance bottleneck, those checks should be moved to the class's
# check() method during optimization so we can still check the object
# when necessary.
old_sha = self.id
try:
self._deserialize(self.as_raw_chunks())
self._sha = None
new_sha = self.id
except Exception as e:
raise ObjectFormatException(e)
if old_sha != new_sha:
raise ChecksumMismatch(new_sha, old_sha)
def _header(self):
return object_header(self.type, self.raw_length())
def raw_length(self):
"""Returns the length of the raw string of this object."""
ret = 0
for chunk in self.as_raw_chunks():
ret += len(chunk)
return ret
def sha(self):
"""The SHA1 object that is the name of this object."""
if self._sha is None or self._needs_serialization:
# this is a local because as_raw_chunks() overwrites self._sha
new_sha = sha1()
new_sha.update(self._header())
for chunk in self.as_raw_chunks():
new_sha.update(chunk)
self._sha = new_sha
return self._sha
def copy(self):
"""Create a new copy of this SHA1 object from its raw string"""
obj_class = object_class(self.get_type())
return obj_class.from_raw_string(
self.get_type(),
self.as_raw_string(),
self.id)
@property
def id(self):
"""The hex SHA of this object."""
return self.sha().hexdigest().encode('ascii')
def get_type(self):
"""Return the type number for this object class."""
return self.type_num
def set_type(self, type):
"""Set the type number for this object class."""
self.type_num = type
# DEPRECATED: use type_num or type_name as needed.
type = property(get_type, set_type)
def __repr__(self):
return "<%s %s>" % (self.__class__.__name__, self.id)
def __ne__(self, other):
"""Check whether this object does not match the other."""
return not isinstance(other, ShaFile) or self.id != other.id
def __eq__(self, other):
"""Return True if the SHAs of the two objects match.
"""
return isinstance(other, ShaFile) and self.id == other.id
def __lt__(self, other):
"""Return whether SHA of this object is less than the other.
"""
if not isinstance(other, ShaFile):
raise TypeError
return self.id < other.id
def __le__(self, other):
"""Check whether SHA of this object is less than or equal to the other.
"""
if not isinstance(other, ShaFile):
raise TypeError
return self.id <= other.id
def __cmp__(self, other):
"""Compare the SHA of this object with that of the other object.
"""
if not isinstance(other, ShaFile):
raise TypeError
return cmp(self.id, other.id) # noqa: F821
class Blob(ShaFile):
"""A Git Blob object."""
__slots__ = ()
type_name = b'blob'
type_num = 3
def __init__(self):
super(Blob, self).__init__()
self._chunked_text = []
self._needs_serialization = False
def _get_data(self):
return self.as_raw_string()
def _set_data(self, data):
self.set_raw_string(data)
data = property(_get_data, _set_data,
doc="The text contained within the blob object.")
def _get_chunked(self):
return self._chunked_text
def _set_chunked(self, chunks):
self._chunked_text = chunks
def _serialize(self):
return self._chunked_text
def _deserialize(self, chunks):
self._chunked_text = chunks
chunked = property(
_get_chunked, _set_chunked,
doc="The text in the blob object, as chunks (not necessarily lines)")
@classmethod
def from_path(cls, path):
blob = ShaFile.from_path(path)
if not isinstance(blob, cls):
raise NotBlobError(path)
return blob
def check(self):
"""Check this object for internal consistency.
Raises:
ObjectFormatException: if the object is malformed in some way
"""
super(Blob, self).check()
def splitlines(self):
"""Return list of lines in this blob.
This preserves the original line endings.
"""
chunks = self.chunked
if not chunks:
return []
if len(chunks) == 1:
return chunks[0].splitlines(True)
remaining = None
ret = []
for chunk in chunks:
lines = chunk.splitlines(True)
if len(lines) > 1:
ret.append((remaining or b"") + lines[0])
ret.extend(lines[1:-1])
remaining = lines[-1]
elif len(lines) == 1:
if remaining is None:
remaining = lines.pop()
else:
remaining += lines.pop()
if remaining is not None:
ret.append(remaining)
return ret
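# Illustrative sketch: lines are reassembled even when they span chunk
# boundaries.
#
#     >>> b = Blob()
#     >>> b.chunked = [b'one\ntw', b'o\nthree']
#     >>> b.splitlines()
#     [b'one\n', b'two\n', b'three']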
def _parse_message(chunks):
"""Parse a message with a list of fields and a body.
Args:
chunks: the raw chunks of the tag or commit object.
Returns: iterator of tuples of (field, value), one per header line, in the
order read from the text, possibly including duplicates. Includes a
field named None for the freeform tag/commit text.
"""
f = BytesIO(b''.join(chunks))
k = None
v = ""
eof = False
def _strip_last_newline(value):
"""Strip the last newline from value"""
if value and value.endswith(b'\n'):
return value[:-1]
return value
# Parse the headers
#
# Headers can contain newlines. The next line is indented with a space.
# We store the latest key as 'k', and the accumulated value as 'v'.
for line in f:
if line.startswith(b' '):
# Indented continuation of the previous line
v += line[1:]
else:
if k is not None:
# We parsed a new header, return its value
yield (k, _strip_last_newline(v))
if line == b'\n':
# Empty line indicates end of headers
break
(k, v) = line.split(b' ', 1)
else:
# We reached end of file before the headers ended. We still need to
# return the previous header, then we need to return a None field for
# the text.
eof = True
if k is not None:
yield (k, _strip_last_newline(v))
yield (None, None)
if not eof:
# We didn't reach the end of file while parsing headers. We can return
# the rest of the file as a message.
yield (None, f.read())
f.close()
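# Illustrative sketch: headers are yielded first; after the blank separator
# line the free-form body is yielded under the field name None.
#
#     >>> list(_parse_message([b'tree abc\n\nhello\n']))
#     [(b'tree', b'abc'), (None, b'hello\n')]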
class Tag(ShaFile):
"""A Git Tag object."""
type_name = b'tag'
type_num = 4
__slots__ = ('_tag_timezone_neg_utc', '_name', '_object_sha',
'_object_class', '_tag_time', '_tag_timezone',
'_tagger', '_message', '_signature')
def __init__(self):
super(Tag, self).__init__()
self._tagger = None
self._tag_time = None
self._tag_timezone = None
self._tag_timezone_neg_utc = False
self._signature = None
@classmethod
def from_path(cls, filename):
tag = ShaFile.from_path(filename)
if not isinstance(tag, cls):
raise NotTagError(filename)
return tag
def check(self):
"""Check this object for internal consistency.
Raises:
ObjectFormatException: if the object is malformed in some way
"""
super(Tag, self).check()
self._check_has_member("_object_sha", "missing object sha")
self._check_has_member("_object_class", "missing object type")
self._check_has_member("_name", "missing tag name")
if not self._name:
raise ObjectFormatException("empty tag name")
check_hexsha(self._object_sha, "invalid object sha")
if getattr(self, "_tagger", None):
check_identity(self._tagger, "invalid tagger")
self._check_has_member("_tag_time", "missing tag time")
check_time(self._tag_time)
last = None
for field, _ in _parse_message(self._chunked_text):
if field == _OBJECT_HEADER and last is not None:
raise ObjectFormatException("unexpected object")
elif field == _TYPE_HEADER and last != _OBJECT_HEADER:
raise ObjectFormatException("unexpected type")
elif field == _TAG_HEADER and last != _TYPE_HEADER:
raise ObjectFormatException("unexpected tag name")
elif field == _TAGGER_HEADER and last != _TAG_HEADER:
raise ObjectFormatException("unexpected tagger")
last = field
def _serialize(self):
chunks = []
chunks.append(git_line(_OBJECT_HEADER, self._object_sha))
chunks.append(git_line(_TYPE_HEADER, self._object_class.type_name))
chunks.append(git_line(_TAG_HEADER, self._name))
if self._tagger:
if self._tag_time is None:
chunks.append(git_line(_TAGGER_HEADER, self._tagger))
else:
chunks.append(git_line(
_TAGGER_HEADER, self._tagger,
str(self._tag_time).encode('ascii'),
format_timezone(
self._tag_timezone, self._tag_timezone_neg_utc)))
if self._message is not None:
chunks.append(b'\n') # To close headers
chunks.append(self._message)
if self._signature is not None:
chunks.append(self._signature)
return chunks
def _deserialize(self, chunks):
"""Grab the metadata attached to the tag"""
self._tagger = None
self._tag_time = None
self._tag_timezone = None
self._tag_timezone_neg_utc = False
for field, value in _parse_message(chunks):
if field == _OBJECT_HEADER:
self._object_sha = value
elif field == _TYPE_HEADER:
obj_class = object_class(value)
if not obj_class:
raise ObjectFormatException("Not a known type: %s" % value)
self._object_class = obj_class
elif field == _TAG_HEADER:
self._name = value
elif field == _TAGGER_HEADER:
(self._tagger,
self._tag_time,
(self._tag_timezone,
self._tag_timezone_neg_utc)) = parse_time_entry(value)
elif field is None:
if value is None:
self._message = None
self._signature = None
else:
try:
sig_idx = value.index(BEGIN_PGP_SIGNATURE)
except ValueError:
self._message = value
self._signature = None
else:
self._message = value[:sig_idx]
self._signature = value[sig_idx:]
else:
raise ObjectFormatException("Unknown field %s" % field)
def _get_object(self):
"""Get the object pointed to by this tag.
Returns: tuple of (object class, sha).
"""
return (self._object_class, self._object_sha)
def _set_object(self, value):
(self._object_class, self._object_sha) = value
self._needs_serialization = True
object = property(_get_object, _set_object)
name = serializable_property("name", "The name of this tag")
tagger = serializable_property(
"tagger",
"Returns the name of the person who created this tag")
tag_time = serializable_property(
"tag_time",
"The creation timestamp of the tag. As the number of seconds "
"since the epoch")
tag_timezone = serializable_property(
"tag_timezone",
"The timezone that tag_time is in.")
message = serializable_property(
"message", "the message attached to this tag")
signature = serializable_property(
"signature", "Optional detached GPG signature")
class TreeEntry(namedtuple('TreeEntry', ['path', 'mode', 'sha'])):
"""Named tuple encapsulating a single tree entry."""
def in_path(self, path):
"""Return a copy of this entry with the given path prepended."""
if not isinstance(self.path, bytes):
raise TypeError('Expected bytes for path, got %r' % self.path)
return TreeEntry(posixpath.join(path, self.path), self.mode, self.sha)
def parse_tree(text, strict=False):
"""Parse a tree text.
Args:
text: Serialized text to parse
Returns: iterator of tuples of (name, mode, sha)
Raises:
ObjectFormatException: if the object was malformed in some way
"""
count = 0
length = len(text)
while count < length:
mode_end = text.index(b' ', count)
mode_text = text[count:mode_end]
if strict and mode_text.startswith(b'0'):
raise ObjectFormatException("Invalid mode '%s'" % mode_text)
try:
mode = int(mode_text, 8)
except ValueError:
raise ObjectFormatException("Invalid mode '%s'" % mode_text)
name_end = text.index(b'\0', mode_end)
name = text[mode_end+1:name_end]
count = name_end+21
sha = text[name_end+1:count]
if len(sha) != 20:
raise ObjectFormatException("Sha has invalid length")
hexsha = sha_to_hex(sha)
yield (name, mode, hexsha)
def serialize_tree(items):
"""Serialize the items in a tree to a text.
Args:
items: Sorted iterable over (name, mode, sha) tuples
Returns: Serialized tree text as chunks
"""
for name, mode, hexsha in items:
yield (("%04o" % mode).encode('ascii') + b' ' + name +
b'\0' + hex_to_sha(hexsha))
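# Illustrative sketch (hypothetical sha): serialize_tree() and parse_tree()
# round-trip a single entry.
#
#     >>> data = b''.join(serialize_tree([(b'spam', 0o100644, b'aa' * 20)]))
#     >>> next(parse_tree(data)) == (b'spam', 0o100644, b'aa' * 20)
#     True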
def sorted_tree_items(entries, name_order):
"""Iterate over a tree entries dictionary.
Args:
name_order: If True, iterate entries in order of their name. If
False, iterate entries in tree order, that is, treat subtree entries as
having '/' appended.
entries: Dictionary mapping names to (mode, sha) tuples
Returns: Iterator over (name, mode, hexsha)
"""
key_func = name_order and key_entry_name_order or key_entry
for name, entry in sorted(entries.items(), key=key_func):
mode, hexsha = entry
# Stricter type checks than normal to mirror checks in the C version.
mode = int(mode)
if not isinstance(hexsha, bytes):
raise TypeError('Expected bytes for SHA, got %r' % hexsha)
yield TreeEntry(name, mode, hexsha)
def key_entry(entry):
"""Sort key for tree entry.
Args:
entry: (name, value) tuple
"""
(name, value) = entry
if stat.S_ISDIR(value[0]):
name += b'/'
return name
def key_entry_name_order(entry):
"""Sort key for tree entry in name order."""
return entry[0]
def pretty_format_tree_entry(name, mode, hexsha, encoding="utf-8"):
"""Pretty format tree entry.
Args:
name: Name of the directory entry
mode: Mode of entry
hexsha: Hexsha of the referenced object
Returns: string describing the tree entry
"""
if mode & stat.S_IFDIR:
kind = "tree"
else:
kind = "blob"
return "%04o %s %s\t%s\n" % (
mode, kind, hexsha.decode('ascii'),
name.decode(encoding, 'replace'))
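# Illustrative sketch (hypothetical sha):
#
#     >>> pretty_format_tree_entry(b'spam.txt', 0o100644, b'aa' * 20)
#     '100644 blob aaaa...aa\tspam.txt\n'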
class Tree(ShaFile):
"""A Git tree object"""
type_name = b'tree'
type_num = 2
__slots__ = ('_entries',)
def __init__(self):
super(Tree, self).__init__()
self._entries = {}
@classmethod
def from_path(cls, filename):
tree = ShaFile.from_path(filename)
if not isinstance(tree, cls):
raise NotTreeError(filename)
return tree
def __contains__(self, name):
return name in self._entries
def __getitem__(self, name):
return self._entries[name]
def __setitem__(self, name, value):
"""Set a tree entry by name.
Args:
name: The name of the entry, as a string.
value: A tuple of (mode, hexsha), where mode is the mode of the
entry as an integral type and hexsha is the hex SHA of the entry as
a string.
"""
mode, hexsha = value
self._entries[name] = (mode, hexsha)
self._needs_serialization = True
def __delitem__(self, name):
del self._entries[name]
self._needs_serialization = True
def __len__(self):
return len(self._entries)
def __iter__(self):
return iter(self._entries)
def add(self, name, mode, hexsha):
"""Add an entry to the tree.
Args:
mode: The mode of the entry as an integral type. Not all
possible modes are supported by git; see check() for details.
name: The name of the entry, as a string.
hexsha: The hex SHA of the entry as a string.
"""
if isinstance(name, int) and isinstance(mode, bytes):
(name, mode) = (mode, name)
warnings.warn(
"Please use Tree.add(name, mode, hexsha)",
category=DeprecationWarning, stacklevel=2)
self._entries[name] = mode, hexsha
self._needs_serialization = True
def iteritems(self, name_order=False):
"""Iterate over entries.
Args:
name_order: If True, iterate in name order instead of tree
order.
Returns: Iterator over (name, mode, sha) tuples
"""
return sorted_tree_items(self._entries, name_order)
def items(self):
"""Return the sorted entries in this tree.
Returns: List with (name, mode, sha) tuples
"""
return list(self.iteritems())
def _deserialize(self, chunks):
"""Grab the entries in the tree"""
try:
parsed_entries = parse_tree(b''.join(chunks))
except ValueError as e:
raise ObjectFormatException(e)
# TODO: list comprehension is for efficiency in the common (small)
# case; if memory efficiency in the large case is a concern, use a
# genexp.
self._entries = dict([(n, (m, s)) for n, m, s in parsed_entries])
def check(self):
"""Check this object for internal consistency.
Raises:
ObjectFormatException: if the object is malformed in some way
"""
super(Tree, self).check()
last = None
allowed_modes = (stat.S_IFREG | 0o755, stat.S_IFREG | 0o644,
stat.S_IFLNK, stat.S_IFDIR, S_IFGITLINK,
# TODO: optionally exclude as in git fsck --strict
stat.S_IFREG | 0o664)
for name, mode, sha in parse_tree(b''.join(self._chunked_text),
True):
check_hexsha(sha, 'invalid sha %s' % sha)
if b'/' in name or name in (b'', b'.', b'..', b'.git'):
raise ObjectFormatException(
'invalid name %s' %
name.decode('utf-8', 'replace'))
if mode not in allowed_modes:
raise ObjectFormatException('invalid mode %06o' % mode)
entry = (name, (mode, sha))
if last:
if key_entry(last) > key_entry(entry):
raise ObjectFormatException('entries not sorted')
if name == last[0]:
raise ObjectFormatException('duplicate entry %s' % name)
last = entry
def _serialize(self):
return list(serialize_tree(self.iteritems()))
def as_pretty_string(self):
text = []
for name, mode, hexsha in self.iteritems():
text.append(pretty_format_tree_entry(name, mode, hexsha))
return "".join(text)
def lookup_path(self, lookup_obj, path):
"""Look up an object in a Git tree.
Args:
lookup_obj: Callback for retrieving object by SHA1
path: Path to lookup
Returns: A tuple of (mode, SHA) of the resulting path.
"""
parts = path.split(b'/')
sha = self.id
mode = None
for p in parts:
if not p:
continue
obj = lookup_obj(sha)
if not isinstance(obj, Tree):
raise NotTreeError(sha)
mode, sha = obj[p]
return mode, sha
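# Illustrative sketch (hypothetical sha): building a tree in memory with the
# mapping interface defined above.
#
#     >>> t = Tree()
#     >>> t.add(b'spam', 0o100644, b'aa' * 20)
#     >>> t.items()
#     [TreeEntry(path=b'spam', mode=33188, sha=b'aaaa...aa')]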
def parse_timezone(text):
"""Parse a timezone text fragment (e.g. '+0100').
Args:
text: Text to parse.
Returns: Tuple with timezone as seconds difference to UTC
and a boolean indicating whether this was a UTC timezone
prefixed with a negative sign (-0000).
"""
# cgit parses the first character as the sign, and the rest
# as an integer (using strtol), which could also be negative.
# We do the same for compatibility. See #697828.
if not text[0] in b'+-':
raise ValueError("Timezone must start with + or - (%(text)s)" % vars())
sign = text[:1]
offset = int(text[1:])
if sign == b'-':
offset = -offset
unnecessary_negative_timezone = (offset >= 0 and sign == b'-')
signum = (offset < 0) and -1 or 1
offset = abs(offset)
hours = int(offset / 100)
minutes = (offset % 100)
return (signum * (hours * 3600 + minutes * 60),
unnecessary_negative_timezone)
def format_timezone(offset, unnecessary_negative_timezone=False):
"""Format a timezone for Git serialization.
Args:
offset: Timezone offset as seconds difference to UTC
unnecessary_negative_timezone: Whether to use a minus sign for
UTC or positive timezones (-0000 and -0700 rather than +0000 / +0700).
"""
if offset % 60 != 0:
raise ValueError("Unable to handle non-minute offset.")
if offset < 0 or unnecessary_negative_timezone:
sign = '-'
offset = -offset
else:
sign = '+'
return ('%c%02d%02d' %
(sign, offset / 3600, (offset / 60) % 60)).encode('ascii')
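# Illustrative sketch: parse_timezone() and format_timezone() round-trip,
# including the -0000 special case.
#
#     >>> parse_timezone(b'+0130')
#     (5400, False)
#     >>> parse_timezone(b'-0000')
#     (0, True)
#     >>> format_timezone(5400)
#     b'+0130'
#     >>> format_timezone(0, unnecessary_negative_timezone=True)
#     b'-0000'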
def parse_time_entry(value):
"""Parse time entry behavior
Args:
value: Bytes representing a git commit/tag line
Raises:
ObjectFormatException in case of parsing error (malformed
field date)
Returns: Tuple of (author, time, (timezone, timezone_neg_utc))
"""
try:
sep = value.rindex(b'> ')
except ValueError:
return (value, None, (None, False))
try:
person = value[0:sep+1]
rest = value[sep+2:]
timetext, timezonetext = rest.rsplit(b' ', 1)
time = int(timetext)
timezone, timezone_neg_utc = parse_timezone(timezonetext)
except ValueError as e:
raise ObjectFormatException(e)
return person, time, (timezone, timezone_neg_utc)
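# Illustrative sketch (hypothetical identity):
#
#     >>> parse_time_entry(b'Alice <alice@example.com> 1577836800 +0100')
#     (b'Alice <alice@example.com>', 1577836800, (3600, False))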
def parse_commit(chunks):
"""Parse a commit object from chunks.
Args:
chunks: Chunks to parse
Returns: Tuple of (tree, parents, author_info, commit_info,
encoding, mergetag, gpgsig, message, extra)
"""
parents = []
extra = []
tree = None
author_info = (None, None, (None, None))
commit_info = (None, None, (None, None))
encoding = None
mergetag = []
message = None
gpgsig = None
for field, value in _parse_message(chunks):
# TODO(jelmer): Enforce ordering
if field == _TREE_HEADER:
tree = value
elif field == _PARENT_HEADER:
parents.append(value)
elif field == _AUTHOR_HEADER:
author_info = parse_time_entry(value)
elif field == _COMMITTER_HEADER:
commit_info = parse_time_entry(value)
elif field == _ENCODING_HEADER:
encoding = value
elif field == _MERGETAG_HEADER:
mergetag.append(Tag.from_string(value + b'\n'))
elif field == _GPGSIG_HEADER:
gpgsig = value
elif field is None:
message = value
else:
extra.append((field, value))
return (tree, parents, author_info, commit_info, encoding, mergetag,
gpgsig, message, extra)
class Commit(ShaFile):
"""A git commit object"""
type_name = b'commit'
type_num = 1
__slots__ = ('_parents', '_encoding', '_extra', '_author_timezone_neg_utc',
'_commit_timezone_neg_utc', '_commit_time',
'_author_time', '_author_timezone', '_commit_timezone',
'_author', '_committer', '_tree', '_message',
'_mergetag', '_gpgsig')
def __init__(self):
super(Commit, self).__init__()
self._parents = []
self._encoding = None
self._mergetag = []
self._gpgsig = None
self._extra = []
self._author_timezone_neg_utc = False
self._commit_timezone_neg_utc = False
@classmethod
def from_path(cls, path):
commit = ShaFile.from_path(path)
if not isinstance(commit, cls):
raise NotCommitError(path)
return commit
def _deserialize(self, chunks):
(self._tree, self._parents, author_info, commit_info, self._encoding,
self._mergetag, self._gpgsig, self._message, self._extra) = (
parse_commit(chunks))
(self._author, self._author_time,
(self._author_timezone, self._author_timezone_neg_utc)) = author_info
(self._committer, self._commit_time,
(self._commit_timezone, self._commit_timezone_neg_utc)) = commit_info
def check(self):
"""Check this object for internal consistency.
Raises:
ObjectFormatException: if the object is malformed in some way
"""
super(Commit, self).check()
self._check_has_member("_tree", "missing tree")
self._check_has_member("_author", "missing author")
self._check_has_member("_committer", "missing committer")
self._check_has_member("_author_time", "missing author time")
self._check_has_member("_commit_time", "missing commit time")
for parent in self._parents:
check_hexsha(parent, "invalid parent sha")
check_hexsha(self._tree, "invalid tree sha")
check_identity(self._author, "invalid author")
check_identity(self._committer, "invalid committer")
check_time(self._author_time)
check_time(self._commit_time)
last = None
for field, _ in _parse_message(self._chunked_text):
if field == _TREE_HEADER and last is not None:
raise ObjectFormatException("unexpected tree")
elif field == _PARENT_HEADER and last not in (_PARENT_HEADER,
_TREE_HEADER):
raise ObjectFormatException("unexpected parent")
elif field == _AUTHOR_HEADER and last not in (_TREE_HEADER,
_PARENT_HEADER):
raise ObjectFormatException("unexpected author")
elif field == _COMMITTER_HEADER and last != _AUTHOR_HEADER:
raise ObjectFormatException("unexpected committer")
elif field == _ENCODING_HEADER and last != _COMMITTER_HEADER:
raise ObjectFormatException("unexpected encoding")
last = field
# TODO: optionally check for duplicate parents
def _serialize(self):
chunks = []
tree_bytes = (
self._tree.id if isinstance(self._tree, Tree) else self._tree)
chunks.append(git_line(_TREE_HEADER, tree_bytes))
for p in self._parents:
chunks.append(git_line(_PARENT_HEADER, p))
chunks.append(git_line(
_AUTHOR_HEADER, self._author,
str(self._author_time).encode('ascii'),
format_timezone(
self._author_timezone, self._author_timezone_neg_utc)))
chunks.append(git_line(
_COMMITTER_HEADER, self._committer,
str(self._commit_time).encode('ascii'),
format_timezone(self._commit_timezone,
self._commit_timezone_neg_utc)))
if self.encoding:
chunks.append(git_line(_ENCODING_HEADER, self.encoding))
for mergetag in self.mergetag:
mergetag_chunks = mergetag.as_raw_string().split(b'\n')
chunks.append(git_line(_MERGETAG_HEADER, mergetag_chunks[0]))
# Embedded extra header needs leading space
for chunk in mergetag_chunks[1:]:
chunks.append(b' ' + chunk + b'\n')
# No trailing empty line
if chunks[-1].endswith(b' \n'):
chunks[-1] = chunks[-1][:-2]
for k, v in self.extra:
if b'\n' in k or b'\n' in v:
raise AssertionError(
"newline in extra data: %r -> %r" % (k, v))
chunks.append(git_line(k, v))
if self.gpgsig:
sig_chunks = self.gpgsig.split(b'\n')
chunks.append(git_line(_GPGSIG_HEADER, sig_chunks[0]))
for chunk in sig_chunks[1:]:
chunks.append(git_line(b'', chunk))
chunks.append(b'\n') # There must be a new line after the headers
chunks.append(self._message)
return chunks
tree = serializable_property(
"tree", "Tree that is the state of this commit")
def _get_parents(self):
"""Return a list of parents of this commit."""
return self._parents
def _set_parents(self, value):
"""Set a list of parents of this commit."""
self._needs_serialization = True
self._parents = value
parents = property(_get_parents, _set_parents,
doc="Parents of this commit, by their SHA1.")
def _get_extra(self):
"""Return extra settings of this commit."""
return self._extra
extra = property(
_get_extra,
doc="Extra header fields not understood (presumably added in a "
"newer version of git). Kept verbatim so the object can "
"be correctly reserialized. For private commit metadata, use "
"pseudo-headers in Commit.message, rather than this field.")
author = serializable_property(
"author",
"The name of the author of the commit")
committer = serializable_property(
"committer",
"The name of the committer of the commit")
message = serializable_property(
"message", "The commit message")
commit_time = serializable_property(
"commit_time",
"The timestamp of the commit. As the number of seconds since the "
"epoch.")
commit_timezone = serializable_property(
"commit_timezone",
"The zone the commit time is in")
author_time = serializable_property(
"author_time",
"The timestamp the commit was written. As the number of "
"seconds since the epoch.")
author_timezone = serializable_property(
"author_timezone", "Returns the zone the author time is in.")
encoding = serializable_property(
"encoding", "Encoding of the commit message.")
mergetag = serializable_property(
"mergetag", "Associated signed tag.")
gpgsig = serializable_property(
"gpgsig", "GPG Signature.")
OBJECT_CLASSES = (
Commit,
Tree,
Blob,
Tag,
)
_TYPE_MAP: Dict[Union[bytes, int], Type[ShaFile]] = {}
for cls in OBJECT_CLASSES:
_TYPE_MAP[cls.type_name] = cls
_TYPE_MAP[cls.type_num] = cls
# Hold on to the pure-python implementations for testing
_parse_tree_py = parse_tree
_sorted_tree_items_py = sorted_tree_items
try:
# Try to import C versions
from dulwich._objects import parse_tree, sorted_tree_items # type: ignore
except ImportError:
pass
diff --git a/dulwich/porcelain.py b/dulwich/porcelain.py
index 443083b6..b0b669d0 100644
--- a/dulwich/porcelain.py
+++ b/dulwich/porcelain.py
@@ -1,1615 +1,1615 @@
# porcelain.py -- Porcelain-like layer on top of Dulwich
# Copyright (C) 2013 Jelmer Vernooij
#
# Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU
# General Public License as public by the Free Software Foundation; version 2.0
# or (at your option) any later version. You can redistribute it and/or
# modify it under the terms of either of these two licenses.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# You should have received a copy of the licenses; if not, see
# <http://www.gnu.org/licenses/> for a copy of the GNU General Public License
# and <http://www.apache.org/licenses/LICENSE-2.0> for a copy of the Apache
# License, Version 2.0.
#
"""Simple wrapper that provides porcelain-like functions on top of Dulwich.
Currently implemented:
* archive
* add
* branch{_create,_delete,_list}
* check-ignore
* checkout
* clone
* commit
* commit-tree
* daemon
* describe
* diff-tree
* fetch
* init
* ls-files
* ls-remote
* ls-tree
* pull
* push
* rm
* remote{_add}
* receive-pack
* reset
* rev-list
* tag{_create,_delete,_list}
* upload-pack
* update-server-info
* status
* symbolic-ref
These functions are meant to behave similarly to the git subcommands.
Differences in behaviour are considered bugs.
Functions should generally accept both unicode strings and bytestrings.
"""
from collections import namedtuple
from contextlib import (
closing,
contextmanager,
)
from io import BytesIO, RawIOBase
import datetime
import os
import posixpath
import shutil
import stat
import sys
import time
from dulwich.archive import (
tar_stream,
)
from dulwich.client import (
get_transport_and_path,
)
from dulwich.config import (
StackedConfig,
)
from dulwich.diff_tree import (
CHANGE_ADD,
CHANGE_DELETE,
CHANGE_MODIFY,
CHANGE_RENAME,
CHANGE_COPY,
RENAME_CHANGE_TYPES,
)
from dulwich.errors import (
SendPackError,
UpdateRefsError,
)
from dulwich.ignore import IgnoreFilterManager
from dulwich.index import (
blob_from_path_and_stat,
get_unstaged_changes,
)
from dulwich.object_store import (
tree_lookup_path,
)
from dulwich.objects import (
Commit,
Tag,
format_timezone,
parse_timezone,
pretty_format_tree_entry,
)
from dulwich.objectspec import (
parse_commit,
parse_object,
parse_ref,
parse_reftuples,
parse_tree,
)
from dulwich.pack import (
write_pack_index,
write_pack_objects,
)
from dulwich.patch import write_tree_diff
from dulwich.protocol import (
Protocol,
ZERO_SHA,
)
from dulwich.refs import (
ANNOTATED_TAG_SUFFIX,
LOCAL_BRANCH_PREFIX,
strip_peeled_refs,
)
from dulwich.repo import (BaseRepo, Repo)
from dulwich.server import (
FileSystemBackend,
TCPGitServer,
ReceivePackHandler,
UploadPackHandler,
update_server_info as server_update_server_info,
)
# Module level tuple definition for status output
GitStatus = namedtuple('GitStatus', 'staged unstaged untracked')
class NoneStream(RawIOBase):
"""Fallback if stdout or stderr are unavailable, does nothing."""
def read(self, size=-1):
return None
def readall(self):
return None
def readinto(self, b):
return None
def write(self, b):
return None
default_bytes_out_stream = (
getattr(sys.stdout, 'buffer', None) or NoneStream())
default_bytes_err_stream = (
getattr(sys.stderr, 'buffer', None) or NoneStream())
DEFAULT_ENCODING = 'utf-8'
class RemoteExists(Exception):
"""Raised when the remote already exists."""
def open_repo(path_or_repo):
"""Open an argument that can be a repository or a path for a repository."""
if isinstance(path_or_repo, BaseRepo):
return path_or_repo
return Repo(path_or_repo)
@contextmanager
def _noop_context_manager(obj):
"""Context manager that has the same api as closing but does nothing."""
yield obj
def open_repo_closing(path_or_repo):
"""Open an argument that can be a repository or a path for a repository.
Returns a context manager that will close the repo on exit if the argument
is a path, else does nothing if the argument is a repo.
"""
if isinstance(path_or_repo, BaseRepo):
return _noop_context_manager(path_or_repo)
return closing(Repo(path_or_repo))
def path_to_tree_path(repopath, path):
"""Convert a path to a path usable in an index, e.g. bytes and relative to
the repository root.
Args:
repopath: Repository path, absolute or relative to the cwd
path: A path, absolute or relative to the cwd
Returns: A path formatted for use in e.g. an index
"""
if not isinstance(path, bytes):
path = os.fsencode(path)
if not isinstance(repopath, bytes):
repopath = os.fsencode(repopath)
treepath = os.path.relpath(path, repopath)
if treepath.startswith(b'..'):
raise ValueError('Path not in repo')
if os.path.sep != '/':
treepath = treepath.replace(os.path.sep.encode('ascii'), b'/')
return treepath
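A quick illustration of the conversion (illustrative POSIX paths):

    from dulwich.porcelain import path_to_tree_path

    # The result is bytes, relative to the repository root, and always
    # uses '/' as the separator, regardless of platform.
    print(path_to_tree_path('/tmp/repo', '/tmp/repo/docs/index.rst'))
    # b'docs/index.rst'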
def archive(repo, committish=None, outstream=default_bytes_out_stream,
errstream=default_bytes_err_stream):
"""Create an archive.
Args:
repo: Path of repository for which to generate an archive.
committish: Commit SHA1 or ref to use
outstream: Output stream (defaults to stdout)
errstream: Error stream (defaults to stderr)
"""
if committish is None:
committish = "HEAD"
with open_repo_closing(repo) as repo_obj:
c = parse_commit(repo_obj, committish)
for chunk in tar_stream(
repo_obj.object_store, repo_obj.object_store[c.tree],
c.commit_time):
outstream.write(chunk)
def update_server_info(repo="."):
"""Update server info files for a repository.
Args:
repo: path to the repository
"""
with open_repo_closing(repo) as r:
server_update_server_info(r)
def symbolic_ref(repo, ref_name, force=False):
"""Set git symbolic ref into HEAD.
Args:
repo: path to the repository
ref_name: short name of the new ref
force: force settings without checking if it exists in refs/heads
"""
with open_repo_closing(repo) as repo_obj:
ref_path = _make_branch_ref(ref_name)
if not force and ref_path not in repo_obj.refs.keys():
raise ValueError('fatal: ref `%s` is not a ref' % ref_name)
repo_obj.refs.set_symbolic_ref(b'HEAD', ref_path)
def commit(repo=".", message=None, author=None, committer=None, encoding=None):
"""Create a new commit.
Args:
repo: Path to repository
message: Optional commit message
author: Optional author name and email
committer: Optional committer name and email
encoding: Optional encoding for the commit message and identity strings
Returns: SHA1 of the new commit
"""
# FIXME: Support --all argument
# FIXME: Support --signoff argument
if getattr(message, 'encode', None):
message = message.encode(encoding or DEFAULT_ENCODING)
if getattr(author, 'encode', None):
author = author.encode(encoding or DEFAULT_ENCODING)
if getattr(committer, 'encode', None):
committer = committer.encode(encoding or DEFAULT_ENCODING)
with open_repo_closing(repo) as r:
return r.do_commit(
message=message, author=author, committer=committer,
encoding=encoding)
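A minimal usage sketch, with an illustrative path and identity; the new
commit's SHA1 is returned as bytes:

    from dulwich import porcelain

    porcelain.init('/tmp/example-repo')
    sha = porcelain.commit(
        '/tmp/example-repo',
        message='Initial commit',
        author='Alice <alice@example.com>',
        committer='Alice <alice@example.com>')
    print(sha)  # 40-character hex SHA1, as bytes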
def commit_tree(repo, tree, message=None, author=None, committer=None):
"""Create a new commit object.
Args:
repo: Path to repository
tree: An existing tree object
message: Optional commit message
author: Optional author name and email
committer: Optional committer name and email
"""
with open_repo_closing(repo) as r:
return r.do_commit(
message=message, tree=tree, committer=committer, author=author)
def init(path=".", bare=False):
"""Create a new git repository.
Args:
path: Path to repository.
bare: Whether to create a bare repository.
Returns: A Repo instance
"""
if not os.path.exists(path):
os.mkdir(path)
if bare:
return Repo.init_bare(path)
else:
return Repo.init(path)
def clone(source, target=None, bare=False, checkout=None,
errstream=default_bytes_err_stream, outstream=None,
origin=b"origin", depth=None, **kwargs):
"""Clone a local or remote git repository.
Args:
source: Path or URL for source repository
target: Path to target repository (optional)
bare: Whether or not to create a bare repository
checkout: Whether or not to check-out HEAD after cloning
errstream: Optional stream to write progress to
outstream: Optional stream to write progress to (deprecated)
origin: Name of remote from the repository used to clone
depth: Depth to fetch at
Returns: The new repository
"""
# TODO(jelmer): This code overlaps quite a bit with Repo.clone
if outstream is not None:
import warnings
warnings.warn(
"outstream= has been deprecated in favour of errstream=.",
DeprecationWarning, stacklevel=3)
errstream = outstream
if checkout is None:
checkout = (not bare)
if checkout and bare:
raise ValueError("checkout and bare are incompatible")
if target is None:
target = source.split("/")[-1]
if not os.path.exists(target):
os.mkdir(target)
if bare:
r = Repo.init_bare(target)
else:
r = Repo.init(target)
reflog_message = b'clone: from ' + source.encode('utf-8')
try:
fetch_result = fetch(
r, source, origin, errstream=errstream, message=reflog_message,
depth=depth, **kwargs)
target_config = r.get_config()
if not isinstance(source, bytes):
source = source.encode(DEFAULT_ENCODING)
target_config.set((b'remote', origin), b'url', source)
target_config.set(
(b'remote', origin), b'fetch',
b'+refs/heads/*:refs/remotes/' + origin + b'/*')
target_config.write_to_path()
# TODO(jelmer): Support symref capability,
# https://github.com/jelmer/dulwich/issues/485
try:
head = r[fetch_result[b'HEAD']]
except KeyError:
head = None
else:
r[b'HEAD'] = head.id
if checkout and not bare and head is not None:
errstream.write(b'Checking out ' + head.id + b'\n')
r.reset_index(head.tree)
except BaseException:
shutil.rmtree(target)
r.close()
raise
return r
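A usage sketch with an example URL; any smart-HTTP remote behaves the same:

    from dulwich import porcelain

    # Clone into ./dulwich-clone and check out HEAD (the default when
    # bare=False); the returned Repo should be closed when done.
    r = porcelain.clone('https://github.com/jelmer/dulwich', 'dulwich-clone')
    print(r.path)
    r.close()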
def add(repo=".", paths=None):
"""Add files to the staging area.
Args:
repo: Repository for the files
paths: Paths to add. If none are given, all untracked files under the
current directory are staged.
Returns: Tuple with set of added files and ignored files
"""
ignored = set()
with open_repo_closing(repo) as r:
ignore_manager = IgnoreFilterManager.from_repo(r)
if not paths:
paths = list(
get_untracked_paths(os.getcwd(), r.path, r.open_index()))
relpaths = []
if not isinstance(paths, list):
paths = [paths]
for p in paths:
relpath = os.path.relpath(p, r.path)
if relpath.startswith('..' + os.path.sep):
raise ValueError('path %r is not in repo' % relpath)
# FIXME: Support patterns, directories.
if ignore_manager.is_ignored(relpath):
ignored.add(relpath)
continue
relpaths.append(relpath)
r.stage(relpaths)
return (relpaths, ignored)
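Continuing the /tmp/example-repo sketch above (paths illustrative):

    from dulwich import porcelain

    with open('/tmp/example-repo/hello.txt', 'w') as f:
        f.write('hello\n')
    added, ignored = porcelain.add(
        '/tmp/example-repo', paths=['/tmp/example-repo/hello.txt'])
    print(added)    # ['hello.txt']
    print(ignored)  # set(), unless a .gitignore rule matched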
def _is_subdir(subdir, parentdir):
"""Check whether subdir is parentdir or a subdir of parentdir
If parentdir or subdir is a relative path, it will be disambiguated
relative to the pwd.
"""
parentdir_abs = os.path.realpath(parentdir) + os.path.sep
subdir_abs = os.path.realpath(subdir) + os.path.sep
return subdir_abs.startswith(parentdir_abs)
# TODO: option to remove ignored files also, in line with `git clean -fdx`
def clean(repo=".", target_dir=None):
"""Remove any untracked files from the target directory recursively
Equivalent to running `git clean -fd` in target_dir.
Args:
repo: Repository where the files may be tracked
target_dir: Directory to clean - current directory if None
"""
if target_dir is None:
target_dir = os.getcwd()
with open_repo_closing(repo) as r:
if not _is_subdir(target_dir, r.path):
raise ValueError("target_dir must be in the repo's working dir")
index = r.open_index()
ignore_manager = IgnoreFilterManager.from_repo(r)
paths_in_wd = _walk_working_dir_paths(target_dir, r.path)
# Reverse file visit order, so that files and subdirectories are
# removed before containing directory
for ap, is_dir in reversed(list(paths_in_wd)):
if is_dir:
# All subdirectories and files have been removed if untracked,
# so dir contains no tracked files iff it is empty.
is_empty = len(os.listdir(ap)) == 0
if is_empty:
os.rmdir(ap)
else:
ip = path_to_tree_path(r.path, ap)
is_tracked = ip in index
rp = os.path.relpath(ap, r.path)
is_ignored = ignore_manager.is_ignored(rp)
if not is_tracked and not is_ignored:
os.remove(ap)
def remove(repo=".", paths=None, cached=False):
"""Remove files from the staging area.
Args:
repo: Repository for the files
paths: Paths to remove
cached: Only remove the paths from the index; leave working-tree files
in place
"""
with open_repo_closing(repo) as r:
index = r.open_index()
for p in paths:
full_path = os.fsencode(os.path.abspath(p))
tree_path = path_to_tree_path(r.path, p)
try:
index_sha = index[tree_path].sha
except KeyError:
raise Exception('%s did not match any files' % p)
if not cached:
try:
st = os.lstat(full_path)
except OSError:
pass
else:
try:
blob = blob_from_path_and_stat(full_path, st)
except IOError:
pass
else:
try:
committed_sha = tree_lookup_path(
r.__getitem__, r[r.head()].tree, tree_path)[1]
except KeyError:
committed_sha = None
if blob.id != index_sha and index_sha != committed_sha:
raise Exception(
'file has staged content differing '
'from both the file and head: %s' % p)
if index_sha != committed_sha:
raise Exception(
'file has staged changes: %s' % p)
os.remove(full_path)
del index[tree_path]
index.write()
rm = remove
def commit_decode(commit, contents, default_encoding=DEFAULT_ENCODING):
if commit.encoding:
encoding = commit.encoding.decode('ascii')
else:
encoding = default_encoding
return contents.decode(encoding, "replace")
def commit_encode(commit, contents, default_encoding=DEFAULT_ENCODING):
if commit.encoding:
encoding = commit.encoding.decode('ascii')
else:
encoding = default_encoding
return contents.encode(encoding)
def print_commit(commit, decode, outstream=sys.stdout):
"""Write a human-readable commit log entry.
Args:
commit: A `Commit` object
decode: Function for decoding bytes to unicode string
outstream: A stream file to write to
"""
outstream.write("-" * 50 + "\n")
outstream.write("commit: " + commit.id.decode('ascii') + "\n")
if len(commit.parents) > 1:
outstream.write(
"merge: " +
"...".join([c.decode('ascii') for c in commit.parents[1:]]) + "\n")
outstream.write("Author: " + decode(commit.author) + "\n")
if commit.author != commit.committer:
outstream.write("Committer: " + decode(commit.committer) + "\n")
time_tuple = time.gmtime(commit.author_time + commit.author_timezone)
time_str = time.strftime("%a %b %d %Y %H:%M:%S", time_tuple)
timezone_str = format_timezone(commit.author_timezone).decode('ascii')
outstream.write("Date: " + time_str + " " + timezone_str + "\n")
outstream.write("\n")
outstream.write(decode(commit.message) + "\n")
outstream.write("\n")
def print_tag(tag, decode, outstream=sys.stdout):
"""Write a human-readable tag.
Args:
tag: A `Tag` object
decode: Function for decoding bytes to unicode string
outstream: A stream to write to
"""
outstream.write("Tagger: " + decode(tag.tagger) + "\n")
time_tuple = time.gmtime(tag.tag_time + tag.tag_timezone)
time_str = time.strftime("%a %b %d %Y %H:%M:%S", time_tuple)
timezone_str = format_timezone(tag.tag_timezone).decode('ascii')
outstream.write("Date: " + time_str + " " + timezone_str + "\n")
outstream.write("\n")
outstream.write(decode(tag.message) + "\n")
outstream.write("\n")
def show_blob(repo, blob, decode, outstream=sys.stdout):
"""Write a blob to a stream.
Args:
repo: A `Repo` object
blob: A `Blob` object
decode: Function for decoding bytes to unicode string
outstream: A stream file to write to
"""
outstream.write(decode(blob.data))
def show_commit(repo, commit, decode, outstream=sys.stdout):
"""Show a commit to a stream.
Args:
repo: A `Repo` object
commit: A `Commit` object
decode: Function for decoding bytes to unicode string
outstream: Stream to write to
"""
print_commit(commit, decode=decode, outstream=outstream)
if commit.parents:
parent_commit = repo[commit.parents[0]]
base_tree = parent_commit.tree
else:
base_tree = None
diffstream = BytesIO()
write_tree_diff(
diffstream,
repo.object_store, base_tree, commit.tree)
diffstream.seek(0)
outstream.write(commit_decode(commit, diffstream.getvalue()))
def show_tree(repo, tree, decode, outstream=sys.stdout):
"""Print a tree to a stream.
Args:
repo: A `Repo` object
tree: A `Tree` object
decode: Function for decoding bytes to unicode string
outstream: Stream to write to
"""
for n in tree:
outstream.write(decode(n) + "\n")
def show_tag(repo, tag, decode, outstream=sys.stdout):
"""Print a tag to a stream.
Args:
repo: A `Repo` object
tag: A `Tag` object
decode: Function for decoding bytes to unicode string
outstream: Stream to write to
"""
print_tag(tag, decode, outstream)
show_object(repo, repo[tag.object[1]], decode, outstream)
def show_object(repo, obj, decode, outstream):
return {
b"tree": show_tree,
b"blob": show_blob,
b"commit": show_commit,
b"tag": show_tag,
}[obj.type_name](repo, obj, decode, outstream)
def print_name_status(changes):
"""Print a simple status summary, listing changed files.
"""
for change in changes:
if not change:
continue
if isinstance(change, list):
change = change[0]
if change.type == CHANGE_ADD:
path1 = change.new.path
path2 = ''
kind = 'A'
elif change.type == CHANGE_DELETE:
path1 = change.old.path
path2 = ''
kind = 'D'
elif change.type == CHANGE_MODIFY:
path1 = change.new.path
path2 = ''
kind = 'M'
elif change.type in RENAME_CHANGE_TYPES:
path1 = change.old.path
path2 = change.new.path
if change.type == CHANGE_RENAME:
kind = 'R'
elif change.type == CHANGE_COPY:
kind = 'C'
yield '%-8s%-20s%-20s' % (kind, path1, path2)
def log(repo=".", paths=None, outstream=sys.stdout, max_entries=None,
reverse=False, name_status=False):
"""Write commit logs.
Args:
repo: Path to repository
paths: Optional set of specific paths to print entries for
outstream: Stream to write log output to
reverse: Reverse order in which entries are printed
name_status: Print name status
max_entries: Optional maximum number of entries to display
"""
with open_repo_closing(repo) as r:
walker = r.get_walker(
max_entries=max_entries, paths=paths, reverse=reverse)
for entry in walker:
def decode(x):
return commit_decode(entry.commit, x)
print_commit(entry.commit, decode, outstream)
if name_status:
outstream.writelines(
- [l+'\n' for l in print_name_status(entry.changes())])
+ [line+'\n' for line in print_name_status(entry.changes())])
# TODO(jelmer): better default for encoding?
def show(repo=".", objects=None, outstream=sys.stdout,
default_encoding=DEFAULT_ENCODING):
"""Print the changes in a commit.
Args:
repo: Path to repository
objects: Objects to show (defaults to [HEAD])
outstream: Stream to write to
default_encoding: Default encoding to use if none is set in the
commit
"""
if objects is None:
objects = ["HEAD"]
if not isinstance(objects, list):
objects = [objects]
with open_repo_closing(repo) as r:
for objectish in objects:
o = parse_object(r, objectish)
if isinstance(o, Commit):
def decode(x):
return commit_decode(o, x, default_encoding)
else:
def decode(x):
return x.decode(default_encoding)
show_object(r, o, decode, outstream)
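Because the decode helpers produce text, any text stream works as outstream;
a sketch that captures the output of show() for HEAD instead of printing it
(repo path illustrative):

    from io import StringIO
    from dulwich import porcelain

    buf = StringIO()
    porcelain.show('/tmp/example-repo', outstream=buf)
    print(buf.getvalue())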
def diff_tree(repo, old_tree, new_tree, outstream=sys.stdout):
"""Compares the content and mode of blobs found via two tree objects.
Args:
repo: Path to repository
old_tree: Id of old tree
new_tree: Id of new tree
outstream: Stream to write to
"""
with open_repo_closing(repo) as r:
write_tree_diff(outstream, r.object_store, old_tree, new_tree)
def rev_list(repo, commits, outstream=sys.stdout):
"""Lists commit objects in reverse chronological order.
Args:
repo: Path to repository
commits: Commits over which to iterate
outstream: Stream to write to
"""
with open_repo_closing(repo) as r:
for entry in r.get_walker(include=[r[c].id for c in commits]):
outstream.write(entry.commit.id + b"\n")
def tag(*args, **kwargs):
import warnings
warnings.warn("tag has been deprecated in favour of tag_create.",
DeprecationWarning)
return tag_create(*args, **kwargs)
def tag_create(
repo, tag, author=None, message=None, annotated=False,
objectish="HEAD", tag_time=None, tag_timezone=None,
sign=False):
"""Creates a tag in git via dulwich calls:
Args:
repo: Path to repository
tag: tag string
author: tag author (optional, if annotated is set)
message: tag message (optional)
annotated: whether to create an annotated tag
objectish: object the tag should point at, defaults to HEAD
tag_time: Optional time for annotated tag
tag_timezone: Optional timezone for annotated tag
sign: GPG Sign the tag
"""
with open_repo_closing(repo) as r:
object = parse_object(r, objectish)
if annotated:
# Create the tag object
tag_obj = Tag()
if author is None:
# TODO(jelmer): Don't use repo private method.
author = r._get_user_identity(r.get_config_stack())
tag_obj.tagger = author
tag_obj.message = message
tag_obj.name = tag
tag_obj.object = (type(object), object.id)
if tag_time is None:
tag_time = int(time.time())
tag_obj.tag_time = tag_time
if tag_timezone is None:
# TODO(jelmer) Use current user timezone rather than UTC
tag_timezone = 0
elif isinstance(tag_timezone, str):
tag_timezone = parse_timezone(tag_timezone)
tag_obj.tag_timezone = tag_timezone
if sign:
import gpg
with gpg.Context(armor=True) as c:
tag_obj.signature, unused_result = c.sign(
tag_obj.as_raw_string())
r.object_store.add_object(tag_obj)
tag_id = tag_obj.id
else:
tag_id = object.id
r.refs[_make_tag_ref(tag)] = tag_id
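A sketch creating both flavours of tag on HEAD (repo path and identity
illustrative):

    from dulwich import porcelain

    # Lightweight tag: the ref points straight at the commit.
    porcelain.tag_create('/tmp/example-repo', b'v0.1')
    # Annotated tag: a Tag object is written and the ref points at it.
    porcelain.tag_create(
        '/tmp/example-repo', b'v0.2', annotated=True,
        author=b'Alice <alice@example.com>', message=b'Release 0.2')
    print(porcelain.tag_list('/tmp/example-repo'))  # [b'v0.1', b'v0.2']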
def list_tags(*args, **kwargs):
import warnings
warnings.warn("list_tags has been deprecated in favour of tag_list.",
DeprecationWarning)
return tag_list(*args, **kwargs)
def tag_list(repo, outstream=sys.stdout):
"""List all tags.
Args:
repo: Path to repository
outstream: Stream to write tags to
"""
with open_repo_closing(repo) as r:
tags = sorted(r.refs.as_dict(b"refs/tags"))
return tags
def tag_delete(repo, name):
"""Remove a tag.
Args:
repo: Path to repository
name: Name of tag to remove
"""
with open_repo_closing(repo) as r:
if isinstance(name, bytes):
names = [name]
elif isinstance(name, list):
names = name
else:
raise TypeError("Unexpected tag name type %r" % name)
for name in names:
del r.refs[_make_tag_ref(name)]
def reset(repo, mode, treeish="HEAD"):
"""Reset current HEAD to the specified state.
Args:
repo: Path to repository
mode: Mode ("hard", "soft", "mixed")
treeish: Treeish to reset to
"""
if mode != "hard":
raise ValueError("hard is the only mode currently supported")
with open_repo_closing(repo) as r:
tree = parse_tree(r, treeish)
r.reset_index(tree.id)
def push(repo, remote_location, refspecs,
outstream=default_bytes_out_stream,
errstream=default_bytes_err_stream, **kwargs):
"""Remote push with dulwich via dulwich.client
Args:
repo: Path to repository
remote_location: Location of the remote
refspecs: Refs to push to remote
outstream: A stream file to write output
errstream: A stream file to write errors
"""
# Open the repo
with open_repo_closing(repo) as r:
# Get the client and path
client, path = get_transport_and_path(
remote_location, config=r.get_config_stack(), **kwargs)
selected_refs = []
def update_refs(refs):
selected_refs.extend(parse_reftuples(r.refs, refs, refspecs))
new_refs = {}
# TODO: Handle selected_refs == {None: None}
for (lh, rh, force) in selected_refs:
if lh is None:
new_refs[rh] = ZERO_SHA
else:
new_refs[rh] = r.refs[lh]
return new_refs
err_encoding = getattr(errstream, 'encoding', None) or DEFAULT_ENCODING
remote_location_bytes = client.get_url(path).encode(err_encoding)
try:
client.send_pack(
path, update_refs,
generate_pack_data=r.generate_pack_data,
progress=errstream.write)
errstream.write(
b"Push to " + remote_location_bytes + b" successful.\n")
except UpdateRefsError as e:
errstream.write(b"Push to " + remote_location_bytes +
b" failed -> " + e.message.encode(err_encoding) +
b"\n")
except SendPackError as e:
errstream.write(b"Push to " + remote_location_bytes +
b" failed -> " + e.args[0] + b"\n")
def pull(repo, remote_location=None, refspecs=None,
outstream=default_bytes_out_stream,
errstream=default_bytes_err_stream, **kwargs):
"""Pull from remote via dulwich.client
Args:
repo: Path to repository
remote_location: Location of the remote
refspecs: Refspecs to fetch
outstream: A stream file to write to output
errstream: A stream file to write to errors
"""
# Open the repo
with open_repo_closing(repo) as r:
if remote_location is None:
config = r.get_config()
remote_name = get_branch_remote(r.path)
section = (b'remote', remote_name)
if config.has_section(section):
url = config.get(section, 'url')
remote_location = url.decode()
if refspecs is None:
refspecs = [b"HEAD"]
selected_refs = []
def determine_wants(remote_refs):
selected_refs.extend(
parse_reftuples(remote_refs, r.refs, refspecs))
return [remote_refs[lh] for (lh, rh, force) in selected_refs]
client, path = get_transport_and_path(
remote_location, config=r.get_config_stack(), **kwargs)
fetch_result = client.fetch(
path, r, progress=errstream.write, determine_wants=determine_wants)
for (lh, rh, force) in selected_refs:
r.refs[rh] = fetch_result.refs[lh]
if selected_refs:
r[b'HEAD'] = fetch_result.refs[selected_refs[0][1]]
# Perform 'git checkout .' - syncs staged changes
tree = r[b"HEAD"].tree
r.reset_index(tree=tree)
def status(repo=".", ignored=False):
"""Returns staged, unstaged, and untracked changes relative to the HEAD.
Args:
repo: Path to repository or repository object
ignored: Whether to include ignored files in `untracked`
Returns: GitStatus tuple,
staged - dict with lists of staged paths (diff index/HEAD)
unstaged - list of unstaged paths (diff index/working-tree)
untracked - list of untracked, un-ignored & non-.git paths
"""
with open_repo_closing(repo) as r:
# 1. Get status of staged
tracked_changes = get_tree_changes(r)
# 2. Get status of unstaged
index = r.open_index()
normalizer = r.get_blob_normalizer()
filter_callback = normalizer.checkin_normalize
unstaged_changes = list(
get_unstaged_changes(index, r.path, filter_callback)
)
ignore_manager = IgnoreFilterManager.from_repo(r)
untracked_paths = get_untracked_paths(r.path, r.path, index)
if ignored:
untracked_changes = list(untracked_paths)
else:
untracked_changes = [
p for p in untracked_paths
if not ignore_manager.is_ignored(p)]
return GitStatus(tracked_changes, unstaged_changes, untracked_changes)
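A sketch of consuming the result (repo path illustrative):

    from dulwich import porcelain

    st = porcelain.status('/tmp/example-repo')
    print(st.staged['add'])   # in the index but not in HEAD
    print(st.unstaged)        # index vs. working-tree differences
    print(st.untracked)       # in neither the index nor ignored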
def _walk_working_dir_paths(frompath, basepath):
"""Get path, is_dir for files in working dir from frompath
Args:
frompath: Path to begin walk
basepath: Path to compare to
"""
for dirpath, dirnames, filenames in os.walk(frompath):
# Skip .git and below.
if '.git' in dirnames:
dirnames.remove('.git')
if dirpath != basepath:
continue
if '.git' in filenames:
filenames.remove('.git')
if dirpath != basepath:
continue
if dirpath != frompath:
yield dirpath, True
for filename in filenames:
filepath = os.path.join(dirpath, filename)
yield filepath, False
def get_untracked_paths(frompath, basepath, index):
"""Get untracked paths.
Args:
frompath: Path to walk
basepath: Path to compare to
index: Index to check against
"""
for ap, is_dir in _walk_working_dir_paths(frompath, basepath):
if not is_dir:
ip = path_to_tree_path(basepath, ap)
if ip not in index:
yield os.path.relpath(ap, frompath)
def get_tree_changes(repo):
"""Return add/delete/modify changes to tree by comparing index to HEAD.
Args:
repo: repo path or object
Returns: dict with lists for each type of change
"""
with open_repo_closing(repo) as r:
index = r.open_index()
# Compares the Index to the HEAD & determines changes
# Iterate through the changes and report add/delete/modify
# TODO: call out to dulwich.diff_tree somehow.
tracked_changes = {
'add': [],
'delete': [],
'modify': [],
}
try:
tree_id = r[b'HEAD'].tree
except KeyError:
tree_id = None
for change in index.changes_from_tree(r.object_store, tree_id):
if not change[0][0]:
tracked_changes['add'].append(change[0][1])
elif not change[0][1]:
tracked_changes['delete'].append(change[0][0])
elif change[0][0] == change[0][1]:
tracked_changes['modify'].append(change[0][0])
else:
raise AssertionError('git mv ops not yet supported')
return tracked_changes
def daemon(path=".", address=None, port=None):
"""Run a daemon serving Git requests over TCP/IP.
Args:
path: Path to the directory to serve.
address: Optional address to listen on (defaults to ::)
port: Optional port to listen on (defaults to TCP_GIT_PORT)
"""
# TODO(jelmer): Support git-daemon-export-ok and --export-all.
backend = FileSystemBackend(path)
server = TCPGitServer(backend, address, port)
server.serve_forever()
def web_daemon(path=".", address=None, port=None):
"""Run a daemon serving Git requests over HTTP.
Args:
path: Path to the directory to serve
address: Optional address to listen on (defaults to ::)
port: Optional port to listen on (defaults to 80)
"""
from dulwich.web import (
make_wsgi_chain,
make_server,
WSGIRequestHandlerLogger,
WSGIServerLogger)
backend = FileSystemBackend(path)
app = make_wsgi_chain(backend)
server = make_server(address, port, app,
handler_class=WSGIRequestHandlerLogger,
server_class=WSGIServerLogger)
server.serve_forever()
def upload_pack(path=".", inf=None, outf=None):
"""Upload a pack file after negotiating its contents using smart protocol.
Args:
path: Path to the repository
inf: Input stream to communicate with client
outf: Output stream to communicate with client
"""
if outf is None:
outf = getattr(sys.stdout, 'buffer', sys.stdout)
if inf is None:
inf = getattr(sys.stdin, 'buffer', sys.stdin)
path = os.path.expanduser(path)
backend = FileSystemBackend(path)
def send_fn(data):
outf.write(data)
outf.flush()
proto = Protocol(inf.read, send_fn)
handler = UploadPackHandler(backend, [path], proto)
# FIXME: Catch exceptions and write a single-line summary to outf.
handler.handle()
return 0
def receive_pack(path=".", inf=None, outf=None):
"""Receive a pack file after negotiating its contents using smart protocol.
Args:
path: Path to the repository
inf: Input stream to communicate with client
outf: Output stream to communicate with client
"""
if outf is None:
outf = getattr(sys.stdout, 'buffer', sys.stdout)
if inf is None:
inf = getattr(sys.stdin, 'buffer', sys.stdin)
path = os.path.expanduser(path)
backend = FileSystemBackend(path)
def send_fn(data):
outf.write(data)
outf.flush()
proto = Protocol(inf.read, send_fn)
handler = ReceivePackHandler(backend, [path], proto)
# FIXME: Catch exceptions and write a single-line summary to outf.
handler.handle()
return 0
def _make_branch_ref(name):
if getattr(name, 'encode', None):
name = name.encode(DEFAULT_ENCODING)
return LOCAL_BRANCH_PREFIX + name
def _make_tag_ref(name):
if getattr(name, 'encode', None):
name = name.encode(DEFAULT_ENCODING)
return b"refs/tags/" + name
def branch_delete(repo, name):
"""Delete a branch.
Args:
repo: Path to the repository
name: Name of the branch
"""
with open_repo_closing(repo) as r:
if isinstance(name, list):
names = name
else:
names = [name]
for name in names:
del r.refs[_make_branch_ref(name)]
def branch_create(repo, name, objectish=None, force=False):
"""Create a branch.
Args:
repo: Path to the repository
name: Name of the new branch
objectish: Target object to point new branch at (defaults to HEAD)
force: Force creation of branch, even if it already exists
"""
with open_repo_closing(repo) as r:
if objectish is None:
objectish = "HEAD"
object = parse_object(r, objectish)
refname = _make_branch_ref(name)
ref_message = b"branch: Created from " + objectish.encode('utf-8')
if force:
r.refs.set_if_equals(refname, None, object.id, message=ref_message)
else:
if not r.refs.add_if_new(refname, object.id, message=ref_message):
raise KeyError("Branch with name %s already exists." % name)
def branch_list(repo):
"""List all branches.
Args:
repo: Path to the repository
"""
with open_repo_closing(repo) as r:
return r.refs.keys(base=LOCAL_BRANCH_PREFIX)
def active_branch(repo):
"""Return the active branch in the repository, if any.
Args:
repo: Repository to open
Returns:
branch name
Raises:
KeyError: if the repository does not have a working tree
IndexError: if HEAD is floating
"""
with open_repo_closing(repo) as r:
active_ref = r.refs.follow(b'HEAD')[0][1]
if not active_ref.startswith(LOCAL_BRANCH_PREFIX):
raise ValueError(active_ref)
return active_ref[len(LOCAL_BRANCH_PREFIX):]
def get_branch_remote(repo):
"""Return the active branch's remote name, if any.
Args:
repo: Repository to open
Returns:
remote name
Raises:
KeyError: if the repository does not have a working tree
"""
with open_repo_closing(repo) as r:
branch_name = active_branch(r.path)
config = r.get_config()
try:
remote_name = config.get((b'branch', branch_name), 'remote')
except KeyError:
remote_name = b'origin'
return remote_name
def fetch(repo, remote_location, remote_name=b'origin', outstream=sys.stdout,
errstream=default_bytes_err_stream, message=None, depth=None,
prune=False, prune_tags=False, **kwargs):
"""Fetch objects from a remote server.
Args:
repo: Path to the repository
remote_location: String identifying a remote server
remote_name: Name for remote server
outstream: Output stream (defaults to stdout)
errstream: Error stream (defaults to stderr)
message: Reflog message (defaults to b"fetch: from ")
depth: Depth to fetch at
prune: Prune remote removed refs
prune_tags: Prune remote removed tags
Returns:
Dictionary with refs on the remote
"""
if message is None:
message = b'fetch: from ' + remote_location.encode("utf-8")
with open_repo_closing(repo) as r:
client, path = get_transport_and_path(
remote_location, config=r.get_config_stack(), **kwargs)
fetch_result = client.fetch(path, r, progress=errstream.write,
depth=depth)
stripped_refs = strip_peeled_refs(fetch_result.refs)
branches = {
n[len(LOCAL_BRANCH_PREFIX):]: v for (n, v) in stripped_refs.items()
if n.startswith(LOCAL_BRANCH_PREFIX)}
r.refs.import_refs(
b'refs/remotes/' + remote_name, branches, message=message,
prune=prune)
tags = {
n[len(b'refs/tags/'):]: v for (n, v) in stripped_refs.items()
if n.startswith(b'refs/tags/') and
not n.endswith(ANNOTATED_TAG_SUFFIX)}
r.refs.import_refs(
b'refs/tags', tags, message=message,
prune=prune_tags)
return fetch_result.refs
def ls_remote(remote, config=None, **kwargs):
"""List the refs in a remote.
Args:
remote: Remote repository location
config: Configuration to use
Returns:
Dictionary with remote refs
"""
if config is None:
config = StackedConfig.default()
client, host_path = get_transport_and_path(remote, config=config, **kwargs)
return client.get_refs(host_path)
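A sketch with an example URL:

    from dulwich import porcelain

    # List the refs a remote advertises, without cloning it.
    refs = porcelain.ls_remote('https://github.com/jelmer/dulwich')
    for name, sha in sorted(refs.items()):
        print(sha.decode('ascii'), name.decode('utf-8'))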
def repack(repo):
"""Repack loose files in a repository.
Currently this only packs loose objects.
Args:
repo: Path to the repository
"""
with open_repo_closing(repo) as r:
r.object_store.pack_loose_objects()
def pack_objects(repo, object_ids, packf, idxf, delta_window_size=None):
"""Pack objects into a file.
Args:
repo: Path to the repository
object_ids: List of object ids to write
packf: File-like object to write to
idxf: File-like object to write to (can be None)
"""
with open_repo_closing(repo) as r:
entries, data_sum = write_pack_objects(
packf,
r.object_store.iter_shas((oid, None) for oid in object_ids),
delta_window_size=delta_window_size)
if idxf is not None:
entries = sorted([(k, v[0], v[1]) for (k, v) in entries.items()])
write_pack_index(idxf, entries, data_sum)
def ls_tree(repo, treeish=b"HEAD", outstream=sys.stdout, recursive=False,
name_only=False):
"""List contents of a tree.
Args:
repo: Path to the repository
treeish: Tree id to list
outstream: Output stream (defaults to stdout)
recursive: Whether to recursively list files
name_only: Only print item name
"""
def list_tree(store, treeid, base):
for (name, mode, sha) in store[treeid].iteritems():
if base:
name = posixpath.join(base, name)
if name_only:
outstream.write(name + b"\n")
else:
outstream.write(pretty_format_tree_entry(name, mode, sha))
if stat.S_ISDIR(mode) and recursive:
list_tree(store, sha, name)
with open_repo_closing(repo) as r:
tree = parse_tree(r, treeish)
list_tree(r.object_store, tree.id, "")
def remote_add(repo, name, url):
"""Add a remote.
Args:
repo: Path to the repository
name: Remote name
url: Remote URL
"""
if not isinstance(name, bytes):
name = name.encode(DEFAULT_ENCODING)
if not isinstance(url, bytes):
url = url.encode(DEFAULT_ENCODING)
with open_repo_closing(repo) as r:
c = r.get_config()
section = (b'remote', name)
if c.has_section(section):
raise RemoteExists(section)
c.set(section, b"url", url)
c.write_to_path()
def check_ignore(repo, paths, no_index=False):
"""Debug gitignore files.
Args:
repo: Path to the repository
paths: List of paths to check for
no_index: Don't check index
Returns: List of ignored files
"""
with open_repo_closing(repo) as r:
index = r.open_index()
ignore_manager = IgnoreFilterManager.from_repo(r)
for path in paths:
if not no_index and path_to_tree_path(r.path, path) in index:
continue
if os.path.isabs(path):
path = os.path.relpath(path, r.path)
if ignore_manager.is_ignored(path):
yield path
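Since check_ignore is a generator, only matching paths are yielded (repo
path and file name illustrative):

    from dulwich import porcelain

    for path in porcelain.check_ignore('/tmp/example-repo',
                                       ['build/output.o']):
        print('ignored:', path)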
def update_head(repo, target, detached=False, new_branch=None):
"""Update HEAD to point at a new branch/commit.
Note that this does not actually update the working tree.
Args:
repo: Path to the repository
target: Branch or committish to switch to
detached: Create a detached head
new_branch: New branch to create
"""
with open_repo_closing(repo) as r:
if new_branch is not None:
to_set = _make_branch_ref(new_branch)
else:
to_set = b"HEAD"
if detached:
# TODO(jelmer): Provide some way so that the actual ref gets
# updated rather than what it points to, so the delete isn't
# necessary.
del r.refs[to_set]
r.refs[to_set] = parse_commit(r, target).id
else:
r.refs.set_symbolic_ref(to_set, parse_ref(r, target))
if new_branch is not None:
r.refs.set_symbolic_ref(b"HEAD", to_set)
def check_mailmap(repo, contact):
"""Check canonical name and email of contact.
Args:
repo: Path to the repository
contact: Contact name and/or email
Returns: Canonical contact data
"""
with open_repo_closing(repo) as r:
from dulwich.mailmap import Mailmap
import errno
try:
mailmap = Mailmap.from_path(os.path.join(r.path, '.mailmap'))
except IOError as e:
if e.errno != errno.ENOENT:
raise
mailmap = Mailmap()
return mailmap.lookup(contact)
def fsck(repo):
"""Check a repository.
Args:
repo: A path to the repository
Returns: Iterator over errors/warnings
"""
with open_repo_closing(repo) as r:
# TODO(jelmer): check pack files
# TODO(jelmer): check graph
# TODO(jelmer): check refs
for sha in r.object_store:
o = r.object_store[sha]
try:
o.check()
except Exception as e:
yield (sha, e)
def stash_list(repo):
"""List all stashes in a repository."""
with open_repo_closing(repo) as r:
from dulwich.stash import Stash
stash = Stash.from_repo(r)
return enumerate(list(stash.stashes()))
def stash_push(repo):
"""Push a new stash onto the stack."""
with open_repo_closing(repo) as r:
from dulwich.stash import Stash
stash = Stash.from_repo(r)
stash.push()
def stash_pop(repo):
"""Pop a new stash from the stack."""
with open_repo_closing(repo) as r:
from dulwich.stash import Stash
stash = Stash.from_repo(r)
stash.pop()
def ls_files(repo):
"""List all files in an index."""
with open_repo_closing(repo) as r:
return sorted(r.open_index())
def describe(repo):
"""Describe the repository version.
Args:
repo: git repository root
Returns: a string description of the current git revision
Examples: "gabcdefh", "v0.1" or "v0.1-5-gabcdefh".
"""
# Get the repository
with open_repo_closing(repo) as r:
# Get a list of all tags
refs = r.get_refs()
tags = {}
for key, value in refs.items():
key = key.decode()
obj = r.get_object(value)
if u'tags' not in key:
continue
_, tag = key.rsplit(u'/', 1)
try:
commit = obj.object
except AttributeError:
continue
else:
commit = r.get_object(commit[1])
tags[tag] = [
datetime.datetime(*time.gmtime(commit.commit_time)[:6]),
commit.id.decode('ascii'),
]
sorted_tags = sorted(tags.items(),
key=lambda tag: tag[1][0],
reverse=True)
# If there are no tags, return the current commit
if len(sorted_tags) == 0:
return 'g{}'.format(r[r.head()].id.decode('ascii')[:7])
# We're now 0 commits from the top
commit_count = 0
# Get the latest commit
latest_commit = r[r.head()]
# Walk through all commits
walker = r.get_walker()
for entry in walker:
# Check if tag
commit_id = entry.commit.id.decode('ascii')
for tag in sorted_tags:
tag_name = tag[0]
tag_commit = tag[1][1]
if commit_id == tag_commit:
if commit_count == 0:
return tag_name
else:
return '{}-{}-g{}'.format(
tag_name,
commit_count,
latest_commit.id.decode('ascii')[:7])
commit_count += 1
# Return plain commit if no parent tag can be found
return 'g{}'.format(latest_commit.id.decode('ascii')[:7])
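A sketch of the three possible shapes of the result (repo path
illustrative):

    from dulwich import porcelain

    print(porcelain.describe('/tmp/example-repo'))
    # 'v0.2' when HEAD is tagged, 'v0.2-3-g1a2b3c4' when three commits
    # past the tag, or 'g1a2b3c4' when no tag is reachable.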
def get_object_by_path(repo, path, committish=None):
"""Get an object by path.
Args:
repo: A path to the repository
path: Path to look up
committish: Commit to look up path in
Returns: A `ShaFile` object
"""
if committish is None:
committish = "HEAD"
# Get the repository
with open_repo_closing(repo) as r:
commit = parse_commit(r, committish)
base_tree = commit.tree
if not isinstance(path, bytes):
path = commit_encode(commit, path)
(mode, sha) = tree_lookup_path(
r.object_store.__getitem__,
base_tree, path)
return r[sha]
def write_tree(repo):
"""Write a tree object from the index.
Args:
repo: Repository for which to write tree
Returns: tree id for the tree that was written
"""
with open_repo_closing(repo) as r:
return r.open_index().commit(r.object_store)
diff --git a/dulwich/reflog.py b/dulwich/reflog.py
index 37a2ff8c..64cb6c5d 100644
--- a/dulwich/reflog.py
+++ b/dulwich/reflog.py
@@ -1,79 +1,79 @@
# reflog.py -- Parsing and writing reflog files
# Copyright (C) 2015 Jelmer Vernooij and others.
#
# Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU
# General Public License as public by the Free Software Foundation; version 2.0
# or (at your option) any later version. You can redistribute it and/or
# modify it under the terms of either of these two licenses.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# You should have received a copy of the licenses; if not, see
# <http://www.gnu.org/licenses/> for a copy of the GNU General Public License
# and <http://www.apache.org/licenses/LICENSE-2.0> for a copy of the Apache
# License, Version 2.0.
#
"""Utilities for reading and generating reflogs.
"""
import collections
from dulwich.objects import (
format_timezone,
parse_timezone,
ZERO_SHA,
)
Entry = collections.namedtuple(
'Entry', ['old_sha', 'new_sha', 'committer', 'timestamp', 'timezone',
'message'])
def format_reflog_line(old_sha, new_sha, committer, timestamp, timezone,
message):
"""Generate a single reflog line.
Args:
old_sha: Old Commit SHA
new_sha: New Commit SHA
committer: Committer name and e-mail
timestamp: Timestamp
timezone: Timezone
message: Message
"""
if old_sha is None:
old_sha = ZERO_SHA
return (old_sha + b' ' + new_sha + b' ' + committer + b' ' +
str(int(timestamp)).encode('ascii') + b' ' +
format_timezone(timezone) + b'\t' + message)
def parse_reflog_line(line):
"""Parse a reflog line.
Args:
line: Line to parse
Returns: Tuple of (old_sha, new_sha, committer, timestamp, timezone,
message)
"""
(begin, message) = line.split(b'\t', 1)
(old_sha, new_sha, rest) = begin.split(b' ', 2)
(committer, timestamp_str, timezone_str) = rest.rsplit(b' ', 2)
return Entry(old_sha, new_sha, committer, int(timestamp_str),
parse_timezone(timezone_str)[0], message)
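A round-trip sketch; the committer field may itself contain spaces, which
is why parsing rsplits the non-message fields from the right:

    from dulwich.reflog import format_reflog_line, parse_reflog_line

    line = format_reflog_line(
        b'0' * 40, b'1' * 40, b'Alice <alice@example.com>',
        1577836800, 0, b'commit: initial')
    entry = parse_reflog_line(line)
    print(entry.new_sha == b'1' * 40)  # True
    print(entry.message)               # b'commit: initial'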
def read_reflog(f):
"""Read reflog.
Args:
f: File-like object
Returns: Iterator over Entry objects
"""
- for l in f:
- yield parse_reflog_line(l)
+ for line in f:
+ yield parse_reflog_line(line)
diff --git a/dulwich/refs.py b/dulwich/refs.py
index 0b5a3c68..54ce22a3 100644
--- a/dulwich/refs.py
+++ b/dulwich/refs.py
@@ -1,968 +1,968 @@
# refs.py -- For dealing with git refs
# Copyright (C) 2008-2013 Jelmer Vernooij
#
# Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU
# General Public License as public by the Free Software Foundation; version 2.0
# or (at your option) any later version. You can redistribute it and/or
# modify it under the terms of either of these two licenses.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# You should have received a copy of the licenses; if not, see
# <http://www.gnu.org/licenses/> for a copy of the GNU General Public License
# and <http://www.apache.org/licenses/LICENSE-2.0> for a copy of the Apache
# License, Version 2.0.
#
"""Ref handling.
"""
import errno
import os
import sys
from dulwich.errors import (
PackedRefsException,
RefFormatError,
)
from dulwich.objects import (
git_line,
valid_hexsha,
ZERO_SHA,
)
from dulwich.file import (
GitFile,
ensure_dir_exists,
)
SYMREF = b'ref: '
LOCAL_BRANCH_PREFIX = b'refs/heads/'
LOCAL_TAG_PREFIX = b'refs/tags/'
BAD_REF_CHARS = set(b'\177 ~^:?*[')
ANNOTATED_TAG_SUFFIX = b'^{}'
def parse_symref_value(contents):
"""Parse a symref value.
Args:
contents: Contents to parse
Returns: Destination
"""
if contents.startswith(SYMREF):
return contents[len(SYMREF):].rstrip(b'\r\n')
raise ValueError(contents)
def check_ref_format(refname):
"""Check if a refname is correctly formatted.
Implements all the same rules as git-check-ref-format[1].
[1]
http://www.kernel.org/pub/software/scm/git/docs/git-check-ref-format.html
Args:
refname: The refname to check
Returns: True if refname is valid, False otherwise
"""
# These could be combined into one big expression, but are listed
# separately to parallel [1].
if b'/.' in refname or refname.startswith(b'.'):
return False
if b'/' not in refname:
return False
if b'..' in refname:
return False
for i, c in enumerate(refname):
if ord(refname[i:i+1]) < 0o40 or c in BAD_REF_CHARS:
return False
if refname[-1] in b'/.':
return False
if refname.endswith(b'.lock'):
return False
if b'@{' in refname:
return False
if b'\\' in refname:
return False
return True
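A few checks mirroring the rules above; note the function expects names
without the 'refs/' prefix, hence the requirement of at least one '/':

    from dulwich.refs import check_ref_format

    assert check_ref_format(b'heads/master')
    assert not check_ref_format(b'master')        # no '/'
    assert not check_ref_format(b'heads/a..b')    # contains '..'
    assert not check_ref_format(b'heads/x.lock')  # ends with '.lock'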
class RefsContainer(object):
"""A container for refs."""
def __init__(self, logger=None):
self._logger = logger
def _log(self, ref, old_sha, new_sha, committer=None, timestamp=None,
timezone=None, message=None):
if self._logger is None:
return
if message is None:
return
self._logger(ref, old_sha, new_sha, committer, timestamp,
timezone, message)
def set_symbolic_ref(self, name, other, committer=None, timestamp=None,
timezone=None, message=None):
"""Make a ref point at another ref.
Args:
name: Name of the ref to set
other: Name of the ref to point at
message: Optional message
"""
raise NotImplementedError(self.set_symbolic_ref)
def get_packed_refs(self):
"""Get contents of the packed-refs file.
Returns: Dictionary mapping ref names to SHA1s
Note: Will return an empty dictionary when no packed-refs file is
present.
"""
raise NotImplementedError(self.get_packed_refs)
def get_peeled(self, name):
"""Return the cached peeled value of a ref, if available.
Args:
name: Name of the ref to peel
Returns: The peeled value of the ref. If the ref is known not to point to
a tag, this will be the SHA the ref refers to. If the ref may point
to a tag, but no cached information is available, None is returned.
"""
return None
def import_refs(self, base, other, committer=None, timestamp=None,
timezone=None, message=None, prune=False):
if prune:
to_delete = set(self.subkeys(base))
else:
to_delete = set()
for name, value in other.items():
self.set_if_equals(b'/'.join((base, name)), None, value,
message=message)
if to_delete:
try:
to_delete.remove(name)
except KeyError:
pass
for ref in to_delete:
self.remove_if_equals(b'/'.join((base, ref)), None)
def allkeys(self):
"""All refs present in this container."""
raise NotImplementedError(self.allkeys)
def __iter__(self):
return iter(self.allkeys())
def keys(self, base=None):
"""Refs present in this container.
Args:
base: An optional base to return refs under.
Returns: An unsorted set of valid refs in this container, including
packed refs.
"""
if base is not None:
return self.subkeys(base)
else:
return self.allkeys()
def subkeys(self, base):
"""Refs present in this container under a base.
Args:
base: The base to return refs under.
Returns: A set of valid refs in this container under the base; the base
prefix is stripped from the ref names returned.
"""
keys = set()
base_len = len(base) + 1
for refname in self.allkeys():
if refname.startswith(base):
keys.add(refname[base_len:])
return keys
def as_dict(self, base=None):
"""Return the contents of this container as a dictionary.
"""
ret = {}
keys = self.keys(base)
if base is None:
base = b''
else:
base = base.rstrip(b'/')
for key in keys:
try:
ret[key] = self[(base + b'/' + key).strip(b'/')]
except KeyError:
continue # Unable to resolve
return ret
def _check_refname(self, name):
"""Ensure a refname is valid and lives in refs or is HEAD.
HEAD is not a valid refname according to git-check-ref-format, but this
class needs to be able to touch HEAD. Also, check_ref_format expects
refnames without the leading 'refs/', but this class requires the
'refs/' prefix so that it cannot touch anything outside the refs dir
(or HEAD).
Args:
name: The name of the reference.
Raises:
KeyError: if a refname is not HEAD or is otherwise not valid.
"""
if name in (b'HEAD', b'refs/stash'):
return
if not name.startswith(b'refs/') or not check_ref_format(name[5:]):
raise RefFormatError(name)
def read_ref(self, refname):
"""Read a reference without following any references.
Args:
refname: The name of the reference
Returns: The contents of the ref file, or None if it does
not exist.
"""
contents = self.read_loose_ref(refname)
if not contents:
contents = self.get_packed_refs().get(refname, None)
return contents
def read_loose_ref(self, name):
"""Read a loose reference and return its contents.
Args:
name: the refname to read
Returns: The contents of the ref file, or None if it does
not exist.
"""
raise NotImplementedError(self.read_loose_ref)
def follow(self, name):
"""Follow a reference name.
Returns: a tuple of (refnames, sha), where refnames are the names of
references in the chain
"""
contents = SYMREF + name
depth = 0
refnames = []
while contents.startswith(SYMREF):
refname = contents[len(SYMREF):]
refnames.append(refname)
contents = self.read_ref(refname)
if not contents:
break
depth += 1
if depth > 5:
raise KeyError(name)
return refnames, contents
def _follow(self, name):
import warnings
warnings.warn(
"RefsContainer._follow is deprecated. Use RefsContainer.follow "
"instead.", DeprecationWarning)
refnames, contents = self.follow(name)
if not refnames:
return (None, contents)
return (refnames[-1], contents)
def __contains__(self, refname):
if self.read_ref(refname):
return True
return False
def __getitem__(self, name):
"""Get the SHA1 for a reference name.
This method follows all symbolic references.
"""
_, sha = self.follow(name)
if sha is None:
raise KeyError(name)
return sha
def set_if_equals(self, name, old_ref, new_ref, committer=None,
timestamp=None, timezone=None, message=None):
"""Set a refname to new_ref only if it currently equals old_ref.
This method follows all symbolic references if applicable for the
subclass, and can be used to perform an atomic compare-and-swap
operation.
Args:
name: The refname to set.
old_ref: The old sha the refname must refer to, or None to set
unconditionally.
new_ref: The new sha the refname will refer to.
message: Message for reflog
Returns: True if the set was successful, False otherwise.
"""
raise NotImplementedError(self.set_if_equals)
def add_if_new(self, name, ref):
"""Add a new reference only if it does not already exist.
Args:
name: Ref name
ref: Ref value
message: Message for reflog
"""
raise NotImplementedError(self.add_if_new)
def __setitem__(self, name, ref):
"""Set a reference name to point to the given SHA1.
This method follows all symbolic references if applicable for the
subclass.
Note: This method unconditionally overwrites the contents of a
reference. To update atomically only if the reference has not
changed, use set_if_equals().
Args:
name: The refname to set.
ref: The new sha the refname will refer to.
"""
self.set_if_equals(name, None, ref)
def remove_if_equals(self, name, old_ref, committer=None,
timestamp=None, timezone=None, message=None):
"""Remove a refname only if it currently equals old_ref.
This method does not follow symbolic references, even if applicable for
the subclass. It can be used to perform an atomic compare-and-delete
operation.
Args:
name: The refname to delete.
old_ref: The old sha the refname must refer to, or None to
delete unconditionally.
message: Message for reflog
Returns: True if the delete was successful, False otherwise.
"""
raise NotImplementedError(self.remove_if_equals)
def __delitem__(self, name):
"""Remove a refname.
This method does not follow symbolic references, even if applicable for
the subclass.
Note: This method unconditionally deletes the contents of a reference.
To delete atomically only if the reference has not changed, use
remove_if_equals().
Args:
name: The refname to delete.
"""
self.remove_if_equals(name, None)
def get_symrefs(self):
"""Get a dict with all symrefs in this container.
Returns: Dictionary mapping source ref to target ref
"""
ret = {}
for src in self.allkeys():
try:
dst = parse_symref_value(self.read_ref(src))
except ValueError:
pass
else:
ret[src] = dst
return ret
class DictRefsContainer(RefsContainer):
"""RefsContainer backed by a simple dict.
This container does not support symbolic or packed references and is not
threadsafe.
"""
def __init__(self, refs, logger=None):
super(DictRefsContainer, self).__init__(logger=logger)
self._refs = refs
self._peeled = {}
def allkeys(self):
return self._refs.keys()
def read_loose_ref(self, name):
return self._refs.get(name, None)
def get_packed_refs(self):
return {}
def set_symbolic_ref(self, name, other, committer=None,
timestamp=None, timezone=None, message=None):
old = self.follow(name)[-1]
self._refs[name] = SYMREF + other
self._log(name, old, old, committer=committer, timestamp=timestamp,
timezone=timezone, message=message)
def set_if_equals(self, name, old_ref, new_ref, committer=None,
timestamp=None, timezone=None, message=None):
if old_ref is not None and self._refs.get(name, ZERO_SHA) != old_ref:
return False
realnames, _ = self.follow(name)
for realname in realnames:
self._check_refname(realname)
old = self._refs.get(realname)
self._refs[realname] = new_ref
self._log(realname, old, new_ref, committer=committer,
timestamp=timestamp, timezone=timezone, message=message)
return True
def add_if_new(self, name, ref, committer=None, timestamp=None,
timezone=None, message=None):
if name in self._refs:
return False
self._refs[name] = ref
self._log(name, None, ref, committer=committer, timestamp=timestamp,
timezone=timezone, message=message)
return True
def remove_if_equals(self, name, old_ref, committer=None, timestamp=None,
timezone=None, message=None):
if old_ref is not None and self._refs.get(name, ZERO_SHA) != old_ref:
return False
try:
old = self._refs.pop(name)
except KeyError:
pass
else:
self._log(name, old, None, committer=committer,
timestamp=timestamp, timezone=timezone, message=message)
return True
def get_peeled(self, name):
return self._peeled.get(name)
def _update(self, refs):
"""Update multiple refs; intended only for testing."""
# TODO(dborowitz): replace this with a public function that uses
# set_if_equal.
self._refs.update(refs)
def _update_peeled(self, peeled):
"""Update cached peeled refs; intended only for testing."""
self._peeled.update(peeled)
class InfoRefsContainer(RefsContainer):
"""Refs container that reads refs from a info/refs file."""
def __init__(self, f):
self._refs = {}
self._peeled = {}
- for l in f.readlines():
- sha, name = l.rstrip(b'\n').split(b'\t')
+ for line in f.readlines():
+ sha, name = line.rstrip(b'\n').split(b'\t')
if name.endswith(ANNOTATED_TAG_SUFFIX):
name = name[:-3]
if not check_ref_format(name):
raise ValueError("invalid ref name %r" % name)
self._peeled[name] = sha
else:
if not check_ref_format(name):
raise ValueError("invalid ref name %r" % name)
self._refs[name] = sha
def allkeys(self):
return self._refs.keys()
def read_loose_ref(self, name):
return self._refs.get(name, None)
def get_packed_refs(self):
return {}
def get_peeled(self, name):
try:
return self._peeled[name]
except KeyError:
return self._refs[name]
class DiskRefsContainer(RefsContainer):
"""Refs container that reads refs from disk."""
def __init__(self, path, worktree_path=None, logger=None):
super(DiskRefsContainer, self).__init__(logger=logger)
if getattr(path, 'encode', None) is not None:
path = os.fsencode(path)
self.path = path
if worktree_path is None:
worktree_path = path
if getattr(worktree_path, 'encode', None) is not None:
worktree_path = os.fsencode(worktree_path)
self.worktree_path = worktree_path
self._packed_refs = None
self._peeled_refs = None
def __repr__(self):
return "%s(%r)" % (self.__class__.__name__, self.path)
def subkeys(self, base):
subkeys = set()
path = self.refpath(base)
for root, unused_dirs, files in os.walk(path):
dir = root[len(path):]
if os.path.sep != '/':
dir = dir.replace(os.fsencode(os.path.sep), b"/")
dir = dir.strip(b'/')
for filename in files:
refname = b"/".join(([dir] if dir else []) + [filename])
# check_ref_format requires at least one /, so we prepend the
# base before calling it.
if check_ref_format(base + b'/' + refname):
subkeys.add(refname)
for key in self.get_packed_refs():
if key.startswith(base):
subkeys.add(key[len(base):].strip(b'/'))
return subkeys
def allkeys(self):
allkeys = set()
if os.path.exists(self.refpath(b'HEAD')):
allkeys.add(b'HEAD')
path = self.refpath(b'')
refspath = self.refpath(b'refs')
for root, unused_dirs, files in os.walk(refspath):
dir = root[len(path):]
if os.path.sep != '/':
dir = dir.replace(os.fsencode(os.path.sep), b"/")
for filename in files:
refname = b"/".join([dir, filename])
if check_ref_format(refname):
allkeys.add(refname)
allkeys.update(self.get_packed_refs())
return allkeys
def refpath(self, name):
"""Return the disk path of a ref.
"""
if os.path.sep != "/":
name = name.replace(b"/", os.fsencode(os.path.sep))
# TODO: as the 'HEAD' reference is working tree specific, it
# should actually not be a part of RefsContainer
if name == b'HEAD':
return os.path.join(self.worktree_path, name)
else:
return os.path.join(self.path, name)
def get_packed_refs(self):
"""Get contents of the packed-refs file.
Returns: Dictionary mapping ref names to SHA1s
Note: Will return an empty dictionary when no packed-refs file is
present.
"""
# TODO: invalidate the cache on repacking
if self._packed_refs is None:
# set both to empty because we want _peeled_refs to be
# None if and only if _packed_refs is also None.
self._packed_refs = {}
self._peeled_refs = {}
path = os.path.join(self.path, b'packed-refs')
try:
f = GitFile(path, 'rb')
except IOError as e:
if e.errno == errno.ENOENT:
return {}
raise
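# packed-refs may start with a "# pack-refs with: peeled" header; if so,
# "^<sha>" lines carry pre-peeled values for the preceding (tag) refs.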
with f:
first_line = next(iter(f)).rstrip()
if (first_line.startswith(b'# pack-refs') and b' peeled' in
first_line):
for sha, name, peeled in read_packed_refs_with_peeled(f):
self._packed_refs[name] = sha
if peeled:
self._peeled_refs[name] = peeled
else:
f.seek(0)
for sha, name in read_packed_refs(f):
self._packed_refs[name] = sha
return self._packed_refs
def get_peeled(self, name):
"""Return the cached peeled value of a ref, if available.
Args:
name: Name of the ref to peel
Returns: The peeled value of the ref. If the ref is known not to point to
a tag, this will be the SHA the ref refers to. If the ref may point
to a tag, but no cached information is available, None is returned.
"""
self.get_packed_refs()
if self._peeled_refs is None or name not in self._packed_refs:
# No cache: no peeled refs were read, or this ref is loose
return None
if name in self._peeled_refs:
return self._peeled_refs[name]
else:
# Known not peelable
return self[name]
def read_loose_ref(self, name):
"""Read a reference file and return its contents.
If the reference file is a symbolic reference, only read the first line of
the file. Otherwise, only read the first 40 bytes.
Args:
name: the refname to read, relative to refpath
Returns: The contents of the ref file, or None if the file does not
exist.
Raises:
IOError: if any other error occurs
"""
filename = self.refpath(name)
try:
with GitFile(filename, 'rb') as f:
header = f.read(len(SYMREF))
if header == SYMREF:
# Read only the first line
return header + next(iter(f)).rstrip(b'\r\n')
else:
# Read only the first 40 bytes
return header + f.read(40 - len(SYMREF))
except IOError as e:
if e.errno in (errno.ENOENT, errno.EISDIR, errno.ENOTDIR):
return None
raise
def _remove_packed_ref(self, name):
if self._packed_refs is None:
return
filename = os.path.join(self.path, b'packed-refs')
# reread cached refs from disk, while holding the lock
f = GitFile(filename, 'wb')
try:
self._packed_refs = None
self.get_packed_refs()
if name not in self._packed_refs:
return
del self._packed_refs[name]
if name in self._peeled_refs:
del self._peeled_refs[name]
write_packed_refs(f, self._packed_refs, self._peeled_refs)
f.close()
finally:
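# GitFile.abort() is a no-op once close() has committed the rewritten
# packed-refs file; it only releases the lock on the early-return paths.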
f.abort()
def set_symbolic_ref(self, name, other, committer=None, timestamp=None,
timezone=None, message=None):
"""Make a ref point at another ref.
Args:
name: Name of the ref to set
other: Name of the ref to point at
message: Optional message to describe the change
"""
self._check_refname(name)
self._check_refname(other)
filename = self.refpath(name)
f = GitFile(filename, 'wb')
try:
f.write(SYMREF + other + b'\n')
sha = self.follow(name)[-1]
self._log(name, sha, sha, committer=committer,
timestamp=timestamp, timezone=timezone,
message=message)
except BaseException:
f.abort()
raise
else:
f.close()
def set_if_equals(self, name, old_ref, new_ref, committer=None,
timestamp=None, timezone=None, message=None):
"""Set a refname to new_ref only if it currently equals old_ref.
This method follows all symbolic references, and can be used to perform
an atomic compare-and-swap operation.
Args:
name: The refname to set.
old_ref: The old sha the refname must refer to, or None to set
unconditionally.
new_ref: The new sha the refname will refer to.
message: Set message for reflog
Returns: True if the set was successful, False otherwise.
"""
self._check_refname(name)
try:
realnames, _ = self.follow(name)
realname = realnames[-1]
except (KeyError, IndexError):
realname = name
filename = self.refpath(realname)
# make sure none of the ancestor folders is in packed refs
probe_ref = os.path.dirname(realname)
packed_refs = self.get_packed_refs()
while probe_ref:
if packed_refs.get(probe_ref, None) is not None:
raise OSError(errno.ENOTDIR,
'Not a directory: {}'.format(filename))
probe_ref = os.path.dirname(probe_ref)
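# e.g. creating refs/heads/foo/bar must fail with ENOTDIR if a packed
# ref named refs/heads/foo already exists.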
ensure_dir_exists(os.path.dirname(filename))
with GitFile(filename, 'wb') as f:
if old_ref is not None:
try:
# read again while holding the lock
orig_ref = self.read_loose_ref(realname)
if orig_ref is None:
orig_ref = self.get_packed_refs().get(
realname, ZERO_SHA)
if orig_ref != old_ref:
f.abort()
return False
except (OSError, IOError):
f.abort()
raise
try:
f.write(new_ref + b'\n')
except (OSError, IOError):
f.abort()
raise
self._log(realname, old_ref, new_ref, committer=committer,
timestamp=timestamp, timezone=timezone, message=message)
return True
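# Usage sketch (shas hypothetical): atomically advance a branch only if it
# still points at old_sha:
#   refs.set_if_equals(b'refs/heads/master', old_sha, new_sha,
#                      message=b'commit: example')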
def add_if_new(self, name, ref, committer=None, timestamp=None,
timezone=None, message=None):
"""Add a new reference only if it does not already exist.
This method follows symrefs, and only ensures that the last ref in the
chain does not exist.
Args:
name: The refname to set.
ref: The new sha the refname will refer to.
message: Optional message for reflog
Returns: True if the add was successful, False otherwise.
"""
try:
realnames, contents = self.follow(name)
if contents is not None:
return False
realname = realnames[-1]
except (KeyError, IndexError):
realname = name
self._check_refname(realname)
filename = self.refpath(realname)
ensure_dir_exists(os.path.dirname(filename))
with GitFile(filename, 'wb') as f:
if os.path.exists(filename) or name in self.get_packed_refs():
f.abort()
return False
try:
f.write(ref + b'\n')
except (OSError, IOError):
f.abort()
raise
else:
self._log(name, None, ref, committer=committer,
timestamp=timestamp, timezone=timezone,
message=message)
return True
def remove_if_equals(self, name, old_ref, committer=None, timestamp=None,
timezone=None, message=None):
"""Remove a refname only if it currently equals old_ref.
This method does not follow symbolic references. It can be used to
perform an atomic compare-and-delete operation.
Args:
name: The refname to delete.
old_ref: The old sha the refname must refer to, or None to
delete unconditionally.
message: Optional message
Returns: True if the delete was successful, False otherwise.
"""
self._check_refname(name)
filename = self.refpath(name)
ensure_dir_exists(os.path.dirname(filename))
f = GitFile(filename, 'wb')
try:
if old_ref is not None:
orig_ref = self.read_loose_ref(name)
if orig_ref is None:
orig_ref = self.get_packed_refs().get(name, ZERO_SHA)
if orig_ref != old_ref:
return False
# remove the reference file itself
try:
os.remove(filename)
except OSError as e:
if e.errno != errno.ENOENT: # may only be packed
raise
self._remove_packed_ref(name)
self._log(name, old_ref, None, committer=committer,
timestamp=timestamp, timezone=timezone, message=message)
finally:
# never write, we just wanted the lock
f.abort()
# outside of the lock, clean-up any parent directory that might now
# be empty. this ensures that re-creating a reference of the same
# name of what was previously a directory works as expected
parent = name
while True:
try:
parent, _ = parent.rsplit(b'/', 1)
except ValueError:
break
parent_filename = self.refpath(parent)
try:
os.rmdir(parent_filename)
except OSError:
# this can be caused by the parent directory being
# removed by another process, being not empty, etc.
# in any case, this is non fatal because we already
# removed the reference, just ignore it
break
return True
def _split_ref_line(line):
"""Split a single ref line into a tuple of SHA1 and name."""
fields = line.rstrip(b'\n\r').split(b' ')
if len(fields) != 2:
raise PackedRefsException("invalid ref line %r" % line)
sha, name = fields
if not valid_hexsha(sha):
raise PackedRefsException("Invalid hex sha %r" % sha)
if not check_ref_format(name):
raise PackedRefsException("invalid ref name %r" % name)
return (sha, name)
def read_packed_refs(f):
"""Read a packed refs file.
Args:
f: file-like object to read from
Returns: Iterator over tuples with SHA1s and ref names.
"""
- for l in f:
- if l.startswith(b'#'):
+ for line in f:
+ if line.startswith(b'#'):
# Comment
continue
- if l.startswith(b'^'):
+ if line.startswith(b'^'):
raise PackedRefsException(
"found peeled ref in packed-refs without peeled")
- yield _split_ref_line(l)
+ yield _split_ref_line(line)
def read_packed_refs_with_peeled(f):
"""Read a packed refs file including peeled refs.
Assumes the "# pack-refs with: peeled" line was already read. Yields tuples
with ref names, SHA1s, and peeled SHA1s (or None).
Args:
f: file-like object to read from, seek'ed to the second line
"""
last = None
for line in f:
if line.startswith(b'#'):
continue
line = line.rstrip(b'\r\n')
if line.startswith(b'^'):
if not last:
raise PackedRefsException("unexpected peeled ref line")
if not valid_hexsha(line[1:]):
raise PackedRefsException("Invalid hex sha %r" % line[1:])
sha, name = _split_ref_line(last)
last = None
yield (sha, name, line[1:])
else:
if last:
sha, name = _split_ref_line(last)
yield (sha, name, None)
last = line
if last:
sha, name = _split_ref_line(last)
yield (sha, name, None)
def write_packed_refs(f, packed_refs, peeled_refs=None):
"""Write a packed refs file.
Args:
f: empty file-like object to write to
packed_refs: dict of refname to sha of packed refs to write
peeled_refs: dict of refname to peeled value of sha
"""
if peeled_refs is None:
peeled_refs = {}
else:
f.write(b'# pack-refs with: peeled\n')
for refname in sorted(packed_refs.keys()):
f.write(git_line(packed_refs[refname], refname))
if refname in peeled_refs:
f.write(b'^' + peeled_refs[refname] + b'\n')
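# Resulting file looks like (sketch):
#   # pack-refs with: peeled
#   <tag sha> refs/tags/v1.0
#   ^<peeled commit sha>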
def read_info_refs(f):
ret = {}
- for l in f.readlines():
- (sha, name) = l.rstrip(b"\r\n").split(b"\t", 1)
+ for line in f.readlines():
+ (sha, name) = line.rstrip(b"\r\n").split(b"\t", 1)
ret[name] = sha
return ret
def write_info_refs(refs, store):
"""Generate info refs."""
for name, sha in sorted(refs.items()):
# get_refs() includes HEAD as a special case, but we don't want to
# advertise it
if name == b'HEAD':
continue
try:
o = store[sha]
except KeyError:
continue
peeled = store.peel_sha(sha)
yield o.id + b'\t' + name + b'\n'
if o.id != peeled.id:
yield peeled.id + b'\t' + name + ANNOTATED_TAG_SUFFIX + b'\n'
def is_local_branch(x):
return x.startswith(LOCAL_BRANCH_PREFIX)
def strip_peeled_refs(refs):
"""Remove all peeled refs"""
return {ref: sha for (ref, sha) in refs.items()
if not ref.endswith(ANNOTATED_TAG_SUFFIX)}
diff --git a/dulwich/repo.py b/dulwich/repo.py
index 307b4f61..e2ef83af 100644
--- a/dulwich/repo.py
+++ b/dulwich/repo.py
@@ -1,1493 +1,1493 @@
# repo.py -- For dealing with git repositories.
# Copyright (C) 2007 James Westby
# Copyright (C) 2008-2013 Jelmer Vernooij
#
# Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU
# General Public License as public by the Free Software Foundation; version 2.0
# or (at your option) any later version. You can redistribute it and/or
# modify it under the terms of either of these two licenses.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# You should have received a copy of the licenses; if not, see
# <http://www.gnu.org/licenses/> for a copy of the GNU General Public License
# and <http://www.apache.org/licenses/LICENSE-2.0> for a copy of the Apache
# License, Version 2.0.
#
"""Repository access.
This module contains the base class for git repositories
(BaseRepo) and an implementation which uses a repository on
local disk (Repo).
"""
from io import BytesIO
import errno
import os
import sys
import stat
import time
from dulwich.errors import (
NoIndexPresent,
NotBlobError,
NotCommitError,
NotGitRepository,
NotTreeError,
NotTagError,
CommitError,
RefFormatError,
HookError,
)
from dulwich.file import (
GitFile,
)
from dulwich.object_store import (
DiskObjectStore,
MemoryObjectStore,
ObjectStoreGraphWalker,
)
from dulwich.objects import (
check_hexsha,
Blob,
Commit,
ShaFile,
Tag,
Tree,
)
from dulwich.pack import (
pack_objects_to_data,
)
from dulwich.hooks import (
PreCommitShellHook,
PostCommitShellHook,
CommitMsgShellHook,
PostReceiveShellHook,
)
from dulwich.line_ending import BlobNormalizer
from dulwich.refs import ( # noqa: F401
ANNOTATED_TAG_SUFFIX,
check_ref_format,
RefsContainer,
DictRefsContainer,
InfoRefsContainer,
DiskRefsContainer,
read_packed_refs,
read_packed_refs_with_peeled,
write_packed_refs,
SYMREF,
)
import warnings
CONTROLDIR = '.git'
OBJECTDIR = 'objects'
REFSDIR = 'refs'
REFSDIR_TAGS = 'tags'
REFSDIR_HEADS = 'heads'
INDEX_FILENAME = "index"
COMMONDIR = 'commondir'
GITDIR = 'gitdir'
WORKTREES = 'worktrees'
BASE_DIRECTORIES = [
["branches"],
[REFSDIR],
[REFSDIR, REFSDIR_TAGS],
[REFSDIR, REFSDIR_HEADS],
["hooks"],
["info"]
]
DEFAULT_REF = b'refs/heads/master'
class InvalidUserIdentity(Exception):
"""User identity is not of the format 'user '"""
def __init__(self, identity):
self.identity = identity
def _get_default_identity():
import getpass
import socket
username = getpass.getuser()
try:
import pwd
except ImportError:
fullname = None
else:
try:
gecos = pwd.getpwnam(username).pw_gecos
except KeyError:
fullname = None
else:
fullname = gecos.split(',')[0]
if not fullname:
fullname = username
email = os.environ.get('EMAIL')
if email is None:
email = "{}@{}".format(username, socket.gethostname())
return (fullname, email)
def get_user_identity(config, kind=None):
"""Determine the identity to use for new commits.
"""
if kind:
user = os.environ.get("GIT_" + kind + "_NAME")
if user is not None:
user = user.encode('utf-8')
email = os.environ.get("GIT_" + kind + "_EMAIL")
if email is not None:
email = email.encode('utf-8')
else:
user = None
email = None
if user is None:
try:
user = config.get(("user", ), "name")
except KeyError:
user = None
if email is None:
try:
email = config.get(("user", ), "email")
except KeyError:
email = None
default_user, default_email = _get_default_identity()
if user is None:
user = default_user
if not isinstance(user, bytes):
user = user.encode('utf-8')
if email is None:
email = default_email
if not isinstance(email, bytes):
email = email.encode('utf-8')
if email.startswith(b'<') and email.endswith(b'>'):
email = email[1:-1]
return (user + b" <" + email + b">")
def check_user_identity(identity):
"""Verify that a user identity is formatted correctly.
Args:
identity: User identity bytestring
Raises:
InvalidUserIdentity: Raised when identity is invalid
"""
try:
fst, snd = identity.split(b' <', 1)
except ValueError:
raise InvalidUserIdentity(identity)
if b'>' not in snd:
raise InvalidUserIdentity(identity)
def parse_graftpoints(graftpoints):
"""Convert a list of graftpoints into a dict
Args:
graftpoints: Iterator of graftpoint lines
Each line is formatted as:
<commit sha1> <parent sha1> [<parent sha1>]*
Resulting dictionary is:
<commit sha1>: [<parent sha1>*]
https://git.wiki.kernel.org/index.php/GraftPoint
"""
grafts = {}
- for l in graftpoints:
- raw_graft = l.split(None, 1)
+ for line in graftpoints:
+ raw_graft = line.split(None, 1)
commit = raw_graft[0]
if len(raw_graft) == 2:
parents = raw_graft[1].split()
else:
parents = []
for sha in [commit] + parents:
check_hexsha(sha, 'Invalid graftpoint')
grafts[commit] = parents
return grafts
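# Sketch with a hypothetical sha: a graft line containing only a commit sha
# turns that commit into a parentless root, i.e. {b'<commit sha1>': []}.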
def serialize_graftpoints(graftpoints):
"""Convert a dictionary of grafts into string
The graft dictionary is:
<commit sha1>: [<parent sha1>*]
Each line is formatted as:
<commit sha1> <parent sha1> [<parent sha1>]*
https://git.wiki.kernel.org/index.php/GraftPoint
"""
graft_lines = []
for commit, parents in graftpoints.items():
if parents:
graft_lines.append(commit + b' ' + b' '.join(parents))
else:
graft_lines.append(commit)
return b'\n'.join(graft_lines)
def _set_filesystem_hidden(path):
"""Mark path as to be hidden if supported by platform and filesystem.
On win32 uses the SetFileAttributesW api:
<https://docs.microsoft.com/windows/desktop/api/fileapi/nf-fileapi-setfileattributesw>
"""
if sys.platform == 'win32':
import ctypes
from ctypes.wintypes import BOOL, DWORD, LPCWSTR
FILE_ATTRIBUTE_HIDDEN = 2
SetFileAttributesW = ctypes.WINFUNCTYPE(BOOL, LPCWSTR, DWORD)(
("SetFileAttributesW", ctypes.windll.kernel32))
if isinstance(path, bytes):
path = os.fsdecode(path)
if not SetFileAttributesW(path, FILE_ATTRIBUTE_HIDDEN):
pass # Could raise or log `ctypes.WinError()` here
# Could implement other platform specific filesystem hiding here
class BaseRepo(object):
"""Base class for a git repository.
:ivar object_store: Dictionary-like object for accessing
the objects
:ivar refs: Dictionary-like object with the refs in this
repository
"""
def __init__(self, object_store, refs):
"""Open a repository.
This shouldn't be called directly, but rather through one of the
subclasses, such as MemoryRepo or Repo.
Args:
object_store: Object store to use
refs: Refs container to use
"""
self.object_store = object_store
self.refs = refs
self._graftpoints = {}
self.hooks = {}
def _determine_file_mode(self):
"""Probe the file-system to determine whether permissions can be trusted.
Returns: True if permissions can be trusted, False otherwise.
"""
raise NotImplementedError(self._determine_file_mode)
def _init_files(self, bare):
"""Initialize a default set of named files."""
from dulwich.config import ConfigFile
self._put_named_file('description', b"Unnamed repository")
f = BytesIO()
cf = ConfigFile()
cf.set("core", "repositoryformatversion", "0")
if self._determine_file_mode():
cf.set("core", "filemode", True)
else:
cf.set("core", "filemode", False)
cf.set("core", "bare", bare)
cf.set("core", "logallrefupdates", True)
cf.write_to_file(f)
self._put_named_file('config', f.getvalue())
self._put_named_file(os.path.join('info', 'exclude'), b'')
def get_named_file(self, path):
"""Get a file from the control dir with a specific name.
Although the filename should be interpreted as a filename relative to
the control dir in a disk-based Repo, the object returned need not be
pointing to a file in that location.
Args:
path: The path to the file, relative to the control dir.
Returns: An open file object, or None if the file does not exist.
"""
raise NotImplementedError(self.get_named_file)
def _put_named_file(self, path, contents):
"""Write a file to the control dir with the given name and contents.
Args:
path: The path to the file, relative to the control dir.
contents: A string to write to the file.
"""
raise NotImplementedError(self._put_named_file)
def _del_named_file(self, path):
"""Delete a file in the contrl directory with the given name."""
raise NotImplementedError(self._del_named_file)
def open_index(self):
"""Open the index for this repository.
Raises:
NoIndexPresent: If no index is present
Returns: The matching `Index`
"""
raise NotImplementedError(self.open_index)
def fetch(self, target, determine_wants=None, progress=None, depth=None):
"""Fetch objects into another repository.
Args:
target: The target repository
determine_wants: Optional function to determine what refs to
fetch.
progress: Optional progress function
depth: Optional shallow fetch depth
Returns: The local refs
"""
if determine_wants is None:
determine_wants = target.object_store.determine_wants_all
count, pack_data = self.fetch_pack_data(
determine_wants, target.get_graph_walker(), progress=progress,
depth=depth)
target.object_store.add_pack_data(count, pack_data, progress)
return self.get_refs()
def fetch_pack_data(self, determine_wants, graph_walker, progress,
get_tagged=None, depth=None):
"""Fetch the pack data required for a set of revisions.
Args:
determine_wants: Function that takes a dictionary with heads
and returns the list of heads to fetch.
graph_walker: Object that can iterate over the list of revisions
to fetch and has an "ack" method that will be called to acknowledge
that a revision is present.
progress: Simple progress function that will be called with
updated progress strings.
get_tagged: Function that returns a dict of pointed-to sha ->
tag sha for including tags.
depth: Shallow fetch depth
Returns: count and iterator over pack data
"""
# TODO(jelmer): Fetch pack data directly, don't create objects first.
objects = self.fetch_objects(determine_wants, graph_walker, progress,
get_tagged, depth=depth)
return pack_objects_to_data(objects)
def fetch_objects(self, determine_wants, graph_walker, progress,
get_tagged=None, depth=None):
"""Fetch the missing objects required for a set of revisions.
Args:
determine_wants: Function that takes a dictionary with heads
and returns the list of heads to fetch.
graph_walker: Object that can iterate over the list of revisions
to fetch and has an "ack" method that will be called to acknowledge
that a revision is present.
progress: Simple progress function that will be called with
updated progress strings.
get_tagged: Function that returns a dict of pointed-to sha ->
tag sha for including tags.
depth: Shallow fetch depth
Returns: iterator over objects, with __len__ implemented
"""
if depth not in (None, 0):
raise NotImplementedError("depth not supported yet")
refs = {}
for ref, sha in self.get_refs().items():
try:
obj = self.object_store[sha]
except KeyError:
warnings.warn(
'ref %s points at non-present sha %s' % (
ref.decode('utf-8', 'replace'), sha.decode('ascii')),
UserWarning)
continue
else:
if isinstance(obj, Tag):
refs[ref + ANNOTATED_TAG_SUFFIX] = obj.object[1]
refs[ref] = sha
wants = determine_wants(refs)
if not isinstance(wants, list):
raise TypeError("determine_wants() did not return a list")
shallows = getattr(graph_walker, 'shallow', frozenset())
unshallows = getattr(graph_walker, 'unshallow', frozenset())
if wants == []:
# TODO(dborowitz): find a way to short-circuit that doesn't change
# this interface.
if shallows or unshallows:
# Do not send a pack in shallow short-circuit path
return None
return []
# If the graph walker is set up with an implementation that can
# ACK/NAK to the wire, it will write data to the client through
# this call as a side-effect.
haves = self.object_store.find_common_revisions(graph_walker)
# Deal with shallow requests separately because the haves do
# not reflect what objects are missing
if shallows or unshallows:
# TODO: filter the haves commits from iter_shas. the specific
# commits aren't missing.
haves = []
def get_parents(commit):
if commit.id in shallows:
return []
return self.get_parents(commit.id, commit)
return self.object_store.iter_shas(
self.object_store.find_missing_objects(
haves, wants, self.get_shallow(),
progress, get_tagged,
get_parents=get_parents))
def generate_pack_data(self, have, want, progress=None, ofs_delta=None):
"""Generate pack data objects for a set of wants/haves.
Args:
have: List of SHA1s of objects that should not be sent
want: List of SHA1s of objects that should be sent
ofs_delta: Whether OFS deltas can be included
progress: Optional progress reporting method
"""
return self.object_store.generate_pack_data(
have, want, shallow=self.get_shallow(),
progress=progress, ofs_delta=ofs_delta)
def get_graph_walker(self, heads=None):
"""Retrieve a graph walker.
A graph walker is used by a remote repository (or proxy)
to find out which objects are present in this repository.
Args:
heads: Repository heads to use (optional)
Returns: A graph walker object
"""
if heads is None:
heads = [
sha for sha in self.refs.as_dict(b'refs/heads').values()
if sha in self.object_store]
return ObjectStoreGraphWalker(
heads, self.get_parents, shallow=self.get_shallow())
def get_refs(self):
"""Get dictionary with all refs.
Returns: A ``dict`` mapping ref names to SHA1s
"""
return self.refs.as_dict()
def head(self):
"""Return the SHA1 pointed at by HEAD."""
return self.refs[b'HEAD']
def _get_object(self, sha, cls):
assert len(sha) in (20, 40)
ret = self.get_object(sha)
if not isinstance(ret, cls):
if cls is Commit:
raise NotCommitError(ret)
elif cls is Blob:
raise NotBlobError(ret)
elif cls is Tree:
raise NotTreeError(ret)
elif cls is Tag:
raise NotTagError(ret)
else:
raise Exception("Type invalid: %r != %r" % (
ret.type_name, cls.type_name))
return ret
def get_object(self, sha):
"""Retrieve the object with the specified SHA.
Args:
sha: SHA to retrieve
Returns: A ShaFile object
Raises:
KeyError: when the object can not be found
"""
return self.object_store[sha]
def get_parents(self, sha, commit=None):
"""Retrieve the parents of a specific commit.
If the specific commit is a graftpoint, the graft parents
will be returned instead.
Args:
sha: SHA of the commit for which to retrieve the parents
commit: Optional commit matching the sha
Returns: List of parents
"""
try:
return self._graftpoints[sha]
except KeyError:
if commit is None:
commit = self[sha]
return commit.parents
def get_config(self):
"""Retrieve the config object.
Returns: `ConfigFile` object for the ``.git/config`` file.
"""
raise NotImplementedError(self.get_config)
def get_description(self):
"""Retrieve the description for this repository.
Returns: String with the description of the repository
as set by the user.
"""
raise NotImplementedError(self.get_description)
def set_description(self, description):
"""Set the description for this repository.
Args:
description: Text to set as description for this repository.
"""
raise NotImplementedError(self.set_description)
def get_config_stack(self):
"""Return a config stack for this repository.
This stack accesses the configuration for both this repository
itself (.git/config) and the global configuration, which usually
lives in ~/.gitconfig.
Returns: `Config` instance for this repository
"""
from dulwich.config import StackedConfig
backends = [self.get_config()] + StackedConfig.default_backends()
return StackedConfig(backends, writable=backends[0])
def get_shallow(self):
"""Get the set of shallow commits.
Returns: Set of shallow commits.
"""
f = self.get_named_file('shallow')
if f is None:
return set()
with f:
- return set(l.strip() for l in f)
+ return set(line.strip() for line in f)
def update_shallow(self, new_shallow, new_unshallow):
"""Update the list of shallow objects.
Args:
new_shallow: Newly shallow objects
new_unshallow: Newly no longer shallow objects
"""
shallow = self.get_shallow()
if new_shallow:
shallow.update(new_shallow)
if new_unshallow:
shallow.difference_update(new_unshallow)
self._put_named_file(
'shallow',
b''.join([sha + b'\n' for sha in shallow]))
def get_peeled(self, ref):
"""Get the peeled value of a ref.
Args:
ref: The refname to peel.
Returns: The fully-peeled SHA1 of a tag object, after peeling all
intermediate tags; if the original ref does not point to a tag,
this will equal the original SHA1.
"""
cached = self.refs.get_peeled(ref)
if cached is not None:
return cached
return self.object_store.peel_sha(self.refs[ref]).id
def get_walker(self, include=None, *args, **kwargs):
"""Obtain a walker for this repository.
Args:
include: Iterable of SHAs of commits to include along with their
ancestors. Defaults to [HEAD]
exclude: Iterable of SHAs of commits to exclude along with their
ancestors, overriding includes.
order: ORDER_* constant specifying the order of results.
Anything other than ORDER_DATE may result in O(n) memory usage.
reverse: If True, reverse the order of output, requiring O(n)
memory.
max_entries: The maximum number of entries to yield, or None for
no limit.
paths: Iterable of file or subtree paths to show entries for.
rename_detector: diff.RenameDetector object for detecting
renames.
follow: If True, follow path across renames/copies. Forces a
default rename_detector.
since: Timestamp to list commits after.
until: Timestamp to list commits before.
queue_cls: A class to use for a queue of commits, supporting the
iterator protocol. The constructor takes a single argument, the
Walker.
Returns: A `Walker` object
"""
from dulwich.walk import Walker
if include is None:
include = [self.head()]
if isinstance(include, str):
include = [include]
kwargs['get_parents'] = lambda commit: self.get_parents(
commit.id, commit)
return Walker(self.object_store, include, *args, **kwargs)
def __getitem__(self, name):
"""Retrieve a Git object by SHA1 or ref.
Args:
name: A Git object SHA1 or a ref name
Returns: A `ShaFile` object, such as a Commit or Blob
Raises:
KeyError: when the specified ref or object does not exist
"""
if not isinstance(name, bytes):
raise TypeError("'name' must be bytestring, not %.80s" %
type(name).__name__)
if len(name) in (20, 40):
try:
return self.object_store[name]
except (KeyError, ValueError):
pass
try:
return self.object_store[self.refs[name]]
except RefFormatError:
raise KeyError(name)
def __contains__(self, name):
"""Check if a specific Git object or ref is present.
Args:
name: Git object SHA1 or ref name
"""
if len(name) in (20, 40):
return name in self.object_store or name in self.refs
else:
return name in self.refs
def __setitem__(self, name, value):
"""Set a ref.
Args:
name: ref name
value: Ref value - either a ShaFile object, or a hex sha
"""
if name.startswith(b"refs/") or name == b'HEAD':
if isinstance(value, ShaFile):
self.refs[name] = value.id
elif isinstance(value, bytes):
self.refs[name] = value
else:
raise TypeError(value)
else:
raise ValueError(name)
def __delitem__(self, name):
"""Remove a ref.
Args:
name: Name of the ref to remove
"""
if name.startswith(b"refs/") or name == b"HEAD":
del self.refs[name]
else:
raise ValueError(name)
def _get_user_identity(self, config, kind=None):
"""Determine the identity to use for new commits.
"""
# TODO(jelmer): Deprecate this function in favor of get_user_identity
return get_user_identity(config)
def _add_graftpoints(self, updated_graftpoints):
"""Add or modify graftpoints
Args:
updated_graftpoints: Dict of commit shas to list of parent shas
"""
# Simple validation
for commit, parents in updated_graftpoints.items():
for sha in [commit] + parents:
check_hexsha(sha, 'Invalid graftpoint')
self._graftpoints.update(updated_graftpoints)
def _remove_graftpoints(self, to_remove=[]):
"""Remove graftpoints
Args:
to_remove: List of commit shas
"""
for sha in to_remove:
del self._graftpoints[sha]
def _read_heads(self, name):
f = self.get_named_file(name)
if f is None:
return []
with f:
- return [l.strip() for l in f.readlines() if l.strip()]
+ return [line.strip() for line in f.readlines() if line.strip()]
def do_commit(self, message=None, committer=None,
author=None, commit_timestamp=None,
commit_timezone=None, author_timestamp=None,
author_timezone=None, tree=None, encoding=None,
ref=b'HEAD', merge_heads=None):
"""Create a new commit.
Args:
message: Commit message
committer: Committer fullname
author: Author fullname (defaults to committer)
commit_timestamp: Commit timestamp (defaults to now)
commit_timezone: Commit timestamp timezone (defaults to GMT)
author_timestamp: Author timestamp (defaults to commit
timestamp)
author_timezone: Author timestamp timezone
(defaults to commit timestamp timezone)
tree: SHA1 of the tree root to use (if not specified the
current index will be committed).
encoding: Encoding
ref: Optional ref to commit to (defaults to current branch)
merge_heads: Merge heads (defaults to .git/MERGE_HEADS)
Returns: New commit SHA1
"""
import time
c = Commit()
if tree is None:
index = self.open_index()
c.tree = index.commit(self.object_store)
else:
if len(tree) != 40:
raise ValueError("tree must be a 40-byte hex sha string")
c.tree = tree
try:
self.hooks['pre-commit'].execute()
except HookError as e:
raise CommitError(e)
except KeyError: # no hook defined, silent fallthrough
pass
config = self.get_config_stack()
if merge_heads is None:
merge_heads = self._read_heads('MERGE_HEADS')
if committer is None:
committer = get_user_identity(config, kind='COMMITTER')
check_user_identity(committer)
c.committer = committer
if commit_timestamp is None:
# FIXME: Support GIT_COMMITTER_DATE environment variable
commit_timestamp = time.time()
c.commit_time = int(commit_timestamp)
if commit_timezone is None:
# FIXME: Use current user timezone rather than UTC
commit_timezone = 0
c.commit_timezone = commit_timezone
if author is None:
author = get_user_identity(config, kind='AUTHOR')
c.author = author
check_user_identity(author)
if author_timestamp is None:
# FIXME: Support GIT_AUTHOR_DATE environment variable
author_timestamp = commit_timestamp
c.author_time = int(author_timestamp)
if author_timezone is None:
author_timezone = commit_timezone
c.author_timezone = author_timezone
if encoding is None:
try:
encoding = config.get(('i18n', ), 'commitEncoding')
except KeyError:
pass # No dice
if encoding is not None:
c.encoding = encoding
if message is None:
# FIXME: Try to read commit message from .git/MERGE_MSG
raise ValueError("No commit message specified")
try:
c.message = self.hooks['commit-msg'].execute(message)
if c.message is None:
c.message = message
except HookError as e:
raise CommitError(e)
except KeyError: # no hook defined, message not modified
c.message = message
if ref is None:
# Create a dangling commit
c.parents = merge_heads
self.object_store.add_object(c)
else:
try:
old_head = self.refs[ref]
c.parents = [old_head] + merge_heads
self.object_store.add_object(c)
ok = self.refs.set_if_equals(
ref, old_head, c.id, message=b"commit: " + message,
committer=committer, timestamp=commit_timestamp,
timezone=commit_timezone)
except KeyError:
c.parents = merge_heads
self.object_store.add_object(c)
ok = self.refs.add_if_new(
ref, c.id, message=b"commit: " + message,
committer=committer, timestamp=commit_timestamp,
timezone=commit_timezone)
if not ok:
# Fail if the atomic compare-and-swap failed, leaving the
# commit and all its objects as garbage.
raise CommitError("%s changed during commit" % (ref,))
self._del_named_file('MERGE_HEADS')
try:
self.hooks['post-commit'].execute()
except HookError as e: # silent failure
warnings.warn("post-commit hook failed: %s" % e, UserWarning)
except KeyError: # no hook defined, silent fallthrough
pass
return c.id
def read_gitfile(f):
"""Read a ``.git`` file.
The first line of the file should start with "gitdir: "
Args:
f: File-like object to read from
Returns: A path
"""
cs = f.read()
if not cs.startswith("gitdir: "):
raise ValueError("Expected file to start with 'gitdir: '")
return cs[len("gitdir: "):].rstrip("\n")
class Repo(BaseRepo):
"""A git repository backed by local disk.
To open an existing repository, call the constructor with
the path of the repository.
To create a new repository, use the Repo.init class method.
"""
def __init__(self, root):
hidden_path = os.path.join(root, CONTROLDIR)
if os.path.isdir(os.path.join(hidden_path, OBJECTDIR)):
self.bare = False
self._controldir = hidden_path
elif (os.path.isdir(os.path.join(root, OBJECTDIR)) and
os.path.isdir(os.path.join(root, REFSDIR))):
self.bare = True
self._controldir = root
elif os.path.isfile(hidden_path):
self.bare = False
with open(hidden_path, 'r') as f:
path = read_gitfile(f)
self._controldir = os.path.join(root, path)
else:
raise NotGitRepository(
"No git repository was found at %(path)s" % dict(path=root)
)
commondir = self.get_named_file(COMMONDIR)
if commondir is not None:
with commondir:
self._commondir = os.path.join(
self.controldir(),
os.fsdecode(commondir.read().rstrip(b"\r\n")))
else:
self._commondir = self._controldir
self.path = root
config = self.get_config()
object_store = DiskObjectStore.from_config(
os.path.join(self.commondir(), OBJECTDIR),
config)
refs = DiskRefsContainer(self.commondir(), self._controldir,
logger=self._write_reflog)
BaseRepo.__init__(self, object_store, refs)
self._graftpoints = {}
graft_file = self.get_named_file(os.path.join("info", "grafts"),
basedir=self.commondir())
if graft_file:
with graft_file:
self._graftpoints.update(parse_graftpoints(graft_file))
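# The shallow file lists one commit sha per line; parsing it with
# parse_graftpoints() records those commits as parentless roots.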
graft_file = self.get_named_file("shallow",
basedir=self.commondir())
if graft_file:
with graft_file:
self._graftpoints.update(parse_graftpoints(graft_file))
self.hooks['pre-commit'] = PreCommitShellHook(self.controldir())
self.hooks['commit-msg'] = CommitMsgShellHook(self.controldir())
self.hooks['post-commit'] = PostCommitShellHook(self.controldir())
self.hooks['post-receive'] = PostReceiveShellHook(self.controldir())
def _write_reflog(self, ref, old_sha, new_sha, committer, timestamp,
timezone, message):
from .reflog import format_reflog_line
path = os.path.join(self.controldir(), 'logs', os.fsdecode(ref))
try:
os.makedirs(os.path.dirname(path))
except OSError as e:
if e.errno != errno.EEXIST:
raise
if committer is None:
config = self.get_config_stack()
committer = self._get_user_identity(config)
check_user_identity(committer)
if timestamp is None:
timestamp = int(time.time())
if timezone is None:
timezone = 0 # FIXME
with open(path, 'ab') as f:
f.write(format_reflog_line(old_sha, new_sha, committer,
timestamp, timezone, message) + b'\n')
@classmethod
def discover(cls, start='.'):
"""Iterate parent directories to discover a repository
Return a Repo object for the first parent directory that looks like a
Git repository.
Args:
start: The directory to start discovery from (defaults to '.')
"""
remaining = True
path = os.path.abspath(start)
while remaining:
try:
return cls(path)
except NotGitRepository:
path, remaining = os.path.split(path)
raise NotGitRepository(
"No git repository was found at %(path)s" % dict(path=start)
)
def controldir(self):
"""Return the path of the control directory."""
return self._controldir
def commondir(self):
"""Return the path of the common directory.
For a main working tree, it is identical to controldir().
For a linked working tree, it is the control directory of the
main working tree."""
return self._commondir
def _determine_file_mode(self):
"""Probe the file-system to determine whether permissions can be trusted.
Returns: True if permissions can be trusted, False otherwise.
"""
fname = os.path.join(self.path, '.probe-permissions')
with open(fname, 'w') as f:
f.write('')
st1 = os.lstat(fname)
try:
os.chmod(fname, st1.st_mode ^ stat.S_IXUSR)
except EnvironmentError as e:
if e.errno == errno.EPERM:
return False
raise
st2 = os.lstat(fname)
os.unlink(fname)
mode_differs = st1.st_mode != st2.st_mode
st2_has_exec = (st2.st_mode & stat.S_IXUSR) != 0
return mode_differs and st2_has_exec
def _put_named_file(self, path, contents):
"""Write a file to the control dir with the given name and contents.
Args:
path: The path to the file, relative to the control dir.
contents: A string to write to the file.
"""
path = path.lstrip(os.path.sep)
with GitFile(os.path.join(self.controldir(), path), 'wb') as f:
f.write(contents)
def _del_named_file(self, path):
try:
os.unlink(os.path.join(self.controldir(), path))
except (IOError, OSError) as e:
if e.errno == errno.ENOENT:
return
raise
def get_named_file(self, path, basedir=None):
"""Get a file from the control dir with a specific name.
Although the filename should be interpreted as a filename relative to
the control dir in a disk-based Repo, the object returned need not be
pointing to a file in that location.
Args:
path: The path to the file, relative to the control dir.
basedir: Optional argument that specifies an alternative to the
control dir.
Returns: An open file object, or None if the file does not exist.
"""
# TODO(dborowitz): sanitize filenames, since this is used directly by
# the dumb web serving code.
if basedir is None:
basedir = self.controldir()
path = path.lstrip(os.path.sep)
try:
return open(os.path.join(basedir, path), 'rb')
except (IOError, OSError) as e:
if e.errno == errno.ENOENT:
return None
raise
def index_path(self):
"""Return path to the index file."""
return os.path.join(self.controldir(), INDEX_FILENAME)
def open_index(self):
"""Open the index for this repository.
Raises:
NoIndexPresent: If no index is present
Returns: The matching `Index`
"""
from dulwich.index import Index
if not self.has_index():
raise NoIndexPresent()
return Index(self.index_path())
def has_index(self):
"""Check if an index is present."""
# Bare repos must never have index files; non-bare repos may have a
# missing index file, which is treated as empty.
return not self.bare
def stage(self, fs_paths):
"""Stage a set of paths.
Args:
fs_paths: List of paths, relative to the repository path
"""
root_path_bytes = os.fsencode(self.path)
if not isinstance(fs_paths, list):
fs_paths = [fs_paths]
from dulwich.index import (
blob_from_path_and_stat,
index_entry_from_stat,
_fs_to_tree_path,
)
index = self.open_index()
blob_normalizer = self.get_blob_normalizer()
for fs_path in fs_paths:
if not isinstance(fs_path, bytes):
fs_path = os.fsencode(fs_path)
if os.path.isabs(fs_path):
raise ValueError(
"path %r should be relative to "
"repository root, not absolute" % fs_path)
tree_path = _fs_to_tree_path(fs_path)
full_path = os.path.join(root_path_bytes, fs_path)
try:
st = os.lstat(full_path)
except OSError:
# File no longer exists
try:
del index[tree_path]
except KeyError:
pass # already removed
else:
if (not stat.S_ISREG(st.st_mode) and
not stat.S_ISLNK(st.st_mode)):
try:
del index[tree_path]
except KeyError:
pass
else:
blob = blob_from_path_and_stat(full_path, st)
blob = blob_normalizer.checkin_normalize(blob, fs_path)
self.object_store.add_object(blob)
index[tree_path] = index_entry_from_stat(st, blob.id, 0)
index.write()
def clone(self, target_path, mkdir=True, bare=False,
origin=b"origin", checkout=None):
"""Clone this repository.
Args:
target_path: Target path
mkdir: Create the target directory
bare: Whether to create a bare repository
origin: Base name for refs in target repository
cloned from this repository
Returns: Created repository as `Repo`
"""
if not bare:
target = self.init(target_path, mkdir=mkdir)
else:
if checkout:
raise ValueError("checkout and bare are incompatible")
target = self.init_bare(target_path, mkdir=mkdir)
self.fetch(target)
encoded_path = self.path
if not isinstance(encoded_path, bytes):
encoded_path = os.fsencode(encoded_path)
ref_message = b"clone: from " + encoded_path
target.refs.import_refs(
b'refs/remotes/' + origin, self.refs.as_dict(b'refs/heads'),
message=ref_message)
target.refs.import_refs(
b'refs/tags', self.refs.as_dict(b'refs/tags'),
message=ref_message)
try:
target.refs.add_if_new(
DEFAULT_REF, self.refs[DEFAULT_REF],
message=ref_message)
except KeyError:
pass
target_config = target.get_config()
target_config.set(('remote', 'origin'), 'url', encoded_path)
target_config.set(('remote', 'origin'), 'fetch',
'+refs/heads/*:refs/remotes/origin/*')
target_config.write_to_path()
# Update target head
head_chain, head_sha = self.refs.follow(b'HEAD')
if head_chain and head_sha is not None:
target.refs.set_symbolic_ref(b'HEAD', head_chain[-1],
message=ref_message)
target[b'HEAD'] = head_sha
if checkout is None:
checkout = (not bare)
if checkout:
# Checkout HEAD to target dir
target.reset_index()
return target
def reset_index(self, tree=None):
"""Reset the index back to a specific tree.
Args:
tree: Tree SHA to reset to, None for current HEAD tree.
"""
from dulwich.index import (
build_index_from_tree,
validate_path_element_default,
validate_path_element_ntfs,
)
if tree is None:
tree = self[b'HEAD'].tree
config = self.get_config()
honor_filemode = config.get_boolean(
b'core', b'filemode', os.name != "nt")
if config.get_boolean(b'core', b'protectNTFS', os.name == "nt"):
validate_path_element = validate_path_element_ntfs
else:
validate_path_element = validate_path_element_default
return build_index_from_tree(
self.path, self.index_path(), self.object_store, tree,
honor_filemode=honor_filemode,
validate_path_element=validate_path_element)
def get_config(self):
"""Retrieve the config object.
Returns: `ConfigFile` object for the ``.git/config`` file.
"""
from dulwich.config import ConfigFile
path = os.path.join(self._controldir, 'config')
try:
return ConfigFile.from_path(path)
except (IOError, OSError) as e:
if e.errno != errno.ENOENT:
raise
ret = ConfigFile()
ret.path = path
return ret
def get_description(self):
"""Retrieve the description of this repository.
Returns: A string describing the repository or None.
"""
path = os.path.join(self._controldir, 'description')
try:
with GitFile(path, 'rb') as f:
return f.read()
except (IOError, OSError) as e:
if e.errno != errno.ENOENT:
raise
return None
def __repr__(self):
return "" % self.path
def set_description(self, description):
"""Set the description for this repository.
Args:
description: Text to set as description for this repository.
"""
self._put_named_file('description', description)
@classmethod
def _init_maybe_bare(cls, path, bare):
for d in BASE_DIRECTORIES:
os.mkdir(os.path.join(path, *d))
DiskObjectStore.init(os.path.join(path, OBJECTDIR))
ret = cls(path)
ret.refs.set_symbolic_ref(b'HEAD', DEFAULT_REF)
ret._init_files(bare)
return ret
@classmethod
def init(cls, path, mkdir=False):
"""Create a new repository.
Args:
path: Path in which to create the repository
mkdir: Whether to create the directory
Returns: `Repo` instance
"""
if mkdir:
os.mkdir(path)
controldir = os.path.join(path, CONTROLDIR)
os.mkdir(controldir)
_set_filesystem_hidden(controldir)
cls._init_maybe_bare(controldir, False)
return cls(path)
@classmethod
def _init_new_working_directory(cls, path, main_repo, identifier=None,
mkdir=False):
"""Create a new working directory linked to a repository.
Args:
path: Path in which to create the working tree.
main_repo: Main repository to reference
identifier: Worktree identifier
mkdir: Whether to create the directory
Returns: `Repo` instance
"""
if mkdir:
os.mkdir(path)
if identifier is None:
identifier = os.path.basename(path)
main_worktreesdir = os.path.join(main_repo.controldir(), WORKTREES)
worktree_controldir = os.path.join(main_worktreesdir, identifier)
gitdirfile = os.path.join(path, CONTROLDIR)
with open(gitdirfile, 'wb') as f:
f.write(b'gitdir: ' + os.fsencode(worktree_controldir) + b'\n')
try:
os.mkdir(main_worktreesdir)
except OSError as e:
if e.errno != errno.EEXIST:
raise
try:
os.mkdir(worktree_controldir)
except OSError as e:
if e.errno != errno.EEXIST:
raise
with open(os.path.join(worktree_controldir, GITDIR), 'wb') as f:
f.write(os.fsencode(gitdirfile) + b'\n')
with open(os.path.join(worktree_controldir, COMMONDIR), 'wb') as f:
f.write(b'../..\n')
with open(os.path.join(worktree_controldir, 'HEAD'), 'wb') as f:
f.write(main_repo.head() + b'\n')
r = cls(path)
r.reset_index()
return r
@classmethod
def init_bare(cls, path, mkdir=False):
"""Create a new bare repository.
``path`` should already exist and be an empty directory.
Args:
path: Path to create bare repository in
Returns: a `Repo` instance
"""
if mkdir:
os.mkdir(path)
return cls._init_maybe_bare(path, True)
create = init_bare
def close(self):
"""Close any files opened by this repository."""
self.object_store.close()
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()
def get_blob_normalizer(self):
""" Return a BlobNormalizer object
"""
# TODO Parse the git attributes files
git_attributes = {}
return BlobNormalizer(
self.get_config_stack(), git_attributes
)
class MemoryRepo(BaseRepo):
"""Repo that stores refs, objects, and named files in memory.
MemoryRepos are always bare: they have no working tree and no index, since
those have a stronger dependency on the filesystem.
"""
def __init__(self):
from dulwich.config import ConfigFile
self._reflog = []
refs_container = DictRefsContainer({}, logger=self._append_reflog)
BaseRepo.__init__(self, MemoryObjectStore(), refs_container)
self._named_files = {}
self.bare = True
self._config = ConfigFile()
self._description = None
def _append_reflog(self, *args):
self._reflog.append(args)
def set_description(self, description):
self._description = description
def get_description(self):
return self._description
def _determine_file_mode(self):
"""Probe the file-system to determine whether permissions can be trusted.
Returns: True if permissions can be trusted, False otherwise.
"""
return sys.platform != 'win32'
def _put_named_file(self, path, contents):
"""Write a file to the control dir with the given name and contents.
Args:
path: The path to the file, relative to the control dir.
contents: A string to write to the file.
"""
self._named_files[path] = contents
def _del_named_file(self, path):
try:
del self._named_files[path]
except KeyError:
pass
def get_named_file(self, path, basedir=None):
"""Get a file from the control dir with a specific name.
Although the filename should be interpreted as a filename relative to
the control dir in a disk-based Repo, the object returned need not be
pointing to a file in that location.
Args:
path: The path to the file, relative to the control dir.
Returns: An open file object, or None if the file does not exist.
"""
contents = self._named_files.get(path, None)
if contents is None:
return None
return BytesIO(contents)
def open_index(self):
"""Fail to open index for this repo, since it is bare.
Raises:
NoIndexPresent: Raised when no index is present
"""
raise NoIndexPresent()
def get_config(self):
"""Retrieve the config object.
Returns: `ConfigFile` object.
"""
return self._config
@classmethod
def init_bare(cls, objects, refs):
"""Create a new bare repository in memory.
Args:
objects: Objects for the new repository,
as iterable
refs: Refs as dictionary, mapping names
to object SHA1s
"""
ret = cls()
for obj in objects:
ret.object_store.add_object(obj)
for refname, sha in refs.items():
ret.refs.add_if_new(refname, sha)
ret._init_files(bare=True)
return ret
diff --git a/dulwich/tests/compat/test_client.py b/dulwich/tests/compat/test_client.py
index c1db4ef4..39bcb2e6 100644
--- a/dulwich/tests/compat/test_client.py
+++ b/dulwich/tests/compat/test_client.py
@@ -1,599 +1,598 @@
# test_client.py -- Compatibility tests for git client.
# Copyright (C) 2010 Google, Inc.
#
# Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU
# General Public License as public by the Free Software Foundation; version 2.0
# or (at your option) any later version. You can redistribute it and/or
# modify it under the terms of either of these two licenses.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# You should have received a copy of the licenses; if not, see
# <http://www.gnu.org/licenses/> for a copy of the GNU General Public License
# and <http://www.apache.org/licenses/LICENSE-2.0> for a copy of the Apache
# License, Version 2.0.
#
"""Compatibilty tests between the Dulwich client and the cgit server."""
import copy
from io import BytesIO
import os
import select
import signal
import stat
import subprocess
import sys
import tarfile
import tempfile
import threading
-import unittest
from urllib.parse import unquote
import http.server
from dulwich import (
client,
errors,
file,
index,
protocol,
objects,
repo,
)
from dulwich.tests import (
SkipTest,
expectedFailure,
)
from dulwich.tests.compat.utils import (
CompatTestCase,
check_for_daemon,
import_repo_to_dir,
rmtree_ro,
run_git_or_fail,
_DEFAULT_GIT,
)
if sys.platform == 'win32':
import ctypes
class DulwichClientTestBase(object):
"""Tests for client/server compatibility."""
def setUp(self):
self.gitroot = os.path.dirname(
import_repo_to_dir('server_new.export').rstrip(os.sep))
self.dest = os.path.join(self.gitroot, 'dest')
file.ensure_dir_exists(self.dest)
run_git_or_fail(['init', '--quiet', '--bare'], cwd=self.dest)
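# Fixture layout: gitroot/server_new.export is the imported source repo and
# gitroot/dest is an empty bare repo the tests push to and fetch from.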
def tearDown(self):
rmtree_ro(self.gitroot)
def assertDestEqualsSrc(self):
repo_dir = os.path.join(self.gitroot, 'server_new.export')
dest_repo_dir = os.path.join(self.gitroot, 'dest')
with repo.Repo(repo_dir) as src:
with repo.Repo(dest_repo_dir) as dest:
self.assertReposEqual(src, dest)
def _client(self):
raise NotImplementedError()
def _build_path(self):
raise NotImplementedError()
def _do_send_pack(self):
c = self._client()
srcpath = os.path.join(self.gitroot, 'server_new.export')
with repo.Repo(srcpath) as src:
sendrefs = dict(src.get_refs())
del sendrefs[b'HEAD']
c.send_pack(self._build_path('/dest'), lambda _: sendrefs,
src.generate_pack_data)
def test_send_pack(self):
self._do_send_pack()
self.assertDestEqualsSrc()
def test_send_pack_nothing_to_send(self):
self._do_send_pack()
self.assertDestEqualsSrc()
# nothing to send, but shouldn't raise either.
self._do_send_pack()
@staticmethod
def _add_file(repo, tree_id, filename, contents):
tree = repo[tree_id]
blob = objects.Blob()
blob.data = contents.encode('utf-8')
repo.object_store.add_object(blob)
tree.add(filename.encode('utf-8'), stat.S_IFREG | 0o644, blob.id)
repo.object_store.add_object(tree)
return tree.id
def test_send_pack_from_shallow_clone(self):
c = self._client()
server_new_path = os.path.join(self.gitroot, 'server_new.export')
run_git_or_fail(['config', 'http.uploadpack', 'true'],
cwd=server_new_path)
run_git_or_fail(['config', 'http.receivepack', 'true'],
cwd=server_new_path)
remote_path = self._build_path('/server_new.export')
with repo.Repo(self.dest) as local:
result = c.fetch(remote_path, local, depth=1)
for r in result.refs.items():
local.refs.set_if_equals(r[0], None, r[1])
tree_id = local[local.head()].tree
for filename, contents in [('bar', 'bar contents'),
- ('zop', 'zop contents')]:
+ ('zop', 'zop contents')]:
tree_id = self._add_file(local, tree_id, filename, contents)
commit_id = local.do_commit(
message=b"add " + filename.encode('utf-8'),
committer=b"Joe Example ",
tree=tree_id)
sendrefs = dict(local.get_refs())
del sendrefs[b'HEAD']
c.send_pack(remote_path, lambda _: sendrefs,
local.generate_pack_data)
with repo.Repo(server_new_path) as remote:
self.assertEqual(remote.head(), commit_id)
def test_send_without_report_status(self):
c = self._client()
c._send_capabilities.remove(b'report-status')
srcpath = os.path.join(self.gitroot, 'server_new.export')
with repo.Repo(srcpath) as src:
sendrefs = dict(src.get_refs())
del sendrefs[b'HEAD']
c.send_pack(self._build_path('/dest'), lambda _: sendrefs,
src.generate_pack_data)
self.assertDestEqualsSrc()
def make_dummy_commit(self, dest):
b = objects.Blob.from_string(b'hi')
dest.object_store.add_object(b)
t = index.commit_tree(dest.object_store, [(b'hi', b.id, 0o100644)])
c = objects.Commit()
c.author = c.committer = b'Foo Bar <foo@example.com>'
c.author_time = c.commit_time = 0
c.author_timezone = c.commit_timezone = 0
c.message = b'hi'
c.tree = t
dest.object_store.add_object(c)
return c.id
def disable_ff_and_make_dummy_commit(self):
# disable non-fast-forward pushes to the server
dest = repo.Repo(os.path.join(self.gitroot, 'dest'))
run_git_or_fail(['config', 'receive.denyNonFastForwards', 'true'],
cwd=dest.path)
commit_id = self.make_dummy_commit(dest)
return dest, commit_id
def compute_send(self, src):
sendrefs = dict(src.get_refs())
del sendrefs[b'HEAD']
return sendrefs, src.generate_pack_data
def test_send_pack_one_error(self):
dest, dummy_commit = self.disable_ff_and_make_dummy_commit()
dest.refs[b'refs/heads/master'] = dummy_commit
repo_dir = os.path.join(self.gitroot, 'server_new.export')
with repo.Repo(repo_dir) as src:
sendrefs, gen_pack = self.compute_send(src)
c = self._client()
try:
c.send_pack(self._build_path('/dest'), lambda _: sendrefs,
gen_pack)
except errors.UpdateRefsError as e:
self.assertEqual('refs/heads/master failed to update',
e.args[0])
self.assertEqual({b'refs/heads/branch': b'ok',
b'refs/heads/master': b'non-fast-forward'},
e.ref_status)
def test_send_pack_multiple_errors(self):
dest, dummy = self.disable_ff_and_make_dummy_commit()
# set up for two non-ff errors
branch, master = b'refs/heads/branch', b'refs/heads/master'
dest.refs[branch] = dest.refs[master] = dummy
repo_dir = os.path.join(self.gitroot, 'server_new.export')
with repo.Repo(repo_dir) as src:
sendrefs, gen_pack = self.compute_send(src)
c = self._client()
try:
c.send_pack(self._build_path('/dest'), lambda _: sendrefs,
gen_pack)
except errors.UpdateRefsError as e:
self.assertIn(
str(e),
['{0}, {1} failed to update'.format(
branch.decode('ascii'), master.decode('ascii')),
'{1}, {0} failed to update'.format(
branch.decode('ascii'), master.decode('ascii'))])
self.assertEqual({branch: b'non-fast-forward',
master: b'non-fast-forward'},
e.ref_status)
def test_archive(self):
c = self._client()
f = BytesIO()
c.archive(self._build_path('/server_new.export'), b'HEAD', f.write)
f.seek(0)
tf = tarfile.open(fileobj=f)
self.assertEqual(['baz', 'foo'], tf.getnames())
def test_fetch_pack(self):
c = self._client()
with repo.Repo(os.path.join(self.gitroot, 'dest')) as dest:
result = c.fetch(self._build_path('/server_new.export'), dest)
for r in result.refs.items():
dest.refs.set_if_equals(r[0], None, r[1])
self.assertDestEqualsSrc()
def test_fetch_pack_depth(self):
c = self._client()
with repo.Repo(os.path.join(self.gitroot, 'dest')) as dest:
result = c.fetch(self._build_path('/server_new.export'), dest,
depth=1)
for r in result.refs.items():
dest.refs.set_if_equals(r[0], None, r[1])
self.assertEqual(
dest.get_shallow(),
set([b'35e0b59e187dd72a0af294aedffc213eaa4d03ff',
b'514dc6d3fbfe77361bcaef320c4d21b72bc10be9']))
def test_repeat(self):
c = self._client()
with repo.Repo(os.path.join(self.gitroot, 'dest')) as dest:
result = c.fetch(self._build_path('/server_new.export'), dest)
for r in result.refs.items():
dest.refs.set_if_equals(r[0], None, r[1])
self.assertDestEqualsSrc()
result = c.fetch(self._build_path('/server_new.export'), dest)
for r in result.refs.items():
dest.refs.set_if_equals(r[0], None, r[1])
self.assertDestEqualsSrc()
def test_incremental_fetch_pack(self):
self.test_fetch_pack()
dest, dummy = self.disable_ff_and_make_dummy_commit()
dest.refs[b'refs/heads/master'] = dummy
c = self._client()
repo_dir = os.path.join(self.gitroot, 'server_new.export')
with repo.Repo(repo_dir) as dest:
result = c.fetch(self._build_path('/dest'), dest)
for r in result.refs.items():
dest.refs.set_if_equals(r[0], None, r[1])
self.assertDestEqualsSrc()
def test_fetch_pack_no_side_band_64k(self):
c = self._client()
c._fetch_capabilities.remove(b'side-band-64k')
with repo.Repo(os.path.join(self.gitroot, 'dest')) as dest:
result = c.fetch(self._build_path('/server_new.export'), dest)
for r in result.refs.items():
dest.refs.set_if_equals(r[0], None, r[1])
self.assertDestEqualsSrc()
def test_fetch_pack_zero_sha(self):
# zero sha1s are already present on the client, and should
# be ignored
c = self._client()
with repo.Repo(os.path.join(self.gitroot, 'dest')) as dest:
result = c.fetch(
self._build_path('/server_new.export'), dest,
lambda refs: [protocol.ZERO_SHA])
for r in result.refs.items():
dest.refs.set_if_equals(r[0], None, r[1])
def test_send_remove_branch(self):
with repo.Repo(os.path.join(self.gitroot, 'dest')) as dest:
dummy_commit = self.make_dummy_commit(dest)
dest.refs[b'refs/heads/master'] = dummy_commit
dest.refs[b'refs/heads/abranch'] = dummy_commit
sendrefs = dict(dest.refs)
sendrefs[b'refs/heads/abranch'] = b"00" * 20
del sendrefs[b'HEAD']
def gen_pack(have, want, ofs_delta=False):
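            # Send an empty pack: deleting a ref transfers no objects.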
return 0, []
c = self._client()
self.assertEqual(dest.refs[b"refs/heads/abranch"], dummy_commit)
c.send_pack(
self._build_path('/dest'), lambda _: sendrefs, gen_pack)
self.assertFalse(b"refs/heads/abranch" in dest.refs)
def test_get_refs(self):
c = self._client()
refs = c.get_refs(self._build_path('/server_new.export'))
repo_dir = os.path.join(self.gitroot, 'server_new.export')
with repo.Repo(repo_dir) as dest:
self.assertDictEqual(dest.refs.as_dict(), refs)
class DulwichTCPClientTest(CompatTestCase, DulwichClientTestBase):
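    """Tests for TCPGitClient against a real 'git daemon' on localhost."""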
def setUp(self):
CompatTestCase.setUp(self)
DulwichClientTestBase.setUp(self)
if check_for_daemon(limit=1):
raise SkipTest('git-daemon was already running on port %s' %
protocol.TCP_GIT_PORT)
fd, self.pidfile = tempfile.mkstemp(prefix='dulwich-test-git-client',
suffix=".pid")
os.fdopen(fd).close()
args = [_DEFAULT_GIT, 'daemon', '--verbose', '--export-all',
'--pid-file=%s' % self.pidfile,
'--base-path=%s' % self.gitroot,
'--enable=receive-pack', '--enable=upload-archive',
'--listen=localhost', '--reuseaddr',
self.gitroot]
self.process = subprocess.Popen(
args, cwd=self.gitroot,
stdout=subprocess.PIPE, stderr=subprocess.PIPE)
if not check_for_daemon():
raise SkipTest('git-daemon failed to start')
def tearDown(self):
with open(self.pidfile) as f:
pid = int(f.read().strip())
if sys.platform == 'win32':
PROCESS_TERMINATE = 1
handle = ctypes.windll.kernel32.OpenProcess(
PROCESS_TERMINATE, False, pid)
ctypes.windll.kernel32.TerminateProcess(handle, -1)
ctypes.windll.kernel32.CloseHandle(handle)
else:
try:
os.kill(pid, signal.SIGKILL)
os.unlink(self.pidfile)
except (OSError, IOError):
pass
self.process.wait()
self.process.stdout.close()
self.process.stderr.close()
DulwichClientTestBase.tearDown(self)
CompatTestCase.tearDown(self)
def _client(self):
return client.TCPGitClient('localhost')
def _build_path(self, path):
return path
if sys.platform == 'win32':
@expectedFailure
def test_fetch_pack_no_side_band_64k(self):
DulwichClientTestBase.test_fetch_pack_no_side_band_64k(self)
class TestSSHVendor(object):
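    """Stub SSH vendor that runs the requested command in a local
    subprocess instead of connecting over SSH."""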
@staticmethod
def run_command(host, command, username=None, port=None,
password=None, key_filename=None):
cmd, path = command.split(' ')
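        # Split e.g. 'git-upload-pack' into ['git', 'upload-pack'].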
cmd = cmd.split('-', 1)
path = path.replace("'", "")
p = subprocess.Popen(cmd + [path], bufsize=0, stdin=subprocess.PIPE,
stdout=subprocess.PIPE, stderr=subprocess.PIPE)
return client.SubprocessWrapper(p)
class DulwichMockSSHClientTest(CompatTestCase, DulwichClientTestBase):
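    """Tests for SSHGitClient, with the SSH transport mocked out by
    TestSSHVendor."""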
def setUp(self):
CompatTestCase.setUp(self)
DulwichClientTestBase.setUp(self)
self.real_vendor = client.get_ssh_vendor
client.get_ssh_vendor = TestSSHVendor
def tearDown(self):
DulwichClientTestBase.tearDown(self)
CompatTestCase.tearDown(self)
client.get_ssh_vendor = self.real_vendor
def _client(self):
return client.SSHGitClient('localhost')
def _build_path(self, path):
return self.gitroot + path
class DulwichSubprocessClientTest(CompatTestCase, DulwichClientTestBase):
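    """Tests for SubprocessGitClient, which talks to git transport
    processes spawned locally."""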
def setUp(self):
CompatTestCase.setUp(self)
DulwichClientTestBase.setUp(self)
def tearDown(self):
DulwichClientTestBase.tearDown(self)
CompatTestCase.tearDown(self)
def _client(self):
return client.SubprocessGitClient()
def _build_path(self, path):
return self.gitroot + path
class GitHTTPRequestHandler(http.server.SimpleHTTPRequestHandler):
"""HTTP Request handler that calls out to 'git http-backend'."""
# Make rfile unbuffered -- we need to read one line and then pass
# the rest to a subprocess, so we can't use buffered input.
rbufsize = 0
def do_POST(self):
self.run_backend()
def do_GET(self):
self.run_backend()
def send_head(self):
return self.run_backend()
def log_request(self, code='-', size='-'):
# Let's be quiet, the test suite is noisy enough already
pass
def run_backend(self):
"""Call out to git http-backend."""
# Based on CGIHTTPServer.CGIHTTPRequestHandler.run_cgi:
# Copyright (c) 2001-2010 Python Software Foundation;
# All Rights Reserved
# Licensed under the Python Software Foundation License.
rest = self.path
# find an explicit query string, if present.
i = rest.rfind('?')
if i >= 0:
rest, query = rest[:i], rest[i+1:]
else:
query = ''
env = copy.deepcopy(os.environ)
env['SERVER_SOFTWARE'] = self.version_string()
env['SERVER_NAME'] = self.server.server_name
env['GATEWAY_INTERFACE'] = 'CGI/1.1'
env['SERVER_PROTOCOL'] = self.protocol_version
env['SERVER_PORT'] = str(self.server.server_port)
env['GIT_PROJECT_ROOT'] = self.server.root_path
env["GIT_HTTP_EXPORT_ALL"] = "1"
env['REQUEST_METHOD'] = self.command
uqrest = unquote(rest)
env['PATH_INFO'] = uqrest
env['SCRIPT_NAME'] = "/"
if query:
env['QUERY_STRING'] = query
host = self.address_string()
if host != self.client_address[0]:
env['REMOTE_HOST'] = host
env['REMOTE_ADDR'] = self.client_address[0]
authorization = self.headers.get("authorization")
if authorization:
authorization = authorization.split()
if len(authorization) == 2:
import base64
import binascii
env['AUTH_TYPE'] = authorization[0]
if authorization[0].lower() == "basic":
                    try:
                        # decodestring() is a Python 2 remnant that
                        # rejects the str we have here; b64decode
                        # accepts ASCII str directly.
                        authorization = base64.b64decode(
                            authorization[1]).decode('utf-8')
                    except (binascii.Error, UnicodeDecodeError):
                        pass
else:
authorization = authorization.split(':')
if len(authorization) == 2:
env['REMOTE_USER'] = authorization[0]
# XXX REMOTE_IDENT
content_type = self.headers.get('content-type')
if content_type:
env['CONTENT_TYPE'] = content_type
length = self.headers.get('content-length')
if length:
env['CONTENT_LENGTH'] = length
referer = self.headers.get('referer')
if referer:
env['HTTP_REFERER'] = referer
accept = []
for line in self.headers.getallmatchingheaders('accept'):
if line[:1] in "\t\n\r ":
accept.append(line.strip())
else:
accept = accept + line[7:].split(',')
env['HTTP_ACCEPT'] = ','.join(accept)
ua = self.headers.get('user-agent')
if ua:
env['HTTP_USER_AGENT'] = ua
co = self.headers.get('cookie')
if co:
env['HTTP_COOKIE'] = co
# XXX Other HTTP_* headers
# Since we're setting the env in the parent, provide empty
# values to override previously set values
for k in ('QUERY_STRING', 'REMOTE_HOST', 'CONTENT_LENGTH',
'HTTP_USER_AGENT', 'HTTP_COOKIE', 'HTTP_REFERER'):
env.setdefault(k, "")
self.wfile.write(b"HTTP/1.1 200 Script output follows\r\n")
self.wfile.write(
("Server: %s\r\n" % self.server.server_name).encode('ascii'))
self.wfile.write(
("Date: %s\r\n" % self.date_time_string()).encode('ascii'))
decoded_query = query.replace('+', ' ')
try:
nbytes = int(length)
except (TypeError, ValueError):
nbytes = 0
if self.command.lower() == "post" and nbytes > 0:
data = self.rfile.read(nbytes)
else:
data = None
env['CONTENT_LENGTH'] = '0'
# throw away additional data [see bug #427345]
while select.select([self.rfile._sock], [], [], 0)[0]:
if not self.rfile._sock.recv(1):
break
args = ['http-backend']
if '=' not in decoded_query:
args.append(decoded_query)
stdout = run_git_or_fail(
args, input=data, env=env, stderr=subprocess.PIPE)
self.wfile.write(stdout)
class HTTPGitServer(http.server.HTTPServer):
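    """HTTP server that serves a repository root via GitHTTPRequestHandler."""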
allow_reuse_address = True
def __init__(self, server_address, root_path):
http.server.HTTPServer.__init__(
self, server_address, GitHTTPRequestHandler)
self.root_path = root_path
self.server_name = "localhost"
def get_url(self):
return 'http://%s:%s/' % (self.server_name, self.server_port)
class DulwichHttpClientTest(CompatTestCase, DulwichClientTestBase):
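    """Tests for HttpGitClient, run against the HTTPGitServer above."""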
min_git_version = (1, 7, 0, 2)
def setUp(self):
CompatTestCase.setUp(self)
DulwichClientTestBase.setUp(self)
self._httpd = HTTPGitServer(("localhost", 0), self.gitroot)
self.addCleanup(self._httpd.shutdown)
threading.Thread(target=self._httpd.serve_forever).start()
run_git_or_fail(['config', 'http.uploadpack', 'true'],
cwd=self.dest)
run_git_or_fail(['config', 'http.receivepack', 'true'],
cwd=self.dest)
def tearDown(self):
DulwichClientTestBase.tearDown(self)
CompatTestCase.tearDown(self)
self._httpd.shutdown()
self._httpd.socket.close()
def _client(self):
return client.HttpGitClient(self._httpd.get_url())
def _build_path(self, path):
return path
def test_archive(self):
raise SkipTest("exporting archives not supported over http")
diff --git a/dulwich/tests/test_porcelain.py b/dulwich/tests/test_porcelain.py
index cc65ece2..3cfc4d6e 100644
--- a/dulwich/tests/test_porcelain.py
+++ b/dulwich/tests/test_porcelain.py
@@ -1,1837 +1,1837 @@
# test_porcelain.py -- porcelain tests
# Copyright (C) 2013 Jelmer Vernooij <jelmer@jelmer.uk>
#
# Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU
# General Public License as public by the Free Software Foundation; version 2.0
# or (at your option) any later version. You can redistribute it and/or
# modify it under the terms of either of these two licenses.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# You should have received a copy of the licenses; if not, see
# <http://www.gnu.org/licenses/> for a copy of the GNU General Public License
# and <http://www.apache.org/licenses/LICENSE-2.0> for a copy of the Apache
# License, Version 2.0.
#
"""Tests for dulwich.porcelain."""
from io import BytesIO, StringIO
import errno
import os
import shutil
import tarfile
import tempfile
import time
from dulwich import porcelain
from dulwich.diff_tree import tree_changes
from dulwich.objects import (
Blob,
Tag,
Tree,
ZERO_SHA,
)
from dulwich.repo import (
NoIndexPresent,
Repo,
)
from dulwich.tests import (
TestCase,
)
from dulwich.tests.utils import (
build_commit_graph,
make_commit,
make_object,
)
def flat_walk_dir(dir_to_walk):
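    """Yield all files and directories below dir_to_walk, relative to it."""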
for dirpath, _, filenames in os.walk(dir_to_walk):
rel_dirpath = os.path.relpath(dirpath, dir_to_walk)
        if dirpath != dir_to_walk:
yield rel_dirpath
for filename in filenames:
if dirpath == dir_to_walk:
yield filename
else:
yield os.path.join(rel_dirpath, filename)
class PorcelainTestCase(TestCase):
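    """Base test case providing a fresh temporary repository per test."""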
def setUp(self):
super(PorcelainTestCase, self).setUp()
self.test_dir = tempfile.mkdtemp()
self.addCleanup(shutil.rmtree, self.test_dir)
self.repo_path = os.path.join(self.test_dir, 'repo')
self.repo = Repo.init(self.repo_path, mkdir=True)
self.addCleanup(self.repo.close)
class ArchiveTests(PorcelainTestCase):
"""Tests for the archive command."""
def test_simple(self):
c1, c2, c3 = build_commit_graph(
self.repo.object_store, [[1], [2, 1], [3, 1, 2]])
self.repo.refs[b"refs/heads/master"] = c3.id
out = BytesIO()
err = BytesIO()
porcelain.archive(self.repo.path, b"refs/heads/master", outstream=out,
errstream=err)
self.assertEqual(b"", err.getvalue())
tf = tarfile.TarFile(fileobj=out)
self.addCleanup(tf.close)
self.assertEqual([], tf.getnames())
class UpdateServerInfoTests(PorcelainTestCase):
def test_simple(self):
c1, c2, c3 = build_commit_graph(
self.repo.object_store, [[1], [2, 1], [3, 1, 2]])
self.repo.refs[b"refs/heads/foo"] = c3.id
porcelain.update_server_info(self.repo.path)
self.assertTrue(os.path.exists(
os.path.join(self.repo.controldir(), 'info', 'refs')))
class CommitTests(PorcelainTestCase):
def test_custom_author(self):
c1, c2, c3 = build_commit_graph(
self.repo.object_store, [[1], [2, 1], [3, 1, 2]])
self.repo.refs[b"refs/heads/foo"] = c3.id
sha = porcelain.commit(
self.repo.path, message=b"Some message",
author=b"Joe ",
committer=b"Bob ")
self.assertTrue(isinstance(sha, bytes))
self.assertEqual(len(sha), 40)
def test_unicode(self):
c1, c2, c3 = build_commit_graph(
self.repo.object_store, [[1], [2, 1], [3, 1, 2]])
self.repo.refs[b"refs/heads/foo"] = c3.id
sha = porcelain.commit(
self.repo.path, message="Some message",
author="Joe ",
committer="Bob ")
self.assertTrue(isinstance(sha, bytes))
self.assertEqual(len(sha), 40)
class CleanTests(PorcelainTestCase):
def put_files(self, tracked, ignored, untracked, empty_dirs):
"""Put the described files in the wd
"""
all_files = tracked | ignored | untracked
for file_path in all_files:
abs_path = os.path.join(self.repo.path, file_path)
# File may need to be written in a dir that doesn't exist yet, so
# create the parent dir(s) as necessary
parent_dir = os.path.dirname(abs_path)
try:
os.makedirs(parent_dir)
except OSError as err:
                if err.errno != errno.EEXIST:
                    raise
with open(abs_path, 'w') as f:
f.write('')
with open(os.path.join(self.repo.path, '.gitignore'), 'w') as f:
f.writelines(ignored)
for dir_path in empty_dirs:
            os.mkdir(os.path.join(self.repo.path, dir_path))
files_to_add = [os.path.join(self.repo.path, t) for t in tracked]
porcelain.add(repo=self.repo.path, paths=files_to_add)
porcelain.commit(repo=self.repo.path, message="init commit")
def assert_wd(self, expected_paths):
"""Assert paths of files and dirs in wd are same as expected_paths
"""
control_dir_rel = os.path.relpath(
self.repo._controldir, self.repo.path)
# normalize paths to simplify comparison across platforms
found_paths = {
os.path.normpath(p)
for p in flat_walk_dir(self.repo.path)
            if p.split(os.sep)[0] != control_dir_rel}
norm_expected_paths = {os.path.normpath(p) for p in expected_paths}
self.assertEqual(found_paths, norm_expected_paths)
def test_from_root(self):
self.put_files(
tracked={
'tracked_file',
'tracked_dir/tracked_file',
'.gitignore'},
ignored={
'ignored_file'},
untracked={
'untracked_file',
'tracked_dir/untracked_dir/untracked_file',
'untracked_dir/untracked_dir/untracked_file'},
empty_dirs={
'empty_dir'})
porcelain.clean(repo=self.repo.path, target_dir=self.repo.path)
self.assert_wd({
'tracked_file',
'tracked_dir/tracked_file',
'.gitignore',
'ignored_file',
'tracked_dir'})
def test_from_subdir(self):
self.put_files(
tracked={
'tracked_file',
'tracked_dir/tracked_file',
'.gitignore'},
ignored={
'ignored_file'},
untracked={
'untracked_file',
'tracked_dir/untracked_dir/untracked_file',
'untracked_dir/untracked_dir/untracked_file'},
empty_dirs={
'empty_dir'})
porcelain.clean(
repo=self.repo,
target_dir=os.path.join(self.repo.path, 'untracked_dir'))
self.assert_wd({
'tracked_file',
'tracked_dir/tracked_file',
'.gitignore',
'ignored_file',
'untracked_file',
'tracked_dir/untracked_dir/untracked_file',
'empty_dir',
'untracked_dir',
'tracked_dir',
'tracked_dir/untracked_dir'})
class CloneTests(PorcelainTestCase):
def test_simple_local(self):
f1_1 = make_object(Blob, data=b'f1')
commit_spec = [[1], [2, 1], [3, 1, 2]]
trees = {1: [(b'f1', f1_1), (b'f2', f1_1)],
2: [(b'f1', f1_1), (b'f2', f1_1)],
3: [(b'f1', f1_1), (b'f2', f1_1)], }
c1, c2, c3 = build_commit_graph(self.repo.object_store,
commit_spec, trees)
self.repo.refs[b"refs/heads/master"] = c3.id
self.repo.refs[b"refs/tags/foo"] = c3.id
target_path = tempfile.mkdtemp()
errstream = BytesIO()
self.addCleanup(shutil.rmtree, target_path)
r = porcelain.clone(self.repo.path, target_path,
checkout=False, errstream=errstream)
self.addCleanup(r.close)
self.assertEqual(r.path, target_path)
target_repo = Repo(target_path)
self.assertEqual(0, len(target_repo.open_index()))
self.assertEqual(c3.id, target_repo.refs[b'refs/tags/foo'])
self.assertTrue(b'f1' not in os.listdir(target_path))
self.assertTrue(b'f2' not in os.listdir(target_path))
c = r.get_config()
encoded_path = self.repo.path
if not isinstance(encoded_path, bytes):
encoded_path = encoded_path.encode('utf-8')
self.assertEqual(encoded_path, c.get((b'remote', b'origin'), b'url'))
self.assertEqual(
b'+refs/heads/*:refs/remotes/origin/*',
c.get((b'remote', b'origin'), b'fetch'))
def test_simple_local_with_checkout(self):
f1_1 = make_object(Blob, data=b'f1')
commit_spec = [[1], [2, 1], [3, 1, 2]]
trees = {1: [(b'f1', f1_1), (b'f2', f1_1)],
2: [(b'f1', f1_1), (b'f2', f1_1)],
3: [(b'f1', f1_1), (b'f2', f1_1)], }
c1, c2, c3 = build_commit_graph(self.repo.object_store,
commit_spec, trees)
self.repo.refs[b"refs/heads/master"] = c3.id
target_path = tempfile.mkdtemp()
errstream = BytesIO()
self.addCleanup(shutil.rmtree, target_path)
with porcelain.clone(self.repo.path, target_path,
checkout=True,
errstream=errstream) as r:
self.assertEqual(r.path, target_path)
with Repo(target_path) as r:
self.assertEqual(r.head(), c3.id)
self.assertTrue('f1' in os.listdir(target_path))
self.assertTrue('f2' in os.listdir(target_path))
def test_bare_local_with_checkout(self):
f1_1 = make_object(Blob, data=b'f1')
commit_spec = [[1], [2, 1], [3, 1, 2]]
trees = {1: [(b'f1', f1_1), (b'f2', f1_1)],
2: [(b'f1', f1_1), (b'f2', f1_1)],
3: [(b'f1', f1_1), (b'f2', f1_1)], }
c1, c2, c3 = build_commit_graph(self.repo.object_store,
commit_spec, trees)
self.repo.refs[b"refs/heads/master"] = c3.id
target_path = tempfile.mkdtemp()
errstream = BytesIO()
self.addCleanup(shutil.rmtree, target_path)
with porcelain.clone(
self.repo.path, target_path, bare=True,
errstream=errstream) as r:
self.assertEqual(r.path, target_path)
with Repo(target_path) as r:
r.head()
self.assertRaises(NoIndexPresent, r.open_index)
self.assertFalse(b'f1' in os.listdir(target_path))
self.assertFalse(b'f2' in os.listdir(target_path))
def test_no_checkout_with_bare(self):
f1_1 = make_object(Blob, data=b'f1')
commit_spec = [[1]]
trees = {1: [(b'f1', f1_1), (b'f2', f1_1)]}
(c1, ) = build_commit_graph(self.repo.object_store, commit_spec, trees)
self.repo.refs[b"refs/heads/master"] = c1.id
self.repo.refs[b"HEAD"] = c1.id
target_path = tempfile.mkdtemp()
errstream = BytesIO()
self.addCleanup(shutil.rmtree, target_path)
self.assertRaises(
ValueError, porcelain.clone, self.repo.path,
target_path, checkout=True, bare=True, errstream=errstream)
def test_no_head_no_checkout(self):
f1_1 = make_object(Blob, data=b'f1')
commit_spec = [[1]]
trees = {1: [(b'f1', f1_1), (b'f2', f1_1)]}
(c1, ) = build_commit_graph(self.repo.object_store, commit_spec, trees)
self.repo.refs[b"refs/heads/master"] = c1.id
target_path = tempfile.mkdtemp()
self.addCleanup(shutil.rmtree, target_path)
errstream = BytesIO()
r = porcelain.clone(
self.repo.path, target_path, checkout=True, errstream=errstream)
r.close()
def test_no_head_no_checkout_outstream_errstream_autofallback(self):
f1_1 = make_object(Blob, data=b'f1')
commit_spec = [[1]]
trees = {1: [(b'f1', f1_1), (b'f2', f1_1)]}
(c1, ) = build_commit_graph(self.repo.object_store, commit_spec, trees)
self.repo.refs[b"refs/heads/master"] = c1.id
target_path = tempfile.mkdtemp()
self.addCleanup(shutil.rmtree, target_path)
errstream = porcelain.NoneStream()
r = porcelain.clone(
self.repo.path, target_path, checkout=True, errstream=errstream)
r.close()
def test_source_broken(self):
target_path = tempfile.mkdtemp()
self.assertRaises(
Exception, porcelain.clone, '/nonexistant/repo', target_path)
self.assertFalse(os.path.exists(target_path))
class InitTests(TestCase):
def test_non_bare(self):
repo_dir = tempfile.mkdtemp()
self.addCleanup(shutil.rmtree, repo_dir)
porcelain.init(repo_dir)
def test_bare(self):
repo_dir = tempfile.mkdtemp()
self.addCleanup(shutil.rmtree, repo_dir)
porcelain.init(repo_dir, bare=True)
class AddTests(PorcelainTestCase):
def test_add_default_paths(self):
# create a file for initial commit
fullpath = os.path.join(self.repo.path, 'blah')
with open(fullpath, 'w') as f:
f.write("\n")
porcelain.add(repo=self.repo.path, paths=[fullpath])
porcelain.commit(repo=self.repo.path, message=b'test',
author=b'test ', committer=b'test ')
# Add a second test file and a file in a directory
with open(os.path.join(self.repo.path, 'foo'), 'w') as f:
f.write("\n")
os.mkdir(os.path.join(self.repo.path, 'adir'))
with open(os.path.join(self.repo.path, 'adir', 'afile'), 'w') as f:
f.write("\n")
cwd = os.getcwd()
try:
os.chdir(self.repo.path)
porcelain.add(self.repo.path)
finally:
os.chdir(cwd)
# Check that foo was added and nothing in .git was modified
index = self.repo.open_index()
self.assertEqual(sorted(index), [b'adir/afile', b'blah', b'foo'])
def test_add_default_paths_subdir(self):
os.mkdir(os.path.join(self.repo.path, 'foo'))
with open(os.path.join(self.repo.path, 'blah'), 'w') as f:
f.write("\n")
with open(os.path.join(self.repo.path, 'foo', 'blie'), 'w') as f:
f.write("\n")
cwd = os.getcwd()
try:
os.chdir(os.path.join(self.repo.path, 'foo'))
porcelain.add(repo=self.repo.path)
porcelain.commit(repo=self.repo.path, message=b'test',
author=b'test ',
committer=b'test ')
finally:
os.chdir(cwd)
index = self.repo.open_index()
self.assertEqual(sorted(index), [b'foo/blie'])
def test_add_file(self):
fullpath = os.path.join(self.repo.path, 'foo')
with open(fullpath, 'w') as f:
f.write("BAR")
porcelain.add(self.repo.path, paths=[fullpath])
self.assertIn(b"foo", self.repo.open_index())
def test_add_ignored(self):
with open(os.path.join(self.repo.path, '.gitignore'), 'w') as f:
f.write("foo")
with open(os.path.join(self.repo.path, 'foo'), 'w') as f:
f.write("BAR")
with open(os.path.join(self.repo.path, 'bar'), 'w') as f:
f.write("BAR")
(added, ignored) = porcelain.add(self.repo.path, paths=[
os.path.join(self.repo.path, "foo"),
os.path.join(self.repo.path, "bar")])
self.assertIn(b"bar", self.repo.open_index())
self.assertEqual(set(['bar']), set(added))
self.assertEqual(set(['foo']), ignored)
def test_add_file_absolute_path(self):
        # Absolute paths are supported
with open(os.path.join(self.repo.path, 'foo'), 'w') as f:
f.write("BAR")
porcelain.add(self.repo, paths=[os.path.join(self.repo.path, "foo")])
self.assertIn(b"foo", self.repo.open_index())
def test_add_not_in_repo(self):
with open(os.path.join(self.test_dir, 'foo'), 'w') as f:
f.write("BAR")
self.assertRaises(
ValueError,
porcelain.add, self.repo,
paths=[os.path.join(self.test_dir, "foo")])
self.assertRaises(
ValueError,
porcelain.add, self.repo,
paths=["../foo"])
self.assertEqual([], list(self.repo.open_index()))
    def test_add_file_crlf_conversion(self):
# Set the right configuration to the repo
c = self.repo.get_config()
c.set("core", "autocrlf", "input")
c.write_to_path()
# Add a file with CRLF line-ending
fullpath = os.path.join(self.repo.path, 'foo')
with open(fullpath, 'wb') as f:
f.write(b"line1\r\nline2")
porcelain.add(self.repo.path, paths=[fullpath])
# The line-endings should have been converted to LF
index = self.repo.open_index()
self.assertIn(b"foo", index)
entry = index[b"foo"]
blob = self.repo[entry.sha]
self.assertEqual(blob.data, b"line1\nline2")
class RemoveTests(PorcelainTestCase):
def test_remove_file(self):
fullpath = os.path.join(self.repo.path, 'foo')
with open(fullpath, 'w') as f:
f.write("BAR")
porcelain.add(self.repo.path, paths=[fullpath])
porcelain.commit(repo=self.repo, message=b'test',
author=b'test ',
committer=b'test ')
self.assertTrue(os.path.exists(os.path.join(self.repo.path, 'foo')))
cwd = os.getcwd()
try:
os.chdir(self.repo.path)
porcelain.remove(self.repo.path, paths=["foo"])
finally:
os.chdir(cwd)
self.assertFalse(os.path.exists(os.path.join(self.repo.path, 'foo')))
def test_remove_file_staged(self):
fullpath = os.path.join(self.repo.path, 'foo')
with open(fullpath, 'w') as f:
f.write("BAR")
cwd = os.getcwd()
try:
os.chdir(self.repo.path)
porcelain.add(self.repo.path, paths=[fullpath])
self.assertRaises(Exception, porcelain.rm, self.repo.path,
paths=["foo"])
finally:
os.chdir(cwd)
class LogTests(PorcelainTestCase):
def test_simple(self):
c1, c2, c3 = build_commit_graph(
self.repo.object_store, [[1], [2, 1], [3, 1, 2]])
self.repo.refs[b"HEAD"] = c3.id
outstream = StringIO()
porcelain.log(self.repo.path, outstream=outstream)
self.assertEqual(3, outstream.getvalue().count("-" * 50))
def test_max_entries(self):
c1, c2, c3 = build_commit_graph(
self.repo.object_store, [[1], [2, 1], [3, 1, 2]])
self.repo.refs[b"HEAD"] = c3.id
outstream = StringIO()
porcelain.log(self.repo.path, outstream=outstream, max_entries=1)
self.assertEqual(1, outstream.getvalue().count("-" * 50))
class ShowTests(PorcelainTestCase):
def test_nolist(self):
c1, c2, c3 = build_commit_graph(
self.repo.object_store, [[1], [2, 1], [3, 1, 2]])
self.repo.refs[b"HEAD"] = c3.id
outstream = StringIO()
porcelain.show(self.repo.path, objects=c3.id, outstream=outstream)
self.assertTrue(outstream.getvalue().startswith("-" * 50))
def test_simple(self):
c1, c2, c3 = build_commit_graph(
self.repo.object_store, [[1], [2, 1], [3, 1, 2]])
self.repo.refs[b"HEAD"] = c3.id
outstream = StringIO()
porcelain.show(self.repo.path, objects=[c3.id], outstream=outstream)
self.assertTrue(outstream.getvalue().startswith("-" * 50))
def test_blob(self):
b = Blob.from_string(b"The Foo\n")
self.repo.object_store.add_object(b)
outstream = StringIO()
porcelain.show(self.repo.path, objects=[b.id], outstream=outstream)
self.assertEqual(outstream.getvalue(), "The Foo\n")
def test_commit_no_parent(self):
a = Blob.from_string(b"The Foo\n")
ta = Tree()
ta.add(b"somename", 0o100644, a.id)
ca = make_commit(tree=ta.id)
self.repo.object_store.add_objects([(a, None), (ta, None), (ca, None)])
outstream = StringIO()
porcelain.show(self.repo.path, objects=[ca.id], outstream=outstream)
self.assertMultiLineEqual(outstream.getvalue(), """\
--------------------------------------------------
commit: 344da06c1bb85901270b3e8875c988a027ec087d
Author: Test Author <test@nodomain.com>
Committer: Test Committer <test@nodomain.com>
Date: Fri Jan 01 2010 00:00:00 +0000
Test message.
diff --git a/somename b/somename
new file mode 100644
index 0000000..ea5c7bf
--- /dev/null
+++ b/somename
@@ -0,0 +1 @@
+The Foo
""")
def test_tag(self):
a = Blob.from_string(b"The Foo\n")
ta = Tree()
ta.add(b"somename", 0o100644, a.id)
ca = make_commit(tree=ta.id)
self.repo.object_store.add_objects([(a, None), (ta, None), (ca, None)])
porcelain.tag_create(
self.repo.path, b"tryme", b'foo ', b'bar',
annotated=True, objectish=ca.id, tag_time=1552854211,
tag_timezone=0)
outstream = StringIO()
porcelain.show(self.repo, objects=[b'refs/tags/tryme'],
outstream=outstream)
self.maxDiff = None
self.assertMultiLineEqual(outstream.getvalue(), """\
Tagger: foo <foo@bar.com>
Date: Sun Mar 17 2019 20:23:31 +0000
bar
--------------------------------------------------
commit: 344da06c1bb85901270b3e8875c988a027ec087d
Author: Test Author <test@nodomain.com>
Committer: Test Committer <test@nodomain.com>
Date: Fri Jan 01 2010 00:00:00 +0000
Test message.
diff --git a/somename b/somename
new file mode 100644
index 0000000..ea5c7bf
--- /dev/null
+++ b/somename
@@ -0,0 +1 @@
+The Foo
""")
def test_commit_with_change(self):
a = Blob.from_string(b"The Foo\n")
ta = Tree()
ta.add(b"somename", 0o100644, a.id)
ca = make_commit(tree=ta.id)
b = Blob.from_string(b"The Bar\n")
tb = Tree()
tb.add(b"somename", 0o100644, b.id)
cb = make_commit(tree=tb.id, parents=[ca.id])
self.repo.object_store.add_objects(
[(a, None), (b, None), (ta, None), (tb, None),
(ca, None), (cb, None)])
outstream = StringIO()
porcelain.show(self.repo.path, objects=[cb.id], outstream=outstream)
self.assertMultiLineEqual(outstream.getvalue(), """\
--------------------------------------------------
commit: 2c6b6c9cb72c130956657e1fdae58e5b103744fa
Author: Test Author <test@nodomain.com>
Committer: Test Committer <test@nodomain.com>
Date: Fri Jan 01 2010 00:00:00 +0000
Test message.
diff --git a/somename b/somename
index ea5c7bf..fd38bcb 100644
--- a/somename
+++ b/somename
@@ -1 +1 @@
-The Foo
+The Bar
""")
class SymbolicRefTests(PorcelainTestCase):
def test_set_wrong_symbolic_ref(self):
c1, c2, c3 = build_commit_graph(
self.repo.object_store, [[1], [2, 1], [3, 1, 2]])
self.repo.refs[b"HEAD"] = c3.id
self.assertRaises(ValueError, porcelain.symbolic_ref, self.repo.path,
b'foobar')
def test_set_force_wrong_symbolic_ref(self):
c1, c2, c3 = build_commit_graph(
self.repo.object_store, [[1], [2, 1], [3, 1, 2]])
self.repo.refs[b"HEAD"] = c3.id
porcelain.symbolic_ref(self.repo.path, b'force_foobar', force=True)
# test if we actually changed the file
with self.repo.get_named_file('HEAD') as f:
new_ref = f.read()
self.assertEqual(new_ref, b'ref: refs/heads/force_foobar\n')
def test_set_symbolic_ref(self):
c1, c2, c3 = build_commit_graph(
self.repo.object_store, [[1], [2, 1], [3, 1, 2]])
self.repo.refs[b"HEAD"] = c3.id
porcelain.symbolic_ref(self.repo.path, b'master')
def test_set_symbolic_ref_other_than_master(self):
c1, c2, c3 = build_commit_graph(
self.repo.object_store, [[1], [2, 1], [3, 1, 2]],
attrs=dict(refs='develop'))
self.repo.refs[b"HEAD"] = c3.id
self.repo.refs[b"refs/heads/develop"] = c3.id
porcelain.symbolic_ref(self.repo.path, b'develop')
# test if we actually changed the file
with self.repo.get_named_file('HEAD') as f:
new_ref = f.read()
self.assertEqual(new_ref, b'ref: refs/heads/develop\n')
class DiffTreeTests(PorcelainTestCase):
def test_empty(self):
c1, c2, c3 = build_commit_graph(
self.repo.object_store, [[1], [2, 1], [3, 1, 2]])
self.repo.refs[b"HEAD"] = c3.id
outstream = BytesIO()
porcelain.diff_tree(self.repo.path, c2.tree, c3.tree,
outstream=outstream)
self.assertEqual(outstream.getvalue(), b"")
class CommitTreeTests(PorcelainTestCase):
def test_simple(self):
c1, c2, c3 = build_commit_graph(
self.repo.object_store, [[1], [2, 1], [3, 1, 2]])
b = Blob()
b.data = b"foo the bar"
t = Tree()
t.add(b"somename", 0o100644, b.id)
self.repo.object_store.add_object(t)
self.repo.object_store.add_object(b)
sha = porcelain.commit_tree(
self.repo.path, t.id, message=b"Withcommit.",
author=b"Joe ",
committer=b"Jane ")
self.assertTrue(isinstance(sha, bytes))
self.assertEqual(len(sha), 40)
class RevListTests(PorcelainTestCase):
def test_simple(self):
c1, c2, c3 = build_commit_graph(
self.repo.object_store, [[1], [2, 1], [3, 1, 2]])
outstream = BytesIO()
porcelain.rev_list(
self.repo.path, [c3.id], outstream=outstream)
self.assertEqual(
c3.id + b"\n" +
c2.id + b"\n" +
c1.id + b"\n",
outstream.getvalue())
class TagCreateTests(PorcelainTestCase):
def test_annotated(self):
c1, c2, c3 = build_commit_graph(
self.repo.object_store, [[1], [2, 1], [3, 1, 2]])
self.repo.refs[b"HEAD"] = c3.id
porcelain.tag_create(self.repo.path, b"tryme", b'foo ',
b'bar', annotated=True)
tags = self.repo.refs.as_dict(b"refs/tags")
self.assertEqual(list(tags.keys()), [b"tryme"])
tag = self.repo[b'refs/tags/tryme']
self.assertTrue(isinstance(tag, Tag))
self.assertEqual(b"foo ", tag.tagger)
self.assertEqual(b"bar", tag.message)
self.assertLess(time.time() - tag.tag_time, 5)
def test_unannotated(self):
c1, c2, c3 = build_commit_graph(
self.repo.object_store, [[1], [2, 1], [3, 1, 2]])
self.repo.refs[b"HEAD"] = c3.id
porcelain.tag_create(self.repo.path, b"tryme", annotated=False)
tags = self.repo.refs.as_dict(b"refs/tags")
self.assertEqual(list(tags.keys()), [b"tryme"])
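        # Must not raise: the lightweight tag should resolve in the repo.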
self.repo[b'refs/tags/tryme']
self.assertEqual(list(tags.values()), [self.repo.head()])
def test_unannotated_unicode(self):
c1, c2, c3 = build_commit_graph(
self.repo.object_store, [[1], [2, 1], [3, 1, 2]])
self.repo.refs[b"HEAD"] = c3.id
porcelain.tag_create(self.repo.path, "tryme", annotated=False)
tags = self.repo.refs.as_dict(b"refs/tags")
self.assertEqual(list(tags.keys()), [b"tryme"])
self.repo[b'refs/tags/tryme']
self.assertEqual(list(tags.values()), [self.repo.head()])
class TagListTests(PorcelainTestCase):
def test_empty(self):
tags = porcelain.tag_list(self.repo.path)
self.assertEqual([], tags)
def test_simple(self):
self.repo.refs[b"refs/tags/foo"] = b"aa" * 20
self.repo.refs[b"refs/tags/bar/bla"] = b"bb" * 20
tags = porcelain.tag_list(self.repo.path)
self.assertEqual([b"bar/bla", b"foo"], tags)
class TagDeleteTests(PorcelainTestCase):
def test_simple(self):
[c1] = build_commit_graph(self.repo.object_store, [[1]])
self.repo[b"HEAD"] = c1.id
porcelain.tag_create(self.repo, b'foo')
self.assertTrue(b"foo" in porcelain.tag_list(self.repo))
porcelain.tag_delete(self.repo, b'foo')
self.assertFalse(b"foo" in porcelain.tag_list(self.repo))
class ResetTests(PorcelainTestCase):
def test_hard_head(self):
fullpath = os.path.join(self.repo.path, 'foo')
with open(fullpath, 'w') as f:
f.write("BAR")
porcelain.add(self.repo.path, paths=[fullpath])
porcelain.commit(self.repo.path, message=b"Some message",
committer=b"Jane ",
author=b"John ")
with open(os.path.join(self.repo.path, 'foo'), 'wb') as f:
f.write(b"OOH")
porcelain.reset(self.repo, "hard", b"HEAD")
index = self.repo.open_index()
changes = list(tree_changes(self.repo,
index.commit(self.repo.object_store),
self.repo[b'HEAD'].tree))
self.assertEqual([], changes)
def test_hard_commit(self):
fullpath = os.path.join(self.repo.path, 'foo')
with open(fullpath, 'w') as f:
f.write("BAR")
porcelain.add(self.repo.path, paths=[fullpath])
sha = porcelain.commit(self.repo.path, message=b"Some message",
committer=b"Jane ",
author=b"John ")
with open(fullpath, 'wb') as f:
f.write(b"BAZ")
porcelain.add(self.repo.path, paths=[fullpath])
porcelain.commit(self.repo.path, message=b"Some other message",
committer=b"Jane ",
author=b"John ")
porcelain.reset(self.repo, "hard", sha)
index = self.repo.open_index()
changes = list(tree_changes(self.repo,
index.commit(self.repo.object_store),
self.repo[sha].tree))
self.assertEqual([], changes)
class PushTests(PorcelainTestCase):
def test_simple(self):
"""
Basic test of porcelain push where self.repo is the remote. First
clone the remote, commit a file to the clone, then push the changes
back to the remote.
"""
outstream = BytesIO()
errstream = BytesIO()
porcelain.commit(repo=self.repo.path, message=b'init',
author=b'author ',
committer=b'committer ')
# Setup target repo cloned from temp test repo
clone_path = tempfile.mkdtemp()
self.addCleanup(shutil.rmtree, clone_path)
target_repo = porcelain.clone(self.repo.path, target=clone_path,
errstream=errstream)
try:
self.assertEqual(target_repo[b'HEAD'], self.repo[b'HEAD'])
finally:
target_repo.close()
# create a second file to be pushed back to origin
handle, fullpath = tempfile.mkstemp(dir=clone_path)
os.close(handle)
porcelain.add(repo=clone_path, paths=[fullpath])
porcelain.commit(repo=clone_path, message=b'push',
author=b'author ',
committer=b'committer ')
# Setup a non-checked out branch in the remote
refs_path = b"refs/heads/foo"
new_id = self.repo[b'HEAD'].id
self.assertNotEqual(new_id, ZERO_SHA)
self.repo.refs[refs_path] = new_id
# Push to the remote
porcelain.push(clone_path, self.repo.path, b"HEAD:" + refs_path,
outstream=outstream, errstream=errstream)
        # Check that the target and source refs now match
with Repo(clone_path) as r_clone:
self.assertEqual({
b'HEAD': new_id,
b'refs/heads/foo': r_clone[b'HEAD'].id,
b'refs/heads/master': new_id,
}, self.repo.get_refs())
self.assertEqual(r_clone[b'HEAD'].id, self.repo[refs_path].id)
# Get the change in the target repo corresponding to the add
# this will be in the foo branch.
change = list(tree_changes(self.repo, self.repo[b'HEAD'].tree,
self.repo[b'refs/heads/foo'].tree))[0]
self.assertEqual(os.path.basename(fullpath),
change.new.path.decode('ascii'))
def test_delete(self):
"""Basic test of porcelain push, removing a branch.
"""
outstream = BytesIO()
errstream = BytesIO()
porcelain.commit(repo=self.repo.path, message=b'init',
author=b'author ',
committer=b'committer ')
# Setup target repo cloned from temp test repo
clone_path = tempfile.mkdtemp()
self.addCleanup(shutil.rmtree, clone_path)
target_repo = porcelain.clone(self.repo.path, target=clone_path,
errstream=errstream)
target_repo.close()
# Setup a non-checked out branch in the remote
refs_path = b"refs/heads/foo"
new_id = self.repo[b'HEAD'].id
self.assertNotEqual(new_id, ZERO_SHA)
self.repo.refs[refs_path] = new_id
# Push to the remote
porcelain.push(clone_path, self.repo.path, b":" + refs_path,
outstream=outstream, errstream=errstream)
self.assertEqual({
b'HEAD': new_id,
b'refs/heads/master': new_id,
}, self.repo.get_refs())
class PullTests(PorcelainTestCase):
def setUp(self):
super(PullTests, self).setUp()
# create a file for initial commit
handle, fullpath = tempfile.mkstemp(dir=self.repo.path)
os.close(handle)
porcelain.add(repo=self.repo.path, paths=fullpath)
porcelain.commit(repo=self.repo.path, message=b'test',
author=b'test ',
committer=b'test ')
# Setup target repo
self.target_path = tempfile.mkdtemp()
self.addCleanup(shutil.rmtree, self.target_path)
target_repo = porcelain.clone(self.repo.path, target=self.target_path,
errstream=BytesIO())
target_repo.close()
        # create a second file to be pulled
handle, fullpath = tempfile.mkstemp(dir=self.repo.path)
os.close(handle)
porcelain.add(repo=self.repo.path, paths=fullpath)
porcelain.commit(repo=self.repo.path, message=b'test2',
author=b'test2 ',
committer=b'test2 ')
self.assertTrue(b'refs/heads/master' in self.repo.refs)
self.assertTrue(b'refs/heads/master' in target_repo.refs)
def test_simple(self):
outstream = BytesIO()
errstream = BytesIO()
# Pull changes into the cloned repo
porcelain.pull(self.target_path, self.repo.path, b'refs/heads/master',
outstream=outstream, errstream=errstream)
        # Check the target repo for pulled changes
with Repo(self.target_path) as r:
self.assertEqual(r[b'HEAD'].id, self.repo[b'HEAD'].id)
def test_no_refspec(self):
outstream = BytesIO()
errstream = BytesIO()
# Pull changes into the cloned repo
porcelain.pull(self.target_path, self.repo.path, outstream=outstream,
errstream=errstream)
        # Check the target repo for pulled changes
with Repo(self.target_path) as r:
self.assertEqual(r[b'HEAD'].id, self.repo[b'HEAD'].id)
def test_no_remote_location(self):
outstream = BytesIO()
errstream = BytesIO()
# Pull changes into the cloned repo
porcelain.pull(self.target_path, refspecs=b'refs/heads/master',
outstream=outstream, errstream=errstream)
        # Check the target repo for pulled changes
with Repo(self.target_path) as r:
self.assertEqual(r[b'HEAD'].id, self.repo[b'HEAD'].id)
class StatusTests(PorcelainTestCase):
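    """Tests for the status command."""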
def test_empty(self):
results = porcelain.status(self.repo)
self.assertEqual(
{'add': [], 'delete': [], 'modify': []},
results.staged)
self.assertEqual([], results.unstaged)
def test_status_base(self):
"""Integration test for `status` functionality."""
# Commit a dummy file then modify it
fullpath = os.path.join(self.repo.path, 'foo')
with open(fullpath, 'w') as f:
f.write('origstuff')
porcelain.add(repo=self.repo.path, paths=[fullpath])
porcelain.commit(repo=self.repo.path, message=b'test status',
author=b'author ',
committer=b'committer ')
# modify access and modify time of path
os.utime(fullpath, (0, 0))
with open(fullpath, 'wb') as f:
f.write(b'stuff')
# Make a dummy file and stage it
filename_add = 'bar'
fullpath = os.path.join(self.repo.path, filename_add)
with open(fullpath, 'w') as f:
f.write('stuff')
porcelain.add(repo=self.repo.path, paths=fullpath)
results = porcelain.status(self.repo)
self.assertEqual(results.staged['add'][0],
filename_add.encode('ascii'))
self.assertEqual(results.unstaged, [b'foo'])
def test_status_all(self):
del_path = os.path.join(self.repo.path, 'foo')
mod_path = os.path.join(self.repo.path, 'bar')
add_path = os.path.join(self.repo.path, 'baz')
us_path = os.path.join(self.repo.path, 'blye')
ut_path = os.path.join(self.repo.path, 'blyat')
with open(del_path, 'w') as f:
f.write('origstuff')
with open(mod_path, 'w') as f:
f.write('origstuff')
with open(us_path, 'w') as f:
f.write('origstuff')
porcelain.add(repo=self.repo.path, paths=[del_path, mod_path, us_path])
porcelain.commit(repo=self.repo.path, message=b'test status',
author=b'author ',
committer=b'committer ')
porcelain.remove(self.repo.path, [del_path])
with open(add_path, 'w') as f:
f.write('origstuff')
with open(mod_path, 'w') as f:
f.write('more_origstuff')
with open(us_path, 'w') as f:
f.write('more_origstuff')
porcelain.add(repo=self.repo.path, paths=[add_path, mod_path])
with open(us_path, 'w') as f:
f.write('\norigstuff')
with open(ut_path, 'w') as f:
f.write('origstuff')
results = porcelain.status(self.repo.path)
self.assertDictEqual(
{'add': [b'baz'], 'delete': [b'foo'], 'modify': [b'bar']},
results.staged)
self.assertListEqual(results.unstaged, [b'blye'])
self.assertListEqual(results.untracked, ['blyat'])
def test_status_crlf_mismatch(self):
# First make a commit as if the file has been added on a Linux system
# or with core.autocrlf=True
file_path = os.path.join(self.repo.path, 'crlf')
with open(file_path, 'wb') as f:
f.write(b'line1\nline2')
porcelain.add(repo=self.repo.path, paths=[file_path])
porcelain.commit(repo=self.repo.path, message=b'test status',
author=b'author ',
committer=b'committer ')
# Then update the file as if it was created by CGit on a Windows
# system with core.autocrlf=true
with open(file_path, 'wb') as f:
f.write(b'line1\r\nline2')
results = porcelain.status(self.repo)
self.assertDictEqual(
{'add': [], 'delete': [], 'modify': []},
results.staged)
self.assertListEqual(results.unstaged, [b'crlf'])
self.assertListEqual(results.untracked, [])
def test_status_crlf_convert(self):
# First make a commit as if the file has been added on a Linux system
# or with core.autocrlf=True
file_path = os.path.join(self.repo.path, 'crlf')
with open(file_path, 'wb') as f:
f.write(b'line1\nline2')
porcelain.add(repo=self.repo.path, paths=[file_path])
porcelain.commit(repo=self.repo.path, message=b'test status',
author=b'author ',
committer=b'committer ')
# Then update the file as if it was created by CGit on a Windows
# system with core.autocrlf=true
with open(file_path, 'wb') as f:
f.write(b'line1\r\nline2')
# TODO: It should be set automatically by looking at the configuration
c = self.repo.get_config()
c.set("core", "autocrlf", True)
c.write_to_path()
results = porcelain.status(self.repo)
self.assertDictEqual(
{'add': [], 'delete': [], 'modify': []},
results.staged)
self.assertListEqual(results.unstaged, [])
self.assertListEqual(results.untracked, [])
def test_get_tree_changes_add(self):
"""Unit test for get_tree_changes add."""
# Make a dummy file, stage
filename = 'bar'
fullpath = os.path.join(self.repo.path, filename)
with open(fullpath, 'w') as f:
f.write('stuff')
porcelain.add(repo=self.repo.path, paths=fullpath)
porcelain.commit(repo=self.repo.path, message=b'test status',
author=b'author ',
committer=b'committer ')
filename = 'foo'
fullpath = os.path.join(self.repo.path, filename)
with open(fullpath, 'w') as f:
f.write('stuff')
porcelain.add(repo=self.repo.path, paths=fullpath)
changes = porcelain.get_tree_changes(self.repo.path)
self.assertEqual(changes['add'][0], filename.encode('ascii'))
self.assertEqual(len(changes['add']), 1)
self.assertEqual(len(changes['modify']), 0)
self.assertEqual(len(changes['delete']), 0)
def test_get_tree_changes_modify(self):
"""Unit test for get_tree_changes modify."""
# Make a dummy file, stage, commit, modify
filename = 'foo'
fullpath = os.path.join(self.repo.path, filename)
with open(fullpath, 'w') as f:
f.write('stuff')
porcelain.add(repo=self.repo.path, paths=fullpath)
porcelain.commit(repo=self.repo.path, message=b'test status',
author=b'author ',
committer=b'committer ')
with open(fullpath, 'w') as f:
f.write('otherstuff')
porcelain.add(repo=self.repo.path, paths=fullpath)
changes = porcelain.get_tree_changes(self.repo.path)
self.assertEqual(changes['modify'][0], filename.encode('ascii'))
self.assertEqual(len(changes['add']), 0)
self.assertEqual(len(changes['modify']), 1)
self.assertEqual(len(changes['delete']), 0)
def test_get_tree_changes_delete(self):
"""Unit test for get_tree_changes delete."""
# Make a dummy file, stage, commit, remove
filename = 'foo'
fullpath = os.path.join(self.repo.path, filename)
with open(fullpath, 'w') as f:
f.write('stuff')
porcelain.add(repo=self.repo.path, paths=fullpath)
porcelain.commit(repo=self.repo.path, message=b'test status',
author=b'author ',
committer=b'committer ')
cwd = os.getcwd()
try:
os.chdir(self.repo.path)
porcelain.remove(repo=self.repo.path, paths=[filename])
finally:
os.chdir(cwd)
changes = porcelain.get_tree_changes(self.repo.path)
self.assertEqual(changes['delete'][0], filename.encode('ascii'))
self.assertEqual(len(changes['add']), 0)
self.assertEqual(len(changes['modify']), 0)
self.assertEqual(len(changes['delete']), 1)
def test_get_untracked_paths(self):
with open(os.path.join(self.repo.path, '.gitignore'), 'w') as f:
f.write('ignored\n')
with open(os.path.join(self.repo.path, 'ignored'), 'w') as f:
f.write('blah\n')
with open(os.path.join(self.repo.path, 'notignored'), 'w') as f:
f.write('blah\n')
self.assertEqual(
set(['ignored', 'notignored', '.gitignore']),
set(porcelain.get_untracked_paths(self.repo.path, self.repo.path,
self.repo.open_index())))
self.assertEqual(set(['.gitignore', 'notignored']),
set(porcelain.status(self.repo).untracked))
self.assertEqual(set(['.gitignore', 'notignored', 'ignored']),
set(porcelain.status(self.repo, ignored=True)
.untracked))
def test_get_untracked_paths_nested(self):
with open(os.path.join(self.repo.path, 'notignored'), 'w') as f:
f.write('blah\n')
subrepo = Repo.init(os.path.join(self.repo.path, 'nested'), mkdir=True)
with open(os.path.join(subrepo.path, 'another'), 'w') as f:
f.write('foo\n')
self.assertEqual(
set(['notignored']),
set(porcelain.get_untracked_paths(self.repo.path, self.repo.path,
self.repo.open_index())))
self.assertEqual(
set(['another']),
set(porcelain.get_untracked_paths(subrepo.path, subrepo.path,
subrepo.open_index())))
# TODO(jelmer): Add test for dulwich.porcelain.daemon
class UploadPackTests(PorcelainTestCase):
"""Tests for upload_pack."""
def test_upload_pack(self):
outf = BytesIO()
exitcode = porcelain.upload_pack(
self.repo.path, BytesIO(b"0000"), outf)
outlines = outf.getvalue().splitlines()
self.assertEqual([b"0000"], outlines)
self.assertEqual(0, exitcode)
class ReceivePackTests(PorcelainTestCase):
"""Tests for receive_pack."""
def test_receive_pack(self):
filename = 'foo'
fullpath = os.path.join(self.repo.path, filename)
with open(fullpath, 'w') as f:
f.write('stuff')
porcelain.add(repo=self.repo.path, paths=fullpath)
self.repo.do_commit(message=b'test status',
author=b'author ',
committer=b'committer ',
author_timestamp=1402354300,
commit_timestamp=1402354300, author_timezone=0,
commit_timezone=0)
outf = BytesIO()
exitcode = porcelain.receive_pack(
self.repo.path, BytesIO(b"0000"), outf)
outlines = outf.getvalue().splitlines()
self.assertEqual([
b'0091319b56ce3aee2d489f759736a79cc552c9bb86d9 HEAD\x00 report-status ' # noqa: E501
b'delete-refs quiet ofs-delta side-band-64k '
b'no-done symref=HEAD:refs/heads/master',
- b'003f319b56ce3aee2d489f759736a79cc552c9bb86d9 refs/heads/master',
+ b'003f319b56ce3aee2d489f759736a79cc552c9bb86d9 refs/heads/master',
b'0000'], outlines)
self.assertEqual(0, exitcode)
class BranchListTests(PorcelainTestCase):
def test_standard(self):
self.assertEqual(set([]), set(porcelain.branch_list(self.repo)))
def test_new_branch(self):
[c1] = build_commit_graph(self.repo.object_store, [[1]])
self.repo[b"HEAD"] = c1.id
porcelain.branch_create(self.repo, b"foo")
self.assertEqual(
set([b"master", b"foo"]),
set(porcelain.branch_list(self.repo)))
class BranchCreateTests(PorcelainTestCase):
def test_branch_exists(self):
[c1] = build_commit_graph(self.repo.object_store, [[1]])
self.repo[b"HEAD"] = c1.id
porcelain.branch_create(self.repo, b"foo")
self.assertRaises(KeyError, porcelain.branch_create, self.repo, b"foo")
porcelain.branch_create(self.repo, b"foo", force=True)
def test_new_branch(self):
[c1] = build_commit_graph(self.repo.object_store, [[1]])
self.repo[b"HEAD"] = c1.id
porcelain.branch_create(self.repo, b"foo")
self.assertEqual(
set([b"master", b"foo"]),
set(porcelain.branch_list(self.repo)))
class BranchDeleteTests(PorcelainTestCase):
def test_simple(self):
[c1] = build_commit_graph(self.repo.object_store, [[1]])
self.repo[b"HEAD"] = c1.id
porcelain.branch_create(self.repo, b'foo')
self.assertTrue(b"foo" in porcelain.branch_list(self.repo))
porcelain.branch_delete(self.repo, b'foo')
self.assertFalse(b"foo" in porcelain.branch_list(self.repo))
def test_simple_unicode(self):
[c1] = build_commit_graph(self.repo.object_store, [[1]])
self.repo[b"HEAD"] = c1.id
porcelain.branch_create(self.repo, 'foo')
self.assertTrue(b"foo" in porcelain.branch_list(self.repo))
porcelain.branch_delete(self.repo, 'foo')
self.assertFalse(b"foo" in porcelain.branch_list(self.repo))
class FetchTests(PorcelainTestCase):
def test_simple(self):
outstream = BytesIO()
errstream = BytesIO()
# create a file for initial commit
handle, fullpath = tempfile.mkstemp(dir=self.repo.path)
os.close(handle)
porcelain.add(repo=self.repo.path, paths=fullpath)
porcelain.commit(repo=self.repo.path, message=b'test',
author=b'test ',
committer=b'test ')
# Setup target repo
target_path = tempfile.mkdtemp()
self.addCleanup(shutil.rmtree, target_path)
target_repo = porcelain.clone(self.repo.path, target=target_path,
errstream=errstream)
        # create a second file to be fetched
handle, fullpath = tempfile.mkstemp(dir=self.repo.path)
os.close(handle)
porcelain.add(repo=self.repo.path, paths=fullpath)
porcelain.commit(repo=self.repo.path, message=b'test2',
author=b'test2 ',
committer=b'test2 ')
self.assertFalse(self.repo[b'HEAD'].id in target_repo)
target_repo.close()
# Fetch changes into the cloned repo
porcelain.fetch(target_path, self.repo.path,
outstream=outstream, errstream=errstream)
# Assert that fetch updated the local image of the remote
self.assert_correct_remote_refs(
target_repo.get_refs(), self.repo.get_refs())
        # Check the target repo for fetched changes
with Repo(target_path) as r:
self.assertTrue(self.repo[b'HEAD'].id in r)
def test_with_remote_name(self):
remote_name = b'origin'
outstream = BytesIO()
errstream = BytesIO()
# create a file for initial commit
handle, fullpath = tempfile.mkstemp(dir=self.repo.path)
os.close(handle)
porcelain.add(repo=self.repo.path, paths=fullpath)
porcelain.commit(repo=self.repo.path, message=b'test',
author=b'test ',
committer=b'test ')
# Setup target repo
target_path = tempfile.mkdtemp()
self.addCleanup(shutil.rmtree, target_path)
target_repo = porcelain.clone(self.repo.path, target=target_path,
errstream=errstream)
# Capture current refs
target_refs = target_repo.get_refs()
        # create a second file to be fetched
handle, fullpath = tempfile.mkstemp(dir=self.repo.path)
os.close(handle)
porcelain.add(repo=self.repo.path, paths=fullpath)
porcelain.commit(repo=self.repo.path, message=b'test2',
author=b'test2 ',
committer=b'test2 ')
self.assertFalse(self.repo[b'HEAD'].id in target_repo)
target_repo.close()
# Fetch changes into the cloned repo
porcelain.fetch(target_path, self.repo.path, remote_name=remote_name,
outstream=outstream, errstream=errstream)
# Assert that fetch updated the local image of the remote
self.assert_correct_remote_refs(
target_repo.get_refs(), self.repo.get_refs())
        # Check the target repo for fetched changes, as well as updates
# for the refs
with Repo(target_path) as r:
self.assertTrue(self.repo[b'HEAD'].id in r)
self.assertNotEqual(self.repo.get_refs(), target_refs)
def assert_correct_remote_refs(
self, local_refs, remote_refs, remote_name=b'origin'):
"""Assert that known remote refs corresponds to actual remote refs."""
local_ref_prefix = b'refs/heads'
remote_ref_prefix = b'refs/remotes/' + remote_name
locally_known_remote_refs = {
k[len(remote_ref_prefix) + 1:]: v for k, v in local_refs.items()
if k.startswith(remote_ref_prefix)}
normalized_remote_refs = {
k[len(local_ref_prefix) + 1:]: v for k, v in remote_refs.items()
if k.startswith(local_ref_prefix)}
self.assertEqual(locally_known_remote_refs, normalized_remote_refs)
class RepackTests(PorcelainTestCase):
def test_empty(self):
porcelain.repack(self.repo)
def test_simple(self):
handle, fullpath = tempfile.mkstemp(dir=self.repo.path)
os.close(handle)
porcelain.add(repo=self.repo.path, paths=fullpath)
porcelain.repack(self.repo)
class LsTreeTests(PorcelainTestCase):
def test_empty(self):
porcelain.commit(repo=self.repo.path, message=b'test status',
author=b'author ',
committer=b'committer ')
f = StringIO()
porcelain.ls_tree(self.repo, b"HEAD", outstream=f)
self.assertEqual(f.getvalue(), "")
def test_simple(self):
# Commit a dummy file then modify it
fullpath = os.path.join(self.repo.path, 'foo')
with open(fullpath, 'w') as f:
f.write('origstuff')
porcelain.add(repo=self.repo.path, paths=[fullpath])
porcelain.commit(repo=self.repo.path, message=b'test status',
author=b'author ',
committer=b'committer ')
f = StringIO()
porcelain.ls_tree(self.repo, b"HEAD", outstream=f)
self.assertEqual(
f.getvalue(),
'100644 blob 8b82634d7eae019850bb883f06abf428c58bc9aa\tfoo\n')
def test_recursive(self):
# Create a directory then write a dummy file in it
dirpath = os.path.join(self.repo.path, 'adir')
filepath = os.path.join(dirpath, 'afile')
os.mkdir(dirpath)
with open(filepath, 'w') as f:
f.write('origstuff')
porcelain.add(repo=self.repo.path, paths=[filepath])
porcelain.commit(repo=self.repo.path, message=b'test status',
author=b'author ',
committer=b'committer ')
f = StringIO()
porcelain.ls_tree(self.repo, b"HEAD", outstream=f)
self.assertEqual(
f.getvalue(),
'40000 tree b145cc69a5e17693e24d8a7be0016ed8075de66d\tadir\n')
f = StringIO()
porcelain.ls_tree(self.repo, b"HEAD", outstream=f, recursive=True)
self.assertEqual(
f.getvalue(),
'40000 tree b145cc69a5e17693e24d8a7be0016ed8075de66d\tadir\n'
'100644 blob 8b82634d7eae019850bb883f06abf428c58bc9aa\tadir'
'/afile\n')
class LsRemoteTests(PorcelainTestCase):
def test_empty(self):
self.assertEqual({}, porcelain.ls_remote(self.repo.path))
def test_some(self):
cid = porcelain.commit(repo=self.repo.path, message=b'test status',
author=b'author ',
committer=b'committer ')
self.assertEqual({
b'refs/heads/master': cid,
b'HEAD': cid},
porcelain.ls_remote(self.repo.path))
class LsFilesTests(PorcelainTestCase):
def test_empty(self):
self.assertEqual([], list(porcelain.ls_files(self.repo)))
def test_simple(self):
# Commit a dummy file then modify it
fullpath = os.path.join(self.repo.path, 'foo')
with open(fullpath, 'w') as f:
f.write('origstuff')
porcelain.add(repo=self.repo.path, paths=[fullpath])
self.assertEqual([b'foo'], list(porcelain.ls_files(self.repo)))
class RemoteAddTests(PorcelainTestCase):
def test_new(self):
porcelain.remote_add(
self.repo, 'jelmer', 'git://jelmer.uk/code/dulwich')
c = self.repo.get_config()
self.assertEqual(
c.get((b'remote', b'jelmer'), b'url'),
b'git://jelmer.uk/code/dulwich')
def test_exists(self):
porcelain.remote_add(
self.repo, 'jelmer', 'git://jelmer.uk/code/dulwich')
self.assertRaises(porcelain.RemoteExists, porcelain.remote_add,
self.repo, 'jelmer', 'git://jelmer.uk/code/dulwich')
class CheckIgnoreTests(PorcelainTestCase):
def test_check_ignored(self):
with open(os.path.join(self.repo.path, '.gitignore'), 'w') as f:
f.write('foo')
foo_path = os.path.join(self.repo.path, 'foo')
with open(foo_path, 'w') as f:
f.write('BAR')
bar_path = os.path.join(self.repo.path, 'bar')
with open(bar_path, 'w') as f:
f.write('BAR')
self.assertEqual(
['foo'],
list(porcelain.check_ignore(self.repo, [foo_path])))
self.assertEqual(
[], list(porcelain.check_ignore(self.repo, [bar_path])))
def test_check_added_abs(self):
path = os.path.join(self.repo.path, 'foo')
with open(path, 'w') as f:
f.write('BAR')
self.repo.stage(['foo'])
with open(os.path.join(self.repo.path, '.gitignore'), 'w') as f:
f.write('foo\n')
self.assertEqual(
[], list(porcelain.check_ignore(self.repo, [path])))
self.assertEqual(
['foo'],
list(porcelain.check_ignore(self.repo, [path], no_index=True)))
def test_check_added_rel(self):
with open(os.path.join(self.repo.path, 'foo'), 'w') as f:
f.write('BAR')
self.repo.stage(['foo'])
with open(os.path.join(self.repo.path, '.gitignore'), 'w') as f:
f.write('foo\n')
cwd = os.getcwd()
os.mkdir(os.path.join(self.repo.path, 'bar'))
os.chdir(os.path.join(self.repo.path, 'bar'))
try:
self.assertEqual(
list(porcelain.check_ignore(self.repo, ['../foo'])), [])
self.assertEqual(['../foo'], list(
porcelain.check_ignore(self.repo, ['../foo'], no_index=True)))
finally:
os.chdir(cwd)
class UpdateHeadTests(PorcelainTestCase):
def test_set_to_branch(self):
[c1] = build_commit_graph(self.repo.object_store, [[1]])
self.repo.refs[b"refs/heads/blah"] = c1.id
porcelain.update_head(self.repo, "blah")
self.assertEqual(c1.id, self.repo.head())
self.assertEqual(b'ref: refs/heads/blah',
self.repo.refs.read_ref(b'HEAD'))
def test_set_to_branch_detached(self):
[c1] = build_commit_graph(self.repo.object_store, [[1]])
self.repo.refs[b"refs/heads/blah"] = c1.id
porcelain.update_head(self.repo, "blah", detached=True)
self.assertEqual(c1.id, self.repo.head())
self.assertEqual(c1.id, self.repo.refs.read_ref(b'HEAD'))
def test_set_to_commit_detached(self):
[c1] = build_commit_graph(self.repo.object_store, [[1]])
self.repo.refs[b"refs/heads/blah"] = c1.id
porcelain.update_head(self.repo, c1.id, detached=True)
self.assertEqual(c1.id, self.repo.head())
self.assertEqual(c1.id, self.repo.refs.read_ref(b'HEAD'))
def test_set_new_branch(self):
[c1] = build_commit_graph(self.repo.object_store, [[1]])
self.repo.refs[b"refs/heads/blah"] = c1.id
porcelain.update_head(self.repo, "blah", new_branch="bar")
self.assertEqual(c1.id, self.repo.head())
self.assertEqual(b'ref: refs/heads/bar',
self.repo.refs.read_ref(b'HEAD'))


class MailmapTests(PorcelainTestCase):
def test_no_mailmap(self):
self.assertEqual(
            b'Jelmer Vernooij <jelmer@samba.org>',
            porcelain.check_mailmap(
                self.repo, b'Jelmer Vernooij <jelmer@samba.org>'))
def test_mailmap_lookup(self):
with open(os.path.join(self.repo.path, '.mailmap'), 'wb') as f:
f.write(b"""\
Jelmer Vernooij <jelmer@debian.org>
""")
        self.assertEqual(
            b'Jelmer Vernooij <jelmer@debian.org>',
            porcelain.check_mailmap(
                self.repo, b'Jelmer Vernooij <jelmer@samba.org>'))


class FsckTests(PorcelainTestCase):
def test_none(self):
self.assertEqual(
[],
list(porcelain.fsck(self.repo)))
def test_git_dir(self):
obj = Tree()
a = Blob()
a.data = b"foo"
obj.add(b".git", 0o100644, a.id)
self.repo.object_store.add_objects(
[(a, None), (obj, None)])
self.assertEqual(
[(obj.id, 'invalid name .git')],
[(sha, str(e)) for (sha, e) in porcelain.fsck(self.repo)])


class DescribeTests(PorcelainTestCase):
def test_no_commits(self):
self.assertRaises(KeyError, porcelain.describe, self.repo.path)
def test_single_commit(self):
fullpath = os.path.join(self.repo.path, 'foo')
with open(fullpath, 'w') as f:
f.write("BAR")
porcelain.add(repo=self.repo.path, paths=[fullpath])
sha = porcelain.commit(
self.repo.path, message=b"Some message",
author=b"Joe ",
committer=b"Bob ")
self.assertEqual(
'g{}'.format(sha[:7].decode('ascii')),
porcelain.describe(self.repo.path))
def test_tag(self):
fullpath = os.path.join(self.repo.path, 'foo')
with open(fullpath, 'w') as f:
f.write("BAR")
porcelain.add(repo=self.repo.path, paths=[fullpath])
porcelain.commit(
self.repo.path, message=b"Some message",
author=b"Joe ",
committer=b"Bob ")
porcelain.tag_create(self.repo.path, b"tryme", b'foo ',
b'bar', annotated=True)
self.assertEqual(
"tryme",
porcelain.describe(self.repo.path))
def test_tag_and_commit(self):
fullpath = os.path.join(self.repo.path, 'foo')
with open(fullpath, 'w') as f:
f.write("BAR")
porcelain.add(repo=self.repo.path, paths=[fullpath])
porcelain.commit(
self.repo.path, message=b"Some message",
author=b"Joe ",
committer=b"Bob ")
porcelain.tag_create(self.repo.path, b"tryme", b'foo ',
b'bar', annotated=True)
with open(fullpath, 'w') as f:
f.write("BAR2")
porcelain.add(repo=self.repo.path, paths=[fullpath])
sha = porcelain.commit(
self.repo.path, message=b"Some message",
author=b"Joe ",
committer=b"Bob ")
self.assertEqual(
'tryme-1-g{}'.format(sha[:7].decode('ascii')),
porcelain.describe(self.repo.path))


class HelperTests(PorcelainTestCase):
def test_path_to_tree_path_base(self):
self.assertEqual(
b'bar', porcelain.path_to_tree_path('/home/foo', '/home/foo/bar'))
self.assertEqual(b'bar', porcelain.path_to_tree_path('.', './bar'))
self.assertEqual(b'bar', porcelain.path_to_tree_path('.', 'bar'))
cwd = os.getcwd()
self.assertEqual(
b'bar', porcelain.path_to_tree_path('.', os.path.join(cwd, 'bar')))
self.assertEqual(b'bar', porcelain.path_to_tree_path(cwd, 'bar'))
def test_path_to_tree_path_syntax(self):
self.assertEqual(b'bar', porcelain.path_to_tree_path(b'.', './bar'))
self.assertEqual(b'bar', porcelain.path_to_tree_path('.', b'./bar'))
self.assertEqual(b'bar', porcelain.path_to_tree_path(b'.', b'./bar'))
def test_path_to_tree_path_error(self):
with self.assertRaises(ValueError):
porcelain.path_to_tree_path('/home/foo/', '/home/bar/baz')
def test_path_to_tree_path_rel(self):
cwd = os.getcwd()
os.mkdir(os.path.join(self.repo.path, 'foo'))
os.mkdir(os.path.join(self.repo.path, 'foo/bar'))
try:
os.chdir(os.path.join(self.repo.path, 'foo/bar'))
self.assertEqual(b'bar/baz', porcelain.path_to_tree_path(
'..', 'baz'))
self.assertEqual(b'bar/baz', porcelain.path_to_tree_path(
os.path.join(os.getcwd(), '..'),
os.path.join(os.getcwd(), 'baz')))
self.assertEqual(b'bar/baz', porcelain.path_to_tree_path(
'..', os.path.join(os.getcwd(), 'baz')))
self.assertEqual(b'bar/baz', porcelain.path_to_tree_path(
os.path.join(os.getcwd(), '..'), 'baz'))
finally:
os.chdir(cwd)


class GetObjectByPathTests(PorcelainTestCase):
def test_simple(self):
fullpath = os.path.join(self.repo.path, 'foo')
with open(fullpath, 'w') as f:
f.write("BAR")
porcelain.add(repo=self.repo.path, paths=[fullpath])
porcelain.commit(
self.repo.path, message=b"Some message",
author=b"Joe ",
committer=b"Bob ")
self.assertEqual(
b"BAR",
porcelain.get_object_by_path(self.repo, 'foo').data)
self.assertEqual(
b"BAR",
porcelain.get_object_by_path(self.repo, b'foo').data)
def test_encoding(self):
fullpath = os.path.join(self.repo.path, 'foo')
with open(fullpath, 'w') as f:
f.write("BAR")
porcelain.add(repo=self.repo.path, paths=[fullpath])
porcelain.commit(
self.repo.path, message=b"Some message",
author=b"Joe ",
committer=b"Bob ",
encoding=b"utf-8")
self.assertEqual(
b"BAR",
porcelain.get_object_by_path(self.repo, 'foo').data)
self.assertEqual(
b"BAR",
porcelain.get_object_by_path(self.repo, b'foo').data)
def test_missing(self):
self.assertRaises(
KeyError,
porcelain.get_object_by_path, self.repo, 'foo')


class WriteTreeTests(PorcelainTestCase):
def test_simple(self):
fullpath = os.path.join(self.repo.path, 'foo')
with open(fullpath, 'w') as f:
f.write("BAR")
porcelain.add(repo=self.repo.path, paths=[fullpath])
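        # write_tree serializes the current index as a tree object and
        # returns the new tree's sha.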
self.assertEqual(
b'd2092c8a9f311f0311083bf8d177f2ca0ab5b241',
porcelain.write_tree(self.repo))


class ActiveBranchTests(PorcelainTestCase):
def test_simple(self):
self.assertEqual(b'master', porcelain.active_branch(self.repo))
diff --git a/dulwich/tests/test_refs.py b/dulwich/tests/test_refs.py
index 2c13a6b0..b2ec1e0a 100644
--- a/dulwich/tests/test_refs.py
+++ b/dulwich/tests/test_refs.py
@@ -1,689 +1,689 @@
# test_refs.py -- tests for refs.py
# encoding: utf-8
# Copyright (C) 2013 Jelmer Vernooij
#
# Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU
# General Public License as public by the Free Software Foundation; version 2.0
# or (at your option) any later version. You can redistribute it and/or
# modify it under the terms of either of these two licenses.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# You should have received a copy of the licenses; if not, see
# <http://www.gnu.org/licenses/> for a copy of the GNU General Public License
# and <http://www.apache.org/licenses/LICENSE-2.0> for a copy of the Apache
# License, Version 2.0.
#
"""Tests for dulwich.refs."""
from io import BytesIO
import os
import sys
import tempfile
from dulwich import errors
from dulwich.file import (
GitFile,
)
from dulwich.objects import ZERO_SHA
from dulwich.refs import (
DictRefsContainer,
InfoRefsContainer,
check_ref_format,
_split_ref_line,
parse_symref_value,
read_packed_refs_with_peeled,
read_packed_refs,
strip_peeled_refs,
write_packed_refs,
)
from dulwich.repo import Repo
from dulwich.tests import (
SkipTest,
TestCase,
)
from dulwich.tests.utils import (
open_repo,
tear_down_repo,
)


class CheckRefFormatTests(TestCase):
"""Tests for the check_ref_format function.
These are the same tests as in the git test suite.
"""
def test_valid(self):
self.assertTrue(check_ref_format(b'heads/foo'))
self.assertTrue(check_ref_format(b'foo/bar/baz'))
self.assertTrue(check_ref_format(b'refs///heads/foo'))
self.assertTrue(check_ref_format(b'foo./bar'))
self.assertTrue(check_ref_format(b'heads/foo@bar'))
self.assertTrue(check_ref_format(b'heads/fix.lock.error'))
def test_invalid(self):
self.assertFalse(check_ref_format(b'foo'))
self.assertFalse(check_ref_format(b'heads/foo/'))
self.assertFalse(check_ref_format(b'./foo'))
self.assertFalse(check_ref_format(b'.refs/foo'))
self.assertFalse(check_ref_format(b'heads/foo..bar'))
self.assertFalse(check_ref_format(b'heads/foo?bar'))
self.assertFalse(check_ref_format(b'heads/foo.lock'))
self.assertFalse(check_ref_format(b'heads/v@{ation'))
self.assertFalse(check_ref_format(b'heads/foo\bar'))
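

# Dummy 40-hex-digit object ids used as stand-in SHAs by the tests below.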
ONES = b'1' * 40
TWOS = b'2' * 40
THREES = b'3' * 40
FOURS = b'4' * 40


class PackedRefsFileTests(TestCase):
def test_split_ref_line_errors(self):
self.assertRaises(errors.PackedRefsException, _split_ref_line,
b'singlefield')
self.assertRaises(errors.PackedRefsException, _split_ref_line,
b'badsha name')
self.assertRaises(errors.PackedRefsException, _split_ref_line,
ONES + b' bad/../refname')
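
    # Lines starting with '#' in a packed-refs file are comments and are
    # skipped by the reader.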
def test_read_without_peeled(self):
f = BytesIO(b'\n'.join([
b'# comment',
ONES + b' ref/1',
TWOS + b' ref/2']))
self.assertEqual([(ONES, b'ref/1'), (TWOS, b'ref/2')],
list(read_packed_refs(f)))
def test_read_without_peeled_errors(self):
f = BytesIO(b'\n'.join([
ONES + b' ref/1',
b'^' + TWOS]))
self.assertRaises(errors.PackedRefsException, list,
read_packed_refs(f))
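
    # A '^<sha>' line records the peeled (fully dereferenced) object of the
    # annotated tag on the preceding line.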
def test_read_with_peeled(self):
f = BytesIO(b'\n'.join([
ONES + b' ref/1',
TWOS + b' ref/2',
b'^' + THREES,
FOURS + b' ref/4']))
self.assertEqual([
(ONES, b'ref/1', None),
(TWOS, b'ref/2', THREES),
(FOURS, b'ref/4', None),
], list(read_packed_refs_with_peeled(f)))
def test_read_with_peeled_errors(self):
f = BytesIO(b'\n'.join([
b'^' + TWOS,
ONES + b' ref/1']))
self.assertRaises(errors.PackedRefsException, list,
read_packed_refs(f))
f = BytesIO(b'\n'.join([
ONES + b' ref/1',
b'^' + TWOS,
b'^' + THREES]))
self.assertRaises(errors.PackedRefsException, list,
read_packed_refs(f))
def test_write_with_peeled(self):
f = BytesIO()
write_packed_refs(f, {b'ref/1': ONES, b'ref/2': TWOS},
{b'ref/1': THREES})
self.assertEqual(
b'\n'.join([b'# pack-refs with: peeled',
ONES + b' ref/1',
b'^' + THREES,
TWOS + b' ref/2']) + b'\n',
f.getvalue())
def test_write_without_peeled(self):
f = BytesIO()
write_packed_refs(f, {b'ref/1': ONES, b'ref/2': TWOS})
self.assertEqual(b'\n'.join([ONES + b' ref/1',
TWOS + b' ref/2']) + b'\n',
f.getvalue())


# Dict of refs that we expect all RefsContainerTests subclasses to define.
_TEST_REFS = {
b'HEAD': b'42d06bd4b77fed026b154d16493e5deab78f02ec',
b'refs/heads/40-char-ref-aaaaaaaaaaaaaaaaaa':
b'42d06bd4b77fed026b154d16493e5deab78f02ec',
b'refs/heads/master': b'42d06bd4b77fed026b154d16493e5deab78f02ec',
b'refs/heads/packed': b'42d06bd4b77fed026b154d16493e5deab78f02ec',
b'refs/tags/refs-0.1': b'df6800012397fb85c56e7418dd4eb9405dee075c',
b'refs/tags/refs-0.2': b'3ec9c43c84ff242e3ef4a9fc5bc111fd780a76a8',
b'refs/heads/loop': b'ref: refs/heads/loop',
}
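# Note that b'refs/heads/loop' above is a symref pointing at itself; the
# tests expect containers to detect the cycle rather than follow it forever.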


class RefsContainerTests(object):
def test_keys(self):
actual_keys = set(self._refs.keys())
self.assertEqual(set(self._refs.allkeys()), actual_keys)
self.assertEqual(set(_TEST_REFS.keys()), actual_keys)
actual_keys = self._refs.keys(b'refs/heads')
actual_keys.discard(b'loop')
self.assertEqual(
[b'40-char-ref-aaaaaaaaaaaaaaaaaa', b'master', b'packed'],
sorted(actual_keys))
self.assertEqual([b'refs-0.1', b'refs-0.2'],
sorted(self._refs.keys(b'refs/tags')))
def test_iter(self):
actual_keys = set(self._refs.keys())
self.assertEqual(set(self._refs), actual_keys)
self.assertEqual(set(_TEST_REFS.keys()), actual_keys)
def test_as_dict(self):
# refs/heads/loop does not show up even if it exists
expected_refs = dict(_TEST_REFS)
del expected_refs[b'refs/heads/loop']
self.assertEqual(expected_refs, self._refs.as_dict())
def test_get_symrefs(self):
self._refs.set_symbolic_ref(b'refs/heads/src', b'refs/heads/dst')
symrefs = self._refs.get_symrefs()
if b'HEAD' in symrefs:
symrefs.pop(b'HEAD')
self.assertEqual({b'refs/heads/src': b'refs/heads/dst',
b'refs/heads/loop': b'refs/heads/loop'},
symrefs)
def test_setitem(self):
self._refs[b'refs/some/ref'] = (
b'42d06bd4b77fed026b154d16493e5deab78f02ec')
self.assertEqual(b'42d06bd4b77fed026b154d16493e5deab78f02ec',
self._refs[b'refs/some/ref'])
self.assertRaises(
errors.RefFormatError, self._refs.__setitem__,
b'notrefs/foo', b'42d06bd4b77fed026b154d16493e5deab78f02ec')
def test_set_if_equals(self):
nines = b'9' * 40
self.assertFalse(self._refs.set_if_equals(b'HEAD', b'c0ffee', nines))
self.assertEqual(b'42d06bd4b77fed026b154d16493e5deab78f02ec',
self._refs[b'HEAD'])
self.assertTrue(self._refs.set_if_equals(
b'HEAD', b'42d06bd4b77fed026b154d16493e5deab78f02ec', nines))
self.assertEqual(nines, self._refs[b'HEAD'])
# Setting the ref again is a no-op, but will return True.
self.assertTrue(self._refs.set_if_equals(b'HEAD', nines, nines))
self.assertEqual(nines, self._refs[b'HEAD'])
self.assertTrue(self._refs.set_if_equals(b'refs/heads/master', None,
nines))
self.assertEqual(nines, self._refs[b'refs/heads/master'])
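        # Both None and ZERO_SHA act as "the ref must not currently exist".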
self.assertTrue(self._refs.set_if_equals(
b'refs/heads/nonexistant', ZERO_SHA, nines))
self.assertEqual(nines, self._refs[b'refs/heads/nonexistant'])
def test_add_if_new(self):
nines = b'9' * 40
self.assertFalse(self._refs.add_if_new(b'refs/heads/master', nines))
self.assertEqual(b'42d06bd4b77fed026b154d16493e5deab78f02ec',
self._refs[b'refs/heads/master'])
self.assertTrue(self._refs.add_if_new(b'refs/some/ref', nines))
self.assertEqual(nines, self._refs[b'refs/some/ref'])
def test_set_symbolic_ref(self):
self._refs.set_symbolic_ref(b'refs/heads/symbolic',
b'refs/heads/master')
self.assertEqual(b'ref: refs/heads/master',
self._refs.read_loose_ref(b'refs/heads/symbolic'))
self.assertEqual(b'42d06bd4b77fed026b154d16493e5deab78f02ec',
self._refs[b'refs/heads/symbolic'])
def test_set_symbolic_ref_overwrite(self):
nines = b'9' * 40
self.assertFalse(b'refs/heads/symbolic' in self._refs)
self._refs[b'refs/heads/symbolic'] = nines
self.assertEqual(nines,
self._refs.read_loose_ref(b'refs/heads/symbolic'))
self._refs.set_symbolic_ref(b'refs/heads/symbolic',
b'refs/heads/master')
self.assertEqual(b'ref: refs/heads/master',
self._refs.read_loose_ref(b'refs/heads/symbolic'))
self.assertEqual(b'42d06bd4b77fed026b154d16493e5deab78f02ec',
self._refs[b'refs/heads/symbolic'])
def test_check_refname(self):
self._refs._check_refname(b'HEAD')
self._refs._check_refname(b'refs/stash')
self._refs._check_refname(b'refs/heads/foo')
self.assertRaises(errors.RefFormatError, self._refs._check_refname,
b'refs')
self.assertRaises(errors.RefFormatError, self._refs._check_refname,
b'notrefs/foo')
def test_contains(self):
self.assertTrue(b'refs/heads/master' in self._refs)
self.assertFalse(b'refs/heads/bar' in self._refs)
def test_delitem(self):
self.assertEqual(b'42d06bd4b77fed026b154d16493e5deab78f02ec',
self._refs[b'refs/heads/master'])
del self._refs[b'refs/heads/master']
self.assertRaises(KeyError, lambda: self._refs[b'refs/heads/master'])
def test_remove_if_equals(self):
self.assertFalse(self._refs.remove_if_equals(b'HEAD', b'c0ffee'))
self.assertEqual(b'42d06bd4b77fed026b154d16493e5deab78f02ec',
self._refs[b'HEAD'])
self.assertTrue(self._refs.remove_if_equals(
b'refs/tags/refs-0.2',
b'3ec9c43c84ff242e3ef4a9fc5bc111fd780a76a8'))
self.assertTrue(self._refs.remove_if_equals(
b'refs/tags/refs-0.2', ZERO_SHA))
self.assertFalse(b'refs/tags/refs-0.2' in self._refs)
def test_import_refs_name(self):
self._refs[b'refs/remotes/origin/other'] = (
b'48d01bd4b77fed026b154d16493e5deab78f02ec')
self._refs.import_refs(
b'refs/remotes/origin',
{b'master': b'42d06bd4b77fed026b154d16493e5deab78f02ec'})
self.assertEqual(
b'42d06bd4b77fed026b154d16493e5deab78f02ec',
self._refs[b'refs/remotes/origin/master'])
self.assertEqual(
b'48d01bd4b77fed026b154d16493e5deab78f02ec',
self._refs[b'refs/remotes/origin/other'])
def test_import_refs_name_prune(self):
self._refs[b'refs/remotes/origin/other'] = (
b'48d01bd4b77fed026b154d16493e5deab78f02ec')
self._refs.import_refs(
b'refs/remotes/origin',
{b'master': b'42d06bd4b77fed026b154d16493e5deab78f02ec'},
prune=True)
self.assertEqual(
b'42d06bd4b77fed026b154d16493e5deab78f02ec',
self._refs[b'refs/remotes/origin/master'])
self.assertNotIn(
b'refs/remotes/origin/other', self._refs)


class DictRefsContainerTests(RefsContainerTests, TestCase):
def setUp(self):
TestCase.setUp(self)
self._refs = DictRefsContainer(dict(_TEST_REFS))
def test_invalid_refname(self):
# FIXME: Move this test into RefsContainerTests, but requires
# some way of injecting invalid refs.
self._refs._refs[b'refs/stash'] = b'00' * 20
expected_refs = dict(_TEST_REFS)
del expected_refs[b'refs/heads/loop']
expected_refs[b'refs/stash'] = b'00' * 20
self.assertEqual(expected_refs, self._refs.as_dict())


class DiskRefsContainerTests(RefsContainerTests, TestCase):
def setUp(self):
TestCase.setUp(self)
self._repo = open_repo('refs.git')
self.addCleanup(tear_down_repo, self._repo)
self._refs = self._repo.refs
def test_get_packed_refs(self):
self.assertEqual({
b'refs/heads/packed': b'42d06bd4b77fed026b154d16493e5deab78f02ec',
b'refs/tags/refs-0.1': b'df6800012397fb85c56e7418dd4eb9405dee075c',
}, self._refs.get_packed_refs())
def test_get_peeled_not_packed(self):
# not packed
self.assertEqual(None, self._refs.get_peeled(b'refs/tags/refs-0.2'))
self.assertEqual(b'3ec9c43c84ff242e3ef4a9fc5bc111fd780a76a8',
self._refs[b'refs/tags/refs-0.2'])
# packed, known not peelable
self.assertEqual(self._refs[b'refs/heads/packed'],
self._refs.get_peeled(b'refs/heads/packed'))
# packed, peeled
self.assertEqual(b'42d06bd4b77fed026b154d16493e5deab78f02ec',
self._refs.get_peeled(b'refs/tags/refs-0.1'))
def test_setitem(self):
RefsContainerTests.test_setitem(self)
path = os.path.join(self._refs.path, b'refs', b'some', b'ref')
with open(path, 'rb') as f:
self.assertEqual(b'42d06bd4b77fed026b154d16493e5deab78f02ec',
f.read()[:40])
self.assertRaises(
OSError, self._refs.__setitem__,
b'refs/some/ref/sub', b'42d06bd4b77fed026b154d16493e5deab78f02ec')
def test_setitem_packed(self):
with open(os.path.join(self._refs.path, b'packed-refs'), 'w') as f:
f.write('# pack-refs with: peeled fully-peeled sorted \n')
f.write(
'42d06bd4b77fed026b154d16493e5deab78f02ec refs/heads/packed\n')
        # Setting a ref that exists only in packed-refs is allowed; the new
        # value is written out as a loose ref under refs/.
self._refs[b'refs/heads/packed'] = (
b'3ec9c43c84ff242e3ef4a9fc5bc111fd780a76a8'
)
packed_ref_path = os.path.join(
self._refs.path, b'refs', b'heads', b'packed')
with open(packed_ref_path, 'rb') as f:
self.assertEqual(
b'3ec9c43c84ff242e3ef4a9fc5bc111fd780a76a8',
f.read()[:40])
self.assertRaises(
OSError, self._refs.__setitem__,
b'refs/heads/packed/sub',
b'42d06bd4b77fed026b154d16493e5deab78f02ec')
def test_setitem_symbolic(self):
ones = b'1' * 40
self._refs[b'HEAD'] = ones
self.assertEqual(ones, self._refs[b'HEAD'])
# ensure HEAD was not modified
f = open(os.path.join(self._refs.path, b'HEAD'), 'rb')
v = next(iter(f)).rstrip(b'\n\r')
f.close()
self.assertEqual(b'ref: refs/heads/master', v)
# ensure the symbolic link was written through
f = open(os.path.join(self._refs.path, b'refs', b'heads', b'master'),
'rb')
self.assertEqual(ones, f.read()[:40])
f.close()
def test_set_if_equals(self):
RefsContainerTests.test_set_if_equals(self)
# ensure symref was followed
self.assertEqual(b'9' * 40, self._refs[b'refs/heads/master'])
# ensure lockfile was deleted
self.assertFalse(os.path.exists(
os.path.join(self._refs.path, b'refs', b'heads', b'master.lock')))
self.assertFalse(os.path.exists(
os.path.join(self._refs.path, b'HEAD.lock')))
def test_add_if_new_packed(self):
# don't overwrite packed ref
self.assertFalse(self._refs.add_if_new(b'refs/tags/refs-0.1',
b'9' * 40))
self.assertEqual(b'df6800012397fb85c56e7418dd4eb9405dee075c',
self._refs[b'refs/tags/refs-0.1'])
def test_add_if_new_symbolic(self):
# Use an empty repo instead of the default.
repo_dir = os.path.join(tempfile.mkdtemp(), 'test')
os.makedirs(repo_dir)
repo = Repo.init(repo_dir)
self.addCleanup(tear_down_repo, repo)
refs = repo.refs
nines = b'9' * 40
self.assertEqual(b'ref: refs/heads/master', refs.read_ref(b'HEAD'))
self.assertFalse(b'refs/heads/master' in refs)
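        # add_if_new follows the HEAD symref, creating refs/heads/master.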
self.assertTrue(refs.add_if_new(b'HEAD', nines))
self.assertEqual(b'ref: refs/heads/master', refs.read_ref(b'HEAD'))
self.assertEqual(nines, refs[b'HEAD'])
self.assertEqual(nines, refs[b'refs/heads/master'])
self.assertFalse(refs.add_if_new(b'HEAD', b'1' * 40))
self.assertEqual(nines, refs[b'HEAD'])
self.assertEqual(nines, refs[b'refs/heads/master'])
def test_follow(self):
self.assertEqual(([b'HEAD', b'refs/heads/master'],
b'42d06bd4b77fed026b154d16493e5deab78f02ec'),
self._refs.follow(b'HEAD'))
self.assertEqual(([b'refs/heads/master'],
b'42d06bd4b77fed026b154d16493e5deab78f02ec'),
self._refs.follow(b'refs/heads/master'))
self.assertRaises(KeyError, self._refs.follow, b'refs/heads/loop')
def test_delitem(self):
RefsContainerTests.test_delitem(self)
ref_file = os.path.join(self._refs.path, b'refs', b'heads', b'master')
self.assertFalse(os.path.exists(ref_file))
self.assertFalse(b'refs/heads/master' in self._refs.get_packed_refs())
def test_delitem_symbolic(self):
self.assertEqual(b'ref: refs/heads/master',
self._refs.read_loose_ref(b'HEAD'))
del self._refs[b'HEAD']
self.assertRaises(KeyError, lambda: self._refs[b'HEAD'])
self.assertEqual(b'42d06bd4b77fed026b154d16493e5deab78f02ec',
self._refs[b'refs/heads/master'])
self.assertFalse(
os.path.exists(os.path.join(self._refs.path, b'HEAD')))
def test_remove_if_equals_symref(self):
# HEAD is a symref, so shouldn't equal its dereferenced value
self.assertFalse(self._refs.remove_if_equals(
b'HEAD', b'42d06bd4b77fed026b154d16493e5deab78f02ec'))
self.assertTrue(self._refs.remove_if_equals(
b'refs/heads/master', b'42d06bd4b77fed026b154d16493e5deab78f02ec'))
self.assertRaises(KeyError, lambda: self._refs[b'refs/heads/master'])
# HEAD is now a broken symref
self.assertRaises(KeyError, lambda: self._refs[b'HEAD'])
self.assertEqual(b'ref: refs/heads/master',
self._refs.read_loose_ref(b'HEAD'))
self.assertFalse(os.path.exists(
os.path.join(self._refs.path, b'refs', b'heads', b'master.lock')))
self.assertFalse(os.path.exists(
os.path.join(self._refs.path, b'HEAD.lock')))
def test_remove_packed_without_peeled(self):
refs_file = os.path.join(self._repo.path, 'packed-refs')
f = GitFile(refs_file)
refs_data = f.read()
f.close()
f = GitFile(refs_file, 'wb')
- f.write(b'\n'.join(l for l in refs_data.split(b'\n')
- if not l or l[0] not in b'#^'))
+ f.write(b'\n'.join(line for line in refs_data.split(b'\n')
+ if not line or line[0] not in b'#^'))
f.close()
self._repo = Repo(self._repo.path)
refs = self._repo.refs
self.assertTrue(refs.remove_if_equals(
b'refs/heads/packed', b'42d06bd4b77fed026b154d16493e5deab78f02ec'))
def test_remove_if_equals_packed(self):
# test removing ref that is only packed
self.assertEqual(b'df6800012397fb85c56e7418dd4eb9405dee075c',
self._refs[b'refs/tags/refs-0.1'])
self.assertTrue(
self._refs.remove_if_equals(
b'refs/tags/refs-0.1',
b'df6800012397fb85c56e7418dd4eb9405dee075c'))
self.assertRaises(KeyError, lambda: self._refs[b'refs/tags/refs-0.1'])
def test_remove_parent(self):
self._refs[b'refs/heads/foo/bar'] = (
b'df6800012397fb85c56e7418dd4eb9405dee075c'
)
del self._refs[b'refs/heads/foo/bar']
ref_file = os.path.join(
self._refs.path, b'refs', b'heads', b'foo', b'bar',
)
self.assertFalse(os.path.exists(ref_file))
ref_file = os.path.join(self._refs.path, b'refs', b'heads', b'foo')
self.assertFalse(os.path.exists(ref_file))
ref_file = os.path.join(self._refs.path, b'refs', b'heads')
self.assertTrue(os.path.exists(ref_file))
self._refs[b'refs/heads/foo'] = (
b'df6800012397fb85c56e7418dd4eb9405dee075c'
)
def test_read_ref(self):
self.assertEqual(b'ref: refs/heads/master',
self._refs.read_ref(b'HEAD'))
self.assertEqual(b'42d06bd4b77fed026b154d16493e5deab78f02ec',
self._refs.read_ref(b'refs/heads/packed'))
self.assertEqual(None, self._refs.read_ref(b'nonexistant'))
def test_read_loose_ref(self):
self._refs[b'refs/heads/foo'] = (
b'df6800012397fb85c56e7418dd4eb9405dee075c'
)
self.assertEqual(None, self._refs.read_ref(b'refs/heads/foo/bar'))
def test_non_ascii(self):
try:
encoded_ref = os.fsencode(u'refs/tags/schön')
except UnicodeEncodeError:
raise SkipTest(
"filesystem encoding doesn't support special character")
p = os.path.join(os.fsencode(self._repo.path), encoded_ref)
with open(p, 'w') as f:
f.write('00' * 20)
expected_refs = dict(_TEST_REFS)
expected_refs[encoded_ref] = b'00' * 20
del expected_refs[b'refs/heads/loop']
self.assertEqual(expected_refs, self._repo.get_refs())
def test_cyrillic(self):
if sys.platform == 'win32':
raise SkipTest(
"filesystem encoding doesn't support arbitrary bytes")
# reported in https://github.com/dulwich/dulwich/issues/608
name = b'\xcd\xee\xe2\xe0\xff\xe2\xe5\xf2\xea\xe01'
encoded_ref = b'refs/heads/' + name
with open(os.path.join(
os.fsencode(self._repo.path), encoded_ref), 'w') as f:
f.write('00' * 20)
expected_refs = set(_TEST_REFS.keys())
expected_refs.add(encoded_ref)
self.assertEqual(expected_refs,
set(self._repo.refs.allkeys()))
self.assertEqual({r[len(b'refs/'):] for r in expected_refs
if r.startswith(b'refs/')},
set(self._repo.refs.subkeys(b'refs/')))
expected_refs.remove(b'refs/heads/loop')
expected_refs.add(b'HEAD')
self.assertEqual(expected_refs,
set(self._repo.get_refs().keys()))
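

# _TEST_REFS in the info/refs wire format read by InfoRefsContainer:
# one '<sha>\t<refname>' line per ref; HEAD and symrefs are omitted.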
_TEST_REFS_SERIALIZED = (
b'42d06bd4b77fed026b154d16493e5deab78f02ec\t'
b'refs/heads/40-char-ref-aaaaaaaaaaaaaaaaaa\n'
b'42d06bd4b77fed026b154d16493e5deab78f02ec\trefs/heads/master\n'
b'42d06bd4b77fed026b154d16493e5deab78f02ec\trefs/heads/packed\n'
b'df6800012397fb85c56e7418dd4eb9405dee075c\trefs/tags/refs-0.1\n'
b'3ec9c43c84ff242e3ef4a9fc5bc111fd780a76a8\trefs/tags/refs-0.2\n')


class InfoRefsContainerTests(TestCase):
def test_invalid_refname(self):
text = _TEST_REFS_SERIALIZED + b'00' * 20 + b'\trefs/stash\n'
refs = InfoRefsContainer(BytesIO(text))
expected_refs = dict(_TEST_REFS)
del expected_refs[b'HEAD']
expected_refs[b'refs/stash'] = b'00' * 20
del expected_refs[b'refs/heads/loop']
self.assertEqual(expected_refs, refs.as_dict())
def test_keys(self):
refs = InfoRefsContainer(BytesIO(_TEST_REFS_SERIALIZED))
actual_keys = set(refs.keys())
self.assertEqual(set(refs.allkeys()), actual_keys)
expected_refs = dict(_TEST_REFS)
del expected_refs[b'HEAD']
del expected_refs[b'refs/heads/loop']
self.assertEqual(set(expected_refs.keys()), actual_keys)
actual_keys = refs.keys(b'refs/heads')
actual_keys.discard(b'loop')
self.assertEqual(
[b'40-char-ref-aaaaaaaaaaaaaaaaaa', b'master', b'packed'],
sorted(actual_keys))
self.assertEqual([b'refs-0.1', b'refs-0.2'],
sorted(refs.keys(b'refs/tags')))
def test_as_dict(self):
refs = InfoRefsContainer(BytesIO(_TEST_REFS_SERIALIZED))
# refs/heads/loop does not show up even if it exists
expected_refs = dict(_TEST_REFS)
del expected_refs[b'HEAD']
del expected_refs[b'refs/heads/loop']
self.assertEqual(expected_refs, refs.as_dict())
def test_contains(self):
refs = InfoRefsContainer(BytesIO(_TEST_REFS_SERIALIZED))
self.assertTrue(b'refs/heads/master' in refs)
self.assertFalse(b'refs/heads/bar' in refs)
def test_get_peeled(self):
refs = InfoRefsContainer(BytesIO(_TEST_REFS_SERIALIZED))
# refs/heads/loop does not show up even if it exists
self.assertEqual(
_TEST_REFS[b'refs/heads/master'],
refs.get_peeled(b'refs/heads/master'))


class ParseSymrefValueTests(TestCase):
def test_valid(self):
self.assertEqual(
b'refs/heads/foo',
parse_symref_value(b'ref: refs/heads/foo'))
def test_invalid(self):
self.assertRaises(ValueError, parse_symref_value, b'foobar')


class StripPeeledRefsTests(TestCase):
all_refs = {
b'refs/heads/master': b'8843d7f92416211de9ebb963ff4ce28125932878',
b'refs/heads/testing': b'186a005b134d8639a58b6731c7c1ea821a6eedba',
b'refs/tags/1.0.0': b'a93db4b0360cc635a2b93675010bac8d101f73f0',
b'refs/tags/1.0.0^{}': b'a93db4b0360cc635a2b93675010bac8d101f73f0',
b'refs/tags/2.0.0': b'0749936d0956c661ac8f8d3483774509c165f89e',
b'refs/tags/2.0.0^{}': b'0749936d0956c661ac8f8d3483774509c165f89e',
}
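    # The '^{}' entries above are the peeled targets of the annotated tags;
    # strip_peeled_refs is expected to drop them.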
non_peeled_refs = {
b'refs/heads/master': b'8843d7f92416211de9ebb963ff4ce28125932878',
b'refs/heads/testing': b'186a005b134d8639a58b6731c7c1ea821a6eedba',
b'refs/tags/1.0.0': b'a93db4b0360cc635a2b93675010bac8d101f73f0',
b'refs/tags/2.0.0': b'0749936d0956c661ac8f8d3483774509c165f89e',
}
def test_strip_peeled_refs(self):
# Simple check of two dicts
self.assertEqual(
strip_peeled_refs(self.all_refs),
self.non_peeled_refs)