Make most of the HTTP requests async

main
Thomas Sileo 2022-06-30 00:28:07 +02:00
parent 3e17e17e2a
commit d371e3cd4f
12 changed files with 88 additions and 82 deletions

View File

@ -103,8 +103,9 @@ class NotAnObjectError(Exception):
self.resp = resp self.resp = resp
def fetch(url: str, params: dict[str, Any] | None = None) -> dict[str, Any]: async def fetch(url: str, params: dict[str, Any] | None = None) -> dict[str, Any]:
resp = httpx.get( async with httpx.AsyncClient() as client:
resp = await client.get(
url, url,
headers={ headers={
"User-Agent": config.USER_AGENT, "User-Agent": config.USER_AGENT,
@ -125,7 +126,7 @@ def fetch(url: str, params: dict[str, Any] | None = None) -> dict[str, Any]:
raise NotAnObjectError(url, resp) raise NotAnObjectError(url, resp)
def parse_collection( # noqa: C901 async def parse_collection( # noqa: C901
url: str | None = None, url: str | None = None,
payload: RawObject | None = None, payload: RawObject | None = None,
level: int = 0, level: int = 0,
@ -137,7 +138,7 @@ def parse_collection( # noqa: C901
# Go through all the pages # Go through all the pages
out: list[RawObject] = [] out: list[RawObject] = []
if url: if url:
payload = fetch(url) payload = await fetch(url)
if not payload: if not payload:
raise ValueError("must at least prove a payload or an URL") raise ValueError("must at least prove a payload or an URL")
@ -155,7 +156,9 @@ def parse_collection( # noqa: C901
return payload["items"] return payload["items"]
if "first" in payload: if "first" in payload:
if isinstance(payload["first"], str): if isinstance(payload["first"], str):
out.extend(parse_collection(url=payload["first"], level=level + 1)) out.extend(
await parse_collection(url=payload["first"], level=level + 1)
)
else: else:
if "orderedItems" in payload["first"]: if "orderedItems" in payload["first"]:
out.extend(payload["first"]["orderedItems"]) out.extend(payload["first"]["orderedItems"])
@ -163,7 +166,7 @@ def parse_collection( # noqa: C901
out.extend(payload["first"]["items"]) out.extend(payload["first"]["items"])
n = payload["first"].get("next") n = payload["first"].get("next")
if n: if n:
out.extend(parse_collection(url=n, level=level + 1)) out.extend(await parse_collection(url=n, level=level + 1))
return out return out
while payload: while payload:
@ -175,7 +178,7 @@ def parse_collection( # noqa: C901
n = payload.get("next") n = payload.get("next")
if n is None: if n is None:
break break
payload = fetch(n) payload = await fetch(n)
else: else:
raise ValueError("unexpected activity type {}".format(payload["type"])) raise ValueError("unexpected activity type {}".format(payload["type"]))
@ -263,18 +266,6 @@ def remove_context(raw_object: RawObject) -> RawObject:
return a return a
def get(url: str, params: dict[str, Any] | None = None) -> dict[str, Any]:
resp = httpx.get(
url,
headers={"User-Agent": config.USER_AGENT, "Accept": config.AP_CONTENT_TYPE},
params=params,
follow_redirects=True,
auth=auth,
)
resp.raise_for_status()
return resp.json()
def post(url: str, payload: dict[str, Any]) -> httpx.Response: def post(url: str, payload: dict[str, Any]) -> httpx.Response:
resp = httpx.post( resp = httpx.post(
url, url,

View File

@ -160,7 +160,7 @@ async def fetch_actor(db_session: AsyncSession, actor_id: str) -> "ActorModel":
if existing_actor: if existing_actor:
return existing_actor return existing_actor
ap_actor = ap.get(actor_id) ap_actor = await ap.fetch(actor_id)
return await save_actor(db_session, ap_actor) return await save_actor(db_session, ap_actor)

View File

@ -178,26 +178,35 @@ class Attachment(BaseModel):
class RemoteObject(Object): class RemoteObject(Object):
def __init__(self, raw_object: ap.RawObject, actor: Actor | None = None): def __init__(self, raw_object: ap.RawObject, actor: Actor):
self._raw_object = raw_object self._raw_object = raw_object
self._actor: Actor self._actor = actor
if self._actor.ap_id != ap.get_actor_id(self._raw_object):
raise ValueError(f"Invalid actor {self._actor.ap_id}")
@classmethod
async def from_raw_object(
cls,
raw_object: ap.RawObject,
actor: Actor | None = None,
):
# Pre-fetch the actor # Pre-fetch the actor
actor_id = ap.get_actor_id(raw_object) actor_id = ap.get_actor_id(raw_object)
if actor_id == LOCAL_ACTOR.ap_id: if actor_id == LOCAL_ACTOR.ap_id:
self._actor = LOCAL_ACTOR _actor = LOCAL_ACTOR
elif actor: elif actor:
if actor.ap_id != actor_id: if actor.ap_id != actor_id:
raise ValueError( raise ValueError(
f"Invalid actor, got {actor.ap_id}, " f"expected {actor_id}" f"Invalid actor, got {actor.ap_id}, " f"expected {actor_id}"
) )
self._actor = actor _actor = actor # type: ignore
else: else:
self._actor = RemoteActor( _actor = RemoteActor(
ap_actor=ap.fetch(ap.get_actor_id(raw_object)), ap_actor=await ap.fetch(ap.get_actor_id(raw_object)),
) )
self._og_meta = None return cls(raw_object, _actor)
@property @property
def og_meta(self) -> list[dict[str, Any]] | None: def og_meta(self) -> list[dict[str, Any]] | None:

View File

@ -52,7 +52,7 @@ async def save_outbox_object(
relates_to_actor_id: int | None = None, relates_to_actor_id: int | None = None,
source: str | None = None, source: str | None = None,
) -> models.OutboxObject: ) -> models.OutboxObject:
ra = RemoteObject(raw_object) ra = await RemoteObject.from_raw_object(raw_object)
outbox_object = models.OutboxObject( outbox_object = models.OutboxObject(
public_id=public_id, public_id=public_id,
@ -368,13 +368,13 @@ async def _compute_recipients(
continue continue
# Fetch the object # Fetch the object
raw_object = ap.fetch(r) raw_object = await ap.fetch(r)
if raw_object.get("type") in ap.ACTOR_TYPES: if raw_object.get("type") in ap.ACTOR_TYPES:
saved_actor = await save_actor(db_session, raw_object) saved_actor = await save_actor(db_session, raw_object)
recipients.add(saved_actor.shared_inbox_url or saved_actor.inbox_url) recipients.add(saved_actor.shared_inbox_url or saved_actor.inbox_url)
else: else:
# Assume it's a collection of actors # Assume it's a collection of actors
for raw_actor in ap.parse_collection(payload=raw_object): for raw_actor in await ap.parse_collection(payload=raw_object):
actor = RemoteActor(raw_actor) actor = RemoteActor(raw_actor)
recipients.add(actor.shared_inbox_url or actor.inbox_url) recipients.add(actor.shared_inbox_url or actor.inbox_url)
@ -741,7 +741,7 @@ async def save_to_inbox(db_session: AsyncSession, raw_object: ap.RawObject) -> N
# Save it as an inbox object # Save it as an inbox object
if not ra.activity_object_ap_id: if not ra.activity_object_ap_id:
raise ValueError("Should never happen") raise ValueError("Should never happen")
announced_raw_object = ap.fetch(ra.activity_object_ap_id) announced_raw_object = await ap.fetch(ra.activity_object_ap_id)
announced_actor = await fetch_actor( announced_actor = await fetch_actor(
db_session, ap.get_actor_id(announced_raw_object) db_session, ap.get_actor_id(announced_raw_object)
) )
@ -830,7 +830,7 @@ async def fetch_actor_collection(db_session: AsyncSession, url: str) -> list[Act
else: else:
raise ValueError(f"internal collection for {url}) not supported") raise ValueError(f"internal collection for {url}) not supported")
return [RemoteActor(actor) for actor in ap.parse_collection(url)] return [RemoteActor(actor) for actor in await ap.parse_collection(url)]
@dataclass @dataclass

View File

@ -63,11 +63,11 @@ def _body_digest(body: bytes) -> str:
@lru_cache(32) @lru_cache(32)
def _get_public_key(key_id: str) -> Key: async def _get_public_key(key_id: str) -> Key:
# TODO: use DB to use cache actor # TODO: use DB to use cache actor
from app import activitypub as ap from app import activitypub as ap
actor = ap.fetch(key_id) actor = await ap.fetch(key_id)
if actor["type"] == "Key": if actor["type"] == "Key":
# The Key is not embedded in the Person # The Key is not embedded in the Person
k = Key(actor["owner"], actor["id"]) k = Key(actor["owner"], actor["id"])
@ -111,7 +111,7 @@ async def httpsig_checker(
) )
try: try:
k = _get_public_key(hsig["keyId"]) k = await _get_public_key(hsig["keyId"])
except ap.ObjectIsGoneError: except ap.ObjectIsGoneError:
logger.info("Actor is gone") logger.info("Actor is gone")
return HTTPSigInfo(has_valid_signature=False) return HTTPSigInfo(has_valid_signature=False)

View File

@ -10,13 +10,13 @@ from app.database import AsyncSession
async def lookup(db_session: AsyncSession, query: str) -> Actor | RemoteObject: async def lookup(db_session: AsyncSession, query: str) -> Actor | RemoteObject:
if query.startswith("@"): if query.startswith("@"):
query = webfinger.get_actor_url(query) # type: ignore # None check below query = await webfinger.get_actor_url(query) # type: ignore # None check below
if not query: if not query:
raise ap.NotAnObjectError(query) raise ap.NotAnObjectError(query)
try: try:
ap_obj = ap.fetch(query) ap_obj = await ap.fetch(query)
except ap.NotAnObjectError as not_an_object_error: except ap.NotAnObjectError as not_an_object_error:
resp = not_an_object_error.resp resp = not_an_object_error.resp
if not resp: if not resp:
@ -26,7 +26,7 @@ async def lookup(db_session: AsyncSession, query: str) -> Actor | RemoteObject:
if resp.headers.get("content-type", "").startswith("text/html"): if resp.headers.get("content-type", "").startswith("text/html"):
for alternate in mf2py.parse(doc=resp.text).get("alternates", []): for alternate in mf2py.parse(doc=resp.text).get("alternates", []):
if alternate.get("type") == "application/activity+json": if alternate.get("type") == "application/activity+json":
alternate_obj = ap.fetch(alternate["url"]) alternate_obj = await ap.fetch(alternate["url"])
if alternate_obj: if alternate_obj:
ap_obj = alternate_obj ap_obj = alternate_obj
@ -37,4 +37,4 @@ async def lookup(db_session: AsyncSession, query: str) -> Actor | RemoteObject:
actor = await fetch_actor(db_session, ap_obj["id"]) actor = await fetch_actor(db_session, ap_obj["id"])
return actor return actor
else: else:
return RemoteObject(ap_obj) return await RemoteObject.from_raw_object(ap_obj)

View File

@ -604,7 +604,7 @@ async def post_remote_follow(
if not profile.startswith("@"): if not profile.startswith("@"):
profile = f"@{profile}" profile = f"@{profile}"
remote_follow_template = get_remote_follow_template(profile) remote_follow_template = await get_remote_follow_template(profile)
if not remote_follow_template: if not remote_follow_template:
raise HTTPException(status_code=404) raise HTTPException(status_code=404)

View File

@ -52,7 +52,7 @@ async def _mentionify(
) )
).scalar_one_or_none() ).scalar_one_or_none()
if not actor: if not actor:
actor_url = webfinger.get_actor_url(mention) actor_url = await webfinger.get_actor_url(mention)
if not actor_url: if not actor_url:
# FIXME(ts): raise an error? # FIXME(ts): raise an error?
continue continue

View File

@ -7,7 +7,7 @@ from loguru import logger
from app import config from app import config
def webfinger( async def webfinger(
resource: str, resource: str,
) -> dict[str, Any] | None: # noqa: C901 ) -> dict[str, Any] | None: # noqa: C901
"""Mastodon-like WebFinger resolution to retrieve the activity stream Actor URL.""" """Mastodon-like WebFinger resolution to retrieve the activity stream Actor URL."""
@ -28,10 +28,11 @@ def webfinger(
is_404 = False is_404 = False
async with httpx.AsyncClient() as client:
for i, proto in enumerate(protos): for i, proto in enumerate(protos):
try: try:
url = f"{proto}://{host}/.well-known/webfinger" url = f"{proto}://{host}/.well-known/webfinger"
resp = httpx.get( resp = await client.get(
url, url,
params={"resource": resource}, params={"resource": resource},
headers={ headers={
@ -57,8 +58,8 @@ def webfinger(
return resp.json() return resp.json()
def get_remote_follow_template(resource: str) -> str | None: async def get_remote_follow_template(resource: str) -> str | None:
data = webfinger(resource) data = await webfinger(resource)
if data is None: if data is None:
return None return None
for link in data["links"]: for link in data["links"]:
@ -67,13 +68,13 @@ def get_remote_follow_template(resource: str) -> str | None:
return None return None
def get_actor_url(resource: str) -> str | None: async def get_actor_url(resource: str) -> str | None:
"""Mastodon-like WebFinger resolution to retrieve the activity stream Actor URL. """Mastodon-like WebFinger resolution to retrieve the activity stream Actor URL.
Returns: Returns:
the Actor URL or None if the resolution failed. the Actor URL or None if the resolution failed.
""" """
data = webfinger(resource) data = await webfinger(resource)
if data is None: if data is None:
return None return None
for link in data["links"]: for link in data["links"]:

View File

@ -43,7 +43,8 @@ def test_inbox_follow_request(
factories.build_follow_activity( factories.build_follow_activity(
from_remote_actor=ra, from_remote_actor=ra,
for_remote_actor=LOCAL_ACTOR, for_remote_actor=LOCAL_ACTOR,
) ),
ra,
) )
with mock_httpsig_checker(ra): with mock_httpsig_checker(ra):
response = client.post( response = client.post(
@ -100,7 +101,8 @@ def test_inbox_accept_follow_request(
from_remote_actor=LOCAL_ACTOR, from_remote_actor=LOCAL_ACTOR,
for_remote_actor=ra, for_remote_actor=ra,
outbox_public_id=follow_id, outbox_public_id=follow_id,
) ),
LOCAL_ACTOR,
) )
outbox_object = factories.OutboxObjectFactory.from_remote_object( outbox_object = factories.OutboxObjectFactory.from_remote_object(
follow_id, follow_from_outbox follow_id, follow_from_outbox
@ -111,7 +113,8 @@ def test_inbox_accept_follow_request(
factories.build_accept_activity( factories.build_accept_activity(
from_remote_actor=ra, from_remote_actor=ra,
for_remote_object=follow_from_outbox, for_remote_object=follow_from_outbox,
) ),
ra,
) )
with mock_httpsig_checker(ra): with mock_httpsig_checker(ra):
response = client.post( response = client.post(

View File

@ -112,7 +112,8 @@ def test_send_create_activity__with_followers(
from_remote_actor=ra, from_remote_actor=ra,
for_remote_actor=LOCAL_ACTOR, for_remote_actor=LOCAL_ACTOR,
outbox_public_id=follow_id, outbox_public_id=follow_id,
) ),
ra,
) )
inbox_object = factories.InboxObjectFactory.from_remote_object( inbox_object = factories.InboxObjectFactory.from_remote_object(
follow_from_inbox, actor follow_from_inbox, actor

View File

@ -31,7 +31,8 @@ def _setup_outbox_object() -> models.OutboxObject:
from_remote_actor=LOCAL_ACTOR, from_remote_actor=LOCAL_ACTOR,
for_remote_actor=ra, for_remote_actor=ra,
outbox_public_id=follow_id, outbox_public_id=follow_id,
) ),
LOCAL_ACTOR,
) )
outbox_object = factories.OutboxObjectFactory.from_remote_object( outbox_object = factories.OutboxObjectFactory.from_remote_object(
follow_id, follow_from_outbox follow_id, follow_from_outbox