Start supporting a server blocklist
parent
2d2b2e5873
commit
1e6a290fb3
|
@ -59,6 +59,8 @@ class ObjectNotFoundError(Exception):
|
||||||
class FetchErrorTypeEnum(str, enum.Enum):
|
class FetchErrorTypeEnum(str, enum.Enum):
|
||||||
TIMEOUT = "TIMEOUT"
|
TIMEOUT = "TIMEOUT"
|
||||||
NOT_FOUND = "NOT_FOUND"
|
NOT_FOUND = "NOT_FOUND"
|
||||||
|
UNAUHTORIZED = "UNAUTHORIZED"
|
||||||
|
|
||||||
INTERNAL_ERROR = "INTERNAL_ERROR"
|
INTERNAL_ERROR = "INTERNAL_ERROR"
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -114,6 +114,10 @@ class Actor:
|
||||||
def attachments(self) -> list[ap.RawObject]:
|
def attachments(self) -> list[ap.RawObject]:
|
||||||
return ap.as_list(self.ap_actor.get("attachment", []))
|
return ap.as_list(self.ap_actor.get("attachment", []))
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def server(self) -> str:
|
||||||
|
return urlparse(self.ap_id).netloc
|
||||||
|
|
||||||
|
|
||||||
class RemoteActor(Actor):
|
class RemoteActor(Actor):
|
||||||
def __init__(self, ap_actor: ap.RawObject) -> None:
|
def __init__(self, ap_actor: ap.RawObject) -> None:
|
||||||
|
|
|
@ -26,6 +26,7 @@ from app.actor import fetch_actor
|
||||||
from app.actor import save_actor
|
from app.actor import save_actor
|
||||||
from app.ap_object import RemoteObject
|
from app.ap_object import RemoteObject
|
||||||
from app.config import BASE_URL
|
from app.config import BASE_URL
|
||||||
|
from app.config import BLOCKED_SERVERS
|
||||||
from app.config import ID
|
from app.config import ID
|
||||||
from app.config import MANUALLY_APPROVES_FOLLOWERS
|
from app.config import MANUALLY_APPROVES_FOLLOWERS
|
||||||
from app.database import AsyncSession
|
from app.database import AsyncSession
|
||||||
|
@ -1447,6 +1448,10 @@ async def save_to_inbox(
|
||||||
logger.exception("Failed to fetch actor")
|
logger.exception("Failed to fetch actor")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
if actor.server in BLOCKED_SERVERS:
|
||||||
|
logger.warning(f"Server {actor.server} is blocked")
|
||||||
|
return
|
||||||
|
|
||||||
if "id" not in raw_object:
|
if "id" not in raw_object:
|
||||||
await _process_transient_object(db_session, raw_object, actor)
|
await _process_transient_object(db_session, raw_object, actor)
|
||||||
return None
|
return None
|
||||||
|
|
|
@ -50,6 +50,11 @@ class _ProfileMetadata(pydantic.BaseModel):
|
||||||
value: str
|
value: str
|
||||||
|
|
||||||
|
|
||||||
|
class _BlockedServer(pydantic.BaseModel):
|
||||||
|
hostname: str
|
||||||
|
reason: str | None = None
|
||||||
|
|
||||||
|
|
||||||
class Config(pydantic.BaseModel):
|
class Config(pydantic.BaseModel):
|
||||||
domain: str
|
domain: str
|
||||||
username: str
|
username: str
|
||||||
|
@ -65,6 +70,7 @@ class Config(pydantic.BaseModel):
|
||||||
privacy_replace: list[_PrivacyReplace] | None = None
|
privacy_replace: list[_PrivacyReplace] | None = None
|
||||||
metadata: list[_ProfileMetadata] | None = None
|
metadata: list[_ProfileMetadata] | None = None
|
||||||
code_highlighting_theme = "friendly_grayscale"
|
code_highlighting_theme = "friendly_grayscale"
|
||||||
|
blocked_servers: list[_BlockedServer] = []
|
||||||
|
|
||||||
# Config items to make tests easier
|
# Config items to make tests easier
|
||||||
sqlalchemy_database: str | None = None
|
sqlalchemy_database: str | None = None
|
||||||
|
@ -109,6 +115,9 @@ MANUALLY_APPROVES_FOLLOWERS = CONFIG.manually_approves_followers
|
||||||
PRIVACY_REPLACE = None
|
PRIVACY_REPLACE = None
|
||||||
if CONFIG.privacy_replace:
|
if CONFIG.privacy_replace:
|
||||||
PRIVACY_REPLACE = {pr.domain: pr.replace_by for pr in CONFIG.privacy_replace}
|
PRIVACY_REPLACE = {pr.domain: pr.replace_by for pr in CONFIG.privacy_replace}
|
||||||
|
|
||||||
|
BLOCKED_SERVERS = {blocked_server.hostname for blocked_server in CONFIG.blocked_servers}
|
||||||
|
|
||||||
BASE_URL = ID
|
BASE_URL = ID
|
||||||
DEBUG = CONFIG.debug
|
DEBUG = CONFIG.debug
|
||||||
DB_PATH = CONFIG.sqlalchemy_database or ROOT_DIR / "data" / "microblogpub.db"
|
DB_PATH = CONFIG.sqlalchemy_database or ROOT_DIR / "data" / "microblogpub.db"
|
||||||
|
|
|
@ -9,6 +9,7 @@ from typing import Any
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
from typing import MutableMapping
|
from typing import MutableMapping
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
import fastapi
|
import fastapi
|
||||||
import httpx
|
import httpx
|
||||||
|
@ -21,6 +22,7 @@ from sqlalchemy import select
|
||||||
|
|
||||||
from app import activitypub as ap
|
from app import activitypub as ap
|
||||||
from app import config
|
from app import config
|
||||||
|
from app.config import BLOCKED_SERVERS
|
||||||
from app.config import KEY_PATH
|
from app.config import KEY_PATH
|
||||||
from app.database import AsyncSession
|
from app.database import AsyncSession
|
||||||
from app.database import get_db_session
|
from app.database import get_db_session
|
||||||
|
@ -144,6 +146,7 @@ class HTTPSigInfo:
|
||||||
is_ap_actor_gone: bool = False
|
is_ap_actor_gone: bool = False
|
||||||
is_unsupported_algorithm: bool = False
|
is_unsupported_algorithm: bool = False
|
||||||
is_expired: bool = False
|
is_expired: bool = False
|
||||||
|
server: str | None = None
|
||||||
|
|
||||||
|
|
||||||
async def httpsig_checker(
|
async def httpsig_checker(
|
||||||
|
@ -157,11 +160,22 @@ async def httpsig_checker(
|
||||||
logger.info("No HTTP signature found")
|
logger.info("No HTTP signature found")
|
||||||
return HTTPSigInfo(has_valid_signature=False)
|
return HTTPSigInfo(has_valid_signature=False)
|
||||||
|
|
||||||
|
try:
|
||||||
|
key_id = hsig["keyId"]
|
||||||
|
except KeyError:
|
||||||
|
logger.info("Missing keyId")
|
||||||
|
return HTTPSigInfo(
|
||||||
|
has_valid_signature=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
server = urlparse(key_id).hostname
|
||||||
|
|
||||||
if alg := hsig.get("algorithm") not in ["rsa-sha256", "hs2019"]:
|
if alg := hsig.get("algorithm") not in ["rsa-sha256", "hs2019"]:
|
||||||
logger.info(f"Unsupported HTTP sig algorithm: {alg}")
|
logger.info(f"Unsupported HTTP sig algorithm: {alg}")
|
||||||
return HTTPSigInfo(
|
return HTTPSigInfo(
|
||||||
has_valid_signature=False,
|
has_valid_signature=False,
|
||||||
is_unsupported_algorithm=True,
|
is_unsupported_algorithm=True,
|
||||||
|
server=server,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug(f"hsig={hsig}")
|
logger.debug(f"hsig={hsig}")
|
||||||
|
@ -180,6 +194,7 @@ async def httpsig_checker(
|
||||||
return HTTPSigInfo(
|
return HTTPSigInfo(
|
||||||
has_valid_signature=False,
|
has_valid_signature=False,
|
||||||
is_expired=True,
|
is_expired=True,
|
||||||
|
server=server,
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -196,6 +211,7 @@ async def httpsig_checker(
|
||||||
signed_string, base64.b64decode(hsig["signature"]), k.pubkey
|
signed_string, base64.b64decode(hsig["signature"]), k.pubkey
|
||||||
),
|
),
|
||||||
signed_by_ap_actor_id=k.owner,
|
signed_by_ap_actor_id=k.owner,
|
||||||
|
server=server,
|
||||||
)
|
)
|
||||||
logger.info(f"Valid HTTP signature for {httpsig_info.signed_by_ap_actor_id}")
|
logger.info(f"Valid HTTP signature for {httpsig_info.signed_by_ap_actor_id}")
|
||||||
return httpsig_info
|
return httpsig_info
|
||||||
|
@ -206,6 +222,10 @@ async def enforce_httpsig(
|
||||||
httpsig_info: HTTPSigInfo = fastapi.Depends(httpsig_checker),
|
httpsig_info: HTTPSigInfo = fastapi.Depends(httpsig_checker),
|
||||||
) -> HTTPSigInfo:
|
) -> HTTPSigInfo:
|
||||||
"""FastAPI Depends"""
|
"""FastAPI Depends"""
|
||||||
|
if httpsig_info.server in BLOCKED_SERVERS:
|
||||||
|
logger.warning(f"{httpsig_info.server} is blocked")
|
||||||
|
raise fastapi.HTTPException(status_code=403, detail="Blocked")
|
||||||
|
|
||||||
if not httpsig_info.has_valid_signature:
|
if not httpsig_info.has_valid_signature:
|
||||||
logger.warning(f"Invalid HTTP sig {httpsig_info=}")
|
logger.warning(f"Invalid HTTP sig {httpsig_info=}")
|
||||||
body = await request.body()
|
body = await request.body()
|
||||||
|
|
|
@ -5,6 +5,7 @@ from urllib.parse import urlparse
|
||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
|
from app.config import BLOCKED_SERVERS
|
||||||
from app.config import DEBUG
|
from app.config import DEBUG
|
||||||
|
|
||||||
|
|
||||||
|
@ -53,6 +54,10 @@ def is_url_valid(url: str) -> bool:
|
||||||
if not parsed.hostname or parsed.hostname.lower() in ["localhost"]:
|
if not parsed.hostname or parsed.hostname.lower() in ["localhost"]:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
if parsed.hostname in BLOCKED_SERVERS:
|
||||||
|
logger.warning(f"{parsed.hostname} is blocked")
|
||||||
|
return False
|
||||||
|
|
||||||
ip_address = _getaddrinfo(
|
ip_address = _getaddrinfo(
|
||||||
parsed.hostname, parsed.port or (80 if parsed.scheme == "http" else 443)
|
parsed.hostname, parsed.port or (80 if parsed.scheme == "http" else 443)
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in New Issue