Tweak middleware

main
Thomas Sileo 2022-07-14 15:16:45 +02:00
parent a39f874ad5
commit d245201851
2 changed files with 11 additions and 35 deletions

View File

@ -145,8 +145,7 @@ async def save_actor(db_session: AsyncSession, ap_actor: ap.RawObject) -> "Actor
handle=_handle(ap_actor), handle=_handle(ap_actor),
) )
db_session.add(actor) db_session.add(actor)
await db_session.commit() await db_session.flush()
await db_session.refresh(actor)
return actor return actor

View File

@ -92,54 +92,29 @@ _RESIZED_CACHE: MutableMapping[tuple[str, int], tuple[bytes, str, Any]] = LFUCac
class CustomMiddleware: class CustomMiddleware:
def __init__( def __init__(
self, self,
app: "ASGI3Application", app: ASGI3Application,
) -> None: ) -> None:
self.app = app self.app = app
async def __call__( async def __call__(
self, scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable self, scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable
) -> None: ) -> None:
""" # We only care about HTTP requests
if scope["type"] in ("http", "websocket"):
scope = cast(HTTPScope | WebSocketScope, scope)
client_addr: tuple[str, int] | None = scope.get("client")
client_host = client_addr[0] if client_addr else None
if self.always_trust or client_host in self.trusted_hosts:
headers = dict(scope["headers"])
if b"x-forwarded-proto" in headers:
# Determine if the incoming request was http or https based on
# the X-Forwarded-Proto header.
x_forwarded_proto = headers[b"x-forwarded-proto"].decode("latin1")
scope["scheme"] = x_forwarded_proto.strip() # type: ignore[index]
if b"x-forwarded-for" in headers:
# Determine the client address from the last trusted IP in the
# X-Forwarded-For header. We've lost the connecting client's port
# information by now, so only include the host.
x_forwarded_for = headers[b"x-forwarded-for"].decode("latin1")
x_forwarded_for_hosts = [
item.strip() for item in x_forwarded_for.split(",")
]
host = self.get_trusted_client_host(x_forwarded_for_hosts)
port = 0
scope["client"] = (host, port) # type: ignore[arg-type]
"""
if scope["type"] != "http": if scope["type"] != "http":
await self.app(scope, receive, send) await self.app(scope, receive, send)
return return
instance = {"http_status_code": None} response_details = {}
start_time = time.perf_counter() start_time = time.perf_counter()
request_id = os.urandom(8).hex() request_id = os.urandom(8).hex()
async def send_wrapper(message: Message) -> None: async def send_wrapper(message: Message) -> None:
if message["type"] == "http.response.start": if message["type"] == "http.response.start":
instance["http_status_code"] = message["status"]
# Extract the HTTP response status code
response_details["status_code"] = message["status"]
# And add the security headers
headers = MutableHeaders(scope=message) headers = MutableHeaders(scope=message)
headers["X-Request-ID"] = request_id headers["X-Request-ID"] = request_id
headers["Server"] = "microblogpub" headers["Server"] = "microblogpub"
@ -160,6 +135,8 @@ class CustomMiddleware:
await send(message) # type: ignore await send(message) # type: ignore
# Make loguru ouput the request ID on every log statement within
# the request
with logger.contextualize(request_id=request_id): with logger.contextualize(request_id=request_id):
client_host, client_port = scope["client"] # type: ignore client_host, client_port = scope["client"] # type: ignore
scheme = scope["scheme"] scheme = scope["scheme"]
@ -175,7 +152,7 @@ class CustomMiddleware:
finally: finally:
elapsed_time = time.perf_counter() - start_time elapsed_time = time.perf_counter() - start_time
logger.info( logger.info(
f"status_code={instance['http_status_code']} " f"status_code={response_details['status_code']} "
f"{elapsed_time=:.2f}s" f"{elapsed_time=:.2f}s"
) )