diff --git a/dulwich/contrib/test_swift.py b/dulwich/contrib/test_swift.py index d4b56392..5891b735 100644 --- a/dulwich/contrib/test_swift.py +++ b/dulwich/contrib/test_swift.py @@ -1,484 +1,477 @@ # test_swift.py -- Unittests for the Swift backend. # Copyright (C) 2013 eNovance SAS # # Author: Fabien Boucher # # Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU # General Public License as public by the Free Software Foundation; version 2.0 # or (at your option) any later version. You can redistribute it and/or # modify it under the terms of either of these two licenses. # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # # You should have received a copy of the licenses; if not, see # for a copy of the GNU General Public License # and for a copy of the Apache # License, Version 2.0. # """Tests for dulwich.contrib.swift.""" import posixpath from time import time from io import BytesIO try: from StringIO import StringIO except ImportError: from io import StringIO -import sys from unittest import skipIf from dulwich.tests import ( TestCase, ) from dulwich.tests.test_object_store import ( ObjectStoreTests, ) -from dulwich.tests.utils import ( - build_pack, - ) from dulwich.objects import ( Blob, Commit, Tree, Tag, parse_timezone, ) -from dulwich.pack import ( - REF_DELTA, - ) try: from simplejson import dumps as json_dumps except ImportError: from json import dumps as json_dumps missing_libs = [] try: import gevent # noqa:F401 except ImportError: missing_libs.append("gevent") try: import geventhttpclient # noqa:F401 except ImportError: missing_libs.append("geventhttpclient") try: from mock import patch except ImportError: missing_libs.append("mock") skipmsg = "Required libraries are not installed (%r)" % missing_libs if not missing_libs: from dulwich.contrib import swift config_file = """[swift] auth_url = http://127.0.0.1:8080/auth/%(version_str)s auth_ver = %(version_int)s username = test;tester password = testing region_name = %(region_name)s endpoint_type = %(endpoint_type)s concurrency = %(concurrency)s chunk_length = %(chunk_length)s cache_length = %(cache_length)s http_pool_length = %(http_pool_length)s http_timeout = %(http_timeout)s """ def_config_file = {'version_str': 'v1.0', 'version_int': 1, 'concurrency': 1, 'chunk_length': 12228, 'cache_length': 1, 'region_name': 'test', 'endpoint_type': 'internalURL', 'http_pool_length': 1, 'http_timeout': 1} def create_swift_connector(store={}): return lambda root, conf: FakeSwiftConnector(root, conf=conf, store=store) class Response(object): def __init__(self, headers={}, status=200, content=None): self.headers = headers self.status_code = status self.content = content def __getitem__(self, key): return self.headers[key] def items(self): return self.headers.items() def read(self): return self.content def fake_auth_request_v1(*args, **kwargs): ret = Response({'X-Storage-Url': 'http://127.0.0.1:8080/v1.0/AUTH_fakeuser', 'X-Auth-Token': '12' * 10}, 200) return ret def fake_auth_request_v1_error(*args, **kwargs): ret = Response({}, 401) return ret def fake_auth_request_v2(*args, **kwargs): s_url = 'http://127.0.0.1:8080/v1.0/AUTH_fakeuser' resp = {'access': {'token': {'id': '12' * 10}, 'serviceCatalog': [ {'type': 'object-store', 'endpoints': [{'region': 'test', 'internalURL': s_url, }, ] }, ] } } ret = Response(status=200, content=json_dumps(resp)) return ret def create_commit(data, marker=b'Default', blob=None): if not blob: blob = Blob.from_string(b'The blob content ' + marker) tree = Tree() tree.add(b"thefile_" + marker, 0o100644, blob.id) cmt = Commit() if data: assert isinstance(data[-1], Commit) cmt.parents = [data[-1].id] cmt.tree = tree.id author = b"John Doe " + marker + b" " cmt.author = cmt.committer = author tz = parse_timezone(b'-0200')[0] cmt.commit_time = cmt.author_time = int(time()) cmt.commit_timezone = cmt.author_timezone = tz cmt.encoding = b"UTF-8" cmt.message = b"The commit message " + marker tag = Tag() tag.tagger = b"john@doe.net" tag.message = b"Annotated tag" tag.tag_timezone = parse_timezone(b'-0200')[0] tag.tag_time = cmt.author_time tag.object = (Commit, cmt.id) tag.name = b"v_" + marker + b"_0.1" return blob, tree, tag, cmt def create_commits(length=1, marker=b'Default'): data = [] for i in range(0, length): _marker = ("%s_%s" % (marker, i)).encode() blob, tree, tag, cmt = create_commit(data, _marker) data.extend([blob, tree, tag, cmt]) return data @skipIf(missing_libs, skipmsg) class FakeSwiftConnector(object): def __init__(self, root, conf, store=None): if store: self.store = store else: self.store = {} self.conf = conf self.root = root self.concurrency = 1 self.chunk_length = 12228 self.cache_length = 1 def put_object(self, name, content): name = posixpath.join(self.root, name) if hasattr(content, 'seek'): content.seek(0) content = content.read() self.store[name] = content def get_object(self, name, range=None): name = posixpath.join(self.root, name) if not range: try: return BytesIO(self.store[name]) except KeyError: return None else: l, r = range.split('-') try: if not l: r = -int(r) return self.store[name][r:] else: return self.store[name][int(l):int(r)] except KeyError: return None def get_container_objects(self): return [{'name': k.replace(self.root + '/', '')} for k in self.store] def create_root(self): if self.root in self.store.keys(): pass else: self.store[self.root] = '' def get_object_stat(self, name): name = posixpath.join(self.root, name) if name not in self.store: return None return {'content-length': len(self.store[name])} @skipIf(missing_libs, skipmsg) class TestSwiftRepo(TestCase): def setUp(self): super(TestSwiftRepo, self).setUp() self.conf = swift.load_conf(file=StringIO(config_file % def_config_file)) def test_init(self): store = {'fakerepo/objects/pack': ''} with patch('dulwich.contrib.swift.SwiftConnector', new_callable=create_swift_connector, store=store): swift.SwiftRepo('fakerepo', conf=self.conf) def test_init_no_data(self): with patch('dulwich.contrib.swift.SwiftConnector', new_callable=create_swift_connector): self.assertRaises(Exception, swift.SwiftRepo, 'fakerepo', self.conf) def test_init_bad_data(self): store = {'fakerepo/.git/objects/pack': ''} with patch('dulwich.contrib.swift.SwiftConnector', new_callable=create_swift_connector, store=store): self.assertRaises(Exception, swift.SwiftRepo, 'fakerepo', self.conf) def test_put_named_file(self): store = {'fakerepo/objects/pack': ''} with patch('dulwich.contrib.swift.SwiftConnector', new_callable=create_swift_connector, store=store): repo = swift.SwiftRepo('fakerepo', conf=self.conf) desc = b'Fake repo' repo._put_named_file('description', desc) self.assertEqual(repo.scon.store['fakerepo/description'], desc) def test_init_bare(self): fsc = FakeSwiftConnector('fakeroot', conf=self.conf) with patch('dulwich.contrib.swift.SwiftConnector', new_callable=create_swift_connector, store=fsc.store): swift.SwiftRepo.init_bare(fsc, conf=self.conf) self.assertIn('fakeroot/objects/pack', fsc.store) self.assertIn('fakeroot/info/refs', fsc.store) self.assertIn('fakeroot/description', fsc.store) @skipIf(missing_libs, skipmsg) class TestSwiftInfoRefsContainer(TestCase): def setUp(self): super(TestSwiftInfoRefsContainer, self).setUp() content = ( b"22effb216e3a82f97da599b8885a6cadb488b4c5\trefs/heads/master\n" b"cca703b0e1399008b53a1a236d6b4584737649e4\trefs/heads/dev") self.store = {'fakerepo/info/refs': content} self.conf = swift.load_conf(file=StringIO(config_file % def_config_file)) self.fsc = FakeSwiftConnector('fakerepo', conf=self.conf) self.object_store = {} def test_init(self): """info/refs does not exists""" irc = swift.SwiftInfoRefsContainer(self.fsc, self.object_store) self.assertEqual(len(irc._refs), 0) self.fsc.store = self.store irc = swift.SwiftInfoRefsContainer(self.fsc, self.object_store) self.assertIn(b'refs/heads/dev', irc.allkeys()) self.assertIn(b'refs/heads/master', irc.allkeys()) def test_set_if_equals(self): self.fsc.store = self.store irc = swift.SwiftInfoRefsContainer(self.fsc, self.object_store) irc.set_if_equals(b'refs/heads/dev', b"cca703b0e1399008b53a1a236d6b4584737649e4", b'1'*40) self.assertEqual(irc[b'refs/heads/dev'], b'1'*40) def test_remove_if_equals(self): self.fsc.store = self.store irc = swift.SwiftInfoRefsContainer(self.fsc, self.object_store) irc.remove_if_equals(b'refs/heads/dev', b"cca703b0e1399008b53a1a236d6b4584737649e4") self.assertNotIn(b'refs/heads/dev', irc.allkeys()) @skipIf(missing_libs, skipmsg) class TestSwiftConnector(TestCase): def setUp(self): super(TestSwiftConnector, self).setUp() self.conf = swift.load_conf(file=StringIO(config_file % def_config_file)) with patch('geventhttpclient.HTTPClient.request', fake_auth_request_v1): self.conn = swift.SwiftConnector('fakerepo', conf=self.conf) def test_init_connector(self): self.assertEqual(self.conn.auth_ver, '1') self.assertEqual(self.conn.auth_url, 'http://127.0.0.1:8080/auth/v1.0') self.assertEqual(self.conn.user, 'test:tester') self.assertEqual(self.conn.password, 'testing') self.assertEqual(self.conn.root, 'fakerepo') self.assertEqual(self.conn.storage_url, 'http://127.0.0.1:8080/v1.0/AUTH_fakeuser') self.assertEqual(self.conn.token, '12' * 10) self.assertEqual(self.conn.http_timeout, 1) self.assertEqual(self.conn.http_pool_length, 1) self.assertEqual(self.conn.concurrency, 1) self.conf.set('swift', 'auth_ver', '2') self.conf.set('swift', 'auth_url', 'http://127.0.0.1:8080/auth/v2.0') with patch('geventhttpclient.HTTPClient.request', fake_auth_request_v2): conn = swift.SwiftConnector('fakerepo', conf=self.conf) self.assertEqual(conn.user, 'tester') self.assertEqual(conn.tenant, 'test') self.conf.set('swift', 'auth_ver', '1') self.conf.set('swift', 'auth_url', 'http://127.0.0.1:8080/auth/v1.0') with patch('geventhttpclient.HTTPClient.request', fake_auth_request_v1_error): self.assertRaises(swift.SwiftException, lambda: swift.SwiftConnector('fakerepo', conf=self.conf)) def test_root_exists(self): with patch('geventhttpclient.HTTPClient.request', lambda *args: Response()): self.assertEqual(self.conn.test_root_exists(), True) def test_root_not_exists(self): with patch('geventhttpclient.HTTPClient.request', lambda *args: Response(status=404)): self.assertEqual(self.conn.test_root_exists(), None) def test_create_root(self): with patch('dulwich.contrib.swift.SwiftConnector.test_root_exists', lambda *args: None): with patch('geventhttpclient.HTTPClient.request', lambda *args: Response()): self.assertEqual(self.conn.create_root(), None) def test_create_root_fails(self): with patch('dulwich.contrib.swift.SwiftConnector.test_root_exists', lambda *args: None): with patch('geventhttpclient.HTTPClient.request', lambda *args: Response(status=404)): self.assertRaises(swift.SwiftException, lambda: self.conn.create_root()) def test_get_container_objects(self): with patch('geventhttpclient.HTTPClient.request', lambda *args: Response(content=json_dumps( (({'name': 'a'}, {'name': 'b'}))))): self.assertEqual(len(self.conn.get_container_objects()), 2) def test_get_container_objects_fails(self): with patch('geventhttpclient.HTTPClient.request', lambda *args: Response(status=404)): self.assertEqual(self.conn.get_container_objects(), None) def test_get_object_stat(self): with patch('geventhttpclient.HTTPClient.request', lambda *args: Response(headers={'content-length': '10'})): self.assertEqual(self.conn.get_object_stat('a')['content-length'], '10') def test_get_object_stat_fails(self): with patch('geventhttpclient.HTTPClient.request', lambda *args: Response(status=404)): self.assertEqual(self.conn.get_object_stat('a'), None) def test_put_object(self): with patch('geventhttpclient.HTTPClient.request', lambda *args, **kwargs: Response()): self.assertEqual(self.conn.put_object('a', BytesIO(b'content')), None) def test_put_object_fails(self): with patch('geventhttpclient.HTTPClient.request', lambda *args, **kwargs: Response(status=400)): self.assertRaises(swift.SwiftException, lambda: self.conn.put_object( 'a', BytesIO(b'content'))) def test_get_object(self): with patch('geventhttpclient.HTTPClient.request', lambda *args, **kwargs: Response(content=b'content')): self.assertEqual(self.conn.get_object('a').read(), b'content') with patch('geventhttpclient.HTTPClient.request', lambda *args, **kwargs: Response(content=b'content')): self.assertEqual( self.conn.get_object('a', range='0-6'), b'content') def test_get_object_fails(self): with patch('geventhttpclient.HTTPClient.request', lambda *args, **kwargs: Response(status=404)): self.assertEqual(self.conn.get_object('a'), None) def test_del_object(self): with patch('geventhttpclient.HTTPClient.request', lambda *args: Response()): self.assertEqual(self.conn.del_object('a'), None) def test_del_root(self): with patch('dulwich.contrib.swift.SwiftConnector.del_object', lambda *args: None): with patch('dulwich.contrib.swift.SwiftConnector.' 'get_container_objects', lambda *args: ({'name': 'a'}, {'name': 'b'})): with patch('geventhttpclient.HTTPClient.request', lambda *args: Response()): self.assertEqual(self.conn.del_root(), None) @skipIf(missing_libs, skipmsg) class SwiftObjectStoreTests(ObjectStoreTests, TestCase): def setUp(self): TestCase.setUp(self) conf = swift.load_conf(file=StringIO(config_file % def_config_file)) fsc = FakeSwiftConnector('fakerepo', conf=conf) self.store = swift.SwiftObjectStore(fsc) diff --git a/dulwich/diff_tree.py b/dulwich/diff_tree.py index 4e759f4d..d81fa19e 100644 --- a/dulwich/diff_tree.py +++ b/dulwich/diff_tree.py @@ -1,623 +1,623 @@ # diff_tree.py -- Utilities for diffing files and trees. # Copyright (C) 2010 Google, Inc. # # Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU # General Public License as public by the Free Software Foundation; version 2.0 # or (at your option) any later version. You can redistribute it and/or # modify it under the terms of either of these two licenses. # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # # You should have received a copy of the licenses; if not, see # for a copy of the GNU General Public License # and for a copy of the Apache # License, Version 2.0. # """Utilities for diffing files and trees.""" -import sys + from collections import ( defaultdict, namedtuple, ) from io import BytesIO from itertools import chain import stat from dulwich.objects import ( S_ISGITLINK, TreeEntry, ) # TreeChange type constants. CHANGE_ADD = 'add' CHANGE_MODIFY = 'modify' CHANGE_DELETE = 'delete' CHANGE_RENAME = 'rename' CHANGE_COPY = 'copy' CHANGE_UNCHANGED = 'unchanged' RENAME_CHANGE_TYPES = (CHANGE_RENAME, CHANGE_COPY) _NULL_ENTRY = TreeEntry(None, None, None) _MAX_SCORE = 100 RENAME_THRESHOLD = 60 MAX_FILES = 200 REWRITE_THRESHOLD = None class TreeChange(namedtuple('TreeChange', ['type', 'old', 'new'])): """Named tuple a single change between two trees.""" @classmethod def add(cls, new): return cls(CHANGE_ADD, _NULL_ENTRY, new) @classmethod def delete(cls, old): return cls(CHANGE_DELETE, old, _NULL_ENTRY) def _tree_entries(path, tree): result = [] if not tree: return result for entry in tree.iteritems(name_order=True): result.append(entry.in_path(path)) return result def _merge_entries(path, tree1, tree2): """Merge the entries of two trees. Args: path: A path to prepend to all tree entry names. tree1: The first Tree object to iterate, or None. tree2: The second Tree object to iterate, or None. Returns: A list of pairs of TreeEntry objects for each pair of entries in the trees. If an entry exists in one tree but not the other, the other entry will have all attributes set to None. If neither entry's path is None, they are guaranteed to match. """ entries1 = _tree_entries(path, tree1) entries2 = _tree_entries(path, tree2) i1 = i2 = 0 len1 = len(entries1) len2 = len(entries2) result = [] while i1 < len1 and i2 < len2: entry1 = entries1[i1] entry2 = entries2[i2] if entry1.path < entry2.path: result.append((entry1, _NULL_ENTRY)) i1 += 1 elif entry1.path > entry2.path: result.append((_NULL_ENTRY, entry2)) i2 += 1 else: result.append((entry1, entry2)) i1 += 1 i2 += 1 for i in range(i1, len1): result.append((entries1[i], _NULL_ENTRY)) for i in range(i2, len2): result.append((_NULL_ENTRY, entries2[i])) return result def _is_tree(entry): mode = entry.mode if mode is None: return False return stat.S_ISDIR(mode) def walk_trees(store, tree1_id, tree2_id, prune_identical=False): """Recursively walk all the entries of two trees. Iteration is depth-first pre-order, as in e.g. os.walk. Args: store: An ObjectStore for looking up objects. tree1_id: The SHA of the first Tree object to iterate, or None. tree2_id: The SHA of the second Tree object to iterate, or None. param prune_identical: If True, identical subtrees will not be walked. Returns: Iterator over Pairs of TreeEntry objects for each pair of entries in the trees and their subtrees recursively. If an entry exists in one tree but not the other, the other entry will have all attributes set to None. If neither entry's path is None, they are guaranteed to match. """ # This could be fairly easily generalized to >2 trees if we find a use # case. mode1 = tree1_id and stat.S_IFDIR or None mode2 = tree2_id and stat.S_IFDIR or None todo = [(TreeEntry(b'', mode1, tree1_id), TreeEntry(b'', mode2, tree2_id))] while todo: entry1, entry2 = todo.pop() is_tree1 = _is_tree(entry1) is_tree2 = _is_tree(entry2) if prune_identical and is_tree1 and is_tree2 and entry1 == entry2: continue tree1 = is_tree1 and store[entry1.sha] or None tree2 = is_tree2 and store[entry2.sha] or None path = entry1.path or entry2.path todo.extend(reversed(_merge_entries(path, tree1, tree2))) yield entry1, entry2 def _skip_tree(entry, include_trees): if entry.mode is None or (not include_trees and stat.S_ISDIR(entry.mode)): return _NULL_ENTRY return entry def tree_changes(store, tree1_id, tree2_id, want_unchanged=False, rename_detector=None, include_trees=False, change_type_same=False): """Find the differences between the contents of two trees. Args: store: An ObjectStore for looking up objects. tree1_id: The SHA of the source tree. tree2_id: The SHA of the target tree. want_unchanged: If True, include TreeChanges for unmodified entries as well. include_trees: Whether to include trees rename_detector: RenameDetector object for detecting renames. change_type_same: Whether to report change types in the same entry or as delete+add. Returns: Iterator over TreeChange instances for each change between the source and target tree. """ if (rename_detector is not None and tree1_id is not None and tree2_id is not None): for change in rename_detector.changes_with_renames( tree1_id, tree2_id, want_unchanged=want_unchanged, include_trees=include_trees): yield change return entries = walk_trees(store, tree1_id, tree2_id, prune_identical=(not want_unchanged)) for entry1, entry2 in entries: if entry1 == entry2 and not want_unchanged: continue # Treat entries for trees as missing. entry1 = _skip_tree(entry1, include_trees) entry2 = _skip_tree(entry2, include_trees) if entry1 != _NULL_ENTRY and entry2 != _NULL_ENTRY: if (stat.S_IFMT(entry1.mode) != stat.S_IFMT(entry2.mode) and not change_type_same): # File type changed: report as delete/add. yield TreeChange.delete(entry1) entry1 = _NULL_ENTRY change_type = CHANGE_ADD elif entry1 == entry2: change_type = CHANGE_UNCHANGED else: change_type = CHANGE_MODIFY elif entry1 != _NULL_ENTRY: change_type = CHANGE_DELETE elif entry2 != _NULL_ENTRY: change_type = CHANGE_ADD else: # Both were None because at least one was a tree. continue yield TreeChange(change_type, entry1, entry2) def _all_eq(seq, key, value): for e in seq: if key(e) != value: return False return True def _all_same(seq, key): return _all_eq(seq[1:], key, key(seq[0])) def tree_changes_for_merge(store, parent_tree_ids, tree_id, rename_detector=None): """Get the tree changes for a merge tree relative to all its parents. Args: store: An ObjectStore for looking up objects. parent_tree_ids: An iterable of the SHAs of the parent trees. tree_id: The SHA of the merge tree. rename_detector: RenameDetector object for detecting renames. Returns: Iterator over lists of TreeChange objects, one per conflicted path in the merge. Each list contains one element per parent, with the TreeChange for that path relative to that parent. An element may be None if it never existed in one parent and was deleted in two others. A path is only included in the output if it is a conflict, i.e. its SHA in the merge tree is not found in any of the parents, or in the case of deletes, if not all of the old SHAs match. """ all_parent_changes = [tree_changes(store, t, tree_id, rename_detector=rename_detector) for t in parent_tree_ids] num_parents = len(parent_tree_ids) changes_by_path = defaultdict(lambda: [None] * num_parents) # Organize by path. for i, parent_changes in enumerate(all_parent_changes): for change in parent_changes: if change.type == CHANGE_DELETE: path = change.old.path else: path = change.new.path changes_by_path[path][i] = change def old_sha(c): return c.old.sha def change_type(c): return c.type # Yield only conflicting changes. for _, changes in sorted(changes_by_path.items()): assert len(changes) == num_parents have = [c for c in changes if c is not None] if _all_eq(have, change_type, CHANGE_DELETE): if not _all_same(have, old_sha): yield changes elif not _all_same(have, change_type): yield changes elif None not in changes: # If no change was found relative to one parent, that means the SHA # must have matched the SHA in that parent, so it is not a # conflict. yield changes _BLOCK_SIZE = 64 def _count_blocks(obj): """Count the blocks in an object. Splits the data into blocks either on lines or <=64-byte chunks of lines. Args: obj: The object to count blocks for. Returns: A dict of block hashcode -> total bytes occurring. """ block_counts = defaultdict(int) block = BytesIO() n = 0 # Cache attrs as locals to avoid expensive lookups in the inner loop. block_write = block.write block_seek = block.seek block_truncate = block.truncate block_getvalue = block.getvalue for c in chain(*obj.as_raw_chunks()): c = c.to_bytes(1, 'big') block_write(c) n += 1 if c == b'\n' or n == _BLOCK_SIZE: value = block_getvalue() block_counts[hash(value)] += len(value) block_seek(0) block_truncate() n = 0 if n > 0: last_block = block_getvalue() block_counts[hash(last_block)] += len(last_block) return block_counts def _common_bytes(blocks1, blocks2): """Count the number of common bytes in two block count dicts. Args: block1: The first dict of block hashcode -> total bytes. block2: The second dict of block hashcode -> total bytes. Returns: The number of bytes in common between blocks1 and blocks2. This is only approximate due to possible hash collisions. """ # Iterate over the smaller of the two dicts, since this is symmetrical. if len(blocks1) > len(blocks2): blocks1, blocks2 = blocks2, blocks1 score = 0 for block, count1 in blocks1.items(): count2 = blocks2.get(block) if count2: score += min(count1, count2) return score def _similarity_score(obj1, obj2, block_cache=None): """Compute a similarity score for two objects. Args: obj1: The first object to score. obj2: The second object to score. block_cache: An optional dict of SHA to block counts to cache results between calls. Returns: The similarity score between the two objects, defined as the number of bytes in common between the two objects divided by the maximum size, scaled to the range 0-100. """ if block_cache is None: block_cache = {} if obj1.id not in block_cache: block_cache[obj1.id] = _count_blocks(obj1) if obj2.id not in block_cache: block_cache[obj2.id] = _count_blocks(obj2) common_bytes = _common_bytes(block_cache[obj1.id], block_cache[obj2.id]) max_size = max(obj1.raw_length(), obj2.raw_length()) if not max_size: return _MAX_SCORE return int(float(common_bytes) * _MAX_SCORE / max_size) def _tree_change_key(entry): # Sort by old path then new path. If only one exists, use it for both keys. path1 = entry.old.path path2 = entry.new.path if path1 is None: path1 = path2 if path2 is None: path2 = path1 return (path1, path2) class RenameDetector(object): """Object for handling rename detection between two trees.""" def __init__(self, store, rename_threshold=RENAME_THRESHOLD, max_files=MAX_FILES, rewrite_threshold=REWRITE_THRESHOLD, find_copies_harder=False): """Initialize the rename detector. Args: store: An ObjectStore for looking up objects. rename_threshold: The threshold similarity score for considering an add/delete pair to be a rename/copy; see _similarity_score. max_files: The maximum number of adds and deletes to consider, or None for no limit. The detector is guaranteed to compare no more than max_files ** 2 add/delete pairs. This limit is provided because rename detection can be quadratic in the project size. If the limit is exceeded, no content rename detection is attempted. rewrite_threshold: The threshold similarity score below which a modify should be considered a delete/add, or None to not break modifies; see _similarity_score. find_copies_harder: If True, consider unmodified files when detecting copies. """ self._store = store self._rename_threshold = rename_threshold self._rewrite_threshold = rewrite_threshold self._max_files = max_files self._find_copies_harder = find_copies_harder self._want_unchanged = False def _reset(self): self._adds = [] self._deletes = [] self._changes = [] def _should_split(self, change): if (self._rewrite_threshold is None or change.type != CHANGE_MODIFY or change.old.sha == change.new.sha): return False old_obj = self._store[change.old.sha] new_obj = self._store[change.new.sha] return _similarity_score(old_obj, new_obj) < self._rewrite_threshold def _add_change(self, change): if change.type == CHANGE_ADD: self._adds.append(change) elif change.type == CHANGE_DELETE: self._deletes.append(change) elif self._should_split(change): self._deletes.append(TreeChange.delete(change.old)) self._adds.append(TreeChange.add(change.new)) elif ((self._find_copies_harder and change.type == CHANGE_UNCHANGED) or change.type == CHANGE_MODIFY): # Treat all modifies as potential deletes for rename detection, # but don't split them (to avoid spurious renames). Setting # find_copies_harder means we treat unchanged the same as # modified. self._deletes.append(change) else: self._changes.append(change) def _collect_changes(self, tree1_id, tree2_id): want_unchanged = self._find_copies_harder or self._want_unchanged for change in tree_changes(self._store, tree1_id, tree2_id, want_unchanged=want_unchanged, include_trees=self._include_trees): self._add_change(change) def _prune(self, add_paths, delete_paths): self._adds = [a for a in self._adds if a.new.path not in add_paths] self._deletes = [d for d in self._deletes if d.old.path not in delete_paths] def _find_exact_renames(self): add_map = defaultdict(list) for add in self._adds: add_map[add.new.sha].append(add.new) delete_map = defaultdict(list) for delete in self._deletes: # Keep track of whether the delete was actually marked as a delete. # If not, it needs to be marked as a copy. is_delete = delete.type == CHANGE_DELETE delete_map[delete.old.sha].append((delete.old, is_delete)) add_paths = set() delete_paths = set() for sha, sha_deletes in delete_map.items(): sha_adds = add_map[sha] for (old, is_delete), new in zip(sha_deletes, sha_adds): if stat.S_IFMT(old.mode) != stat.S_IFMT(new.mode): continue if is_delete: delete_paths.add(old.path) add_paths.add(new.path) new_type = is_delete and CHANGE_RENAME or CHANGE_COPY self._changes.append(TreeChange(new_type, old, new)) num_extra_adds = len(sha_adds) - len(sha_deletes) # TODO(dborowitz): Less arbitrary way of dealing with extra copies. old = sha_deletes[0][0] if num_extra_adds > 0: for new in sha_adds[-num_extra_adds:]: add_paths.add(new.path) self._changes.append(TreeChange(CHANGE_COPY, old, new)) self._prune(add_paths, delete_paths) def _should_find_content_renames(self): return len(self._adds) * len(self._deletes) <= self._max_files ** 2 def _rename_type(self, check_paths, delete, add): if check_paths and delete.old.path == add.new.path: # If the paths match, this must be a split modify, so make sure it # comes out as a modify. return CHANGE_MODIFY elif delete.type != CHANGE_DELETE: # If it's in deletes but not marked as a delete, it must have been # added due to find_copies_harder, and needs to be marked as a # copy. return CHANGE_COPY return CHANGE_RENAME def _find_content_rename_candidates(self): candidates = self._candidates = [] # TODO: Optimizations: # - Compare object sizes before counting blocks. # - Skip if delete's S_IFMT differs from all adds. # - Skip if adds or deletes is empty. # Match C git's behavior of not attempting to find content renames if # the matrix size exceeds the threshold. if not self._should_find_content_renames(): return block_cache = {} check_paths = self._rename_threshold is not None for delete in self._deletes: if S_ISGITLINK(delete.old.mode): continue # Git links don't exist in this repo. old_sha = delete.old.sha old_obj = self._store[old_sha] block_cache[old_sha] = _count_blocks(old_obj) for add in self._adds: if stat.S_IFMT(delete.old.mode) != stat.S_IFMT(add.new.mode): continue new_obj = self._store[add.new.sha] score = _similarity_score(old_obj, new_obj, block_cache=block_cache) if score > self._rename_threshold: new_type = self._rename_type(check_paths, delete, add) rename = TreeChange(new_type, delete.old, add.new) candidates.append((-score, rename)) def _choose_content_renames(self): # Sort scores from highest to lowest, but keep names in ascending # order. self._candidates.sort() delete_paths = set() add_paths = set() for _, change in self._candidates: new_path = change.new.path if new_path in add_paths: continue old_path = change.old.path orig_type = change.type if old_path in delete_paths: change = TreeChange(CHANGE_COPY, change.old, change.new) # If the candidate was originally a copy, that means it came from a # modified or unchanged path, so we don't want to prune it. if orig_type != CHANGE_COPY: delete_paths.add(old_path) add_paths.add(new_path) self._changes.append(change) self._prune(add_paths, delete_paths) def _join_modifies(self): if self._rewrite_threshold is None: return modifies = {} delete_map = dict((d.old.path, d) for d in self._deletes) for add in self._adds: path = add.new.path delete = delete_map.get(path) if (delete is not None and stat.S_IFMT(delete.old.mode) == stat.S_IFMT(add.new.mode)): modifies[path] = TreeChange(CHANGE_MODIFY, delete.old, add.new) self._adds = [a for a in self._adds if a.new.path not in modifies] self._deletes = [a for a in self._deletes if a.new.path not in modifies] self._changes += modifies.values() def _sorted_changes(self): result = [] result.extend(self._adds) result.extend(self._deletes) result.extend(self._changes) result.sort(key=_tree_change_key) return result def _prune_unchanged(self): if self._want_unchanged: return self._deletes = [ d for d in self._deletes if d.type != CHANGE_UNCHANGED] def changes_with_renames(self, tree1_id, tree2_id, want_unchanged=False, include_trees=False): """Iterate TreeChanges between two tree SHAs, with rename detection.""" self._reset() self._want_unchanged = want_unchanged self._include_trees = include_trees self._collect_changes(tree1_id, tree2_id) self._find_exact_renames() self._find_content_rename_candidates() self._choose_content_renames() self._join_modifies() self._prune_unchanged() return self._sorted_changes() # Hold on to the pure-python implementations for testing. _is_tree_py = _is_tree _merge_entries_py = _merge_entries _count_blocks_py = _count_blocks try: # Try to import C versions from dulwich._diff_tree import _is_tree, _merge_entries, _count_blocks except ImportError: pass diff --git a/dulwich/fastexport.py b/dulwich/fastexport.py index 89ae5932..c4fd1cf6 100644 --- a/dulwich/fastexport.py +++ b/dulwich/fastexport.py @@ -1,245 +1,243 @@ # __init__.py -- Fast export/import functionality # Copyright (C) 2010-2013 Jelmer Vernooij # # Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU # General Public License as public by the Free Software Foundation; version 2.0 # or (at your option) any later version. You can redistribute it and/or # modify it under the terms of either of these two licenses. # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # # You should have received a copy of the licenses; if not, see # for a copy of the GNU General Public License # and for a copy of the Apache # License, Version 2.0. # """Fast export/import functionality.""" -import sys - from dulwich.index import ( commit_tree, ) from dulwich.objects import ( Blob, Commit, Tag, ZERO_SHA, ) from fastimport import ( # noqa: E402 commands, errors as fastimport_errors, parser, processor, ) import stat # noqa: E402 def split_email(text): (name, email) = text.rsplit(b" <", 1) return (name, email.rstrip(b">")) class GitFastExporter(object): """Generate a fast-export output stream for Git objects.""" def __init__(self, outf, store): self.outf = outf self.store = store self.markers = {} self._marker_idx = 0 def print_cmd(self, cmd): self.outf.write(getattr(cmd, "__bytes__", cmd.__repr__)() + b"\n") def _allocate_marker(self): self._marker_idx += 1 return ("%d" % (self._marker_idx,)).encode('ascii') def _export_blob(self, blob): marker = self._allocate_marker() self.markers[marker] = blob.id return (commands.BlobCommand(marker, blob.data), marker) def emit_blob(self, blob): (cmd, marker) = self._export_blob(blob) self.print_cmd(cmd) return marker def _iter_files(self, base_tree, new_tree): for ((old_path, new_path), (old_mode, new_mode), (old_hexsha, new_hexsha)) in \ self.store.tree_changes(base_tree, new_tree): if new_path is None: yield commands.FileDeleteCommand(old_path) continue if not stat.S_ISDIR(new_mode): blob = self.store[new_hexsha] marker = self.emit_blob(blob) if old_path != new_path and old_path is not None: yield commands.FileRenameCommand(old_path, new_path) if old_mode != new_mode or old_hexsha != new_hexsha: prefixed_marker = b':' + marker yield commands.FileModifyCommand( new_path, new_mode, prefixed_marker, None ) def _export_commit(self, commit, ref, base_tree=None): file_cmds = list(self._iter_files(base_tree, commit.tree)) marker = self._allocate_marker() if commit.parents: from_ = commit.parents[0] merges = commit.parents[1:] else: from_ = None merges = [] author, author_email = split_email(commit.author) committer, committer_email = split_email(commit.committer) cmd = commands.CommitCommand( ref, marker, (author, author_email, commit.author_time, commit.author_timezone), (committer, committer_email, commit.commit_time, commit.commit_timezone), commit.message, from_, merges, file_cmds) return (cmd, marker) def emit_commit(self, commit, ref, base_tree=None): cmd, marker = self._export_commit(commit, ref, base_tree) self.print_cmd(cmd) return marker class GitImportProcessor(processor.ImportProcessor): """An import processor that imports into a Git repository using Dulwich. """ # FIXME: Batch creation of objects? def __init__(self, repo, params=None, verbose=False, outf=None): processor.ImportProcessor.__init__(self, params, verbose) self.repo = repo self.last_commit = ZERO_SHA self.markers = {} self._contents = {} def lookup_object(self, objectish): if objectish.startswith(b":"): return self.markers[objectish[1:]] return objectish def import_stream(self, stream): p = parser.ImportParser(stream) self.process(p.iter_commands) return self.markers def blob_handler(self, cmd): """Process a BlobCommand.""" blob = Blob.from_string(cmd.data) self.repo.object_store.add_object(blob) if cmd.mark: self.markers[cmd.mark] = blob.id def checkpoint_handler(self, cmd): """Process a CheckpointCommand.""" pass def commit_handler(self, cmd): """Process a CommitCommand.""" commit = Commit() if cmd.author is not None: author = cmd.author else: author = cmd.committer (author_name, author_email, author_timestamp, author_timezone) = author (committer_name, committer_email, commit_timestamp, commit_timezone) = cmd.committer commit.author = author_name + b" <" + author_email + b">" commit.author_timezone = author_timezone commit.author_time = int(author_timestamp) commit.committer = committer_name + b" <" + committer_email + b">" commit.commit_timezone = commit_timezone commit.commit_time = int(commit_timestamp) commit.message = cmd.message commit.parents = [] if cmd.from_: cmd.from_ = self.lookup_object(cmd.from_) self._reset_base(cmd.from_) for filecmd in cmd.iter_files(): if filecmd.name == b"filemodify": if filecmd.data is not None: blob = Blob.from_string(filecmd.data) self.repo.object_store.add(blob) blob_id = blob.id else: blob_id = self.lookup_object(filecmd.dataref) self._contents[filecmd.path] = (filecmd.mode, blob_id) elif filecmd.name == b"filedelete": del self._contents[filecmd.path] elif filecmd.name == b"filecopy": self._contents[filecmd.dest_path] = self._contents[ filecmd.src_path] elif filecmd.name == b"filerename": self._contents[filecmd.new_path] = self._contents[ filecmd.old_path] del self._contents[filecmd.old_path] elif filecmd.name == b"filedeleteall": self._contents = {} else: raise Exception("Command %s not supported" % filecmd.name) commit.tree = commit_tree( self.repo.object_store, ((path, hexsha, mode) for (path, (mode, hexsha)) in self._contents.items())) if self.last_commit != ZERO_SHA: commit.parents.append(self.last_commit) for merge in cmd.merges: commit.parents.append(self.lookup_object(merge)) self.repo.object_store.add_object(commit) self.repo[cmd.ref] = commit.id self.last_commit = commit.id if cmd.mark: self.markers[cmd.mark] = commit.id def progress_handler(self, cmd): """Process a ProgressCommand.""" pass def _reset_base(self, commit_id): if self.last_commit == commit_id: return self._contents = {} self.last_commit = commit_id if commit_id != ZERO_SHA: tree_id = self.repo[commit_id].tree for (path, mode, hexsha) in ( self.repo.object_store.iter_tree_contents(tree_id)): self._contents[path] = (mode, hexsha) def reset_handler(self, cmd): """Process a ResetCommand.""" if cmd.from_ is None: from_ = ZERO_SHA else: from_ = self.lookup_object(cmd.from_) self._reset_base(from_) self.repo.refs[cmd.ref] = from_ def tag_handler(self, cmd): """Process a TagCommand.""" tag = Tag() tag.tagger = cmd.tagger tag.message = cmd.message tag.name = cmd.tag self.repo.add_object(tag) self.repo.refs["refs/tags/" + tag.name] = tag.id def feature_handler(self, cmd): """Process a FeatureCommand.""" raise fastimport_errors.UnknownFeature(cmd.feature_name) diff --git a/dulwich/hooks.py b/dulwich/hooks.py index 4c586844..380b1638 100644 --- a/dulwich/hooks.py +++ b/dulwich/hooks.py @@ -1,197 +1,196 @@ # hooks.py -- for dealing with git hooks # Copyright (C) 2012-2013 Jelmer Vernooij and others. # # Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU # General Public License as public by the Free Software Foundation; version 2.0 # or (at your option) any later version. You can redistribute it and/or # modify it under the terms of either of these two licenses. # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # # You should have received a copy of the licenses; if not, see # for a copy of the GNU General Public License # and for a copy of the Apache # License, Version 2.0. # """Access to hooks.""" import os import subprocess -import sys import tempfile from dulwich.errors import ( HookError, ) class Hook(object): """Generic hook object.""" def execute(self, *args): """Execute the hook with the given args Args: args: argument list to hook Raises: HookError: hook execution failure Returns: a hook may return a useful value """ raise NotImplementedError(self.execute) class ShellHook(Hook): """Hook by executable file Implements standard githooks(5) [0]: [0] http://www.kernel.org/pub/software/scm/git/docs/githooks.html """ def __init__(self, name, path, numparam, pre_exec_callback=None, post_exec_callback=None, cwd=None): """Setup shell hook definition Args: name: name of hook for error messages path: absolute path to executable file numparam: number of requirements parameters pre_exec_callback: closure for setup before execution Defaults to None. Takes in the variable argument list from the execute functions and returns a modified argument list for the shell hook. post_exec_callback: closure for cleanup after execution Defaults to None. Takes in a boolean for hook success and the modified argument list and returns the final hook return value if applicable cwd: working directory to switch to when executing the hook """ self.name = name self.filepath = path self.numparam = numparam self.pre_exec_callback = pre_exec_callback self.post_exec_callback = post_exec_callback self.cwd = cwd def execute(self, *args): """Execute the hook with given args""" if len(args) != self.numparam: raise HookError("Hook %s executed with wrong number of args. \ Expected %d. Saw %d. args: %s" % (self.name, self.numparam, len(args), args)) if (self.pre_exec_callback is not None): args = self.pre_exec_callback(*args) try: ret = subprocess.call([self.filepath] + list(args), cwd=self.cwd) if ret != 0: if (self.post_exec_callback is not None): self.post_exec_callback(0, *args) raise HookError("Hook %s exited with non-zero status" % (self.name)) if (self.post_exec_callback is not None): return self.post_exec_callback(1, *args) except OSError: # no file. silent failure. if (self.post_exec_callback is not None): self.post_exec_callback(0, *args) class PreCommitShellHook(ShellHook): """pre-commit shell hook""" def __init__(self, controldir): filepath = os.path.join(controldir, 'hooks', 'pre-commit') ShellHook.__init__(self, 'pre-commit', filepath, 0, cwd=controldir) class PostCommitShellHook(ShellHook): """post-commit shell hook""" def __init__(self, controldir): filepath = os.path.join(controldir, 'hooks', 'post-commit') ShellHook.__init__(self, 'post-commit', filepath, 0, cwd=controldir) class CommitMsgShellHook(ShellHook): """commit-msg shell hook Args: args[0]: commit message Returns: new commit message or None """ def __init__(self, controldir): filepath = os.path.join(controldir, 'hooks', 'commit-msg') def prepare_msg(*args): (fd, path) = tempfile.mkstemp() with os.fdopen(fd, 'wb') as f: f.write(args[0]) return (path,) def clean_msg(success, *args): if success: with open(args[0], 'rb') as f: new_msg = f.read() os.unlink(args[0]) return new_msg os.unlink(args[0]) ShellHook.__init__(self, 'commit-msg', filepath, 1, prepare_msg, clean_msg, controldir) class PostReceiveShellHook(ShellHook): """post-receive shell hook""" def __init__(self, controldir): self.controldir = controldir filepath = os.path.join(controldir, 'hooks', 'post-receive') ShellHook.__init__(self, 'post-receive', filepath, 0) def execute(self, client_refs): # do nothing if the script doesn't exist if not os.path.exists(self.filepath): return None try: env = os.environ.copy() env['GIT_DIR'] = self.controldir p = subprocess.Popen( self.filepath, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env ) # client_refs is a list of (oldsha, newsha, ref) in_data = '\n'.join([' '.join(ref) for ref in client_refs]) out_data, err_data = p.communicate(in_data) if (p.returncode != 0) or err_data: err_fmt = "post-receive exit code: %d\n" \ + "stdout:\n%s\nstderr:\n%s" err_msg = err_fmt % (p.returncode, out_data, err_data) raise HookError(err_msg) return out_data except OSError as err: raise HookError(repr(err)) diff --git a/dulwich/objects.py b/dulwich/objects.py index eeb330f7..59f6b85c 100644 --- a/dulwich/objects.py +++ b/dulwich/objects.py @@ -1,1435 +1,1434 @@ # objects.py -- Access to base git objects # Copyright (C) 2007 James Westby # Copyright (C) 2008-2013 Jelmer Vernooij # # Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU # General Public License as public by the Free Software Foundation; version 2.0 # or (at your option) any later version. You can redistribute it and/or # modify it under the terms of either of these two licenses. # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # # You should have received a copy of the licenses; if not, see # for a copy of the GNU General Public License # and for a copy of the Apache # License, Version 2.0. # """Access to base git objects.""" import binascii from io import BytesIO from collections import namedtuple import os import posixpath import stat -import sys import warnings import zlib from hashlib import sha1 from dulwich.errors import ( ChecksumMismatch, NotBlobError, NotCommitError, NotTagError, NotTreeError, ObjectFormatException, FileFormatException, ) from dulwich.file import GitFile ZERO_SHA = b'0' * 40 # Header fields for commits _TREE_HEADER = b'tree' _PARENT_HEADER = b'parent' _AUTHOR_HEADER = b'author' _COMMITTER_HEADER = b'committer' _ENCODING_HEADER = b'encoding' _MERGETAG_HEADER = b'mergetag' _GPGSIG_HEADER = b'gpgsig' # Header fields for objects _OBJECT_HEADER = b'object' _TYPE_HEADER = b'type' _TAG_HEADER = b'tag' _TAGGER_HEADER = b'tagger' S_IFGITLINK = 0o160000 MAX_TIME = 9223372036854775807 # (2**63) - 1 - signed long int max BEGIN_PGP_SIGNATURE = b"-----BEGIN PGP SIGNATURE-----" class EmptyFileException(FileFormatException): """An unexpectedly empty file was encountered.""" def S_ISGITLINK(m): """Check if a mode indicates a submodule. Args: m: Mode to check Returns: a ``boolean`` """ return (stat.S_IFMT(m) == S_IFGITLINK) def _decompress(string): dcomp = zlib.decompressobj() dcomped = dcomp.decompress(string) dcomped += dcomp.flush() return dcomped def sha_to_hex(sha): """Takes a string and returns the hex of the sha within""" hexsha = binascii.hexlify(sha) assert len(hexsha) == 40, "Incorrect length of sha1 string: %d" % hexsha return hexsha def hex_to_sha(hex): """Takes a hex sha and returns a binary sha""" assert len(hex) == 40, "Incorrect length of hexsha: %s" % hex try: return binascii.unhexlify(hex) except TypeError as exc: if not isinstance(hex, bytes): raise raise ValueError(exc.args[0]) def valid_hexsha(hex): if len(hex) != 40: return False try: binascii.unhexlify(hex) except (TypeError, binascii.Error): return False else: return True def hex_to_filename(path, hex): """Takes a hex sha and returns its filename relative to the given path.""" # os.path.join accepts bytes or unicode, but all args must be of the same # type. Make sure that hex which is expected to be bytes, is the same type # as path. if getattr(path, 'encode', None) is not None: hex = hex.decode('ascii') dir = hex[:2] file = hex[2:] # Check from object dir return os.path.join(path, dir, file) def filename_to_hex(filename): """Takes an object filename and returns its corresponding hex sha.""" # grab the last (up to) two path components names = filename.rsplit(os.path.sep, 2)[-2:] errmsg = "Invalid object filename: %s" % filename assert len(names) == 2, errmsg base, rest = names assert len(base) == 2 and len(rest) == 38, errmsg hex = (base + rest).encode('ascii') hex_to_sha(hex) return hex def object_header(num_type, length): """Return an object header for the given numeric type and text length.""" return (object_class(num_type).type_name + b' ' + str(length).encode('ascii') + b'\0') def serializable_property(name, docstring=None): """A property that helps tracking whether serialization is necessary. """ def set(obj, value): setattr(obj, "_"+name, value) obj._needs_serialization = True def get(obj): return getattr(obj, "_"+name) return property(get, set, doc=docstring) def object_class(type): """Get the object class corresponding to the given type. Args: type: Either a type name string or a numeric type. Returns: The ShaFile subclass corresponding to the given type, or None if type is not a valid type name/number. """ return _TYPE_MAP.get(type, None) def check_hexsha(hex, error_msg): """Check if a string is a valid hex sha string. Args: hex: Hex string to check error_msg: Error message to use in exception Raises: ObjectFormatException: Raised when the string is not valid """ if not valid_hexsha(hex): raise ObjectFormatException("%s %s" % (error_msg, hex)) def check_identity(identity, error_msg): """Check if the specified identity is valid. This will raise an exception if the identity is not valid. Args: identity: Identity string error_msg: Error message to use in exception """ email_start = identity.find(b'<') email_end = identity.find(b'>') if (email_start < 0 or email_end < 0 or email_end <= email_start or identity.find(b'<', email_start + 1) >= 0 or identity.find(b'>', email_end + 1) >= 0 or not identity.endswith(b'>')): raise ObjectFormatException(error_msg) def check_time(time_seconds): """Check if the specified time is not prone to overflow error. This will raise an exception if the time is not valid. Args: time_info: author/committer/tagger info """ # Prevent overflow error if time_seconds > MAX_TIME: raise ObjectFormatException( 'Date field should not exceed %s' % MAX_TIME) def git_line(*items): """Formats items into a space separated line.""" return b' '.join(items) + b'\n' class FixedSha(object): """SHA object that behaves like hashlib's but is given a fixed value.""" __slots__ = ('_hexsha', '_sha') def __init__(self, hexsha): if getattr(hexsha, 'encode', None) is not None: hexsha = hexsha.encode('ascii') if not isinstance(hexsha, bytes): raise TypeError('Expected bytes for hexsha, got %r' % hexsha) self._hexsha = hexsha self._sha = hex_to_sha(hexsha) def digest(self): """Return the raw SHA digest.""" return self._sha def hexdigest(self): """Return the hex SHA digest.""" return self._hexsha.decode('ascii') class ShaFile(object): """A git SHA file.""" __slots__ = ('_chunked_text', '_sha', '_needs_serialization') @staticmethod def _parse_legacy_object_header(magic, f): """Parse a legacy object, creating it but not reading the file.""" bufsize = 1024 decomp = zlib.decompressobj() header = decomp.decompress(magic) start = 0 end = -1 while end < 0: extra = f.read(bufsize) header += decomp.decompress(extra) magic += extra end = header.find(b'\0', start) start = len(header) header = header[:end] type_name, size = header.split(b' ', 1) try: int(size) # sanity check except ValueError as e: raise ObjectFormatException("Object size not an integer: %s" % e) obj_class = object_class(type_name) if not obj_class: raise ObjectFormatException("Not a known type: %s" % type_name) return obj_class() def _parse_legacy_object(self, map): """Parse a legacy object, setting the raw string.""" text = _decompress(map) header_end = text.find(b'\0') if header_end < 0: raise ObjectFormatException("Invalid object header, no \\0") self.set_raw_string(text[header_end+1:]) def as_legacy_object_chunks(self, compression_level=-1): """Return chunks representing the object in the experimental format. Returns: List of strings """ compobj = zlib.compressobj(compression_level) yield compobj.compress(self._header()) for chunk in self.as_raw_chunks(): yield compobj.compress(chunk) yield compobj.flush() def as_legacy_object(self, compression_level=-1): """Return string representing the object in the experimental format. """ return b''.join(self.as_legacy_object_chunks( compression_level=compression_level)) def as_raw_chunks(self): """Return chunks with serialization of the object. Returns: List of strings, not necessarily one per line """ if self._needs_serialization: self._sha = None self._chunked_text = self._serialize() self._needs_serialization = False return self._chunked_text def as_raw_string(self): """Return raw string with serialization of the object. Returns: String object """ return b''.join(self.as_raw_chunks()) def __bytes__(self): """Return raw string serialization of this object.""" return self.as_raw_string() def __hash__(self): """Return unique hash for this object.""" return hash(self.id) def as_pretty_string(self): """Return a string representing this object, fit for display.""" return self.as_raw_string() def set_raw_string(self, text, sha=None): """Set the contents of this object from a serialized string.""" if not isinstance(text, bytes): raise TypeError('Expected bytes for text, got %r' % text) self.set_raw_chunks([text], sha) def set_raw_chunks(self, chunks, sha=None): """Set the contents of this object from a list of chunks.""" self._chunked_text = chunks self._deserialize(chunks) if sha is None: self._sha = None else: self._sha = FixedSha(sha) self._needs_serialization = False @staticmethod def _parse_object_header(magic, f): """Parse a new style object, creating it but not reading the file.""" num_type = (ord(magic[0:1]) >> 4) & 7 obj_class = object_class(num_type) if not obj_class: raise ObjectFormatException("Not a known type %d" % num_type) return obj_class() def _parse_object(self, map): """Parse a new style object, setting self._text.""" # skip type and size; type must have already been determined, and # we trust zlib to fail if it's otherwise corrupted byte = ord(map[0:1]) used = 1 while (byte & 0x80) != 0: byte = ord(map[used:used+1]) used += 1 raw = map[used:] self.set_raw_string(_decompress(raw)) @classmethod def _is_legacy_object(cls, magic): b0 = ord(magic[0:1]) b1 = ord(magic[1:2]) word = (b0 << 8) + b1 return (b0 & 0x8F) == 0x08 and (word % 31) == 0 @classmethod def _parse_file(cls, f): map = f.read() if not map: raise EmptyFileException('Corrupted empty file detected') if cls._is_legacy_object(map): obj = cls._parse_legacy_object_header(map, f) obj._parse_legacy_object(map) else: obj = cls._parse_object_header(map, f) obj._parse_object(map) return obj def __init__(self): """Don't call this directly""" self._sha = None self._chunked_text = [] self._needs_serialization = True def _deserialize(self, chunks): raise NotImplementedError(self._deserialize) def _serialize(self): raise NotImplementedError(self._serialize) @classmethod def from_path(cls, path): """Open a SHA file from disk.""" with GitFile(path, 'rb') as f: return cls.from_file(f) @classmethod def from_file(cls, f): """Get the contents of a SHA file on disk.""" try: obj = cls._parse_file(f) obj._sha = None return obj except (IndexError, ValueError): raise ObjectFormatException("invalid object header") @staticmethod def from_raw_string(type_num, string, sha=None): """Creates an object of the indicated type from the raw string given. Args: type_num: The numeric type of the object. string: The raw uncompressed contents. sha: Optional known sha for the object """ obj = object_class(type_num)() obj.set_raw_string(string, sha) return obj @staticmethod def from_raw_chunks(type_num, chunks, sha=None): """Creates an object of the indicated type from the raw chunks given. Args: type_num: The numeric type of the object. chunks: An iterable of the raw uncompressed contents. sha: Optional known sha for the object """ obj = object_class(type_num)() obj.set_raw_chunks(chunks, sha) return obj @classmethod def from_string(cls, string): """Create a ShaFile from a string.""" obj = cls() obj.set_raw_string(string) return obj def _check_has_member(self, member, error_msg): """Check that the object has a given member variable. Args: member: the member variable to check for error_msg: the message for an error if the member is missing Raises: ObjectFormatException: with the given error_msg if member is missing or is None """ if getattr(self, member, None) is None: raise ObjectFormatException(error_msg) def check(self): """Check this object for internal consistency. Raises: ObjectFormatException: if the object is malformed in some way ChecksumMismatch: if the object was created with a SHA that does not match its contents """ # TODO: if we find that error-checking during object parsing is a # performance bottleneck, those checks should be moved to the class's # check() method during optimization so we can still check the object # when necessary. old_sha = self.id try: self._deserialize(self.as_raw_chunks()) self._sha = None new_sha = self.id except Exception as e: raise ObjectFormatException(e) if old_sha != new_sha: raise ChecksumMismatch(new_sha, old_sha) def _header(self): return object_header(self.type, self.raw_length()) def raw_length(self): """Returns the length of the raw string of this object.""" ret = 0 for chunk in self.as_raw_chunks(): ret += len(chunk) return ret def sha(self): """The SHA1 object that is the name of this object.""" if self._sha is None or self._needs_serialization: # this is a local because as_raw_chunks() overwrites self._sha new_sha = sha1() new_sha.update(self._header()) for chunk in self.as_raw_chunks(): new_sha.update(chunk) self._sha = new_sha return self._sha def copy(self): """Create a new copy of this SHA1 object from its raw string""" obj_class = object_class(self.get_type()) return obj_class.from_raw_string( self.get_type(), self.as_raw_string(), self.id) @property def id(self): """The hex SHA of this object.""" return self.sha().hexdigest().encode('ascii') def get_type(self): """Return the type number for this object class.""" return self.type_num def set_type(self, type): """Set the type number for this object class.""" self.type_num = type # DEPRECATED: use type_num or type_name as needed. type = property(get_type, set_type) def __repr__(self): return "<%s %s>" % (self.__class__.__name__, self.id) def __ne__(self, other): """Check whether this object does not match the other.""" return not isinstance(other, ShaFile) or self.id != other.id def __eq__(self, other): """Return True if the SHAs of the two objects match. """ return isinstance(other, ShaFile) and self.id == other.id def __lt__(self, other): """Return whether SHA of this object is less than the other. """ if not isinstance(other, ShaFile): raise TypeError return self.id < other.id def __le__(self, other): """Check whether SHA of this object is less than or equal to the other. """ if not isinstance(other, ShaFile): raise TypeError return self.id <= other.id def __cmp__(self, other): """Compare the SHA of this object with that of the other object. """ if not isinstance(other, ShaFile): raise TypeError return cmp(self.id, other.id) # noqa: F821 class Blob(ShaFile): """A Git Blob object.""" __slots__ = () type_name = b'blob' type_num = 3 def __init__(self): super(Blob, self).__init__() self._chunked_text = [] self._needs_serialization = False def _get_data(self): return self.as_raw_string() def _set_data(self, data): self.set_raw_string(data) data = property(_get_data, _set_data, "The text contained within the blob object.") def _get_chunked(self): return self._chunked_text def _set_chunked(self, chunks): self._chunked_text = chunks def _serialize(self): return self._chunked_text def _deserialize(self, chunks): self._chunked_text = chunks chunked = property( _get_chunked, _set_chunked, "The text within the blob object, as chunks (not necessarily lines).") @classmethod def from_path(cls, path): blob = ShaFile.from_path(path) if not isinstance(blob, cls): raise NotBlobError(path) return blob def check(self): """Check this object for internal consistency. Raises: ObjectFormatException: if the object is malformed in some way """ super(Blob, self).check() def splitlines(self): """Return list of lines in this blob. This preserves the original line endings. """ chunks = self.chunked if not chunks: return [] if len(chunks) == 1: return chunks[0].splitlines(True) remaining = None ret = [] for chunk in chunks: lines = chunk.splitlines(True) if len(lines) > 1: ret.append((remaining or b"") + lines[0]) ret.extend(lines[1:-1]) remaining = lines[-1] elif len(lines) == 1: if remaining is None: remaining = lines.pop() else: remaining += lines.pop() if remaining is not None: ret.append(remaining) return ret def _parse_message(chunks): """Parse a message with a list of fields and a body. Args: chunks: the raw chunks of the tag or commit object. Returns: iterator of tuples of (field, value), one per header line, in the order read from the text, possibly including duplicates. Includes a field named None for the freeform tag/commit text. """ f = BytesIO(b''.join(chunks)) k = None v = "" eof = False def _strip_last_newline(value): """Strip the last newline from value""" if value and value.endswith(b'\n'): return value[:-1] return value # Parse the headers # # Headers can contain newlines. The next line is indented with a space. # We store the latest key as 'k', and the accumulated value as 'v'. for line in f: if line.startswith(b' '): # Indented continuation of the previous line v += line[1:] else: if k is not None: # We parsed a new header, return its value yield (k, _strip_last_newline(v)) if line == b'\n': # Empty line indicates end of headers break (k, v) = line.split(b' ', 1) else: # We reached end of file before the headers ended. We still need to # return the previous header, then we need to return a None field for # the text. eof = True if k is not None: yield (k, _strip_last_newline(v)) yield (None, None) if not eof: # We didn't reach the end of file while parsing headers. We can return # the rest of the file as a message. yield (None, f.read()) f.close() class Tag(ShaFile): """A Git Tag object.""" type_name = b'tag' type_num = 4 __slots__ = ('_tag_timezone_neg_utc', '_name', '_object_sha', '_object_class', '_tag_time', '_tag_timezone', '_tagger', '_message', '_signature') def __init__(self): super(Tag, self).__init__() self._tagger = None self._tag_time = None self._tag_timezone = None self._tag_timezone_neg_utc = False self._signature = None @classmethod def from_path(cls, filename): tag = ShaFile.from_path(filename) if not isinstance(tag, cls): raise NotTagError(filename) return tag def check(self): """Check this object for internal consistency. Raises: ObjectFormatException: if the object is malformed in some way """ super(Tag, self).check() self._check_has_member("_object_sha", "missing object sha") self._check_has_member("_object_class", "missing object type") self._check_has_member("_name", "missing tag name") if not self._name: raise ObjectFormatException("empty tag name") check_hexsha(self._object_sha, "invalid object sha") if getattr(self, "_tagger", None): check_identity(self._tagger, "invalid tagger") self._check_has_member("_tag_time", "missing tag time") check_time(self._tag_time) last = None for field, _ in _parse_message(self._chunked_text): if field == _OBJECT_HEADER and last is not None: raise ObjectFormatException("unexpected object") elif field == _TYPE_HEADER and last != _OBJECT_HEADER: raise ObjectFormatException("unexpected type") elif field == _TAG_HEADER and last != _TYPE_HEADER: raise ObjectFormatException("unexpected tag name") elif field == _TAGGER_HEADER and last != _TAG_HEADER: raise ObjectFormatException("unexpected tagger") last = field def _serialize(self): chunks = [] chunks.append(git_line(_OBJECT_HEADER, self._object_sha)) chunks.append(git_line(_TYPE_HEADER, self._object_class.type_name)) chunks.append(git_line(_TAG_HEADER, self._name)) if self._tagger: if self._tag_time is None: chunks.append(git_line(_TAGGER_HEADER, self._tagger)) else: chunks.append(git_line( _TAGGER_HEADER, self._tagger, str(self._tag_time).encode('ascii'), format_timezone( self._tag_timezone, self._tag_timezone_neg_utc))) if self._message is not None: chunks.append(b'\n') # To close headers chunks.append(self._message) if self._signature is not None: chunks.append(self._signature) return chunks def _deserialize(self, chunks): """Grab the metadata attached to the tag""" self._tagger = None self._tag_time = None self._tag_timezone = None self._tag_timezone_neg_utc = False for field, value in _parse_message(chunks): if field == _OBJECT_HEADER: self._object_sha = value elif field == _TYPE_HEADER: obj_class = object_class(value) if not obj_class: raise ObjectFormatException("Not a known type: %s" % value) self._object_class = obj_class elif field == _TAG_HEADER: self._name = value elif field == _TAGGER_HEADER: (self._tagger, self._tag_time, (self._tag_timezone, self._tag_timezone_neg_utc)) = parse_time_entry(value) elif field is None: if value is None: self._message = None self._signature = None else: try: sig_idx = value.index(BEGIN_PGP_SIGNATURE) except ValueError: self._message = value self._signature = None else: self._message = value[:sig_idx] self._signature = value[sig_idx:] else: raise ObjectFormatException("Unknown field %s" % field) def _get_object(self): """Get the object pointed to by this tag. Returns: tuple of (object class, sha). """ return (self._object_class, self._object_sha) def _set_object(self, value): (self._object_class, self._object_sha) = value self._needs_serialization = True object = property(_get_object, _set_object) name = serializable_property("name", "The name of this tag") tagger = serializable_property( "tagger", "Returns the name of the person who created this tag") tag_time = serializable_property( "tag_time", "The creation timestamp of the tag. As the number of seconds " "since the epoch") tag_timezone = serializable_property( "tag_timezone", "The timezone that tag_time is in.") message = serializable_property( "message", "the message attached to this tag") signature = serializable_property( "signature", "Optional detached GPG signature") class TreeEntry(namedtuple('TreeEntry', ['path', 'mode', 'sha'])): """Named tuple encapsulating a single tree entry.""" def in_path(self, path): """Return a copy of this entry with the given path prepended.""" if not isinstance(self.path, bytes): raise TypeError('Expected bytes for path, got %r' % path) return TreeEntry(posixpath.join(path, self.path), self.mode, self.sha) def parse_tree(text, strict=False): """Parse a tree text. Args: text: Serialized text to parse Returns: iterator of tuples of (name, mode, sha) Raises: ObjectFormatException: if the object was malformed in some way """ count = 0 length = len(text) while count < length: mode_end = text.index(b' ', count) mode_text = text[count:mode_end] if strict and mode_text.startswith(b'0'): raise ObjectFormatException("Invalid mode '%s'" % mode_text) try: mode = int(mode_text, 8) except ValueError: raise ObjectFormatException("Invalid mode '%s'" % mode_text) name_end = text.index(b'\0', mode_end) name = text[mode_end+1:name_end] count = name_end+21 sha = text[name_end+1:count] if len(sha) != 20: raise ObjectFormatException("Sha has invalid length") hexsha = sha_to_hex(sha) yield (name, mode, hexsha) def serialize_tree(items): """Serialize the items in a tree to a text. Args: items: Sorted iterable over (name, mode, sha) tuples Returns: Serialized tree text as chunks """ for name, mode, hexsha in items: yield (("%04o" % mode).encode('ascii') + b' ' + name + b'\0' + hex_to_sha(hexsha)) def sorted_tree_items(entries, name_order): """Iterate over a tree entries dictionary. Args: name_order: If True, iterate entries in order of their name. If False, iterate entries in tree order, that is, treat subtree entries as having '/' appended. entries: Dictionary mapping names to (mode, sha) tuples Returns: Iterator over (name, mode, hexsha) """ key_func = name_order and key_entry_name_order or key_entry for name, entry in sorted(entries.items(), key=key_func): mode, hexsha = entry # Stricter type checks than normal to mirror checks in the C version. mode = int(mode) if not isinstance(hexsha, bytes): raise TypeError('Expected bytes for SHA, got %r' % hexsha) yield TreeEntry(name, mode, hexsha) def key_entry(entry): """Sort key for tree entry. Args: entry: (name, value) tuplee """ (name, value) = entry if stat.S_ISDIR(value[0]): name += b'/' return name def key_entry_name_order(entry): """Sort key for tree entry in name order.""" return entry[0] def pretty_format_tree_entry(name, mode, hexsha, encoding="utf-8"): """Pretty format tree entry. Args: name: Name of the directory entry mode: Mode of entry hexsha: Hexsha of the referenced object Returns: string describing the tree entry """ if mode & stat.S_IFDIR: kind = "tree" else: kind = "blob" return "%04o %s %s\t%s\n" % ( mode, kind, hexsha.decode('ascii'), name.decode(encoding, 'replace')) class Tree(ShaFile): """A Git tree object""" type_name = b'tree' type_num = 2 __slots__ = ('_entries') def __init__(self): super(Tree, self).__init__() self._entries = {} @classmethod def from_path(cls, filename): tree = ShaFile.from_path(filename) if not isinstance(tree, cls): raise NotTreeError(filename) return tree def __contains__(self, name): return name in self._entries def __getitem__(self, name): return self._entries[name] def __setitem__(self, name, value): """Set a tree entry by name. Args: name: The name of the entry, as a string. value: A tuple of (mode, hexsha), where mode is the mode of the entry as an integral type and hexsha is the hex SHA of the entry as a string. """ mode, hexsha = value self._entries[name] = (mode, hexsha) self._needs_serialization = True def __delitem__(self, name): del self._entries[name] self._needs_serialization = True def __len__(self): return len(self._entries) def __iter__(self): return iter(self._entries) def add(self, name, mode, hexsha): """Add an entry to the tree. Args: mode: The mode of the entry as an integral type. Not all possible modes are supported by git; see check() for details. name: The name of the entry, as a string. hexsha: The hex SHA of the entry as a string. """ if isinstance(name, int) and isinstance(mode, bytes): (name, mode) = (mode, name) warnings.warn( "Please use Tree.add(name, mode, hexsha)", category=DeprecationWarning, stacklevel=2) self._entries[name] = mode, hexsha self._needs_serialization = True def iteritems(self, name_order=False): """Iterate over entries. Args: name_order: If True, iterate in name order instead of tree order. Returns: Iterator over (name, mode, sha) tuples """ return sorted_tree_items(self._entries, name_order) def items(self): """Return the sorted entries in this tree. Returns: List with (name, mode, sha) tuples """ return list(self.iteritems()) def _deserialize(self, chunks): """Grab the entries in the tree""" try: parsed_entries = parse_tree(b''.join(chunks)) except ValueError as e: raise ObjectFormatException(e) # TODO: list comprehension is for efficiency in the common (small) # case; if memory efficiency in the large case is a concern, use a # genexp. self._entries = dict([(n, (m, s)) for n, m, s in parsed_entries]) def check(self): """Check this object for internal consistency. Raises: ObjectFormatException: if the object is malformed in some way """ super(Tree, self).check() last = None allowed_modes = (stat.S_IFREG | 0o755, stat.S_IFREG | 0o644, stat.S_IFLNK, stat.S_IFDIR, S_IFGITLINK, # TODO: optionally exclude as in git fsck --strict stat.S_IFREG | 0o664) for name, mode, sha in parse_tree(b''.join(self._chunked_text), True): check_hexsha(sha, 'invalid sha %s' % sha) if b'/' in name or name in (b'', b'.', b'..', b'.git'): raise ObjectFormatException( 'invalid name %s' % name.decode('utf-8', 'replace')) if mode not in allowed_modes: raise ObjectFormatException('invalid mode %06o' % mode) entry = (name, (mode, sha)) if last: if key_entry(last) > key_entry(entry): raise ObjectFormatException('entries not sorted') if name == last[0]: raise ObjectFormatException('duplicate entry %s' % name) last = entry def _serialize(self): return list(serialize_tree(self.iteritems())) def as_pretty_string(self): text = [] for name, mode, hexsha in self.iteritems(): text.append(pretty_format_tree_entry(name, mode, hexsha)) return "".join(text) def lookup_path(self, lookup_obj, path): """Look up an object in a Git tree. Args: lookup_obj: Callback for retrieving object by SHA1 path: Path to lookup Returns: A tuple of (mode, SHA) of the resulting path. """ parts = path.split(b'/') sha = self.id mode = None for p in parts: if not p: continue obj = lookup_obj(sha) if not isinstance(obj, Tree): raise NotTreeError(sha) mode, sha = obj[p] return mode, sha def parse_timezone(text): """Parse a timezone text fragment (e.g. '+0100'). Args: text: Text to parse. Returns: Tuple with timezone as seconds difference to UTC and a boolean indicating whether this was a UTC timezone prefixed with a negative sign (-0000). """ # cgit parses the first character as the sign, and the rest # as an integer (using strtol), which could also be negative. # We do the same for compatibility. See #697828. if not text[0] in b'+-': raise ValueError("Timezone must start with + or - (%(text)s)" % vars()) sign = text[:1] offset = int(text[1:]) if sign == b'-': offset = -offset unnecessary_negative_timezone = (offset >= 0 and sign == b'-') signum = (offset < 0) and -1 or 1 offset = abs(offset) hours = int(offset / 100) minutes = (offset % 100) return (signum * (hours * 3600 + minutes * 60), unnecessary_negative_timezone) def format_timezone(offset, unnecessary_negative_timezone=False): """Format a timezone for Git serialization. Args: offset: Timezone offset as seconds difference to UTC unnecessary_negative_timezone: Whether to use a minus sign for UTC or positive timezones (-0000 and --700 rather than +0000 / +0700). """ if offset % 60 != 0: raise ValueError("Unable to handle non-minute offset.") if offset < 0 or unnecessary_negative_timezone: sign = '-' offset = -offset else: sign = '+' return ('%c%02d%02d' % (sign, offset / 3600, (offset / 60) % 60)).encode('ascii') def parse_time_entry(value): """Parse time entry behavior Args: value: Bytes representing a git commit/tag line Raises: ObjectFormatException in case of parsing error (malformed field date) Returns: Tuple of (author, time, (timezone, timezone_neg_utc)) """ try: sep = value.rindex(b'> ') except ValueError: return (value, None, (None, False)) try: person = value[0:sep+1] rest = value[sep+2:] timetext, timezonetext = rest.rsplit(b' ', 1) time = int(timetext) timezone, timezone_neg_utc = parse_timezone(timezonetext) except ValueError as e: raise ObjectFormatException(e) return person, time, (timezone, timezone_neg_utc) def parse_commit(chunks): """Parse a commit object from chunks. Args: chunks: Chunks to parse Returns: Tuple of (tree, parents, author_info, commit_info, encoding, mergetag, gpgsig, message, extra) """ parents = [] extra = [] tree = None author_info = (None, None, (None, None)) commit_info = (None, None, (None, None)) encoding = None mergetag = [] message = None gpgsig = None for field, value in _parse_message(chunks): # TODO(jelmer): Enforce ordering if field == _TREE_HEADER: tree = value elif field == _PARENT_HEADER: parents.append(value) elif field == _AUTHOR_HEADER: author_info = parse_time_entry(value) elif field == _COMMITTER_HEADER: commit_info = parse_time_entry(value) elif field == _ENCODING_HEADER: encoding = value elif field == _MERGETAG_HEADER: mergetag.append(Tag.from_string(value + b'\n')) elif field == _GPGSIG_HEADER: gpgsig = value elif field is None: message = value else: extra.append((field, value)) return (tree, parents, author_info, commit_info, encoding, mergetag, gpgsig, message, extra) class Commit(ShaFile): """A git commit object""" type_name = b'commit' type_num = 1 __slots__ = ('_parents', '_encoding', '_extra', '_author_timezone_neg_utc', '_commit_timezone_neg_utc', '_commit_time', '_author_time', '_author_timezone', '_commit_timezone', '_author', '_committer', '_tree', '_message', '_mergetag', '_gpgsig') def __init__(self): super(Commit, self).__init__() self._parents = [] self._encoding = None self._mergetag = [] self._gpgsig = None self._extra = [] self._author_timezone_neg_utc = False self._commit_timezone_neg_utc = False @classmethod def from_path(cls, path): commit = ShaFile.from_path(path) if not isinstance(commit, cls): raise NotCommitError(path) return commit def _deserialize(self, chunks): (self._tree, self._parents, author_info, commit_info, self._encoding, self._mergetag, self._gpgsig, self._message, self._extra) = ( parse_commit(chunks)) (self._author, self._author_time, (self._author_timezone, self._author_timezone_neg_utc)) = author_info (self._committer, self._commit_time, (self._commit_timezone, self._commit_timezone_neg_utc)) = commit_info def check(self): """Check this object for internal consistency. Raises: ObjectFormatException: if the object is malformed in some way """ super(Commit, self).check() self._check_has_member("_tree", "missing tree") self._check_has_member("_author", "missing author") self._check_has_member("_committer", "missing committer") self._check_has_member("_author_time", "missing author time") self._check_has_member("_commit_time", "missing commit time") for parent in self._parents: check_hexsha(parent, "invalid parent sha") check_hexsha(self._tree, "invalid tree sha") check_identity(self._author, "invalid author") check_identity(self._committer, "invalid committer") check_time(self._author_time) check_time(self._commit_time) last = None for field, _ in _parse_message(self._chunked_text): if field == _TREE_HEADER and last is not None: raise ObjectFormatException("unexpected tree") elif field == _PARENT_HEADER and last not in (_PARENT_HEADER, _TREE_HEADER): raise ObjectFormatException("unexpected parent") elif field == _AUTHOR_HEADER and last not in (_TREE_HEADER, _PARENT_HEADER): raise ObjectFormatException("unexpected author") elif field == _COMMITTER_HEADER and last != _AUTHOR_HEADER: raise ObjectFormatException("unexpected committer") elif field == _ENCODING_HEADER and last != _COMMITTER_HEADER: raise ObjectFormatException("unexpected encoding") last = field # TODO: optionally check for duplicate parents def _serialize(self): chunks = [] tree_bytes = ( self._tree.id if isinstance(self._tree, Tree) else self._tree) chunks.append(git_line(_TREE_HEADER, tree_bytes)) for p in self._parents: chunks.append(git_line(_PARENT_HEADER, p)) chunks.append(git_line( _AUTHOR_HEADER, self._author, str(self._author_time).encode('ascii'), format_timezone( self._author_timezone, self._author_timezone_neg_utc))) chunks.append(git_line( _COMMITTER_HEADER, self._committer, str(self._commit_time).encode('ascii'), format_timezone(self._commit_timezone, self._commit_timezone_neg_utc))) if self.encoding: chunks.append(git_line(_ENCODING_HEADER, self.encoding)) for mergetag in self.mergetag: mergetag_chunks = mergetag.as_raw_string().split(b'\n') chunks.append(git_line(_MERGETAG_HEADER, mergetag_chunks[0])) # Embedded extra header needs leading space for chunk in mergetag_chunks[1:]: chunks.append(b' ' + chunk + b'\n') # No trailing empty line if chunks[-1].endswith(b' \n'): chunks[-1] = chunks[-1][:-2] for k, v in self.extra: if b'\n' in k or b'\n' in v: raise AssertionError( "newline in extra data: %r -> %r" % (k, v)) chunks.append(git_line(k, v)) if self.gpgsig: sig_chunks = self.gpgsig.split(b'\n') chunks.append(git_line(_GPGSIG_HEADER, sig_chunks[0])) for chunk in sig_chunks[1:]: chunks.append(git_line(b'', chunk)) chunks.append(b'\n') # There must be a new line after the headers chunks.append(self._message) return chunks tree = serializable_property( "tree", "Tree that is the state of this commit") def _get_parents(self): """Return a list of parents of this commit.""" return self._parents def _set_parents(self, value): """Set a list of parents of this commit.""" self._needs_serialization = True self._parents = value parents = property(_get_parents, _set_parents, doc="Parents of this commit, by their SHA1.") def _get_extra(self): """Return extra settings of this commit.""" return self._extra extra = property( _get_extra, doc="Extra header fields not understood (presumably added in a " "newer version of git). Kept verbatim so the object can " "be correctly reserialized. For private commit metadata, use " "pseudo-headers in Commit.message, rather than this field.") author = serializable_property( "author", "The name of the author of the commit") committer = serializable_property( "committer", "The name of the committer of the commit") message = serializable_property( "message", "The commit message") commit_time = serializable_property( "commit_time", "The timestamp of the commit. As the number of seconds since the " "epoch.") commit_timezone = serializable_property( "commit_timezone", "The zone the commit time is in") author_time = serializable_property( "author_time", "The timestamp the commit was written. As the number of " "seconds since the epoch.") author_timezone = serializable_property( "author_timezone", "Returns the zone the author time is in.") encoding = serializable_property( "encoding", "Encoding of the commit message.") mergetag = serializable_property( "mergetag", "Associated signed tag.") gpgsig = serializable_property( "gpgsig", "GPG Signature.") OBJECT_CLASSES = ( Commit, Tree, Blob, Tag, ) _TYPE_MAP = {} for cls in OBJECT_CLASSES: _TYPE_MAP[cls.type_name] = cls _TYPE_MAP[cls.type_num] = cls # Hold on to the pure-python implementations for testing _parse_tree_py = parse_tree _sorted_tree_items_py = sorted_tree_items try: # Try to import C versions from dulwich._objects import parse_tree, sorted_tree_items except ImportError: pass