More thorough URL checks

main
Thomas Sileo 2022-07-15 20:50:27 +02:00
parent 0d3b41272f
commit c160b75851
6 changed files with 25 additions and 6 deletions

View File

@ -10,6 +10,7 @@ from app import config
from app.config import AP_CONTENT_TYPE # noqa: F401 from app.config import AP_CONTENT_TYPE # noqa: F401
from app.httpsig import auth from app.httpsig import auth
from app.key import get_pubkey_as_pem from app.key import get_pubkey_as_pem
from app.utils.url import check_url
if TYPE_CHECKING: if TYPE_CHECKING:
from app.actor import Actor from app.actor import Actor
@ -112,6 +113,8 @@ async def fetch(
params: dict[str, Any] | None = None, params: dict[str, Any] | None = None,
disable_httpsig: bool = False, disable_httpsig: bool = False,
) -> RawObject: ) -> RawObject:
check_url(url)
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
resp = await client.get( resp = await client.get(
url, url,
@ -291,6 +294,8 @@ def remove_context(raw_object: RawObject) -> RawObject:
def post(url: str, payload: dict[str, Any]) -> httpx.Response: def post(url: str, payload: dict[str, Any]) -> httpx.Response:
check_url(url)
resp = httpx.post( resp = httpx.post(
url, url,
headers={ headers={

View File

@ -66,6 +66,7 @@ from app.templates import is_current_user_admin
from app.uploads import UPLOAD_DIR from app.uploads import UPLOAD_DIR
from app.utils import pagination from app.utils import pagination
from app.utils.emoji import EMOJIS_BY_NAME from app.utils.emoji import EMOJIS_BY_NAME
from app.utils.url import check_url
from app.webfinger import get_remote_follow_template from app.webfinger import get_remote_follow_template
_RESIZED_CACHE: MutableMapping[tuple[str, int], tuple[bytes, str, Any]] = LFUCache(32) _RESIZED_CACHE: MutableMapping[tuple[str, int], tuple[bytes, str, Any]] = LFUCache(32)
@ -76,15 +77,12 @@ _RESIZED_CACHE: MutableMapping[tuple[str, int], tuple[bytes, str, Any]] = LFUCac
# Next: # Next:
# - fix stream (only content from follows + mention, and dedup shares) # - fix stream (only content from follows + mention, and dedup shares)
# - custom emoji in data/ # - custom emoji in data/
# - handle remove activity
# - retries httpx?
# - DB models for webmentions
# - allow to undo follow requests # - allow to undo follow requests
# - indieauth tweaks # - indieauth tweaks
# - API for posting notes # - API for posting notes
# - allow to block servers # - allow to block servers
# - FT5 text search # - FT5 text search
# - support update post with history # - support update post with history?
# #
# - [ ] block support # - [ ] block support
# - [ ] prevent SSRF (urlutils from little-boxes) # - [ ] prevent SSRF (urlutils from little-boxes)
@ -93,6 +91,12 @@ _RESIZED_CACHE: MutableMapping[tuple[str, int], tuple[bytes, str, Any]] = LFUCac
class CustomMiddleware: class CustomMiddleware:
"""Raw ASGI middleware as using starlette base middleware causes issues
with both:
- Jinja2: https://github.com/encode/starlette/issues/472
- async SQLAchemy: https://github.com/tiangolo/fastapi/issues/4719
"""
def __init__( def __init__(
self, self,
app: ASGI3Application, app: ASGI3Application,
@ -808,6 +812,8 @@ proxy_client = httpx.AsyncClient(follow_redirects=True, http2=True)
async def serve_proxy_media(request: Request, encoded_url: str) -> StreamingResponse: async def serve_proxy_media(request: Request, encoded_url: str) -> StreamingResponse:
# Decode the base64-encoded URL # Decode the base64-encoded URL
url = base64.urlsafe_b64decode(encoded_url).decode() url = base64.urlsafe_b64decode(encoded_url).decode()
check_url(url)
# Request the URL (and filter request headers) # Request the URL (and filter request headers)
proxy_req = proxy_client.build_request( proxy_req = proxy_client.build_request(
request.method, request.method,
@ -856,6 +862,7 @@ async def serve_proxy_media_resized(
# Decode the base64-encoded URL # Decode the base64-encoded URL
url = base64.urlsafe_b64decode(encoded_url).decode() url = base64.urlsafe_b64decode(encoded_url).decode()
check_url(url)
if cached_resp := _RESIZED_CACHE.get((url, size)): if cached_resp := _RESIZED_CACHE.get((url, size)):
resized_content, resized_mimetype, resp_headers = cached_resp resized_content, resized_mimetype, resp_headers = cached_resp

View File

@ -22,6 +22,7 @@ from app.database import AsyncSession
from app.database import SessionLocal from app.database import SessionLocal
from app.key import Key from app.key import Key
from app.utils.datetime import now from app.utils.datetime import now
from app.utils.url import check_url
_MAX_RETRIES = 16 _MAX_RETRIES = 16
@ -218,6 +219,7 @@ def process_next_outgoing_activity(db: Session) -> bool:
"target": next_activity.webmention_target, "target": next_activity.webmention_target,
} }
logger.info(f"{webmention_payload=}") logger.info(f"{webmention_payload=}")
check_url(next_activity.recipient)
resp = httpx.post( resp = httpx.post(
next_activity.recipient, next_activity.recipient,
data=webmention_payload, data=webmention_payload,

View File

@ -24,7 +24,7 @@ class InvalidURLError(Exception):
pass pass
@functools.lru_cache @functools.lru_cache(maxsize=256)
def _getaddrinfo(hostname: str, port: int) -> str: def _getaddrinfo(hostname: str, port: int) -> str:
try: try:
ip_address = str(ipaddress.ip_address(hostname)) ip_address = str(ipaddress.ip_address(hostname))
@ -65,7 +65,7 @@ def is_url_valid(url: str) -> bool:
return True return True
def check_url(url: str, debug: bool = False) -> None: def check_url(url: str) -> None:
logger.debug(f"check_url {url=}") logger.debug(f"check_url {url=}")
if not is_url_valid(url): if not is_url_valid(url):
raise InvalidURLError(f'"{url}" is invalid') raise InvalidURLError(f'"{url}" is invalid')

View File

@ -8,6 +8,7 @@ from loguru import logger
from app import config from app import config
from app.utils.datetime import now from app.utils.datetime import now
from app.utils.url import check_url
from app.utils.url import is_url_valid from app.utils.url import is_url_valid
from app.utils.url import make_abs from app.utils.url import make_abs
@ -46,6 +47,8 @@ async def discover_webmention_endpoint(url: str) -> str | None:
Passes all the tests at https://webmention.rocks! Passes all the tests at https://webmention.rocks!
""" """
check_url(url)
wurl = await _discover_webmention_endoint(url) wurl = await _discover_webmention_endoint(url)
if wurl is None: if wurl is None:
return None return None

View File

@ -5,6 +5,7 @@ import httpx
from loguru import logger from loguru import logger
from app import config from app import config
from app.utils.url import check_url
async def webfinger( async def webfinger(
@ -32,6 +33,7 @@ async def webfinger(
for i, proto in enumerate(protos): for i, proto in enumerate(protos):
try: try:
url = f"{proto}://{host}/.well-known/webfinger" url = f"{proto}://{host}/.well-known/webfinger"
check_url(url)
resp = await client.get( resp = await client.get(
url, url,
params={"resource": resource}, params={"resource": resource},