diff --git a/requirements.txt b/requirements.txt
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,5 @@
 click
 deprecated
+python-magic
 pyyaml
 sentry-sdk
diff --git a/swh/core/tarball.py b/swh/core/tarball.py
--- a/swh/core/tarball.py
+++ b/swh/core/tarball.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2015-2019 The Software Heritage developers
+# Copyright (C) 2015-2021 The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
@@ -10,6 +10,8 @@
 import tarfile
 import zipfile
 
+import magic
+
 from . import utils
 
 
@@ -73,6 +75,22 @@
     shutil.register_unpack_format(name, extensions, function)
 
 
+# Map MIME types reported by libmagic to shutil unpack format names. Used by
+# uncompress() when the file extension does not identify the archive format.
+# "tar.Z|x" and "tar.lz" are custom formats (see register_new_archive_formats).
+_mime_to_archive_format = {
+    "application/x-compress": "tar.Z|x",
+    "application/x-tar": "tar",
+    "application/x-bzip2": "bztar",
+    "application/gzip": "gztar",
+    # alternative gzip MIME type reported by some libmagic versions
+    "application/x-gzip": "gztar",
+    "application/x-xz": "xztar",
+    "application/x-lzip": "tar.lz",
+    "application/zip": "zip",
+}
+
+
 def uncompress(tarpath: str, dest: str):
     """Uncompress tarpath to dest folder if tarball is supported.
 
@@ -91,15 +109,20 @@
     try:
         os.makedirs(dest, exist_ok=True)
         format = None
+        # try to get archive format from extension
         for format_, exts, _ in shutil.get_unpack_formats():
             if any([tarpath.lower().endswith(ext.lower()) for ext in exts]):
                 format = format_
                 break
+        # try to get archive format from file mimetype
+        if format is None:
+            mime = magic.from_file(tarpath, mime=True)
+            format = _mime_to_archive_format.get(mime)
         shutil.unpack_archive(tarpath, extract_dir=dest, format=format)
     except shutil.ReadError as e:
         raise ValueError(f"Problem during unpacking {tarpath}. Reason: {e}")
     except NotImplementedError:
-        if tarpath.lower().endswith(".zip"):
+        if tarpath.lower().endswith(".zip") or format == "zip":
             _unpack_zip(tarpath, dest)
         else:
             raise
diff --git a/swh/core/tests/test_tarball.py b/swh/core/tests/test_tarball.py
--- a/swh/core/tests/test_tarball.py
+++ b/swh/core/tests/test_tarball.py
@@ -1,8 +1,9 @@
-# Copyright (C) 2019 The Software Heritage developers
+# Copyright (C) 2019-2021 The Software Heritage developers
 # See the AUTHORS file at the top-level directory of this distribution
 # License: GNU General Public License version 3, or any later version
 # See top-level LICENSE file for more information
 
+import hashlib
 import os
 import shutil
 
@@ -160,28 +161,13 @@
     assert format_id[0] in unpack_formats_v2
 
 
-def test_uncompress_archives(tmp_path, datadir, prepare_shutil_state):
+def test_uncompress_archives(tmp_path, datadir):
     """High level call uncompression on un/supported archives
 
     """
     archive_dir = os.path.join(datadir, "archives")
     archive_files = os.listdir(archive_dir)
 
-    # not supported yet
-    unsupported_archives = []
-    for archive_file in archive_files:
-        if archive_file.endswith((".Z", ".x", ".lz", ".crate")):
-            unsupported_archives.append(os.path.join(archive_dir, archive_file))
-
-    for archive_path in unsupported_archives:
-        with pytest.raises(
-            ValueError, match=f"Problem during unpacking {archive_path}."
-        ):
-            tarball.uncompress(archive_path, dest=tmp_path)
-
-    # register those unsupported formats
-    tarball.register_new_archive_formats()
-
-    # unsupported formats are now supported
+    # all archive formats are supported
     for archive_file in archive_files:
         archive_path = os.path.join(archive_dir, archive_file)
@@ -242,3 +228,24 @@
     shutil.copy(os.path.join(archives_path, archive_file), archive_file_upper)
     tarball.uncompress(archive_file_upper, extract_dir)
     assert len(os.listdir(extract_dir)) > 0
+
+
+def test_uncompress_archive_no_extension(tmp_path, datadir):
+    """Copy test archives in a temporary directory but turn their names
+    to their md5 sums, then check they can be successfully extracted.
+    """
+    archives_path = os.path.join(datadir, "archives")
+    archive_files = [
+        f
+        for f in os.listdir(archives_path)
+        if os.path.isfile(os.path.join(archives_path, f))
+    ]
+    for archive_file in archive_files:
+        archive_file_path = os.path.join(archives_path, archive_file)
+        with open(archive_file_path, "rb") as f:
+            md5sum = hashlib.md5(f.read()).hexdigest()
+        archive_file_md5sum = os.path.join(tmp_path, md5sum)
+        extract_dir = os.path.join(tmp_path, archive_file)
+        shutil.copy(archive_file_path, archive_file_md5sum)
+        tarball.uncompress(archive_file_md5sum, extract_dir)
+        assert len(os.listdir(extract_dir)) > 0