diff --git a/swh/web/inbound_email/utils.py b/swh/web/inbound_email/utils.py index b469e40f..e7545ac7 100644 --- a/swh/web/inbound_email/utils.py +++ b/swh/web/inbound_email/utils.py @@ -1,80 +1,166 @@ # Copyright (C) 2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU Affero General Public License version 3, or any later version # See top-level LICENSE file for more information from dataclasses import dataclass from email.headerregistry import Address from email.message import EmailMessage -from typing import List, Optional +import logging +from typing import List, Optional, Set + +from django.core.signing import Signer +from django.utils.crypto import constant_time_compare + +logger = logging.getLogger(__name__) def extract_recipients(message: EmailMessage) -> List[Address]: """Extract a list of recipients of the `message`. This uses the ``To`` and ``Cc`` fields. """ ret = [] for header_name in ("to", "cc"): for header in message.get_all(header_name, []): ret.extend(header.addresses) return ret @dataclass class AddressMatch: """Data related to a recipient match""" recipient: Address """The original recipient that matched the expected address""" extension: Optional[str] """The parsed +-extension of the matched recipient address""" def single_recipient_matches( recipient: Address, address: str ) -> Optional[AddressMatch]: """Check whether a single address matches the provided base address. The match is case-insensitive, which is not really RFC-compliant but is consistent with what most people would expect. This function supports "+-addressing", where the local part of the email address is appended with a `+`. 
""" parsed_address = Address(addr_spec=address.lower()) if recipient.domain.lower() != parsed_address.domain: return None base_username, _, extension = recipient.username.partition("+") if base_username.lower() != parsed_address.username: return None return AddressMatch(recipient=recipient, extension=extension or None) def recipient_matches(message: EmailMessage, address: str) -> List[AddressMatch]: """Check whether any of the message recipients match the given address. The match is case-insensitive, which is not really RFC-compliant but matches what most people would expect. This function supports "+-addressing", where the local part of the email address is appended with a `+`. """ ret = [] for recipient in extract_recipients(message): match = single_recipient_matches(recipient, address) if match: ret.append(match) return ret + + +ADDRESS_SIGNER_SEP = "." +"""Separator for email address signatures""" + + +def get_address_signer(salt: str) -> Signer: + """Get a signer for the given salt""" + return Signer(salt=salt, sep=ADDRESS_SIGNER_SEP) + + +def get_address_for_pk(salt: str, base_address: str, pk: int) -> str: + """Get the email address that will be able to receive messages to be logged in + this request.""" + if "@" not in base_address: + raise ValueError("Base address needs to contain an @") + + username, domain = base_address.split("@") + + extension = get_address_signer(salt).sign(str(pk)) + + return f"{username}+{extension}@{domain}" + + +def get_pk_from_extension(salt: str, extension: str) -> int: + """Retrieve the primary key for the given inbound address extension. + + We reimplement `Signer.unsign`, because the extension can be casemapped at any + point in the email chain (even though email is, theoretically, case sensitive), + so we have to compare lowercase versions of both the extension and the + signature... + + Raises ValueError if the signature couldn't be verified. 
+ + """ + + value, signature = extension.rsplit(ADDRESS_SIGNER_SEP, 1) + expected_signature = get_address_signer(salt).signature(value) + if not constant_time_compare(signature.lower(), expected_signature.lower()): + raise ValueError(f"Invalid signature for extension {extension}") + + return int(value) + + +def get_pks_from_message( + salt: str, base_address: str, message: EmailMessage +) -> Set[int]: + """Retrieve the set of primary keys that were successfully decoded from the + recipients of the ``message`` matching ``base_address``. + + This uses :func:`recipient_matches` to retrieve all the recipient addresses matching + ``base_address``, then :func:`get_pk_from_extension` to decode the primary key and + verify the signature for every extension. To generate relevant email addresses, use + :func:`get_address_for_pk` with the same ``base_address`` and ``salt``. + + Returns: + the set of primary keys that were successfully decoded from the recipients of the + ``message`` + + """ + ret: Set[int] = set() + + for match in recipient_matches(message, base_address): + extension = match.extension + if extension is None: + logger.debug( + "Recipient address %s cannot be matched to a request, ignoring", + match.recipient.addr_spec, + ) + continue + + try: + ret.add(get_pk_from_extension(salt, extension)) + except ValueError: + logger.debug( + "Recipient address %s failed validation", match.recipient.addr_spec + ) + continue + + return ret diff --git a/swh/web/tests/inbound_email/test_utils.py b/swh/web/tests/inbound_email/test_utils.py index 83af934c..1f328ee1 100644 --- a/swh/web/tests/inbound_email/test_utils.py +++ b/swh/web/tests/inbound_email/test_utils.py @@ -1,132 +1,243 @@ # Copyright (C) 2022 The Software Heritage developers # See the AUTHORS file at the top-level directory of this distribution # License: GNU Affero General Public License version 3, or any later version # See top-level LICENSE file for more information from email.headerregistry import Address from 
email.message import EmailMessage from swh.web.inbound_email import utils def test_extract_recipients(): message = EmailMessage() assert utils.extract_recipients(message) == [] message["To"] = "Test Recipient " assert utils.extract_recipients(message) == [ Address(display_name="Test Recipient", addr_spec="test-recipient@example.com") ] message["Cc"] = ( "test-recipient-2@example.com, " "Another Test Recipient " ) assert utils.extract_recipients(message) == [ Address(display_name="Test Recipient", addr_spec="test-recipient@example.com"), Address(addr_spec="test-recipient-2@example.com"), Address( display_name="Another Test Recipient", addr_spec="test-recipient-3@example.com", ), ] del message["To"] assert utils.extract_recipients(message) == [ Address(addr_spec="test-recipient-2@example.com"), Address( display_name="Another Test Recipient", addr_spec="test-recipient-3@example.com", ), ] def test_single_recipient_matches(): assert ( utils.single_recipient_matches( Address(addr_spec="test@example.com"), "match@example.com" ) is None ) assert utils.single_recipient_matches( Address(addr_spec="match@example.com"), "match@example.com" ) == utils.AddressMatch( recipient=Address(addr_spec="match@example.com"), extension=None ) assert utils.single_recipient_matches( Address(addr_spec="MaTch+12345AbC@exaMple.Com"), "match@example.com" ) == utils.AddressMatch( recipient=Address(addr_spec="MaTch+12345AbC@exaMple.Com"), extension="12345AbC" ) def test_recipient_matches(): message = EmailMessage() assert utils.recipient_matches(message, "match@example.com") == [] message = EmailMessage() message["to"] = "nomatch@example.com" assert utils.recipient_matches(message, "match@example.com") == [] message = EmailMessage() message["to"] = "match@example.com" assert utils.recipient_matches(message, "match@example.com") == [ utils.AddressMatch( recipient=Address(addr_spec="match@example.com"), extension=None ) ] message = EmailMessage() message["to"] = "match+extension@example.com" assert 
utils.recipient_matches(message, "match@example.com") == [ utils.AddressMatch( recipient=Address(addr_spec="match+extension@example.com"), extension="extension", ) ] message = EmailMessage() message["to"] = "match+weird+plussed+extension@example.com" assert utils.recipient_matches(message, "match@example.com") == [ utils.AddressMatch( recipient=Address(addr_spec="match+weird+plussed+extension@example.com"), extension="weird+plussed+extension", ) ] message = EmailMessage() message["to"] = "nomatch@example.com" message["cc"] = ", ".join( ( "match@example.com", "match@notamatch.example.com", "Another Match ", ) ) assert utils.recipient_matches(message, "match@example.com") == [ utils.AddressMatch( recipient=Address(addr_spec="match@example.com"), extension=None, ), utils.AddressMatch( recipient=Address( display_name="Another Match", addr_spec="match+extension@example.com" ), extension="extension", ), ] def test_recipient_matches_casemapping(): message = EmailMessage() message["to"] = "match@example.com" assert utils.recipient_matches(message, "Match@Example.Com") assert utils.recipient_matches(message, "match@example.com") message = EmailMessage() message["to"] = "Match+weirdCaseMapping@Example.Com" matches = utils.recipient_matches(message, "match@example.com") assert matches assert matches[0].extension == "weirdCaseMapping" + + +def test_get_address_for_pk(): + salt = "test_salt" + pks = [1, 10, 1000] + base_address = "base@example.com" + + addresses = { + pk: utils.get_address_for_pk(salt=salt, base_address=base_address, pk=pk) + for pk in pks + } + + assert len(set(addresses.values())) == len(addresses) + + for pk, address in addresses.items(): + localpart, _, domain = address.partition("@") + base_localpart, _, extension = localpart.partition("+") + assert domain == "example.com" + assert base_localpart == "base" + assert extension.startswith(f"{pk}.") + + +def test_get_address_for_pk_salt(): + pk = 1000 + base_address = "base@example.com" + addresses = [ + 
utils.get_address_for_pk(salt=salt, base_address=base_address, pk=pk) + for salt in ["salt1", "salt2"] + ] + + assert len(addresses) == len(set(addresses)) + + +def test_get_pks_from_message(): + salt = "test_salt" + pks = [1, 10, 1000] + base_address = "base@example.com" + + addresses = { + pk: utils.get_address_for_pk(salt=salt, base_address=base_address, pk=pk) + for pk in pks + } + + message = EmailMessage() + message["To"] = "test@example.com" + + assert utils.get_pks_from_message(salt, base_address, message) == set() + + message = EmailMessage() + message["To"] = f"Test Address <{addresses[1]}>" + + assert utils.get_pks_from_message(salt, base_address, message) == {1} + + message = EmailMessage() + message["To"] = f"Test Address <{addresses[1]}>" + message["Cc"] = ", ".join( + [ + f"Test Address <{addresses[1]}>", + f"Another Test Address <{addresses[10].lower()}>", + "A Third Address ", + ] + ) + + assert utils.get_pks_from_message(salt, base_address, message) == {1, 10} + + +def test_get_pks_from_message_logging(caplog): + salt = "test_salt" + pks = [1, 10, 1000] + base_address = "base@example.com" + + addresses = { + pk: utils.get_address_for_pk(salt=salt, base_address=base_address, pk=pk) + for pk in pks + } + + message = EmailMessage() + message["To"] = f"Test Address <{base_address}>" + + assert utils.get_pks_from_message(salt, base_address, message) == set() + + relevant_records = [ + record + for record in caplog.records + if record.name == "swh.web.inbound_email.utils" + ] + assert len(relevant_records) == 1 + assert relevant_records[0].levelname == "DEBUG" + assert ( + f"{base_address} cannot be matched to a request" + in relevant_records[0].getMessage() + ) + + # Replace the signature with "mangle{signature}" + mangled_address = addresses[1].replace(".", ".mangle", 1) + + message = EmailMessage() + message["To"] = f"Test Address <{mangled_address}>" + + assert utils.get_pks_from_message(salt, base_address, message) == set() + + relevant_records = [ 
+ record + for record in caplog.records + if record.name == "swh.web.inbound_email.utils" + ] + assert len(relevant_records) == 2 + assert relevant_records[0].levelname == "DEBUG" + assert relevant_records[1].levelname == "DEBUG" + + assert f"{mangled_address} failed" in relevant_records[1].getMessage()