181 lines
4.9 KiB
Python
181 lines
4.9 KiB
Python
from os import getenv, path
|
|
from typing import Annotated
|
|
|
|
from docker import errors, from_env
|
|
from docker.models.containers import Container
|
|
from dotenv import load_dotenv
|
|
from fastapi import Depends, FastAPI, HTTPException, Request
|
|
from fastapi.security import HTTPBasic, HTTPBasicCredentials
|
|
from fastapi.staticfiles import StaticFiles
|
|
from pydantic import BaseModel
|
|
from starlette import status, types
|
|
from uvicorn import run
|
|
|
|
# HTTP Basic scheme used by check_auth, the route handlers and the static
# file mount to extract credentials from the Authorization header.
security = HTTPBasic()

# Load variables from a .env file (if present) into the process environment
# before anything reads it -- this must precede from_env(), which configures
# the Docker client from environment variables.
load_dotenv()
client = from_env()
|
|
|
|
|
|
def http_401() -> HTTPException:
    """Build the error raised whenever HTTP Basic authentication fails.

    Returns:
        An ``HTTPException`` with status 401, a fixed detail message and a
        ``WWW-Authenticate: Basic`` header so clients re-prompt for
        credentials.
    """
    challenge_headers = {"WWW-Authenticate": "Basic"}
    error = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Incorrect username or password",
        headers=challenge_headers,
    )
    return error
|
|
|
|
|
|
async def check_auth(
    credentials: Annotated[HTTPBasicCredentials, Depends(security)],
) -> HTTPBasicCredentials:
    """Validate HTTP Basic credentials against the configured accounts.

    Accounts are read on every call from two comma-separated environment
    variables, ``USERNAMES`` and ``PASSWORDS``; the password at position i
    belongs to the username at position i. If a username is listed more than
    once, its first occurrence is used.

    Args:
        credentials: Username/password parsed from the ``Authorization``
            header by the module-level ``HTTPBasic`` scheme.

    Returns:
        The same ``credentials`` object when they match a configured account.

    Raises:
        HTTPException: 401 (built by ``http_401``) when the username is empty
            or unknown, has no paired password, or the password is wrong.
    """
    from secrets import compare_digest

    usernames = getenv("USERNAMES", "").split(",")
    passwords = getenv("PASSWORDS", "").split(",")

    # zip() drops usernames that have no paired password instead of raising
    # IndexError (which used to surface as an HTTP 500). Empty names are
    # skipped so an unset USERNAMES -- which splits to [""] -- cannot be
    # satisfied by blank credentials ("Authorization: Basic Og==").
    expected = next(
        (
            password
            for username, password in zip(usernames, passwords)
            if username and username == credentials.username
        ),
        None,
    )

    # compare_digest takes time independent of where the inputs differ, so
    # response timing does not reveal how much of the password was right.
    # Compare UTF-8 bytes because compare_digest rejects non-ASCII str.
    if expected is None or not compare_digest(
        credentials.password.encode("utf-8"), expected.encode("utf-8")
    ):
        raise http_401()

    return credentials
|
|
|
|
|
|
# Every path operation registered on this app requires credentials accepted
# by check_auth. Mounted ASGI apps are not path operations, so the static
# mount performs its own check.
app = FastAPI(
    dependencies=[
        Depends(check_auth),
    ],
)
|
|
|
|
|
|
class SerializedContainer(BaseModel):
    """JSON representation of a Docker container returned by the API.

    Instances are built by ``serialize_container``.
    """

    # Short container ID.
    id: str
    # Container name.
    name: str | None
    # First tag of the container's image, when available.
    image: str | None
    # All labels attached to the container.
    labels: dict[str, str]
    # Docker status string of the container.
    status: str
    # Health status reported for the container.
    health: str
    # Value of the "engine" label, if set.
    engine: str | None
    # Value of the "owner" label, if set.
    owner: str | None
    # Environment entries from the container config ("KEY=value" strings).
    environment: list[str]
|
|
|
|
|
|
def serialize_container(container: Container) -> SerializedContainer:
    """Convert a Docker SDK container into the API response model.

    Args:
        container: Container object returned by the Docker SDK.

    Returns:
        A ``SerializedContainer``. ``image`` is the image's first tag, or
        ``None`` when the container has no image or the image is untagged.
        ``engine`` and ``owner`` come from the container labels of the same
        name and are ``None`` when absent.
    """
    # Container.image is a property that queries the Docker daemon on each
    # access, so read it once rather than twice.
    image = container.image
    # Untagged images (dangling, or pulled by digest) have an empty tag list;
    # indexing it blindly raised IndexError and failed the whole request.
    image_tag = image.tags[0] if image and image.tags else None

    labels = container.labels
    return SerializedContainer(
        id=container.short_id,
        name=container.name,
        image=image_tag,
        labels=labels,
        status=container.status,
        health=container.health,
        engine=labels.get("engine"),
        owner=labels.get("owner"),
        # NOTE(review): Env is assumed to be null for images that define no
        # environment; the model requires a list, so normalise to [].
        environment=container.attrs["Config"]["Env"] or [],
    )
|
|
|
|
|
|
def select_container(
    container_name: str, credentials: Annotated[HTTPBasicCredentials, Depends(security)]
) -> Container:
    """Fetch a container by name or ID and check the caller may access it.

    The ``admin`` user may access any container. Every other user may only
    access containers labelled both ``engine=pilotwings`` and
    ``owner=<their username>`` -- the same set ``get_containers`` lists.

    Args:
        container_name: Name or ID passed to ``client.containers.get``.
        credentials: Authenticated caller.

    Returns:
        The matching Docker SDK container.

    Raises:
        HTTPException: 404 if no such container exists; 403 if it exists but
            the caller may not access it.
    """
    try:
        container = client.containers.get(container_name)
    except errors.NotFound:
        # Only a missing container is a 404; other daemon errors propagate
        # instead of being misreported as "not found".
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND) from None

    if credentials.username == "admin":
        return container

    # Require BOTH labels: checking them with `and` let any user through for
    # any pilotwings-managed container, whoever owned it.
    labels = container.labels
    if (
        labels.get("engine") != "pilotwings"
        or labels.get("owner") != credentials.username
    ):
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN)

    return container
|
|
|
|
|
|
@app.get("/api/containers")
def get_containers(
    credentials: Annotated[HTTPBasicCredentials, Depends(security)],
) -> list[SerializedContainer]:
    """List the pilotwings-managed containers visible to the caller.

    ``admin`` sees every container labelled ``engine=pilotwings``; any other
    user sees only those that are also labelled ``owner=<their username>``.
    ``client.containers.list`` is called without ``all=True``.
    """
    label_filters = ["engine=pilotwings"]
    if credentials.username != "admin":
        label_filters.append(f"owner={credentials.username}")

    found = client.containers.list(filters={"label": label_filters})
    return list(map(serialize_container, found))
|
|
|
|
|
|
@app.get("/api/container/{container_name}")
def get_container(
    container_name: str,
    credentials: Annotated[HTTPBasicCredentials, Depends(security)],
) -> SerializedContainer:
    """Return a single container after select_container's existence/access checks."""
    selected = select_container(container_name, credentials)
    return serialize_container(selected)
|
|
|
|
|
|
class ContainerRequest(BaseModel):
    """Body of a create-or-replace container request."""

    # Image reference the new container is started from.
    image: str
    # Environment variables for the new container, mapped name -> value.
    environment: dict[str, str]
|
|
|
|
|
|
@app.post("/api/container/{container_name}")
def create_or_update_container(
    container_name: str,
    request_body: ContainerRequest,
    credentials: Annotated[HTTPBasicCredentials, Depends(security)],
) -> SerializedContainer:
    """Create a container, replacing any existing one with the same name.

    Ensures a network named exactly ``pilotwings`` exists. Stops and removes a
    previous container called ``container_name`` (only if the caller may
    access it). Then starts a detached container from ``request_body`` on that
    network, labelled ``engine=pilotwings`` / ``owner=<caller>``, with restart
    policy ``always``.

    Raises:
        HTTPException: 403 if a container with this name exists but the
            caller may not access it (propagated from ``select_container``).
    """
    # Docker's network "name" filter also matches partial names (e.g.
    # "pilotwings-old"), so confirm an exact match before skipping creation.
    networks = client.networks.list(names=["pilotwings"])
    if not any(network.name == "pilotwings" for network in networks):
        client.networks.create("pilotwings")

    try:
        existing = select_container(container_name, credentials)
    except HTTPException as exc:
        # select_container signals "no such container" with a 404, which just
        # means this is a fresh create. Anything else (a 403) must abort so
        # users cannot replace containers they do not own.
        if exc.status_code != status.HTTP_404_NOT_FOUND:
            raise
        existing = None

    if existing is not None:
        try:
            # Graceful stop is best-effort: remove(force=True) below kills
            # the container anyway.
            existing.stop()
        except errors.APIError:
            pass
        try:
            existing.remove(v=True, force=True)
        except errors.NotFound:
            pass  # Already gone, e.g. removed by a concurrent request.

    return serialize_container(
        client.containers.run(
            request_body.image,
            detach=True,
            environment=request_body.environment,
            labels={"engine": "pilotwings", "owner": credentials.username},
            name=container_name,
            network="pilotwings",
            restart_policy={"Name": "always"},
        )
    )
|
|
|
|
|
|
class AuthStaticFiles(StaticFiles):
    """``StaticFiles`` that enforces the same HTTP Basic auth as the API.

    The app-level ``dependencies`` only run for path operations, not for
    mounted ASGI apps, so credentials are checked here before any file is
    served.
    """

    async def __call__(
        self, scope: types.Scope, receive: types.Receive, send: types.Send
    ) -> None:
        basic_credentials = await security(Request(scope, receive))
        # Defensive: with HTTPBasic's default settings the scheme raises 401
        # itself when the header is missing, so this rarely triggers.
        if not basic_credentials:
            raise http_401()

        # Raises a 401 when the username/password pair is not accepted.
        await check_auth(basic_credentials)
        await super().__call__(scope, receive, send)
|
|
|
|
|
|
# Serve the "dist" directory that sits next to this file at the site root.
# html=True lets StaticFiles serve index.html for directory paths; every
# request goes through AuthStaticFiles' credential check first.
app.mount(
    path="/",
    app=AuthStaticFiles(
        directory=f"{path.dirname(path.realpath(__file__))}/dist",
        html=True,
    ),
    name="static",
)
|
|
|
|
|
|
def launch() -> None:
    """Run ``app`` under uvicorn, listening on all interfaces (0.0.0.0)."""
    run(app=app, host="0.0.0.0")
|