Improve caching
parent
d371e3cd4f
commit
6458d2a6c7
|
@ -8,22 +8,27 @@ import hashlib
|
||||||
import typing
|
import typing
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from functools import lru_cache
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import fastapi
|
import fastapi
|
||||||
import httpx
|
import httpx
|
||||||
|
from cachetools import LFUCache
|
||||||
from Crypto.Hash import SHA256
|
from Crypto.Hash import SHA256
|
||||||
from Crypto.Signature import PKCS1_v1_5
|
from Crypto.Signature import PKCS1_v1_5
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
from sqlalchemy import select
|
||||||
|
|
||||||
from app import activitypub as ap
|
from app import activitypub as ap
|
||||||
from app import config
|
from app import config
|
||||||
|
from app.database import AsyncSession
|
||||||
|
from app.database import get_db_session
|
||||||
from app.key import Key
|
from app.key import Key
|
||||||
from app.key import get_key
|
from app.key import get_key
|
||||||
|
|
||||||
|
_KEY_CACHE = LFUCache(256)
|
||||||
|
|
||||||
|
|
||||||
def _build_signed_string(
|
def _build_signed_string(
|
||||||
signed_headers: str, method: str, path: str, headers: Any, body_digest: str | None
|
signed_headers: str, method: str, path: str, headers: Any, body_digest: str | None
|
||||||
|
@ -62,9 +67,25 @@ def _body_digest(body: bytes) -> str:
|
||||||
return "SHA-256=" + base64.b64encode(h.digest()).decode("utf-8")
|
return "SHA-256=" + base64.b64encode(h.digest()).decode("utf-8")
|
||||||
|
|
||||||
|
|
||||||
@lru_cache(32)
|
async def _get_public_key(db_session: AsyncSession, key_id: str) -> Key:
|
||||||
async def _get_public_key(key_id: str) -> Key:
|
if cached_key := _KEY_CACHE.get(key_id):
|
||||||
# TODO: use DB to use cache actor
|
return cached_key
|
||||||
|
|
||||||
|
# Check if the key belongs to an actor already in DB
|
||||||
|
from app import models
|
||||||
|
existing_actor = (
|
||||||
|
await db_session.scalars(
|
||||||
|
select(models.Actor).where(models.Actor.ap_id == key_id.split("#")[0])
|
||||||
|
)
|
||||||
|
).one_or_none()
|
||||||
|
if existing_actor and existing_actor.public_key_id == key_id:
|
||||||
|
k = Key(existing_actor.ap_id, key_id)
|
||||||
|
k.load_pub(existing_actor.public_key_as_pem)
|
||||||
|
logger.info(f"Found {key_id} on an existing actor")
|
||||||
|
_KEY_CACHE[key_id] = k
|
||||||
|
return k
|
||||||
|
|
||||||
|
# Fetch it
|
||||||
from app import activitypub as ap
|
from app import activitypub as ap
|
||||||
|
|
||||||
actor = await ap.fetch(key_id)
|
actor = await ap.fetch(key_id)
|
||||||
|
@ -82,6 +103,7 @@ async def _get_public_key(key_id: str) -> Key:
|
||||||
f"failed to fetch requested key {key_id}: got {actor['publicKey']['id']}"
|
f"failed to fetch requested key {key_id}: got {actor['publicKey']['id']}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
_KEY_CACHE[key_id] = k
|
||||||
return k
|
return k
|
||||||
|
|
||||||
|
|
||||||
|
@ -93,6 +115,7 @@ class HTTPSigInfo:
|
||||||
|
|
||||||
async def httpsig_checker(
|
async def httpsig_checker(
|
||||||
request: fastapi.Request,
|
request: fastapi.Request,
|
||||||
|
db_session: AsyncSession = fastapi.Depends(get_db_session),
|
||||||
) -> HTTPSigInfo:
|
) -> HTTPSigInfo:
|
||||||
body = await request.body()
|
body = await request.body()
|
||||||
|
|
||||||
|
@ -111,7 +134,7 @@ async def httpsig_checker(
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
k = await _get_public_key(hsig["keyId"])
|
k = await _get_public_key(db_session, hsig["keyId"])
|
||||||
except ap.ObjectIsGoneError:
|
except ap.ObjectIsGoneError:
|
||||||
logger.info("Actor is gone")
|
logger.info("Actor is gone")
|
||||||
return HTTPSigInfo(has_valid_signature=False)
|
return HTTPSigInfo(has_valid_signature=False)
|
||||||
|
|
62
app/main.py
62
app/main.py
|
@ -8,6 +8,7 @@ from typing import Any
|
||||||
from typing import Type
|
from typing import Type
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
from cachetools import LFUCache
|
||||||
from fastapi import Depends
|
from fastapi import Depends
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from fastapi import Form
|
from fastapi import Form
|
||||||
|
@ -56,6 +57,9 @@ from app.utils import pagination
|
||||||
from app.utils.emoji import EMOJIS_BY_NAME
|
from app.utils.emoji import EMOJIS_BY_NAME
|
||||||
from app.webfinger import get_remote_follow_template
|
from app.webfinger import get_remote_follow_template
|
||||||
|
|
||||||
|
_RESIZED_CACHE = LFUCache(32)
|
||||||
|
|
||||||
|
|
||||||
# TODO(ts):
|
# TODO(ts):
|
||||||
#
|
#
|
||||||
# Next:
|
# Next:
|
||||||
|
@ -728,7 +732,7 @@ async def serve_proxy_media(request: Request, encoded_url: str) -> StreamingResp
|
||||||
|
|
||||||
|
|
||||||
@app.get("/proxy/media/{encoded_url}/{size}")
|
@app.get("/proxy/media/{encoded_url}/{size}")
|
||||||
def serve_proxy_media_resized(
|
async def serve_proxy_media_resized(
|
||||||
request: Request,
|
request: Request,
|
||||||
encoded_url: str,
|
encoded_url: str,
|
||||||
size: int,
|
size: int,
|
||||||
|
@ -738,18 +742,38 @@ def serve_proxy_media_resized(
|
||||||
|
|
||||||
# Decode the base64-encoded URL
|
# Decode the base64-encoded URL
|
||||||
url = base64.urlsafe_b64decode(encoded_url).decode()
|
url = base64.urlsafe_b64decode(encoded_url).decode()
|
||||||
|
|
||||||
|
is_cached = False
|
||||||
|
is_resized = False
|
||||||
|
if cached_resp := _RESIZED_CACHE.get((url, size)):
|
||||||
|
is_resized, resized_content, resized_mimetype, resp_headers = cached_resp
|
||||||
|
if is_resized:
|
||||||
|
return PlainTextResponse(
|
||||||
|
resized_content,
|
||||||
|
media_type=resized_mimetype,
|
||||||
|
headers=resp_headers,
|
||||||
|
)
|
||||||
|
is_cached = True
|
||||||
|
|
||||||
# Request the URL (and filter request headers)
|
# Request the URL (and filter request headers)
|
||||||
proxy_resp = httpx.get(
|
async with httpx.AsyncClient() as client:
|
||||||
url,
|
proxy_resp = await client.get(
|
||||||
headers=[
|
url,
|
||||||
(k, v)
|
headers=[
|
||||||
for (k, v) in request.headers.raw
|
(k, v)
|
||||||
if k.lower()
|
for (k, v) in request.headers.raw
|
||||||
not in [b"host", b"cookie", b"x-forwarded-for", b"x-real-ip", b"user-agent"]
|
if k.lower()
|
||||||
]
|
not in [
|
||||||
+ [(b"user-agent", USER_AGENT.encode())],
|
b"host",
|
||||||
)
|
b"cookie",
|
||||||
if proxy_resp.status_code != 200:
|
b"x-forwarded-for",
|
||||||
|
b"x-real-ip",
|
||||||
|
b"user-agent",
|
||||||
|
]
|
||||||
|
]
|
||||||
|
+ [(b"user-agent", USER_AGENT.encode())],
|
||||||
|
)
|
||||||
|
if proxy_resp.status_code != 200 or (is_cached and not is_resized):
|
||||||
return PlainTextResponse(
|
return PlainTextResponse(
|
||||||
proxy_resp.content,
|
proxy_resp.content,
|
||||||
status_code=proxy_resp.status_code,
|
status_code=proxy_resp.status_code,
|
||||||
|
@ -772,15 +796,23 @@ def serve_proxy_media_resized(
|
||||||
try:
|
try:
|
||||||
out = BytesIO(proxy_resp.content)
|
out = BytesIO(proxy_resp.content)
|
||||||
i = Image.open(out)
|
i = Image.open(out)
|
||||||
if i.is_animated:
|
if getattr(i, "is_animated", False):
|
||||||
raise ValueError
|
raise ValueError
|
||||||
i.thumbnail((size, size))
|
i.thumbnail((size, size))
|
||||||
resized_buf = BytesIO()
|
resized_buf = BytesIO()
|
||||||
i.save(resized_buf, format=i.format)
|
i.save(resized_buf, format=i.format)
|
||||||
resized_buf.seek(0)
|
resized_buf.seek(0)
|
||||||
|
resized_content = resized_buf.read()
|
||||||
|
resized_mimetype = i.get_format_mimetype() # type: ignore
|
||||||
|
_RESIZED_CACHE[(url, size)] = (
|
||||||
|
True,
|
||||||
|
resized_content,
|
||||||
|
resized_mimetype,
|
||||||
|
proxy_resp_headers,
|
||||||
|
)
|
||||||
return PlainTextResponse(
|
return PlainTextResponse(
|
||||||
resized_buf.read(),
|
resized_content,
|
||||||
media_type=i.get_format_mimetype(), # type: ignore
|
media_type=resized_mimetype,
|
||||||
headers=proxy_resp_headers,
|
headers=proxy_resp_headers,
|
||||||
)
|
)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
|
|
|
@ -190,7 +190,8 @@ def _clean_html(html: str, note: Object) -> str:
|
||||||
strip=True,
|
strip=True,
|
||||||
),
|
),
|
||||||
note,
|
note,
|
||||||
)
|
),
|
||||||
|
is_local=note.ap_id.startswith(BASE_URL),
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
raise
|
raise
|
||||||
|
@ -241,12 +242,15 @@ def _html2text(content: str) -> str:
|
||||||
return H2T.handle(content)
|
return H2T.handle(content)
|
||||||
|
|
||||||
|
|
||||||
def _replace_emoji(u, data):
|
def _replace_emoji(u: str, _) -> str:
|
||||||
filename = hex(ord(u))[2:]
|
filename = hex(ord(u))[2:]
|
||||||
return config.EMOJI_TPL.format(filename=filename, raw=u)
|
return config.EMOJI_TPL.format(filename=filename, raw=u)
|
||||||
|
|
||||||
|
|
||||||
def _emojify(text: str):
|
def _emojify(text: str, is_local: bool) -> str:
|
||||||
|
if not is_local:
|
||||||
|
return text
|
||||||
|
|
||||||
return emoji.replace_emoji(
|
return emoji.replace_emoji(
|
||||||
text,
|
text,
|
||||||
replace=_replace_emoji,
|
replace=_replace_emoji,
|
||||||
|
|
|
@ -16,7 +16,10 @@
|
||||||
</div>
|
</div>
|
||||||
{{ utils.display_actor(inbox_object.actor, actors_metadata) }}
|
{{ utils.display_actor(inbox_object.actor, actors_metadata) }}
|
||||||
{% else %}
|
{% else %}
|
||||||
|
<p>
|
||||||
Implement {{ inbox_object.ap_type }}
|
Implement {{ inbox_object.ap_type }}
|
||||||
|
{{ inbox_object.ap_object }}
|
||||||
|
</p>
|
||||||
{% endif %}
|
{% endif %}
|
||||||
{% endfor %}
|
{% endfor %}
|
||||||
|
|
||||||
|
|
|
@ -1143,7 +1143,7 @@ dev = ["pytest (>=4.6.2)", "black (>=19.3b0)"]
|
||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "1.1"
|
lock-version = "1.1"
|
||||||
python-versions = "^3.10"
|
python-versions = "^3.10"
|
||||||
content-hash = "19151bbc858317aec5747a8f45a86b47cc198111422cc166a94634ad1941d8bc"
|
content-hash = "91e35a13d21bb5fd3e8916aee95c0a8019bec3cf4f0c677bb86641f1d88dcfe3"
|
||||||
|
|
||||||
[metadata.files]
|
[metadata.files]
|
||||||
aiosqlite = [
|
aiosqlite = [
|
||||||
|
|
|
@ -40,6 +40,7 @@ emoji = "^1.7.0"
|
||||||
PyLD = "^2.0.3"
|
PyLD = "^2.0.3"
|
||||||
aiosqlite = "^0.17.0"
|
aiosqlite = "^0.17.0"
|
||||||
sqlalchemy2-stubs = "^0.0.2-alpha.24"
|
sqlalchemy2-stubs = "^0.0.2-alpha.24"
|
||||||
|
cachetools = "^5.2.0"
|
||||||
|
|
||||||
[tool.poetry.dev-dependencies]
|
[tool.poetry.dev-dependencies]
|
||||||
black = "^22.3.0"
|
black = "^22.3.0"
|
||||||
|
|
Loading…
Reference in New Issue