Switch to raw ASGI middleware
parent
dd50db40d9
commit
a39f874ad5
|
@ -29,4 +29,7 @@ def now() -> datetime.datetime:
|
||||||
|
|
||||||
async def get_db_session() -> AsyncGenerator[AsyncSession, None]:
|
async def get_db_session() -> AsyncGenerator[AsyncSession, None]:
|
||||||
async with async_session() as session:
|
async with async_session() as session:
|
||||||
yield session
|
try:
|
||||||
|
yield session
|
||||||
|
finally:
|
||||||
|
await session.close()
|
||||||
|
|
159
app/main.py
159
app/main.py
|
@ -9,6 +9,10 @@ from typing import MutableMapping
|
||||||
from typing import Type
|
from typing import Type
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
from asgiref.typing import ASGI3Application
|
||||||
|
from asgiref.typing import ASGIReceiveCallable
|
||||||
|
from asgiref.typing import ASGISendCallable
|
||||||
|
from asgiref.typing import Scope
|
||||||
from cachetools import LFUCache
|
from cachetools import LFUCache
|
||||||
from fastapi import Depends
|
from fastapi import Depends
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
|
@ -28,7 +32,9 @@ from sqlalchemy import func
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm import joinedload
|
from sqlalchemy.orm import joinedload
|
||||||
from starlette.background import BackgroundTask
|
from starlette.background import BackgroundTask
|
||||||
|
from starlette.datastructures import MutableHeaders
|
||||||
from starlette.responses import JSONResponse
|
from starlette.responses import JSONResponse
|
||||||
|
from starlette.types import Message
|
||||||
|
|
||||||
from app import activitypub as ap
|
from app import activitypub as ap
|
||||||
from app import admin
|
from app import admin
|
||||||
|
@ -82,12 +88,107 @@ _RESIZED_CACHE: MutableMapping[tuple[str, int], tuple[bytes, str, Any]] = LFUCac
|
||||||
# - [ ] Dockerization
|
# - [ ] Dockerization
|
||||||
# - [ ] cleanup tasks
|
# - [ ] cleanup tasks
|
||||||
|
|
||||||
|
|
||||||
|
class CustomMiddleware:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
app: "ASGI3Application",
|
||||||
|
) -> None:
|
||||||
|
self.app = app
|
||||||
|
|
||||||
|
async def __call__(
|
||||||
|
self, scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
if scope["type"] in ("http", "websocket"):
|
||||||
|
scope = cast(HTTPScope | WebSocketScope, scope)
|
||||||
|
client_addr: tuple[str, int] | None = scope.get("client")
|
||||||
|
client_host = client_addr[0] if client_addr else None
|
||||||
|
|
||||||
|
if self.always_trust or client_host in self.trusted_hosts:
|
||||||
|
headers = dict(scope["headers"])
|
||||||
|
|
||||||
|
if b"x-forwarded-proto" in headers:
|
||||||
|
# Determine if the incoming request was http or https based on
|
||||||
|
# the X-Forwarded-Proto header.
|
||||||
|
x_forwarded_proto = headers[b"x-forwarded-proto"].decode("latin1")
|
||||||
|
scope["scheme"] = x_forwarded_proto.strip() # type: ignore[index]
|
||||||
|
|
||||||
|
if b"x-forwarded-for" in headers:
|
||||||
|
# Determine the client address from the last trusted IP in the
|
||||||
|
# X-Forwarded-For header. We've lost the connecting client's port
|
||||||
|
# information by now, so only include the host.
|
||||||
|
x_forwarded_for = headers[b"x-forwarded-for"].decode("latin1")
|
||||||
|
x_forwarded_for_hosts = [
|
||||||
|
item.strip() for item in x_forwarded_for.split(",")
|
||||||
|
]
|
||||||
|
host = self.get_trusted_client_host(x_forwarded_for_hosts)
|
||||||
|
port = 0
|
||||||
|
scope["client"] = (host, port) # type: ignore[arg-type]
|
||||||
|
"""
|
||||||
|
|
||||||
|
if scope["type"] != "http":
|
||||||
|
await self.app(scope, receive, send)
|
||||||
|
return
|
||||||
|
|
||||||
|
instance = {"http_status_code": None}
|
||||||
|
|
||||||
|
start_time = time.perf_counter()
|
||||||
|
request_id = os.urandom(8).hex()
|
||||||
|
|
||||||
|
async def send_wrapper(message: Message) -> None:
|
||||||
|
if message["type"] == "http.response.start":
|
||||||
|
instance["http_status_code"] = message["status"]
|
||||||
|
|
||||||
|
headers = MutableHeaders(scope=message)
|
||||||
|
headers["X-Request-ID"] = request_id
|
||||||
|
headers["Server"] = "microblogpub"
|
||||||
|
headers[
|
||||||
|
"referrer-policy"
|
||||||
|
] = "no-referrer, strict-origin-when-cross-origin"
|
||||||
|
headers["x-content-type-options"] = "nosniff"
|
||||||
|
headers["x-xss-protection"] = "1; mode=block"
|
||||||
|
headers["x-frame-options"] = "SAMEORIGIN"
|
||||||
|
# TODO(ts): disallow inline CSS?
|
||||||
|
headers["content-security-policy"] = (
|
||||||
|
"default-src 'self'" + " style-src 'self' 'unsafe-inline';"
|
||||||
|
)
|
||||||
|
if not DEBUG:
|
||||||
|
headers[
|
||||||
|
"strict-transport-security"
|
||||||
|
] = "max-age=63072000; includeSubdomains"
|
||||||
|
|
||||||
|
await send(message) # type: ignore
|
||||||
|
|
||||||
|
with logger.contextualize(request_id=request_id):
|
||||||
|
client_host, client_port = scope["client"] # type: ignore
|
||||||
|
scheme = scope["scheme"]
|
||||||
|
server_host, server_port = scope["server"] # type: ignore
|
||||||
|
request_method = scope["method"]
|
||||||
|
request_path = scope["path"]
|
||||||
|
logger.info(
|
||||||
|
f"{client_host}:{client_port} - "
|
||||||
|
f"{request_method} {scheme}://{server_host}:{server_port}{request_path}"
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
await self.app(scope, receive, send_wrapper) # type: ignore
|
||||||
|
finally:
|
||||||
|
elapsed_time = time.perf_counter() - start_time
|
||||||
|
logger.info(
|
||||||
|
f"status_code={instance['http_status_code']} "
|
||||||
|
f"{elapsed_time=:.2f}s"
|
||||||
|
)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
app = FastAPI(docs_url=None, redoc_url=None)
|
app = FastAPI(docs_url=None, redoc_url=None)
|
||||||
app.mount("/static", StaticFiles(directory="app/static"), name="static")
|
app.mount("/static", StaticFiles(directory="app/static"), name="static")
|
||||||
app.include_router(admin.router, prefix="/admin")
|
app.include_router(admin.router, prefix="/admin")
|
||||||
app.include_router(admin.unauthenticated_router, prefix="/admin")
|
app.include_router(admin.unauthenticated_router, prefix="/admin")
|
||||||
app.include_router(indieauth.router)
|
app.include_router(indieauth.router)
|
||||||
app.include_router(webmentions.router)
|
app.include_router(webmentions.router)
|
||||||
|
app.add_middleware(CustomMiddleware)
|
||||||
|
|
||||||
logger.configure(extra={"request_id": "no_req_id"})
|
logger.configure(extra={"request_id": "no_req_id"})
|
||||||
logger.remove()
|
logger.remove()
|
||||||
|
@ -100,64 +201,6 @@ logger_format = (
|
||||||
logger.add(sys.stdout, format=logger_format)
|
logger.add(sys.stdout, format=logger_format)
|
||||||
|
|
||||||
|
|
||||||
@app.middleware("http")
|
|
||||||
async def request_middleware(request, call_next):
|
|
||||||
start_time = time.perf_counter()
|
|
||||||
request_id = os.urandom(8).hex()
|
|
||||||
with logger.contextualize(request_id=request_id):
|
|
||||||
logger.info(
|
|
||||||
f"{request.client.host}:{request.client.port} - "
|
|
||||||
f"{request.method} {request.url}"
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
response = await call_next(request)
|
|
||||||
response.headers["X-Request-ID"] = request_id
|
|
||||||
response.headers["Server"] = "microblogpub"
|
|
||||||
elapsed_time = time.perf_counter() - start_time
|
|
||||||
logger.info(f"status_code={response.status_code} {elapsed_time=:.2f}s")
|
|
||||||
return response
|
|
||||||
except Exception:
|
|
||||||
logger.exception("Request failed")
|
|
||||||
raise
|
|
||||||
|
|
||||||
|
|
||||||
@app.middleware("http")
|
|
||||||
async def add_security_headers(request: Request, call_next):
|
|
||||||
try:
|
|
||||||
response = await call_next(request)
|
|
||||||
except RuntimeError as exc:
|
|
||||||
# https://github.com/encode/starlette/discussions/1527#discussioncomment-2234702
|
|
||||||
if await request.is_disconnected() and str(exc) == "No response returned.":
|
|
||||||
return Response(status_code=204)
|
|
||||||
|
|
||||||
response.headers["referrer-policy"] = "no-referrer, strict-origin-when-cross-origin"
|
|
||||||
response.headers["x-content-type-options"] = "nosniff"
|
|
||||||
response.headers["x-xss-protection"] = "1; mode=block"
|
|
||||||
response.headers["x-frame-options"] = "SAMEORIGIN"
|
|
||||||
if request.url.path.startswith("/admin/login") or (
|
|
||||||
is_current_user_admin(request)
|
|
||||||
and not (
|
|
||||||
request.url.path.startswith("/attachments")
|
|
||||||
or request.url.path.startswith("/proxy")
|
|
||||||
or request.url.path.startswith("/static")
|
|
||||||
)
|
|
||||||
):
|
|
||||||
# Prevent caching (to prevent caching CSRF tokens)
|
|
||||||
response.headers["Cache-Control"] = "private"
|
|
||||||
|
|
||||||
# TODO(ts): disallow inline CSS?
|
|
||||||
if DEBUG:
|
|
||||||
return response
|
|
||||||
response.headers["content-security-policy"] = (
|
|
||||||
"default-src 'self'" + " style-src 'self' 'unsafe-inline';"
|
|
||||||
)
|
|
||||||
if not DEBUG:
|
|
||||||
response.headers[
|
|
||||||
"strict-transport-security"
|
|
||||||
] = "max-age=63072000; includeSubdomains"
|
|
||||||
return response
|
|
||||||
|
|
||||||
|
|
||||||
class ActivityPubResponse(JSONResponse):
|
class ActivityPubResponse(JSONResponse):
|
||||||
media_type = "application/activity+json"
|
media_type = "application/activity+json"
|
||||||
|
|
||||||
|
|
|
@ -1202,7 +1202,7 @@ dev = ["pytest (>=4.6.2)", "black (>=19.3b0)"]
|
||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "1.1"
|
lock-version = "1.1"
|
||||||
python-versions = "^3.10"
|
python-versions = "^3.10"
|
||||||
content-hash = "7bc5ba65a004438ac015dcd01c27e1d327dbf491f9f881a48a2a790bb0bbf710"
|
content-hash = "4353bb98b40254eea5277799de3329b6658e21178a6da44113e78c897c7f140b"
|
||||||
|
|
||||||
[metadata.files]
|
[metadata.files]
|
||||||
aiosqlite = [
|
aiosqlite = [
|
||||||
|
|
|
@ -40,6 +40,7 @@ aiosqlite = "^0.17.0"
|
||||||
cachetools = "^5.2.0"
|
cachetools = "^5.2.0"
|
||||||
humanize = "^4.2.3"
|
humanize = "^4.2.3"
|
||||||
tabulate = "^0.8.10"
|
tabulate = "^0.8.10"
|
||||||
|
asgiref = "^3.5.2"
|
||||||
|
|
||||||
[tool.poetry.dev-dependencies]
|
[tool.poetry.dev-dependencies]
|
||||||
black = "^22.3.0"
|
black = "^22.3.0"
|
||||||
|
|
Loading…
Reference in New Issue