Skip to content

Commit

Permalink
fix: add Starlette middleware for checking Cross-Site Request Forgery…
Browse files Browse the repository at this point in the history
… (CSRF) when trusted origins are specified via environment variable (#4916)
  • Loading branch information
RogerHYang authored Oct 9, 2024
1 parent 199d0eb commit 26f8e4b
Show file tree
Hide file tree
Showing 5 changed files with 85 additions and 1 deletion.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ dependencies = [
"litellm>=1.0.3",
"openai>=1.0.0",
"tenacity",
"protobuf==3.20", # version minimum (for tests)
"protobuf==3.20.2", # version minimum (for tests)
"grpc-interceptor[testing]",
"responses",
"tiktoken",
Expand Down
24 changes: 24 additions & 0 deletions src/phoenix/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,14 @@
"""
The duration, in minutes, before password reset tokens expire.
"""
ENV_PHOENIX_CSRF_TRUSTED_ORIGINS = "PHOENIX_CSRF_TRUSTED_ORIGINS"
"""
A comma-separated list of origins allowed to bypass Cross-Site Request Forgery (CSRF)
protection. This setting is recommended when configuring OAuth2 clients or sending
password reset emails. If this variable is left unspecified or contains no origins, CSRF
protection will not be enabled. In such cases, when a request includes `origin` or `referer`
headers, those values will not be validated.
"""

# SMTP settings
ENV_PHOENIX_SMTP_HOSTNAME = "PHOENIX_SMTP_HOSTNAME"
Expand Down Expand Up @@ -321,6 +329,22 @@ def get_env_refresh_token_expiry() -> timedelta:
return timedelta(minutes=minutes)


def get_env_csrf_trusted_origins() -> List[str]:
    """
    Parse the comma-separated origins listed in the environment variable named by
    `ENV_PHOENIX_CSRF_TRUSTED_ORIGINS`.

    Whitespace surrounding each entry is ignored, so values written with spaces after
    the commas (e.g. `"http://a.com, http://b.com"`) are accepted. Empty entries, such
    as those produced by leading, trailing or repeated commas, are skipped.

    Returns:
        The distinct origins in sorted order, or an empty list when the variable is
        unset or contains no origins.

    Raises:
        ValueError: If an entry does not parse into a url with a hostname.
    """
    origins: List[str] = []
    if not (csrf_trusted_origins := os.getenv(ENV_PHOENIX_CSRF_TRUSTED_ORIGINS)):
        return origins
    for raw_origin in csrf_trusted_origins.split(","):
        # Strip first: stray whitespace would otherwise end up inside the parsed
        # hostname (or hide the scheme), producing an origin that can never match.
        origin = raw_origin.strip()
        if not origin:
            continue
        if not urlparse(origin).hostname:
            raise ValueError(
                f"The environment variable `{ENV_PHOENIX_CSRF_TRUSTED_ORIGINS}` contains a url "
                f"with missing hostname. Please ensure that each url has a valid hostname."
            )
        origins.append(origin)
    # De-duplicate and sort so the result is deterministic.
    return sorted(set(origins))


def get_env_smtp_username() -> str:
    """
    Return the value of the environment variable named by `ENV_PHOENIX_SMTP_USERNAME`,
    or an empty string when it is unset or empty.
    """
    username = os.getenv(ENV_PHOENIX_SMTP_USERNAME)
    return username if username else ""

Expand Down
33 changes: 33 additions & 0 deletions src/phoenix/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
Union,
cast,
)
from urllib.parse import urlparse

import strawberry
from fastapi import APIRouter, Depends, FastAPI
Expand All @@ -42,6 +43,7 @@
from starlette.requests import Request
from starlette.responses import PlainTextResponse, Response
from starlette.staticfiles import StaticFiles
from starlette.status import HTTP_401_UNAUTHORIZED
from starlette.templating import Jinja2Templates
from starlette.types import Scope, StatefulLifespan
from strawberry.extensions import SchemaExtension
Expand All @@ -53,8 +55,10 @@
import phoenix.trace.v1 as pb
from phoenix.config import (
DEFAULT_PROJECT_NAME,
ENV_PHOENIX_CSRF_TRUSTED_ORIGINS,
SERVER_DIR,
OAuth2ClientConfig,
get_env_csrf_trusted_origins,
get_env_host,
get_env_port,
server_instrumentation_is_enabled,
Expand Down Expand Up @@ -226,6 +230,25 @@ async def get_response(self, path: str, scope: Scope) -> Response:
return response


class RequestOriginHostnameValidator(BaseHTTPMiddleware):
    """
    Middleware that rejects requests whose `origin` or `referer` header points at a
    hostname that is not in the list of trusted hostnames.

    Requests carrying neither header are passed through unchanged. When a header is
    present, its value is parsed as a url and the hostname is compared against the
    trusted hostnames; a value with no hostname, or one that cannot be parsed at all,
    is treated as untrusted.
    """

    def __init__(self, trusted_hostnames: List[str], *args: Any, **kwargs: Any) -> None:
        """
        Args:
            trusted_hostnames: Hostnames that `origin` / `referer` headers may refer to.
            *args: Forwarded to `BaseHTTPMiddleware`.
            **kwargs: Forwarded to `BaseHTTPMiddleware`.
        """
        super().__init__(*args, **kwargs)
        # Membership is checked on every request, so keep the hostnames in a set.
        self._trusted_hostnames = frozenset(trusted_hostnames)

    async def dispatch(
        self,
        request: Request,
        call_next: RequestResponseEndpoint,
    ) -> Response:
        """
        Respond with 401 if the `origin` or `referer` header names an untrusted
        hostname; otherwise hand the request to the next handler.
        """
        headers = request.headers
        for key in "origin", "referer":
            if not (url := headers.get(key)):
                continue
            try:
                hostname = urlparse(url).hostname
            except ValueError:
                # Header values are client-controlled; `urlparse` raises on malformed
                # input such as unmatched brackets. Treat that as untrusted rather
                # than letting the exception escape as a server error.
                hostname = None
            if hostname not in self._trusted_hostnames:
                return Response(f"untrusted {key}", status_code=HTTP_401_UNAUTHORIZED)
        return await call_next(request)


class HeadersMiddleware(BaseHTTPMiddleware):
async def dispatch(
self,
Expand Down Expand Up @@ -660,6 +683,16 @@ def create_app(
)
last_updated_at = LastUpdatedAt()
middlewares: List[Middleware] = [Middleware(HeadersMiddleware)]
if origins := get_env_csrf_trusted_origins():
trusted_hostnames = [h for o in origins if o and (h := urlparse(o).hostname)]
middlewares.append(Middleware(RequestOriginHostnameValidator, trusted_hostnames))
elif email_sender or oauth2_client_configs:
logger.warning(
"CSRF protection can be enabled by listing trusted origins via "
f"the `{ENV_PHOENIX_CSRF_TRUSTED_ORIGINS}` environment variable. "
"This is recommended when setting up OAuth2 clients or sending "
"password reset emails."
)
if authentication_enabled and secret:
token_store = JwtStore(db, secret)
middlewares.append(
Expand Down
2 changes: 2 additions & 0 deletions tests/integration/auth/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from faker import Faker
from phoenix.auth import DEFAULT_SECRET_LENGTH
from phoenix.config import (
ENV_PHOENIX_CSRF_TRUSTED_ORIGINS,
ENV_PHOENIX_DISABLE_RATE_LIMIT,
ENV_PHOENIX_ENABLE_AUTH,
ENV_PHOENIX_SECRET,
Expand Down Expand Up @@ -52,6 +53,7 @@ def _app(
(ENV_PHOENIX_SMTP_PASSWORD, "test"),
(ENV_PHOENIX_SMTP_MAIL_FROM, _fake.email()),
(ENV_PHOENIX_SMTP_VALIDATE_CERTS, "false"),
(ENV_PHOENIX_CSRF_TRUSTED_ORIGINS, ",http://localhost,"),
)
with ExitStack() as stack:
stack.enter_context(mock.patch.dict(os.environ, values))
Expand Down
25 changes: 25 additions & 0 deletions tests/integration/auth/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from datetime import datetime, timedelta, timezone
from functools import partial
from typing import (
Any,
ContextManager,
DefaultDict,
Dict,
Expand Down Expand Up @@ -53,6 +54,7 @@
_grpc_span_exporter,
_Headers,
_http_span_exporter,
_httpx_client,
_initiate_password_reset,
_log_in,
_log_out,
Expand All @@ -73,6 +75,29 @@
_TokenT = TypeVar("_TokenT", _AccessToken, _RefreshToken)


class TestOriginAndReferer:
    """Exercises how the server responds to `origin` and `referer` request headers."""

    @pytest.mark.parametrize(
        "headers,expectation",
        [
            # No headers at all.
            ({}, _OK),
            # Headers referring to localhost.
            ({"origin": "http://localhost"}, _OK),
            ({"referer": "http://localhost/xyz"}, _OK),
            # Headers referring to another host.
            ({"origin": "http://xyz.com"}, _EXPECTATION_401),
            ({"referer": "http://xyz.com/xyz"}, _EXPECTATION_401),
            # One header refers to localhost while the other does not.
            ({"origin": "http://xyz.com", "referer": "http://localhost/xyz"}, _EXPECTATION_401),
            ({"origin": "http://localhost", "referer": "http://xyz.com/xyz"}, _EXPECTATION_401),
        ],
    )
    def test_csrf_origin_validation(
        self,
        headers: Dict[str, str],
        expectation: ContextManager[Any],
    ) -> None:
        """
        Send a GET to `/healthz` with the given headers and check the response status
        inside the paired expectation context manager.
        """
        response = _httpx_client(headers=headers).get("/healthz")
        with expectation:
            response.raise_for_status()


class TestLogIn:
@pytest.mark.parametrize("role_or_user", [_MEMBER, _ADMIN, _DEFAULT_ADMIN])
def test_can_log_in(
Expand Down

0 comments on commit 26f8e4b

Please sign in to comment.