50 lines
1.4 KiB
Python
50 lines
1.4 KiB
Python
|
from os import getenv
|
||
|
from typing import Annotated
|
||
|
|
||
|
from fastapi import Depends, HTTPException, Request
|
||
|
from fastapi.security import HTTPBasic, HTTPBasicCredentials
|
||
|
from fastapi.staticfiles import StaticFiles
|
||
|
from starlette import status, types
|
||
|
|
||
|
# Shared HTTP Basic scheme: injected via Depends() in check_auth and awaited
# directly on a Request inside AuthStaticFiles.__call__.
security: HTTPBasic = HTTPBasic()
|
||
|
|
||
|
|
||
|
def http_401() -> HTTPException:
    """Create the HTTP 401 error raised whenever Basic authentication fails.

    The exception carries a ``WWW-Authenticate: Basic`` header naming the
    authentication scheme the client is expected to use.

    Returns:
        A new ``HTTPException`` (status 401) ready to be raised by the caller.
    """
    challenge = {"WWW-Authenticate": "Basic"}
    return HTTPException(
        headers=challenge,
        detail="Incorrect username or password",
        status_code=status.HTTP_401_UNAUTHORIZED,
    )
|
||
|
|
||
|
|
||
|
async def check_auth(
    credentials: Annotated[HTTPBasicCredentials, Depends(security)],
) -> HTTPBasicCredentials:
    """Validate HTTP Basic credentials against the configured accounts.

    Accounts are read on every call from two comma-separated environment
    variables paired by position: the i-th entry of ``USERNAMES`` is checked
    against the i-th entry of ``PASSWORDS``. If a username appears more than
    once, only its first (usable) occurrence counts.

    Entries whose username or password is empty are ignored. This means an
    unset variable, or a stray trailing comma, can never grant access with
    blank credentials. A username without a positional password counterpart
    is rejected with 401 rather than crashing with ``IndexError``.

    Args:
        credentials: Username/password parsed by the module-level
            ``HTTPBasic`` scheme.

    Returns:
        The same ``credentials`` object when authentication succeeds.

    Raises:
        HTTPException: 401 (built by ``http_401``) when the username is
            unknown or the password does not match.
    """
    import secrets  # function-scope import keeps this fix self-contained

    usernames = getenv("USERNAMES", "").split(",")
    passwords = getenv("PASSWORDS", "").split(",")

    # compare_digest on str only accepts ASCII; compare UTF-8 bytes instead.
    supplied_user = credentials.username.encode("utf-8")
    supplied_pass = credentials.password.encode("utf-8")

    # Scan every configured pair without an early exit, using a constant-time
    # comparison, so timing does not reveal which username (if any) matched.
    expected = None
    for username, password in zip(usernames, passwords):
        if not username or not password:
            # Empty entries come from unset vars or stray commas: never usable.
            continue
        matched = secrets.compare_digest(username.encode("utf-8"), supplied_user)
        if matched and expected is None:
            expected = password.encode("utf-8")

    # Always perform the password comparison, even for unknown users, so both
    # failure paths do similar work.
    password_ok = secrets.compare_digest(
        expected if expected is not None else b"", supplied_pass
    )
    if expected is None or not password_ok:
        raise http_401()

    return credentials
|
||
|
|
||
|
|
||
|
class AuthStaticFiles(StaticFiles):
    """``StaticFiles`` variant that requires HTTP Basic auth before serving."""

    async def __call__(
        self, scope: types.Scope, receive: types.Receive, send: types.Send
    ) -> None:
        """Authenticate the incoming request, then delegate to ``StaticFiles``.

        Credentials are extracted with the module-level ``security`` scheme and
        validated by ``check_auth``.

        Raises:
            HTTPException: 401 when the scheme yields no credentials, or when
                ``check_auth`` rejects them.
        """
        incoming = Request(scope, receive)
        supplied = await security(incoming)

        if supplied:
            # check_auth raises on failure, so falling through means success.
            await check_auth(supplied)
            await super().__call__(scope, receive, send)
            return

        raise http_401()
|