More thorough URL checks
parent
0d3b41272f
commit
c160b75851
|
@ -10,6 +10,7 @@ from app import config
|
||||||
from app.config import AP_CONTENT_TYPE # noqa: F401
|
from app.config import AP_CONTENT_TYPE # noqa: F401
|
||||||
from app.httpsig import auth
|
from app.httpsig import auth
|
||||||
from app.key import get_pubkey_as_pem
|
from app.key import get_pubkey_as_pem
|
||||||
|
from app.utils.url import check_url
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from app.actor import Actor
|
from app.actor import Actor
|
||||||
|
@ -112,6 +113,8 @@ async def fetch(
|
||||||
params: dict[str, Any] | None = None,
|
params: dict[str, Any] | None = None,
|
||||||
disable_httpsig: bool = False,
|
disable_httpsig: bool = False,
|
||||||
) -> RawObject:
|
) -> RawObject:
|
||||||
|
check_url(url)
|
||||||
|
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
resp = await client.get(
|
resp = await client.get(
|
||||||
url,
|
url,
|
||||||
|
@ -291,6 +294,8 @@ def remove_context(raw_object: RawObject) -> RawObject:
|
||||||
|
|
||||||
|
|
||||||
def post(url: str, payload: dict[str, Any]) -> httpx.Response:
|
def post(url: str, payload: dict[str, Any]) -> httpx.Response:
|
||||||
|
check_url(url)
|
||||||
|
|
||||||
resp = httpx.post(
|
resp = httpx.post(
|
||||||
url,
|
url,
|
||||||
headers={
|
headers={
|
||||||
|
|
15
app/main.py
15
app/main.py
|
@ -66,6 +66,7 @@ from app.templates import is_current_user_admin
|
||||||
from app.uploads import UPLOAD_DIR
|
from app.uploads import UPLOAD_DIR
|
||||||
from app.utils import pagination
|
from app.utils import pagination
|
||||||
from app.utils.emoji import EMOJIS_BY_NAME
|
from app.utils.emoji import EMOJIS_BY_NAME
|
||||||
|
from app.utils.url import check_url
|
||||||
from app.webfinger import get_remote_follow_template
|
from app.webfinger import get_remote_follow_template
|
||||||
|
|
||||||
_RESIZED_CACHE: MutableMapping[tuple[str, int], tuple[bytes, str, Any]] = LFUCache(32)
|
_RESIZED_CACHE: MutableMapping[tuple[str, int], tuple[bytes, str, Any]] = LFUCache(32)
|
||||||
|
@ -76,15 +77,12 @@ _RESIZED_CACHE: MutableMapping[tuple[str, int], tuple[bytes, str, Any]] = LFUCac
|
||||||
# Next:
|
# Next:
|
||||||
# - fix stream (only content from follows + mention, and dedup shares)
|
# - fix stream (only content from follows + mention, and dedup shares)
|
||||||
# - custom emoji in data/
|
# - custom emoji in data/
|
||||||
# - handle remove activity
|
|
||||||
# - retries httpx?
|
|
||||||
# - DB models for webmentions
|
|
||||||
# - allow to undo follow requests
|
# - allow to undo follow requests
|
||||||
# - indieauth tweaks
|
# - indieauth tweaks
|
||||||
# - API for posting notes
|
# - API for posting notes
|
||||||
# - allow to block servers
|
# - allow to block servers
|
||||||
# - FT5 text search
|
# - FT5 text search
|
||||||
# - support update post with history
|
# - support update post with history?
|
||||||
#
|
#
|
||||||
# - [ ] block support
|
# - [ ] block support
|
||||||
# - [ ] prevent SSRF (urlutils from little-boxes)
|
# - [ ] prevent SSRF (urlutils from little-boxes)
|
||||||
|
@ -93,6 +91,12 @@ _RESIZED_CACHE: MutableMapping[tuple[str, int], tuple[bytes, str, Any]] = LFUCac
|
||||||
|
|
||||||
|
|
||||||
class CustomMiddleware:
|
class CustomMiddleware:
|
||||||
|
"""Raw ASGI middleware as using starlette base middleware causes issues
|
||||||
|
with both:
|
||||||
|
- Jinja2: https://github.com/encode/starlette/issues/472
|
||||||
|
- async SQLAchemy: https://github.com/tiangolo/fastapi/issues/4719
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
app: ASGI3Application,
|
app: ASGI3Application,
|
||||||
|
@ -808,6 +812,8 @@ proxy_client = httpx.AsyncClient(follow_redirects=True, http2=True)
|
||||||
async def serve_proxy_media(request: Request, encoded_url: str) -> StreamingResponse:
|
async def serve_proxy_media(request: Request, encoded_url: str) -> StreamingResponse:
|
||||||
# Decode the base64-encoded URL
|
# Decode the base64-encoded URL
|
||||||
url = base64.urlsafe_b64decode(encoded_url).decode()
|
url = base64.urlsafe_b64decode(encoded_url).decode()
|
||||||
|
check_url(url)
|
||||||
|
|
||||||
# Request the URL (and filter request headers)
|
# Request the URL (and filter request headers)
|
||||||
proxy_req = proxy_client.build_request(
|
proxy_req = proxy_client.build_request(
|
||||||
request.method,
|
request.method,
|
||||||
|
@ -856,6 +862,7 @@ async def serve_proxy_media_resized(
|
||||||
|
|
||||||
# Decode the base64-encoded URL
|
# Decode the base64-encoded URL
|
||||||
url = base64.urlsafe_b64decode(encoded_url).decode()
|
url = base64.urlsafe_b64decode(encoded_url).decode()
|
||||||
|
check_url(url)
|
||||||
|
|
||||||
if cached_resp := _RESIZED_CACHE.get((url, size)):
|
if cached_resp := _RESIZED_CACHE.get((url, size)):
|
||||||
resized_content, resized_mimetype, resp_headers = cached_resp
|
resized_content, resized_mimetype, resp_headers = cached_resp
|
||||||
|
|
|
@ -22,6 +22,7 @@ from app.database import AsyncSession
|
||||||
from app.database import SessionLocal
|
from app.database import SessionLocal
|
||||||
from app.key import Key
|
from app.key import Key
|
||||||
from app.utils.datetime import now
|
from app.utils.datetime import now
|
||||||
|
from app.utils.url import check_url
|
||||||
|
|
||||||
_MAX_RETRIES = 16
|
_MAX_RETRIES = 16
|
||||||
|
|
||||||
|
@ -218,6 +219,7 @@ def process_next_outgoing_activity(db: Session) -> bool:
|
||||||
"target": next_activity.webmention_target,
|
"target": next_activity.webmention_target,
|
||||||
}
|
}
|
||||||
logger.info(f"{webmention_payload=}")
|
logger.info(f"{webmention_payload=}")
|
||||||
|
check_url(next_activity.recipient)
|
||||||
resp = httpx.post(
|
resp = httpx.post(
|
||||||
next_activity.recipient,
|
next_activity.recipient,
|
||||||
data=webmention_payload,
|
data=webmention_payload,
|
||||||
|
|
|
@ -24,7 +24,7 @@ class InvalidURLError(Exception):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@functools.lru_cache
|
@functools.lru_cache(maxsize=256)
|
||||||
def _getaddrinfo(hostname: str, port: int) -> str:
|
def _getaddrinfo(hostname: str, port: int) -> str:
|
||||||
try:
|
try:
|
||||||
ip_address = str(ipaddress.ip_address(hostname))
|
ip_address = str(ipaddress.ip_address(hostname))
|
||||||
|
@ -65,7 +65,7 @@ def is_url_valid(url: str) -> bool:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def check_url(url: str, debug: bool = False) -> None:
|
def check_url(url: str) -> None:
|
||||||
logger.debug(f"check_url {url=}")
|
logger.debug(f"check_url {url=}")
|
||||||
if not is_url_valid(url):
|
if not is_url_valid(url):
|
||||||
raise InvalidURLError(f'"{url}" is invalid')
|
raise InvalidURLError(f'"{url}" is invalid')
|
||||||
|
|
|
@ -8,6 +8,7 @@ from loguru import logger
|
||||||
|
|
||||||
from app import config
|
from app import config
|
||||||
from app.utils.datetime import now
|
from app.utils.datetime import now
|
||||||
|
from app.utils.url import check_url
|
||||||
from app.utils.url import is_url_valid
|
from app.utils.url import is_url_valid
|
||||||
from app.utils.url import make_abs
|
from app.utils.url import make_abs
|
||||||
|
|
||||||
|
@ -46,6 +47,8 @@ async def discover_webmention_endpoint(url: str) -> str | None:
|
||||||
Passes all the tests at https://webmention.rocks!
|
Passes all the tests at https://webmention.rocks!
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
check_url(url)
|
||||||
|
|
||||||
wurl = await _discover_webmention_endoint(url)
|
wurl = await _discover_webmention_endoint(url)
|
||||||
if wurl is None:
|
if wurl is None:
|
||||||
return None
|
return None
|
||||||
|
|
|
@ -5,6 +5,7 @@ import httpx
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from app import config
|
from app import config
|
||||||
|
from app.utils.url import check_url
|
||||||
|
|
||||||
|
|
||||||
async def webfinger(
|
async def webfinger(
|
||||||
|
@ -32,6 +33,7 @@ async def webfinger(
|
||||||
for i, proto in enumerate(protos):
|
for i, proto in enumerate(protos):
|
||||||
try:
|
try:
|
||||||
url = f"{proto}://{host}/.well-known/webfinger"
|
url = f"{proto}://{host}/.well-known/webfinger"
|
||||||
|
check_url(url)
|
||||||
resp = await client.get(
|
resp = await client.get(
|
||||||
url,
|
url,
|
||||||
params={"resource": resource},
|
params={"resource": resource},
|
||||||
|
|
Loading…
Reference in New Issue