274 lines
8.7 KiB
Python
274 lines
8.7 KiB
Python
import base64
|
|
import hashlib
|
|
import typing
|
|
from dataclasses import dataclass
|
|
from datetime import datetime
|
|
from datetime import timedelta
|
|
from datetime import timezone
|
|
from typing import Any
|
|
from typing import Dict
|
|
from typing import MutableMapping
|
|
from typing import Optional
|
|
|
|
import fastapi
|
|
import httpx
|
|
from cachetools import LFUCache
|
|
from Crypto.Hash import SHA256
|
|
from Crypto.Signature import PKCS1_v1_5
|
|
from dateutil.parser import parse
|
|
from loguru import logger
|
|
from sqlalchemy import select
|
|
|
|
from app import activitypub as ap
|
|
from app import config
|
|
from app.config import KEY_PATH
|
|
from app.database import AsyncSession
|
|
from app.database import get_db_session
|
|
from app.key import Key
|
|
from app.utils.datetime import now
|
|
|
|
_KEY_CACHE: MutableMapping[str, Key] = LFUCache(256)
|
|
|
|
|
|
def _build_signed_string(
|
|
signed_headers: str,
|
|
method: str,
|
|
path: str,
|
|
headers: Any,
|
|
body_digest: str | None,
|
|
sig_data: dict[str, Any],
|
|
) -> tuple[str, datetime | None]:
|
|
signature_date: datetime | None = None
|
|
out = []
|
|
for signed_header in signed_headers.split(" "):
|
|
if signed_header == "(created)":
|
|
signature_date = datetime.fromtimestamp(int(sig_data["created"])).replace(
|
|
tzinfo=timezone.utc
|
|
)
|
|
elif signed_header == "date":
|
|
signature_date = parse(headers["date"])
|
|
|
|
if signed_header == "(request-target)":
|
|
out.append("(request-target): " + method.lower() + " " + path)
|
|
elif signed_header == "digest" and body_digest:
|
|
out.append("digest: " + body_digest)
|
|
elif signed_header in ["(created)", "(expires)"]:
|
|
out.append(
|
|
signed_header
|
|
+ ": "
|
|
+ sig_data[signed_header[1 : len(signed_header) - 1]]
|
|
)
|
|
else:
|
|
out.append(signed_header + ": " + headers[signed_header])
|
|
return "\n".join(out), signature_date
|
|
|
|
|
|
def _parse_sig_header(val: Optional[str]) -> Optional[Dict[str, str]]:
|
|
if not val:
|
|
return None
|
|
out = {}
|
|
for data in val.split(","):
|
|
k, v = data.split("=", 1)
|
|
out[k] = v[1 : len(v) - 1] # noqa: black conflict
|
|
return out
|
|
|
|
|
|
def _verify_h(signed_string, signature, pubkey):
|
|
signer = PKCS1_v1_5.new(pubkey)
|
|
digest = SHA256.new()
|
|
digest.update(signed_string.encode("utf-8"))
|
|
return signer.verify(digest, signature)
|
|
|
|
|
|
def _body_digest(body: bytes) -> str:
|
|
h = hashlib.new("sha256")
|
|
h.update(body) # type: ignore
|
|
return "SHA-256=" + base64.b64encode(h.digest()).decode("utf-8")
|
|
|
|
|
|
async def _get_public_key(db_session: AsyncSession, key_id: str) -> Key:
|
|
if cached_key := _KEY_CACHE.get(key_id):
|
|
logger.info(f"Key {key_id} found in cache")
|
|
return cached_key
|
|
|
|
# Check if the key belongs to an actor already in DB
|
|
from app import models
|
|
|
|
existing_actor = (
|
|
await db_session.scalars(
|
|
select(models.Actor).where(models.Actor.ap_id == key_id.split("#")[0])
|
|
)
|
|
).one_or_none()
|
|
if existing_actor and existing_actor.public_key_id == key_id:
|
|
k = Key(existing_actor.ap_id, key_id)
|
|
k.load_pub(existing_actor.public_key_as_pem)
|
|
logger.info(f"Found {key_id} on an existing actor")
|
|
_KEY_CACHE[key_id] = k
|
|
return k
|
|
|
|
# Fetch it
|
|
from app import activitypub as ap
|
|
|
|
# Without signing the request as if it's the first contact, the 2 servers
|
|
# might race to fetch each other key
|
|
try:
|
|
actor = await ap.fetch(key_id, disable_httpsig=True)
|
|
except httpx.HTTPStatusError as http_err:
|
|
if http_err.response.status_code in [401, 403]:
|
|
actor = await ap.fetch(key_id, disable_httpsig=False)
|
|
else:
|
|
raise
|
|
|
|
if actor["type"] == "Key":
|
|
# The Key is not embedded in the Person
|
|
k = Key(actor["owner"], actor["id"])
|
|
k.load_pub(actor["publicKeyPem"])
|
|
else:
|
|
k = Key(actor["id"], actor["publicKey"]["id"])
|
|
k.load_pub(actor["publicKey"]["publicKeyPem"])
|
|
|
|
# Ensure the right key was fetch
|
|
if key_id not in [k.key_id(), k.owner]:
|
|
raise ValueError(
|
|
f"failed to fetch requested key {key_id}: got {actor['publicKey']}"
|
|
)
|
|
|
|
_KEY_CACHE[key_id] = k
|
|
return k
|
|
|
|
|
|
@dataclass(frozen=True)
|
|
class HTTPSigInfo:
|
|
has_valid_signature: bool
|
|
signed_by_ap_actor_id: str | None = None
|
|
is_ap_actor_gone: bool = False
|
|
is_unsupported_algorithm: bool = False
|
|
is_expired: bool = False
|
|
|
|
|
|
async def httpsig_checker(
|
|
request: fastapi.Request,
|
|
db_session: AsyncSession = fastapi.Depends(get_db_session),
|
|
) -> HTTPSigInfo:
|
|
body = await request.body()
|
|
|
|
hsig = _parse_sig_header(request.headers.get("Signature"))
|
|
if not hsig:
|
|
logger.info("No HTTP signature found")
|
|
return HTTPSigInfo(has_valid_signature=False)
|
|
|
|
if alg := hsig.get("algorithm") not in ["rsa-sha256", "hs2019"]:
|
|
logger.info(f"Unsupported HTTP sig algorithm: {alg}")
|
|
return HTTPSigInfo(
|
|
has_valid_signature=False,
|
|
is_unsupported_algorithm=True,
|
|
)
|
|
|
|
logger.debug(f"hsig={hsig}")
|
|
signed_string, signature_date = _build_signed_string(
|
|
hsig["headers"],
|
|
request.method,
|
|
request.url.path,
|
|
request.headers,
|
|
_body_digest(body) if body else None,
|
|
hsig,
|
|
)
|
|
|
|
# Sanity checks on the signature date
|
|
if signature_date is None or now() - signature_date > timedelta(hours=12):
|
|
logger.info(f"Signature expired: {signature_date=}")
|
|
return HTTPSigInfo(
|
|
has_valid_signature=False,
|
|
is_expired=True,
|
|
)
|
|
|
|
try:
|
|
k = await _get_public_key(db_session, hsig["keyId"])
|
|
except (ap.ObjectIsGoneError, ap.ObjectNotFoundError):
|
|
logger.info("Actor is gone or not found")
|
|
return HTTPSigInfo(has_valid_signature=False, is_ap_actor_gone=True)
|
|
except Exception:
|
|
logger.exception(f'Failed to fetch HTTP sig key {hsig["keyId"]}')
|
|
return HTTPSigInfo(has_valid_signature=False)
|
|
|
|
httpsig_info = HTTPSigInfo(
|
|
has_valid_signature=_verify_h(
|
|
signed_string, base64.b64decode(hsig["signature"]), k.pubkey
|
|
),
|
|
signed_by_ap_actor_id=k.owner,
|
|
)
|
|
logger.info(f"Valid HTTP signature for {httpsig_info.signed_by_ap_actor_id}")
|
|
return httpsig_info
|
|
|
|
|
|
async def enforce_httpsig(
|
|
request: fastapi.Request,
|
|
httpsig_info: HTTPSigInfo = fastapi.Depends(httpsig_checker),
|
|
) -> HTTPSigInfo:
|
|
"""FastAPI Depends"""
|
|
if not httpsig_info.has_valid_signature:
|
|
logger.warning(f"Invalid HTTP sig {httpsig_info=}")
|
|
body = await request.body()
|
|
logger.info(f"{body=}")
|
|
|
|
# Special case for Mastoodon instance that keep resending Delete
|
|
# activities for actor we don't know about if we raise a 401
|
|
if httpsig_info.is_ap_actor_gone:
|
|
logger.info("Let's make Mastodon happy, returning a 202")
|
|
raise fastapi.HTTPException(status_code=202)
|
|
|
|
detail = "Invalid HTTP sig"
|
|
if httpsig_info.is_unsupported_algorithm:
|
|
detail = "Unsupported signature algorithm, must be rsa-sha256 or hs2019"
|
|
elif httpsig_info.is_expired:
|
|
detail = "Signature expired"
|
|
|
|
raise fastapi.HTTPException(status_code=401, detail=detail)
|
|
|
|
return httpsig_info
|
|
|
|
|
|
class HTTPXSigAuth(httpx.Auth):
|
|
def __init__(self, key: Key) -> None:
|
|
self.key = key
|
|
|
|
def auth_flow(
|
|
self, r: httpx.Request
|
|
) -> typing.Generator[httpx.Request, httpx.Response, None]:
|
|
logger.info(f"keyid={self.key.key_id()}")
|
|
|
|
bodydigest = None
|
|
if r.content:
|
|
bh = hashlib.new("sha256")
|
|
bh.update(r.content)
|
|
bodydigest = "SHA-256=" + base64.b64encode(bh.digest()).decode("utf-8")
|
|
|
|
date = datetime.utcnow().strftime("%a, %d %b %Y %H:%M:%S GMT")
|
|
r.headers["Date"] = date
|
|
if bodydigest:
|
|
r.headers["Digest"] = bodydigest
|
|
sigheaders = "(request-target) user-agent host date digest content-type"
|
|
else:
|
|
sigheaders = "(request-target) user-agent host date accept"
|
|
|
|
to_be_signed, _ = _build_signed_string(
|
|
sigheaders, r.method, r.url.path, r.headers, bodydigest, {}
|
|
)
|
|
if not self.key.privkey:
|
|
raise ValueError("Should never happen")
|
|
signer = PKCS1_v1_5.new(self.key.privkey)
|
|
digest = SHA256.new()
|
|
digest.update(to_be_signed.encode("utf-8"))
|
|
sig = base64.b64encode(signer.sign(digest)).decode()
|
|
|
|
key_id = self.key.key_id()
|
|
sig_value = f'keyId="{key_id}",algorithm="rsa-sha256",headers="{sigheaders}",signature="{sig}"' # noqa: E501
|
|
logger.debug(f"signed request {sig_value=}")
|
|
r.headers["Signature"] = sig_value
|
|
yield r
|
|
|
|
|
|
k = Key(config.ID, f"{config.ID}#main-key")
|
|
k.load(KEY_PATH.read_text())
|
|
auth = HTTPXSigAuth(k)
|