From 73dceee0f59e93f1609096e921a0936e1ae15592 Mon Sep 17 00:00:00 2001 From: Thomas Sileo Date: Fri, 2 Dec 2022 19:28:59 +0100 Subject: [PATCH] Fix proxy client --- app/main.py | 77 ++++++++++++++++++++++++++++++++++------------------- 1 file changed, 49 insertions(+), 28 deletions(-) diff --git a/app/main.py b/app/main.py index 4daaa1f..bd8f298 100644 --- a/app/main.py +++ b/app/main.py @@ -1180,33 +1180,31 @@ async def nodeinfo( async def _proxy_get( - request: starlette.requests.Request, url: str, stream: bool + proxy_client: httpx.AsyncClient, + request: starlette.requests.Request, + url: str, + stream: bool, ) -> httpx.Response: - async with httpx.AsyncClient( - follow_redirects=True, - timeout=httpx.Timeout(timeout=10.0), - transport=httpx.AsyncHTTPTransport(retries=1), - ) as proxy_client: - # Request the URL (and filter request headers) - proxy_req = proxy_client.build_request( - request.method, - url, - headers=[ - (k, v) - for (k, v) in request.headers.raw - if k.lower() - not in [ - b"host", - b"cookie", - b"x-forwarded-for", - b"x-forwarded-proto", - b"x-real-ip", - b"user-agent", - ] + # Request the URL (and filter request headers) + proxy_req = proxy_client.build_request( + request.method, + url, + headers=[ + (k, v) + for (k, v) in request.headers.raw + if k.lower() + not in [ + b"host", + b"cookie", + b"x-forwarded-for", + b"x-forwarded-proto", + b"x-real-ip", + b"user-agent", ] - + [(b"user-agent", USER_AGENT.encode())], - ) - return await proxy_client.send(proxy_req, stream=stream) + ] + + [(b"user-agent", USER_AGENT.encode())], + ) + return await proxy_client.send(proxy_req, stream=stream) def _filter_proxy_resp_headers( @@ -1232,18 +1230,29 @@ async def serve_proxy_media( exp: int, sig: str, encoded_url: str, + background_tasks: fastapi.BackgroundTasks, ) -> StreamingResponse | PlainTextResponse: # Decode the base64-encoded URL url = base64.urlsafe_b64decode(encoded_url).decode() check_url(url) media.verify_proxied_media_sig(exp, url, sig) - proxy_resp = await _proxy_get(request, url, stream=True) + proxy_client = httpx.AsyncClient( + follow_redirects=True, + timeout=httpx.Timeout(timeout=10.0), + transport=httpx.AsyncHTTPTransport(retries=1), + ) + + async def _close_proxy_client(): + await proxy_client.aclose() + + background_tasks.add_task(_close_proxy_client) + proxy_resp = await _proxy_get(proxy_client, request, url, stream=True) if proxy_resp.status_code >= 300: logger.info(f"failed to proxy {url}, got {proxy_resp.status_code}") + await proxy_resp.aclose() return PlainTextResponse( - "proxy error", status_code=proxy_resp.status_code, ) @@ -1276,6 +1285,7 @@ async def serve_proxy_media_resized( sig: str, encoded_url: str, size: int, + background_tasks: fastapi.BackgroundTasks, ) -> PlainTextResponse: if size not in {50, 740}: raise ValueError("Unsupported size") @@ -1293,9 +1303,20 @@ async def serve_proxy_media_resized( headers=resp_headers, ) - proxy_resp = await _proxy_get(request, url, stream=False) + proxy_client = httpx.AsyncClient( + follow_redirects=True, + timeout=httpx.Timeout(timeout=10.0), + transport=httpx.AsyncHTTPTransport(retries=1), + ) + + async def _close_proxy_client(): + await proxy_client.aclose() + + background_tasks.add_task(_close_proxy_client) + proxy_resp = await _proxy_get(proxy_client, request, url, stream=False) if proxy_resp.status_code >= 300: logger.info(f"failed to proxy {url}, got {proxy_resp.status_code}") + await proxy_resp.aclose() return PlainTextResponse( "proxy error", status_code=proxy_resp.status_code,