diff --git a/CHANGELOG.md b/CHANGELOG.md index e76f31e26..d1b5edf61 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,7 @@ # Change Log +- Add [ASGI TLS Extension version: 0.2](https://github.com/jonfoster/asgiref/blob/master/specs/tls.rst) to h11, httptools, websockets, and wsproto impl (#1119) + ## 0.20.0 - 2022-11-20 ### Added diff --git a/tests/conftest.py b/tests/conftest.py index d26d39432..9485b1d82 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -43,6 +43,13 @@ def tls_certificate(tls_certificate_authority: trustme.CA) -> trustme.LeafCert: ) +@pytest.fixture +def tls_client_certificate(tls_certificate_authority: trustme.CA) -> trustme.LeafCert: + return tls_certificate_authority.issue_cert( + "client@example.com", common_name="uvicorn client" + ) + + @pytest.fixture def tls_ca_certificate_pem_path(tls_certificate_authority: trustme.CA): with tls_certificate_authority.cert_pem.tempfile() as ca_cert_pem: @@ -96,6 +103,13 @@ def tls_ca_ssl_context(tls_certificate_authority: trustme.CA) -> ssl.SSLContext: return ssl_ctx +@pytest.fixture +def tls_client_certificate_pem_path(tls_client_certificate: trustme.LeafCert): + private_key_and_cert_chain = tls_client_certificate.private_key_and_cert_chain_pem + with private_key_and_cert_chain.tempfile() as client_cert_pem: + yield client_cert_pem + + @pytest.fixture(scope="package") def reload_directory_structure(tmp_path_factory: pytest.TempPathFactory): """ diff --git a/tests/test_ssl.py b/tests/test_ssl.py index d60bcf54e..b113d89b4 100644 --- a/tests/test_ssl.py +++ b/tests/test_ssl.py @@ -1,3 +1,5 @@ +import ssl + import httpx import pytest @@ -34,6 +36,56 @@ async def test_run( assert response.status_code == 204 +@pytest.mark.anyio +async def test_run_httptools_client_cert( + tls_ca_ssl_context, + tls_ca_certificate_pem_path, + tls_ca_certificate_private_key_path, + tls_client_certificate_pem_path, +): + config = Config( + app=app, + loop="asyncio", + http="httptools", + limit_max_requests=1, + ssl_keyfile=tls_ca_certificate_private_key_path, + ssl_certfile=tls_ca_certificate_pem_path, + ssl_ca_certs=tls_ca_certificate_pem_path, + ssl_cert_reqs=ssl.CERT_REQUIRED, + ) + async with run_server(config): + async with httpx.AsyncClient( + verify=tls_ca_ssl_context, cert=tls_client_certificate_pem_path + ) as client: + response = await client.get("https://127.0.0.1:8000") + assert response.status_code == 204 + + +@pytest.mark.anyio +async def test_run_h11_client_cert( + tls_ca_ssl_context, + tls_ca_certificate_pem_path, + tls_ca_certificate_private_key_path, + tls_client_certificate_pem_path, +): + config = Config( + app=app, + loop="asyncio", + http="h11", + limit_max_requests=1, + ssl_keyfile=tls_ca_certificate_private_key_path, + ssl_certfile=tls_ca_certificate_pem_path, + ssl_ca_certs=tls_ca_certificate_pem_path, + ssl_cert_reqs=ssl.CERT_REQUIRED, + ) + async with run_server(config): + async with httpx.AsyncClient( + verify=tls_ca_ssl_context, cert=tls_client_certificate_pem_path + ) as client: + response = await client.get("https://127.0.0.1:8000") + assert response.status_code == 204 + + @pytest.mark.anyio async def test_run_chain( tls_ca_ssl_context, diff --git a/uvicorn/config.py b/uvicorn/config.py index 0ebc562c1..e7993f1d9 100644 --- a/uvicorn/config.py +++ b/uvicorn/config.py @@ -285,6 +285,7 @@ def __init__( self.callback_notify = callback_notify self.ssl_keyfile = ssl_keyfile self.ssl_certfile = ssl_certfile + self.ssl_cert_pem: Optional[str] = None self.ssl_keyfile_password = ssl_keyfile_password self.ssl_version = ssl_version self.ssl_cert_reqs = ssl_cert_reqs @@ -446,6 +447,8 @@ def load(self) -> None: ca_certs=self.ssl_ca_certs, ciphers=self.ssl_ciphers, ) + with open(self.ssl_certfile) as file: + self.ssl_cert_pem = file.read() else: self.ssl = None diff --git a/uvicorn/protocols/http/h11_impl.py b/uvicorn/protocols/http/h11_impl.py index c2764b028..430b9ba74 100644 --- a/uvicorn/protocols/http/h11_impl.py +++ b/uvicorn/protocols/http/h11_impl.py @@ -2,7 +2,17 @@ import http import logging import sys -from typing import TYPE_CHECKING, Callable, List, Optional, Tuple, Union, cast +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + List, + Optional, + Tuple, + Union, + cast, +) from urllib.parse import unquote import h11 @@ -20,6 +30,7 @@ get_local_addr, get_path_with_query_string, get_remote_addr, + get_tls_info, is_ssl, ) from uvicorn.server import ServerState @@ -99,6 +110,7 @@ def __init__( self.server: Optional[Tuple[str, int]] = None self.client: Optional[Tuple[str, int]] = None self.scheme: Optional[Literal["http", "https"]] = None + self.tls: Optional[Dict[str, Any]] = None # Per-request state self.scope: HTTPScope = None # type: ignore[assignment] @@ -117,6 +129,9 @@ def connection_made( # type: ignore[override] self.client = get_remote_addr(transport) self.scheme = "https" if is_ssl(transport) else "http" + if self.scheme == "https": + self.tls = get_tls_info(transport, self.config.ssl_cert_pem) + if self.logger.level <= TRACE_LOG_LEVEL: prefix = "%s:%d - " % self.client if self.client else "" self.logger.log(TRACE_LOG_LEVEL, "%sHTTP connection made", prefix) @@ -216,14 +231,17 @@ def handle_events(self) -> None: "http_version": event.http_version.decode("ascii"), "server": self.server, "client": self.client, - "scheme": self.scheme, + "scheme": self.scheme, # type: ignore[typeddict-item] "method": event.method.decode("ascii"), "root_path": self.root_path, "path": unquote(raw_path.decode("ascii")), "raw_path": raw_path, "query_string": query_string, "headers": self.headers, + "extensions": {}, } + if self.scheme == "https": + self.scope["extensions"]["tls"] = self.tls # type: ignore upgrade = self._get_upgrade() if upgrade == b"websocket" and self._should_upgrade_to_ws(): diff --git a/uvicorn/protocols/http/httptools_impl.py b/uvicorn/protocols/http/httptools_impl.py index 734e8945d..64bd6f899 100644 --- a/uvicorn/protocols/http/httptools_impl.py +++ b/uvicorn/protocols/http/httptools_impl.py @@ -6,7 +6,18 @@ import urllib from asyncio.events import TimerHandle from collections import deque -from typing import TYPE_CHECKING, Callable, Deque, List, Optional, Tuple, Union, cast +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Deque, + Dict, + List, + Optional, + Tuple, + Union, + cast, +) import httptools @@ -23,6 +34,7 @@ get_local_addr, get_path_with_query_string, get_remote_addr, + get_tls_info, is_ssl, ) from uvicorn.server import ServerState @@ -98,6 +110,7 @@ def __init__( self.client: Optional[Tuple[str, int]] = None self.scheme: Optional[Literal["http", "https"]] = None self.pipeline: Deque[Tuple[RequestResponseCycle, ASGI3Application]] = deque() + self.tls: Optional[Dict[str, Any]] = None # Per-request state self.scope: HTTPScope = None # type: ignore[assignment] @@ -117,6 +130,9 @@ def connection_made( # type: ignore[override] self.client = get_remote_addr(transport) self.scheme = "https" if is_ssl(transport) else "http" + if self.scheme == "https": + self.tls = get_tls_info(transport, self.config.ssl_cert_pem) + if self.logger.level <= TRACE_LOG_LEVEL: prefix = "%s:%d - " % self.client if self.client else "" self.logger.log(TRACE_LOG_LEVEL, "%sHTTP connection made", prefix) @@ -237,7 +253,10 @@ def on_message_begin(self) -> None: "scheme": self.scheme, "root_path": self.root_path, "headers": self.headers, + "extensions": {}, } + if self.scheme == "https": + self.scope["extensions"]["tls"] = self.tls # type: ignore # Parser callbacks def on_url(self, url: bytes) -> None: diff --git a/uvicorn/protocols/utils.py b/uvicorn/protocols/utils.py index fbd4b4d5d..96b9ac09d 100644 --- a/uvicorn/protocols/utils.py +++ b/uvicorn/protocols/utils.py @@ -1,10 +1,30 @@ import asyncio +import ssl import urllib.parse -from typing import TYPE_CHECKING, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple if TYPE_CHECKING: from asgiref.typing import WWWScope +RDNS_MAPPING: Dict[str, str] = { + "commonName": "CN", + "localityName": "L", + "stateOrProvinceName": "ST", + "organizationName": "O", + "organizationalUnitName": "OU", + "countryName": "C", + "streetAddress": "STREET", + "domainComponent": "DC", + "userId": "UID", +} + +TLS_VERSION_MAP: Dict[str, int] = { + "TLSv1": 0x0301, + "TLSv1.1": 0x0302, + "TLSv1.2": 0x0303, + "TLSv1.3": 0x0304, +} + def get_remote_addr(transport: asyncio.Transport) -> Optional[Tuple[str, int]]: socket_info = transport.get_extra_info("socket") @@ -53,3 +73,56 @@ def get_path_with_query_string(scope: "WWWScope") -> str: path_with_query_string, scope["query_string"].decode("ascii") ) return path_with_query_string + + +def get_tls_info( + transport: asyncio.Transport, server_pem: Optional[str] = None +) -> Dict: + + ### + # server_cert: Unable to set from transport information, need to read from config + # client_cert_chain: Just the peercert, currently no access to the full cert chain + # client_cert_name: + # client_cert_error: No access to this + # tls_version: + # cipher_suite: Too hard to convert without direct access to openssl + ### + + ssl_info: Dict[str, Any] = { + "server_cert": server_pem, + "client_cert_chain": [], + "client_cert_name": None, + "client_cert_error": None, + "tls_version": None, + "cipher_suite": None, + } + + ssl_object = transport.get_extra_info("ssl_object", default=None) + peercert = ssl_object.getpeercert() + + if peercert: + rdn_strings = [] + for rdn in peercert["subject"]: + rdn_strings.append( + "+".join( + [ + "%s = %s" % (RDNS_MAPPING[entry[0]], entry[1]) + for entry in reversed(rdn) + if entry[0] in RDNS_MAPPING + ] + ) + ) + + ssl_info["client_cert_chain"] = [ + ssl.DER_cert_to_PEM_cert(ssl_object.getpeercert(binary_form=True)) + ] + ssl_info["client_cert_name"] = ", ".join(rdn_strings) if rdn_strings else "" + + ssl_info["tls_version"] = ( + TLS_VERSION_MAP[ssl_object.version()] + if ssl_object.version() in TLS_VERSION_MAP + else None + ) + ssl_info["cipher_suite"] = list(ssl_object.cipher()) + + return ssl_info diff --git a/uvicorn/protocols/websockets/websockets_impl.py b/uvicorn/protocols/websockets/websockets_impl.py index 297203ec6..6142dca9e 100644 --- a/uvicorn/protocols/websockets/websockets_impl.py +++ b/uvicorn/protocols/websockets/websockets_impl.py @@ -2,7 +2,17 @@ import http import logging import sys -from typing import TYPE_CHECKING, Any, List, Optional, Sequence, Tuple, Union, cast +from typing import ( + TYPE_CHECKING, + Any, + Dict, + List, + Optional, + Sequence, + Tuple, + Union, + cast, +) from urllib.parse import unquote import websockets @@ -19,6 +29,7 @@ get_local_addr, get_path_with_query_string, get_remote_addr, + get_tls_info, is_ssl, ) from uvicorn.server import ServerState @@ -80,6 +91,7 @@ def __init__( self.server: Optional[Tuple[str, int]] = None self.client: Optional[Tuple[str, int]] = None self.scheme: Literal["wss", "ws"] = None # type: ignore[assignment] + self.tls: Optional[Dict[str, Any]] = None # Connection events self.scope: WebSocketScope = None # type: ignore[assignment] @@ -121,6 +133,9 @@ def connection_made( # type: ignore[override] self.client = get_remote_addr(transport) self.scheme = "wss" if is_ssl(transport) else "ws" + if self.scheme == "wss": + self.tls = get_tls_info(transport, self.config.ssl_cert_pem) + if self.logger.isEnabledFor(TRACE_LOG_LEVEL): prefix = "%s:%d - " % self.client if self.client else "" self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection made", prefix) @@ -190,7 +205,11 @@ async def process_request( "query_string": query_string.encode("ascii"), "headers": asgi_headers, "subprotocols": subprotocols, + "extensions": {}, } + if self.scheme == "wss": + self.scope["extensions"]["tls"] = self.tls # type: ignore + task = self.loop.create_task(self.run_asgi()) task.add_done_callback(self.on_task_complete) self.tasks.add(task) diff --git a/uvicorn/protocols/websockets/wsproto_impl.py b/uvicorn/protocols/websockets/wsproto_impl.py index 1d76f3a88..951f13e07 100644 --- a/uvicorn/protocols/websockets/wsproto_impl.py +++ b/uvicorn/protocols/websockets/wsproto_impl.py @@ -16,6 +16,7 @@ get_local_addr, get_path_with_query_string, get_remote_addr, + get_tls_info, is_ssl, ) from uvicorn.server import ServerState @@ -70,6 +71,7 @@ def __init__( self.server: typing.Optional[typing.Tuple[str, int]] = None self.client: typing.Optional[typing.Tuple[str, int]] = None self.scheme: Literal["wss", "ws"] = None # type: ignore[assignment] + self.tls: typing.Optional[typing.Dict[str, typing.Any]] = None # WebSocket state self.queue: asyncio.Queue["WebSocketEvent"] = asyncio.Queue() @@ -97,6 +99,9 @@ def connection_made( # type: ignore[override] self.client = get_remote_addr(transport) self.scheme = "wss" if is_ssl(transport) else "ws" + if self.scheme == "wss": + self.tls = get_tls_info(transport, self.config.ssl_cert_pem) + if self.logger.level <= TRACE_LOG_LEVEL: prefix = "%s:%d - " % self.client if self.client else "" self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection made", prefix) @@ -183,8 +188,11 @@ def handle_connect(self, event: events.Request) -> None: "query_string": query_string.encode("ascii"), "headers": headers, "subprotocols": event.subprotocols, - "extensions": None, + "extensions": {}, } + if self.scheme == "wss": + self.scope["extensions"]["tls"] = self.tls # type: ignore + self.queue.put_nowait({"type": "websocket.connect"}) task = self.loop.create_task(self.run_asgi()) task.add_done_callback(self.on_task_complete)