diff --git a/swh/web/inbound_email/utils.py b/swh/web/inbound_email/utils.py
--- a/swh/web/inbound_email/utils.py
+++ b/swh/web/inbound_email/utils.py
@@ -6,7 +6,13 @@
 from dataclasses import dataclass
 from email.headerregistry import Address
 from email.message import EmailMessage
-from typing import List, Optional
+import logging
+from typing import List, Optional, Set
+
+from django.core.signing import Signer
+from django.utils.crypto import constant_time_compare
+
+logger = logging.getLogger(__name__)
 
 
 def extract_recipients(message: EmailMessage) -> List[Address]:
@@ -78,3 +84,95 @@
             ret.append(match)
 
     return ret
+
+
+ADDRESS_SIGNER_SEP = "."
+"""Separator for email address signatures"""
+
+
+def get_address_signer(salt: str) -> Signer:
+    """Get a signer for the given salt, using :data:`ADDRESS_SIGNER_SEP` as
+    separator between the signed value and its signature."""
+    return Signer(salt=salt, sep=ADDRESS_SIGNER_SEP)
+
+
+def get_address_for_pk(salt: str, base_address: str, pk: int) -> str:
+    """Get the email address, derived from ``base_address``, whose extension
+    carries the signed primary key ``pk``.
+
+    The signed primary key is appended to the local part of ``base_address`` as
+    a ``+`` extension.
+
+    Raises ValueError if ``base_address`` does not contain exactly one ``@``.
+
+    """
+    username, sep, domain = base_address.partition("@")
+    if not sep:
+        raise ValueError("Base address needs to contain an @")
+    if "@" in domain:
+        raise ValueError("Base address needs to contain a single @")
+
+    extension = get_address_signer(salt).sign(str(pk))
+
+    return f"{username}+{extension}@{domain}"
+
+
+def get_pk_from_extension(salt: str, extension: str) -> int:
+    """Retrieve the primary key for the given inbound address extension.
+
+    We reimplement `Signer.unsign`, because the extension can be casemapped at any
+    point in the email chain (even though email is, theoretically, case sensitive),
+    so we have to compare lowercase versions of both the extension and the
+    signature...
+
+    Raises ValueError if the signature is missing or couldn't be verified.
+
+    """
+
+    value, sep, signature = extension.rpartition(ADDRESS_SIGNER_SEP)
+    if not sep:
+        # Be explicit rather than relying on a tuple-unpacking error
+        raise ValueError(f"Missing signature in extension {extension}")
+
+    expected_signature = get_address_signer(salt).signature(value)
+    if not constant_time_compare(signature.lower(), expected_signature.lower()):
+        raise ValueError(f"Invalid signature for extension {extension}")
+
+    return int(value)
+
+
+def get_pks_from_message(
+    salt: str, base_address: str, message: EmailMessage
+) -> Set[int]:
+    """Retrieve the set of primary keys that were successfully decoded from the
+    recipients of the ``message`` matching ``base_address``.
+
+    This uses :func:`recipient_matches` to retrieve all the recipient addresses matching
+    ``base_address``, then :func:`get_pk_from_extension` to decode the primary key and
+    verify the signature for every extension. To generate relevant email addresses, use
+    :func:`get_address_for_pk` with the same ``base_address`` and ``salt``.
+
+    Returns:
+      the set of primary keys that were successfully decoded from the recipients of the
+      ``message``
+
+    """
+    ret: Set[int] = set()
+
+    for match in recipient_matches(message, base_address):
+        extension = match.extension
+        if extension is None:
+            logger.debug(
+                "Recipient address %s cannot be matched to a request, ignoring",
+                match.recipient.addr_spec,
+            )
+            continue
+
+        try:
+            ret.add(get_pk_from_extension(salt, extension))
+        except ValueError:
+            logger.debug(
+                "Recipient address %s failed validation", match.recipient.addr_spec
+            )
+
+    return ret
diff --git a/swh/web/tests/inbound_email/test_utils.py b/swh/web/tests/inbound_email/test_utils.py
--- a/swh/web/tests/inbound_email/test_utils.py
+++ b/swh/web/tests/inbound_email/test_utils.py
@@ -130,3 +130,118 @@
     matches = utils.recipient_matches(message, "match@example.com")
     assert matches
     assert matches[0].extension == "weirdCaseMapping"
+
+
+def test_get_address_for_pk():
+    salt = "test_salt"
+    pks = [1, 10, 1000]
+    base_address = "base@example.com"
+
+    addresses = {
+        pk: utils.get_address_for_pk(salt=salt, base_address=base_address, pk=pk)
+        for pk in pks
+    }
+
+    assert len(set(addresses.values())) == len(addresses)
+
+    for pk, address in addresses.items():
+        localpart, _, domain = address.partition("@")
+        base_localpart, _, extension = localpart.partition("+")
+        assert domain == "example.com"
+        assert base_localpart == "base"
+        assert extension.startswith(f"{pk}.")
+
+
+def test_get_address_for_pk_salt():
+    pk = 1000
+    base_address = "base@example.com"
+    addresses = [
+        utils.get_address_for_pk(salt=salt, base_address=base_address, pk=pk)
+        for salt in ["salt1", "salt2"]
+    ]
+
+    assert len(addresses) == len(set(addresses))
+
+
+def test_get_pks_from_message():
+    salt = "test_salt"
+    pks = [1, 10, 1000]
+    base_address = "base@example.com"
+
+    addresses = {
+        pk: utils.get_address_for_pk(salt=salt, base_address=base_address, pk=pk)
+        for pk in pks
+    }
+
+    message = EmailMessage()
+    message["To"] = "test@example.com"
+
+    assert utils.get_pks_from_message(salt, base_address, message) == set()
+
+    message = EmailMessage()
+    message["To"] = f"Test Address <{addresses[1]}>"
+
+    assert utils.get_pks_from_message(salt, base_address, message) == {1}
+
+    message = EmailMessage()
+    message["To"] = f"Test Address <{addresses[1]}>"
+    message["Cc"] = ", ".join(
+        [
+            f"Test Address <{addresses[1]}>",
+            f"Another Test Address <{addresses[10].lower()}>",
+            # NOTE(review): no addr-spec here; presumably meant to be a
+            # non-matching recipient -- confirm
+            "A Third Address ",
+        ]
+    )
+
+    assert utils.get_pks_from_message(salt, base_address, message) == {1, 10}
+
+
+def test_get_pks_from_message_logging(caplog):
+    # Don't depend on the ambient logging configuration to capture DEBUG records
+    caplog.set_level("DEBUG", logger="swh.web.inbound_email.utils")
+
+    salt = "test_salt"
+    pks = [1, 10, 1000]
+    base_address = "base@example.com"
+
+    addresses = {
+        pk: utils.get_address_for_pk(salt=salt, base_address=base_address, pk=pk)
+        for pk in pks
+    }
+
+    message = EmailMessage()
+    message["To"] = f"Test Address <{base_address}>"
+
+    assert utils.get_pks_from_message(salt, base_address, message) == set()
+
+    relevant_records = [
+        record
+        for record in caplog.records
+        if record.name == "swh.web.inbound_email.utils"
+    ]
+    assert len(relevant_records) == 1
+    assert relevant_records[0].levelname == "DEBUG"
+    assert (
+        f"{base_address} cannot be matched to a request"
+        in relevant_records[0].getMessage()
+    )
+
+    # Replace the signature with "mangle{signature}"
+    mangled_address = addresses[1].replace(".", ".mangle", 1)
+
+    message = EmailMessage()
+    message["To"] = f"Test Address <{mangled_address}>"
+
+    assert utils.get_pks_from_message(salt, base_address, message) == set()
+
+    relevant_records = [
+        record
+        for record in caplog.records
+        if record.name == "swh.web.inbound_email.utils"
+    ]
+    assert len(relevant_records) == 2
+    assert relevant_records[0].levelname == "DEBUG"
+    assert relevant_records[1].levelname == "DEBUG"
+    assert f"{mangled_address} failed" in relevant_records[1].getMessage()