Skip to content

Instantly share code, notes, and snippets.

@sgryjp
Last active November 26, 2024 01:04
Show Gist options
  • Save sgryjp/cac0f78ce7bca356de8ce0a089577409 to your computer and use it in GitHub Desktop.
Save sgryjp/cac0f78ce7bca356de8ce0a089577409 to your computer and use it in GitHub Desktop.
Implement CRAM-MD5 with aiosmtpd for Email Clients’ Unit Testing
"""Example to add support of CRAM-MD5 to aiosmtpd (unit testing purpose).
1. Defines a "handler" which will be an event handler for aiosmtpd.
* It has a method ``handle_message``.
* This is a requirement for subclasses of ``aiosmtpd.handlers.Message``,
which is designed to let users focus on handling received messages.
* In this example, it just keeps the received messages in a list.
* It defines a special method ``auth_CRAM__MD5`` to tell aiosmtpd that
(1) the handler supports CRAM-MD5 authentication mechanism, and
(2) it can perform the CRAM-MD5 authentication by calling the method.
* See https://aiosmtpd.aio-libs.org/en/latest/auth.html#auth_MECHANISM
* Note that defining a method with this name makes aiosmtpd advertise
`CRAM-MD5` as a parameter of the `250-AUTH` line in its EHLO response.
Some examples instead override the EHLO handling (e.g. a ``handle_EHLO``
hook) to craft such a response by hand, but that is not necessary in
the scope of this example.
2. Defines an "Authenticator" class, which is responsible for authentication.
* It provides the authentication logic for the PLAIN and LOGIN mechanisms,
but not for CRAM-MD5 (the handler implements that one itself).
"""
from __future__ import annotations
import email.utils
import hashlib
import hmac
import logging
import secrets
import smtplib
import warnings
from collections.abc import Iterator, Sequence
from contextlib import contextmanager
from email.message import Message as Em_Message
from typing import Any, Literal
from aiosmtpd.controller import BaseController, Controller
from aiosmtpd.handlers import Message
from aiosmtpd.smtp import SMTP, AuthResult, Envelope, LoginPassword, Session
# Names of the authentication mechanisms this example knows about; used for
# both the client-side preference list and the server-side configuration.
AuthMechanism = Literal["CRAM-MD5", "PLAIN", "LOGIN"]
# When True, sendmail() turns on smtplib's debug output and redirects it to
# this module's logger (relies on a private smtplib hook).
INTERCEPT_CLIENT_DEBUG_MESSAGES = False
# Module-level logger shared by the handler, authenticator, server and client code.
_logger = logging.getLogger(__name__)
class MyHandler(Message):
    """aiosmtpd message handler that keeps every received message in memory.

    Parameters
    ----------
    allowed_credentials : list of (str, str)
        Username/password pairs the server should accept. This class only
        stores them; it does not check them itself.
    """

    def __init__(self, allowed_credentials: list[tuple[str, str]]) -> None:
        super().__init__()
        # Accepted (username, password) pairs.
        self.allowed_credentials = allowed_credentials
        # Every message delivered so far, in arrival order.
        self.received_messages: list[Em_Message] = []

    def handle_message(self, message: Em_Message) -> None:
        """Store *message*; this is the one hook a ``Message`` subclass must define."""
        date_header = message.get("Date")
        subject_header = message.get("Subject")
        _logger.info(
            "[handler] handle_message() [%s] Subject: %s",
            date_header,
            subject_header,
        )
        self.received_messages.append(message)
class MyHandlerWithCramMd5(MyHandler):
    """Message handler that also implements CRAM-MD5 (RFC 2195) authentication.

    aiosmtpd treats a handler method named ``auth_<NAME>`` as the ``<NAME>``
    authentication mechanism (a double underscore stands for a hyphen), so
    defining ``auth_CRAM__MD5`` adds ``CRAM-MD5``, which aiosmtpd does not
    support by default, to the mechanisms the server offers.
    """

    async def auth_CRAM__MD5(self, server: SMTP, args: list[str]) -> AuthResult:
        """Perform one CRAM-MD5 challenge/response exchange with the client.

        Unlike the ``authenticator`` callable, this method does not receive the
        login credential: it sends the challenge and reads the client's answer
        itself via ``server.challenge_auth``.

        Parameters
        ----------
        server : SMTP
            The aiosmtpd protocol instance serving this client.
        args : list of str
            Arguments of the ``AUTH`` command (not used here).

        Returns
        -------
        AuthResult
            ``success=True, handled=True`` when the digest matches one of
            ``allowed_credentials``. ``success=False, handled=True`` when
            ``challenge_auth`` did not return bytes. Otherwise
            ``success=False, handled=False``.
        """
        # RFC 2195 asks for a unique challenge in msg-id form, e.g.
        # "<1896.697170952@postoffice.example>". make_msgid() yields a
        # "<...@domain>" string containing a timestamp; the secrets token makes
        # the challenge unpredictable.
        challenge = email.utils.make_msgid(
            idstring=secrets.token_hex(16),
            domain=server.hostname or None,
        ).encode("utf-8")
        response = await server.challenge_auth(challenge)
        if not isinstance(response, bytes):
            # challenge_auth() returns a non-bytes sentinel when the client
            # aborted the exchange or its answer could not be decoded.
            return AuthResult(success=False, handled=True)  # Unsuccessful, final

        # The response is b"{username} {hex digest}". Split once from the right
        # and reject anything malformed instead of raising ValueError.
        parts = response.rsplit(maxsplit=1)
        if len(parts) != 2:
            return AuthResult(success=False, handled=False)  # Unsuccessful, non-final
        username, digest = parts
        _logger.debug(
            "[handler] Incoming credential: username=%s, hash=%s",
            username,
            digest,
        )

        for valid_username, valid_password in self.allowed_credentials:
            if username != valid_username.encode("utf-8"):
                continue
            expected_digest = (
                hmac.new(
                    key=valid_password.encode("utf-8"),
                    msg=challenge,
                    digestmod=hashlib.md5,
                )
                .hexdigest()
                .encode("ascii")
            )
            # Constant-time comparison so response timing does not reveal how
            # much of the submitted digest was correct.
            if hmac.compare_digest(digest, expected_digest):
                return AuthResult(success=True, handled=True)  # Successful and final

        # Return an unsuccessful non-final result to let the client try other mechanisms.
        return AuthResult(success=False, handled=False)  # Unsuccessful and non-final
class MyAuthenticator:
    """A callable object which authenticates a credential.

    It checks PLAIN and LOGIN credentials against a fixed list of
    username/password pairs. Every other mechanism is rejected.
    """

    def __init__(self, allowed_credentials: list[tuple[str, str]]) -> None:
        # Stored as UTF-8 bytes so they compare directly with the login and
        # password carried by LoginPassword.
        self.allowed_credentials = [
            (user.encode("utf-8"), pwd.encode("utf-8"))
            for user, pwd in allowed_credentials
        ]

    def __call__(
        self,
        _server: SMTP,
        _session: Session,
        _envelope: Envelope,
        mechanism: str,
        auth_data: Any,
    ) -> AuthResult:
        """Authenticate a received login credential.

        Parameters
        ----------
        _server, _session, _envelope
            Supplied by aiosmtpd; unused here.
        mechanism : str
            Name of the AUTH mechanism the client used.
        auth_data : Any
            For PLAIN and LOGIN, a ``LoginPassword`` with the login and password.

        Returns
        -------
        AuthResult
            ``success=True, handled=True`` on a match, otherwise
            ``success=False, handled=False``.
        """
        if mechanism not in ("LOGIN", "PLAIN") or not isinstance(
            auth_data, LoginPassword
        ):
            # Unsupported mechanism or unexpected payload. This is an explicit
            # check rather than an ``assert``, which ``python -O`` would strip.
            _logger.debug("[authenticator] %s: unsupported", mechanism)
            return AuthResult(success=False, handled=False)

        login, pwd = auth_data.login, auth_data.password
        # Log only the login: auth_data also carries the plain-text password.
        _logger.debug("[authenticator] %s: login=%s", mechanism, login)

        for valid_username, valid_password in self.allowed_credentials:
            # compare_digest runs in constant time, so timing does not reveal
            # how much of the submitted credential matched.
            user_ok = hmac.compare_digest(login, valid_username)
            pwd_ok = hmac.compare_digest(pwd, valid_password)
            if user_ok and pwd_ok:
                return AuthResult(success=True, handled=True)

        # Reached whenever PLAIN/LOGIN credentials match no allowed pair.
        return AuthResult(success=False, handled=False)
@contextmanager
def start_threaded_server(
    handler: MyHandler,
    hostname: str,
    port: int,
    server_auth_required: bool,
    server_auth_mechanisms: Sequence[AuthMechanism],
) -> Iterator[BaseController]:
    """Start a thread and run an SMTP server in the thread.

    The server is stopped when the ``with`` block exits, even if it raises.

    Parameters
    ----------
    handler : MyHandler
        Event handler for aiosmtpd. Its ``allowed_credentials`` also configure
        the PLAIN/LOGIN authenticator.
    hostname, port
        Address the server listens on.
    server_auth_required : bool
        Whether clients must authenticate before sending mail.
    server_auth_mechanisms : sequence of AuthMechanism
        Mechanisms the server may offer. Every other known mechanism is excluded.

    Yields
    ------
    BaseController
        The running controller.
    """
    # Exclude every known mechanism that was not requested, including CRAM-MD5.
    # aiosmtpd drops each listed name from its table of mechanisms (built-in or
    # handler-provided), so names the handler does not implement are harmless.
    known_mechanisms: tuple[AuthMechanism, ...] = ("CRAM-MD5", "PLAIN", "LOGIN")
    auth_exclude_mechanism = [
        mech for mech in known_mechanisms if mech not in server_auth_mechanisms
    ]

    controller = Controller(
        handler,
        hostname=hostname,
        port=port,
        # Keyword arguments below are forwarded to aiosmtpd.smtp.SMTP
        auth_required=server_auth_required,
        auth_exclude_mechanism=auth_exclude_mechanism or None,
        auth_require_tls=False,  # !!CAUTION!! Insecure
        authenticator=MyAuthenticator(handler.allowed_credentials),
    )

    # Suppress a warning for allowing authentication while TLS is not required.
    # (This is an implementation for unit test.) The filter stays active for
    # the server's whole lifetime because the protocol factory builds a new SMTP
    # object for each incoming connection, not only during start().
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", module="aiosmtpd.smtp")
        _logger.info("Starting a threaded server...")
        controller.start()
        try:
            _logger.info("Waiting until the server finishes...")
            yield controller
            _logger.info("Server finished.")
        finally:
            _logger.info("Stopping the server...")
            controller.stop()
            _logger.info("Stopped the server...")
def sendmail(
    hostname: str,
    port: int,
    username: str | None = None,
    password: str | None = None,
    auth_mechanisms: Sequence[AuthMechanism] | None = None,
    timeout: float = 20.0,
    sender: str = "alice@example.com",
    recipient: str = "bob@example.com",
) -> None:
    """Send a small "Hello" message to an SMTP server, logging in first if asked.

    Parameters
    ----------
    hostname, port
        Address of the SMTP server.
    username, password : str, optional
        Credential to log in with. Login is skipped unless both are given.
    auth_mechanisms : sequence of AuthMechanism, optional
        Mechanisms to try, in order of preference. Defaults to CRAM-MD5,
        PLAIN, then LOGIN.
    timeout : float
        Timeout in seconds for the client's blocking operations.
    sender, recipient : str
        Addresses written to the ``From`` and ``To`` headers.

    Raises
    ------
    smtplib.SMTPException
        Propagated from ``login`` or ``smtplib.SMTP.send_message``.
    """
    if auth_mechanisms is None:
        auth_mechanisms = ["CRAM-MD5", "PLAIN", "LOGIN"]

    msg = Em_Message()
    msg["Date"] = email.utils.formatdate(localtime=True)
    msg["Subject"] = "Hello"
    msg.set_payload("World!")
    # send_message() takes the envelope sender and recipients from these
    # headers, so they must contain real addresses.
    msg["From"] = sender
    msg["To"] = recipient

    with smtplib.SMTP(hostname, port, timeout=timeout) as client:
        if INTERCEPT_CLIENT_DEBUG_MESSAGES:  # !!CAUTION!! This is an unreliable hack
            # Redirect smtplib's debug output (private hook) to our logger.
            client.debuglevel = 2
            client._print_debug = lambda *args: _logger.info("[client] %s", args)  # type: ignore
        # Try login, if a credential is given
        if username is not None and password is not None:
            # If mechanism negotiation is not required, use `client.login` instead.
            mechanism = login(
                client,
                username,
                password,
                auth_mechanisms=auth_mechanisms,
            )
            _logger.info("Login succeeded with authentication mechanism: %s", mechanism)
        client.send_message(msg)
def login(
    client: smtplib.SMTP,
    username: str,
    password: str,
    auth_mechanisms: Sequence[AuthMechanism],
) -> str:
    """Try login to server, using the first mechanism that works.

    Unlike ``smtplib.SMTP.login``, the caller chooses which mechanisms may be
    used and in what order. Only those the server also advertises are tried.

    Parameters
    ----------
    client : smtplib.SMTP
        A connected client.
    username, password : str
        The credential to send.
    auth_mechanisms : sequence of AuthMechanism
        Mechanisms the client is willing to use, in order of preference.

    Returns
    -------
    mechanism : str
        The authentication mechanism used.

    Raises
    ------
    smtplib.SMTPException
        No suitable authentication method was found.
    smtplib.SMTPAuthenticationError
        Authentication failed for every mechanism tried (the last error is raised).
    """
    # Send EHLO first to fill client object's ESMTP feature dictionary
    client.ehlo()

    # smtplib helper that produces the response for each supported mechanism.
    authenticators = {
        "CRAM-MD5": client.auth_cram_md5,
        "PLAIN": client.auth_plain,
        "LOGIN": client.auth_login,
    }

    # Select mechanisms to try: preferred by the caller, advertised by the
    # server, and implemented above.
    advertised_mechanisms = client.esmtp_features.get("auth", "").split()
    mechanisms_to_try = [
        m
        for m in auth_mechanisms
        if m in advertised_mechanisms and m in authenticators
    ]
    _logger.info(
        "Mechanism to try: %s (server supports %s)",
        mechanisms_to_try,
        advertised_mechanisms,
    )

    # Set the credential as attributes of the SMTP object. The auth_* helpers
    # such as ``smtplib.SMTP.auth_plain`` read them from there
    # (see the documentation of ``smtplib.SMTP.auth``).
    client.user = username
    client.password = password

    # Try authentication using each mechanism.
    auth_error: smtplib.SMTPAuthenticationError | None = None
    for mech in mechanisms_to_try:
        try:
            code, _response = client.auth(mech, authenticators[mech])
        except smtplib.SMTPAuthenticationError as exc:
            auth_error = exc
            continue
        # SMTP.auth returns only for 235 (success) or 503 (already
        # authenticated); smtplib.SMTP.login treats both as success.
        if code in (235, 503):
            return mech

    # Re-raise the last authentication error, just like smtplib.SMTP.login does.
    if auth_error is not None:
        raise auth_error

    # If no authentication was tried, raise an error for negotiation failure
    msg = "No suitable authentication method found."
    raise smtplib.SMTPException(msg)
def main() -> None:
    """Run the demo: start a local SMTP server, send one message, list what arrived."""
    # --- Scenario settings -------------------------------------------------
    accepted_logins = [
        ("alice", "wonderland"),
        ("bob", "sponge"),
    ]
    host, port = "127.0.0.1", 8025
    user, secret = "alice", "wonderland"
    timeout_seconds = 4.0
    client_mechanisms: list[AuthMechanism] = ["CRAM-MD5", "PLAIN", "LOGIN"]
    require_auth = True
    server_mechanisms: list[AuthMechanism] = ["CRAM-MD5", "PLAIN", "LOGIN"]

    # --- Logging -------------------------------------------------------------
    logging.basicConfig(
        format="%(asctime)s %(levelname)8s %(thread)06x %(name)-16s | %(message)s",
        level="DEBUG",
    )
    # Keep aiosmtpd's "mail.log" logger less verbose than the rest.
    logging.getLogger("mail.log").setLevel("INFO")

    # Only the CRAM-MD5-capable handler makes aiosmtpd offer that mechanism.
    handler_cls = (
        MyHandlerWithCramMd5 if "CRAM-MD5" in server_mechanisms else MyHandler
    )
    handler = handler_cls(accepted_logins)

    with start_threaded_server(
        handler,
        host,
        port,
        server_auth_required=require_auth,
        server_auth_mechanisms=server_mechanisms,
    ):
        sendmail(
            host,
            port,
            user,
            secret,
            auth_mechanisms=client_mechanisms,
            timeout=timeout_seconds,
        )

    _logger.info("List of received messages:")
    for received in handler.received_messages:
        _logger.info(" [%s] %s", received["Date"], received["Subject"])


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment