Media proxy cleanup

main
Thomas Sileo 2022-07-19 08:12:49 +02:00
parent 66a9778995
commit 9882fc555c
2 changed files with 52 additions and 55 deletions

View File

@ -9,6 +9,7 @@ from typing import MutableMapping
from typing import Type from typing import Type
import httpx import httpx
import starlette
from asgiref.typing import ASGI3Application from asgiref.typing import ASGI3Application
from asgiref.typing import ASGIReceiveCallable from asgiref.typing import ASGIReceiveCallable
from asgiref.typing import ASGISendCallable from asgiref.typing import ASGISendCallable
@ -57,7 +58,6 @@ from app.config import DOMAIN
from app.config import ID from app.config import ID
from app.config import USER_AGENT from app.config import USER_AGENT
from app.config import USERNAME from app.config import USERNAME
from app.config import generate_csrf_token
from app.config import is_activitypub_requested from app.config import is_activitypub_requested
from app.config import verify_csrf_token from app.config import verify_csrf_token
from app.database import AsyncSession from app.database import AsyncSession
@ -76,6 +76,7 @@ _RESIZED_CACHE: MutableMapping[tuple[str, int], tuple[bytes, str, Any]] = LFUCac
# TODO(ts): # TODO(ts):
# #
# Next: # Next:
# - Article support
# - indieauth tweaks # - indieauth tweaks
# - API for posting notes # - API for posting notes
# - allow to block servers # - allow to block servers
@ -390,7 +391,6 @@ async def following(
.all() .all()
) )
# TODO: support next_cursor/prev_cursor
actors_metadata = {} actors_metadata = {}
if is_current_user_admin(request): if is_current_user_admin(request):
actors_metadata = await get_actors_metadata( actors_metadata = await get_actors_metadata(
@ -482,13 +482,17 @@ async def _check_outbox_object_acl(
ap.VisibilityEnum.UNLISTED, ap.VisibilityEnum.UNLISTED,
]: ]:
return None return None
elif ap_object.visibility == ap.VisibilityEnum.FOLLOWERS_ONLY: elif ap_object.visibility == ap.VisibilityEnum.FOLLOWERS_ONLY:
# Is the signing actor a follower?
followers = await boxes.fetch_actor_collection( followers = await boxes.fetch_actor_collection(
db_session, BASE_URL + "/followers" db_session, BASE_URL + "/followers"
) )
if httpsig_info.signed_by_ap_actor_id in [actor.ap_id for actor in followers]: if httpsig_info.signed_by_ap_actor_id in [actor.ap_id for actor in followers]:
return None return None
elif ap_object.visibility == ap.VisibilityEnum.DIRECT: elif ap_object.visibility == ap.VisibilityEnum.DIRECT:
# Is the signing actor targeted in the object audience?
audience = ap_object.ap_object.get("to", []) + ap_object.ap_object.get("cc", []) audience = ap_object.ap_object.get("to", []) + ap_object.ap_object.get("cc", [])
if httpsig_info.signed_by_ap_actor_id in audience: if httpsig_info.signed_by_ap_actor_id in audience:
return None return None
@ -718,7 +722,7 @@ async def get_remote_follow(
db_session, db_session,
request, request,
"remote_follow.html", "remote_follow.html",
{"remote_follow_csrf_token": generate_csrf_token()}, {},
) )
@ -733,6 +737,7 @@ async def post_remote_follow(
remote_follow_template = await get_remote_follow_template(profile) remote_follow_template = await get_remote_follow_template(profile)
if not remote_follow_template: if not remote_follow_template:
# TODO(ts): error message to user
raise HTTPException(status_code=404) raise HTTPException(status_code=404)
return RedirectResponse( return RedirectResponse(
@ -812,12 +817,9 @@ async def nodeinfo(
proxy_client = httpx.AsyncClient(follow_redirects=True, http2=True) proxy_client = httpx.AsyncClient(follow_redirects=True, http2=True)
@app.get("/proxy/media/{encoded_url}") async def _proxy_get(
async def serve_proxy_media(request: Request, encoded_url: str) -> StreamingResponse: request: starlette.requests.Request, url: str, stream: bool
# Decode the base64-encoded URL ) -> httpx.Response:
url = base64.urlsafe_b64decode(encoded_url).decode()
check_url(url)
# Request the URL (and filter request headers) # Request the URL (and filter request headers)
proxy_req = proxy_client.build_request( proxy_req = proxy_client.build_request(
request.method, request.method,
@ -830,13 +832,32 @@ async def serve_proxy_media(request: Request, encoded_url: str) -> StreamingResp
] ]
+ [(b"user-agent", USER_AGENT.encode())], + [(b"user-agent", USER_AGENT.encode())],
) )
proxy_resp = await proxy_client.send(proxy_req, stream=True) return await proxy_client.send(proxy_req, stream=stream)
# Filter the headers
proxy_resp_headers = [
(k, v) def _filter_proxy_resp_headers(
for (k, v) in proxy_resp.headers.items() proxy_resp: httpx.Response,
if k.lower() allowed_headers: list[str],
in [ ) -> dict[str, str]:
return {
k: v for (k, v) in proxy_resp.headers.items() if k.lower() in allowed_headers
}
@app.get("/proxy/media/{encoded_url}")
async def serve_proxy_media(request: Request, encoded_url: str) -> StreamingResponse:
# Decode the base64-encoded URL
url = base64.urlsafe_b64decode(encoded_url).decode()
check_url(url)
proxy_resp = await _proxy_get(request, url, stream=True)
return StreamingResponse(
proxy_resp.aiter_raw(),
status_code=proxy_resp.status_code,
headers=_filter_proxy_resp_headers(
proxy_resp,
[
"content-length", "content-length",
"content-type", "content-type",
"content-range", "content-range",
@ -845,12 +866,8 @@ async def serve_proxy_media(request: Request, encoded_url: str) -> StreamingResp
"expires", "expires",
"date", "date",
"last-modified", "last-modified",
] ],
] ),
return StreamingResponse(
proxy_resp.aiter_raw(),
status_code=proxy_resp.status_code,
headers=dict(proxy_resp_headers),
background=BackgroundTask(proxy_resp.aclose), background=BackgroundTask(proxy_resp.aclose),
) )
@ -876,25 +893,7 @@ async def serve_proxy_media_resized(
headers=resp_headers, headers=resp_headers,
) )
# Request the URL (and filter request headers) proxy_resp = await _proxy_get(request, url, stream=False)
async with httpx.AsyncClient() as client:
proxy_resp = await client.get(
url,
headers=[
(k, v)
for (k, v) in request.headers.raw
if k.lower()
not in [
b"host",
b"cookie",
b"x-forwarded-for",
b"x-real-ip",
b"user-agent",
]
]
+ [(b"user-agent", USER_AGENT.encode())],
follow_redirects=True,
)
if proxy_resp.status_code != 200: if proxy_resp.status_code != 200:
return PlainTextResponse( return PlainTextResponse(
proxy_resp.content, proxy_resp.content,
@ -902,18 +901,16 @@ async def serve_proxy_media_resized(
) )
# Filter the headers # Filter the headers
proxy_resp_headers = { proxy_resp_headers = _filter_proxy_resp_headers(
k: v proxy_resp,
for (k, v) in proxy_resp.headers.items() [
if k.lower()
in [
"content-type", "content-type",
"etag", "etag",
"cache-control", "cache-control",
"expires", "expires",
"last-modified", "last-modified",
] ],
} )
try: try:
out = BytesIO(proxy_resp.content) out = BytesIO(proxy_resp.content)

View File

@ -11,9 +11,9 @@
<div class="box"> <div class="box">
<h2>Remotely follow {{ local_actor.display_name }}</h2> <h2>Remotely follow {{ local_actor.display_name }}</h2>
<form class="form" action="{{ url_for("post_remote_follow") }}" method="POST"> <form class="form" action="{{ url_for("post_remote_follow") }}" method="POST">
<input type="hidden" name="csrf_token" value="{{remote_follow_csrf_token}}"> {{ utils.embed_csrf_token() }}
<input type="text" name="profile" placeholder="you@instance.tld" autofocus> <input type="text" name="profile" placeholder="you@instance.tld" autofocus>
<input type="submit" value="Follow"> <input type="submit" value="follow">
</form> </form>
</div> </div>