275 lines
8.5 KiB
Python
275 lines
8.5 KiB
Python
|
|
import asyncio
|
||
|
|
import base64
|
||
|
|
import hashlib
|
||
|
|
import json
|
||
|
|
import os
|
||
|
|
import sqlite3
|
||
|
|
import time
|
||
|
|
from contextlib import asynccontextmanager
|
||
|
|
from typing import Optional
|
||
|
|
|
||
|
|
from cryptography.exceptions import InvalidSignature
|
||
|
|
from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PublicKey
|
||
|
|
from fastapi import FastAPI, HTTPException, Header, WebSocket, WebSocketDisconnect
|
||
|
|
from fastapi.responses import JSONResponse
|
||
|
|
from pydantic import BaseModel, Field
|
||
|
|
|
||
|
|
# Path of the SQLite database file; override with the MESS_DB environment variable.
DB_PATH = os.getenv("MESS_DB", "/opt/mess/data.db")
# Upper bound, in bytes, on the canonical JSON encoding of a send request (1 MiB).
MAX_BODY = 1024 * 1024
# Maximum allowed difference, in seconds, between a signed timestamp and server time.
SIG_WINDOW = 2 * 60
|
||
|
|
|
||
|
|
def b64d(s: str) -> bytes:
    """Decode URL-safe base64 text into bytes, tolerating missing '=' padding.

    Padding is restored to a multiple of four characters before decoding.
    Decoding errors from :func:`base64.urlsafe_b64decode` propagate to the caller.
    """
    padding = "=" * (-len(s) % 4)
    return base64.urlsafe_b64decode(s + padding)
|
||
|
|
|
||
|
|
def b64e(b: bytes) -> str:
    """Encode bytes as URL-safe base64 text with trailing '=' padding stripped."""
    text = base64.urlsafe_b64encode(b).decode()
    return text.rstrip("=")
|
||
|
|
|
||
|
|
def account_id_for(ed_pub: bytes) -> str:
    """Derive the public account id for an Ed25519 public key.

    The id is the first 8 bytes of SHA-256(ed_pub), rendered as 16 lowercase
    hex characters.
    """
    return hashlib.sha256(ed_pub).hexdigest()[:16]
|
||
|
|
|
||
|
|
class _ClosingConnection(sqlite3.Connection):
    """sqlite3 connection whose ``with`` block also closes the connection.

    The stock ``sqlite3.Connection`` context manager only commits or rolls back
    on exit and never closes. A plain connection used as ``with db() as c:``
    therefore stays open until it is garbage-collected.
    """

    def __exit__(self, exc_type, exc_value, traceback):
        try:
            return super().__exit__(exc_type, exc_value, traceback)
        finally:
            self.close()


def db(path: Optional[str] = None) -> sqlite3.Connection:
    """Open a connection to the message store.

    Args:
        path: Database file to open; defaults to ``DB_PATH`` when omitted.

    Returns:
        An autocommit connection (``isolation_level=None``) with WAL journaling
        and foreign-key enforcement enabled and rows returned as
        ``sqlite3.Row``. Using it in a ``with`` block closes it on exit.
        Rows that were already fetched remain readable after the close.
    """
    conn = sqlite3.connect(
        DB_PATH if path is None else path,
        isolation_level=None,
        factory=_ClosingConnection,
    )
    try:
        conn.execute("PRAGMA journal_mode=WAL")
        conn.execute("PRAGMA foreign_keys=ON")
    except BaseException:
        # Do not leak the handle if the connection cannot be configured.
        conn.close()
        raise
    conn.row_factory = sqlite3.Row
    return conn
|
||
|
|
|
||
|
|
SCHEMA = """
|
||
|
|
CREATE TABLE IF NOT EXISTS accounts (
|
||
|
|
account_id TEXT PRIMARY KEY,
|
||
|
|
ed_pub BLOB NOT NULL,
|
||
|
|
x_pub BLOB NOT NULL,
|
||
|
|
name TEXT NOT NULL,
|
||
|
|
created_at INTEGER NOT NULL
|
||
|
|
);
|
||
|
|
CREATE TABLE IF NOT EXISTS messages (
|
||
|
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||
|
|
sender TEXT NOT NULL,
|
||
|
|
recipient TEXT NOT NULL,
|
||
|
|
nonce BLOB NOT NULL,
|
||
|
|
ciphertext BLOB NOT NULL,
|
||
|
|
sig BLOB NOT NULL,
|
||
|
|
sent_at INTEGER NOT NULL,
|
||
|
|
FOREIGN KEY (sender) REFERENCES accounts(account_id),
|
||
|
|
FOREIGN KEY (recipient) REFERENCES accounts(account_id)
|
||
|
|
);
|
||
|
|
CREATE INDEX IF NOT EXISTS idx_msg_recipient ON messages(recipient, id);
|
||
|
|
CREATE INDEX IF NOT EXISTS idx_msg_pair ON messages(sender, recipient, id);
|
||
|
|
"""
|
||
|
|
|
||
|
|
# Authenticated WebSocket sessions, keyed by account_id. One account may hold
# several sockets at once. ws_endpoint adds a socket after a verified hello,
# and both ws_endpoint and broadcast discard sockets. This is a plain module-level
# dict, so it only covers clients connected to this process.
connections: dict[str, set[WebSocket]] = dict()
|
||
|
|
|
||
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Prepare the database before the app starts serving requests.

    Creates the parent directory of ``DB_PATH`` if needed and applies ``SCHEMA``.
    Nothing is torn down after the ``yield``.
    """
    db_dir = os.path.dirname(DB_PATH)
    # A bare filename (e.g. MESS_DB=data.db) has no directory component, and
    # os.makedirs("") raises FileNotFoundError, which would abort startup.
    if db_dir:
        os.makedirs(db_dir, exist_ok=True)
    conn = db()
    try:
        conn.executescript(SCHEMA)
    finally:
        # Close explicitly: a sqlite3 connection's own ``with`` block does not.
        conn.close()
    yield
|
||
|
|
|
||
|
|
# Interactive docs UIs are disabled.
# NOTE(review): openapi_url is left at its FastAPI default, so the raw schema may
# still be served; set openapi_url=None too if hiding it is the intent (confirm).
app = FastAPI(
    lifespan=lifespan,
    docs_url=None,
    redoc_url=None,
)
|
||
|
|
|
||
|
|
class RegisterRequest(BaseModel):
    # JSON body of POST /api/accounts.
    # Display name; register() strips surrounding whitespace before storing it.
    name: str = Field(min_length=1, max_length=64)
    # Base64 (padding optional) of the 32-byte Ed25519 public key; its hash is the account_id.
    ed_pub: str
    # Base64 (padding optional) of a second 32-byte public key, stored and served as-is.
    # Presumably an X25519 key-agreement key (from the name); confirm.
    x_pub: str
|
||
|
|
|
||
|
|
class RegisterResponse(BaseModel):
    # Response of POST /api/accounts.
    # account_id_for(ed_pub): 16 lowercase hex characters.
    account_id: str
|
||
|
|
|
||
|
|
@app.post("/api/accounts", response_model=RegisterResponse)
def register(req: RegisterRequest):
    # Register an account from its two public keys and return its derived account_id.
    #
    # 400 errors:
    #   "bad_pubkey_b64" - either key fails to base64-decode.
    #   "pubkey_len"     - either key does not decode to exactly 32 bytes.
    #   "bad_ed25519"    - Ed25519PublicKey.from_public_bytes rejects ed_pub.
    #   "bad_name"       - the name is blank after stripping surrounding whitespace.
    #
    # Re-registering is idempotent only for an identical (ed_pub, x_pub, name) triple.
    # Any other clash on the account_id yields 409 "account_exists" instead of a 200
    # that would silently keep the previously stored keys.
    try:
        ed_pub = b64d(req.ed_pub)
        x_pub = b64d(req.x_pub)
    except Exception:
        raise HTTPException(400, "bad_pubkey_b64") from None
    if len(ed_pub) != 32 or len(x_pub) != 32:
        raise HTTPException(400, "pubkey_len")
    try:
        Ed25519PublicKey.from_public_bytes(ed_pub)
    except Exception:
        raise HTTPException(400, "bad_ed25519") from None
    # The model's min_length=1 applies to the raw value, so a whitespace-only name
    # would otherwise be stored as "".
    name = req.name.strip()
    if not name:
        raise HTTPException(400, "bad_name")
    aid = account_id_for(ed_pub)
    now = int(time.time())
    with db() as c:
        try:
            c.execute(
                "INSERT INTO accounts(account_id, ed_pub, x_pub, name, created_at) VALUES (?,?,?,?,?)",
                (aid, ed_pub, x_pub, name, now),
            )
        except sqlite3.IntegrityError:
            # account_id is the PRIMARY KEY, so a row with this id already exists.
            row = c.execute(
                "SELECT ed_pub, x_pub, name FROM accounts WHERE account_id=?", (aid,)
            ).fetchone()
            # Compare every stored field. account_id is only a truncated hash of ed_pub,
            # and x_pub must match or clients would encrypt to a stale key.
            same_registration = (
                row is not None
                and row["ed_pub"] == ed_pub
                and row["x_pub"] == x_pub
                and row["name"] == name
            )
            if same_registration:
                return RegisterResponse(account_id=aid)
            raise HTTPException(409, "account_exists") from None
    return RegisterResponse(account_id=aid)
|
||
|
|
|
||
|
|
class AccountInfo(BaseModel):
    # Public profile returned by GET /api/accounts/{aid}.
    account_id: str
    name: str
    # Unpadded URL-safe base64 (b64e) of the stored Ed25519 public key.
    ed_pub: str
    # Unpadded URL-safe base64 (b64e) of the stored x_pub key.
    x_pub: str
    # Unix seconds at registration.
    created_at: int
|
||
|
|
|
||
|
|
@app.get("/api/accounts/{aid}", response_model=AccountInfo)
def get_account(aid: str):
    # Unauthenticated lookup of an account's public profile.
    # Both keys are returned as unpadded URL-safe base64; an unknown id gives 404 "not_found".
    with db() as c:
        record = c.execute("SELECT * FROM accounts WHERE account_id=?", (aid,)).fetchone()
    if record is None:
        raise HTTPException(404, "not_found")
    encoded_keys = {column: b64e(record[column]) for column in ("ed_pub", "x_pub")}
    return AccountInfo(
        account_id=record["account_id"],
        name=record["name"],
        created_at=record["created_at"],
        **encoded_keys,
    )
|
||
|
|
|
||
|
|
def verify_auth(account_id: str, x_montana_auth: str, body: bytes) -> None:
    """Check a request signature and raise HTTP 401 if it does not verify.

    ``x_montana_auth`` has the form ``"<unix_seconds>:<base64 signature>"``. The
    signature must be an Ed25519 signature, by the account's stored key, over
    ``<unix_seconds> + b"\\n" + body``, where the timestamp is the exact text
    from the header.

    Raises:
        HTTPException(401) with detail "bad_auth_header", "stale_signature",
        "bad_sig_b64", "unknown_account" or "invalid_signature".

    NOTE(review): nothing here records used signatures, so a captured request
    can be replayed while its timestamp is within SIG_WINDOW.
    """
    ts_str, colon, sig_b64 = x_montana_auth.partition(":")
    try:
        ts = int(ts_str)
    except Exception:
        ts = None
    if not colon or ts is None:
        raise HTTPException(401, "bad_auth_header")

    if abs(int(time.time()) - ts) > SIG_WINDOW:
        raise HTTPException(401, "stale_signature")

    try:
        sig = b64d(sig_b64)
    except Exception:
        raise HTTPException(401, "bad_sig_b64")

    with db() as c:
        record = c.execute("SELECT ed_pub FROM accounts WHERE account_id=?", (account_id,)).fetchone()
    if record is None:
        raise HTTPException(401, "unknown_account")

    signed = b"\n".join((ts_str.encode(), body))
    key = Ed25519PublicKey.from_public_bytes(record["ed_pub"])
    try:
        key.verify(sig, signed)
    except InvalidSignature:
        raise HTTPException(401, "invalid_signature")
|
||
|
|
|
||
|
|
class SendRequest(BaseModel):
    # JSON body of POST /api/messages/send. The sender signs the canonical JSON
    # form of these fields (sorted keys, no whitespace), not the raw HTTP body.
    # Recipient account_id.
    to: str
    # Base64 (padding optional); must decode to exactly 12 bytes.
    nonce: str
    # Base64 (padding optional); the decoded ciphertext may be at most 64 KiB.
    ciphertext: str
|
||
|
|
|
||
|
|
def _store_message(sender: str, recipient: str, nonce: bytes, ct: bytes) -> tuple[int, int]:
    """Persist one message and return ``(message_id, sent_at_ms)``.

    This does blocking sqlite I/O; send_message runs it in a worker thread.

    Raises:
        HTTPException(404, "recipient_not_found") if ``recipient`` is not registered.
    """
    with db() as c:
        if not c.execute("SELECT 1 FROM accounts WHERE account_id=?", (recipient,)).fetchone():
            raise HTTPException(404, "recipient_not_found")
        sent_at = int(time.time() * 1000)
        cur = c.execute(
            "INSERT INTO messages(sender, recipient, nonce, ciphertext, sig, sent_at) VALUES (?,?,?,?,?,?)",
            (sender, recipient, nonce, ct, b"", sent_at),
        )
        return cur.lastrowid, sent_at


@app.post("/api/messages/send")
async def send_message(
    req: SendRequest,
    x_montana_account: str = Header(...),
    x_montana_auth: str = Header(...),
):
    # Store an encrypted message from the authenticated sender, push it to the
    # recipient's open WebSockets, and return it.
    #
    # The signature covers the canonical JSON of the parsed body (sorted keys,
    # compact separators), not the raw bytes on the wire.
    #
    # The signature check and database writes are blocking, so they run in worker
    # threads to keep the event loop free.
    #
    # The returned/broadcast nonce and ciphertext use b64e, the same encoding the
    # inbox endpoint serves, instead of echoing whatever base64 variant the client sent.
    raw = json.dumps(req.model_dump(), separators=(",", ":"), sort_keys=True).encode()
    if len(raw) > MAX_BODY:
        raise HTTPException(413, "too_large")
    await asyncio.to_thread(verify_auth, x_montana_account, x_montana_auth, raw)
    try:
        nonce = b64d(req.nonce)
        ct = b64d(req.ciphertext)
    except Exception:
        raise HTTPException(400, "bad_b64") from None
    if len(nonce) != 12 or len(ct) > 64 * 1024:
        raise HTTPException(400, "bad_lengths")
    msg_id, sent_at = await asyncio.to_thread(
        _store_message, x_montana_account, req.to, nonce, ct
    )
    payload = {
        "id": msg_id,
        "from": x_montana_account,
        "to": req.to,
        "nonce": b64e(nonce),
        "ciphertext": b64e(ct),
        "sent_at": sent_at,
    }
    await broadcast(req.to, payload)
    return payload
|
||
|
|
|
||
|
|
@app.get("/api/messages/inbox")
def inbox(
    since: int = 0,
    x_montana_account: str = Header(...),
    x_montana_auth: str = Header(...),
):
    # Authenticated history fetch. The caller signs the literal text "since=<since>".
    # Returns up to 500 messages the caller sent or received with id > since, oldest
    # first. Presumably clients page by passing the last id they saw as `since`.
    verify_auth(x_montana_account, x_montana_auth, b"since=" + str(since).encode())
    me = x_montana_account
    with db() as c:
        rows = c.execute(
            "SELECT id, sender, recipient, nonce, ciphertext, sent_at FROM messages "
            "WHERE (recipient=? OR sender=?) AND id > ? ORDER BY id ASC LIMIT 500",
            (me, me, since),
        ).fetchall()

    def to_wire(rec):
        # Binary columns go out as unpadded URL-safe base64.
        return {
            "id": rec["id"],
            "from": rec["sender"],
            "to": rec["recipient"],
            "nonce": b64e(rec["nonce"]),
            "ciphertext": b64e(rec["ciphertext"]),
            "sent_at": rec["sent_at"],
        }

    return {"messages": [to_wire(rec) for rec in rows]}
|
||
|
|
|
||
|
|
async def broadcast(account_id: str, payload: dict):
    """Send ``payload`` as JSON text to every WebSocket registered for ``account_id``.

    Sends run concurrently, so one slow socket does not hold up the others.
    Sockets whose send raises are dropped from ``connections``. The account's
    entry is deleted once its socket set is empty, so the registry does not
    accumulate empty sets.
    """
    sockets = list(connections.get(account_id, ()))
    if not sockets:
        return
    text = json.dumps(payload)

    async def _send(ws: WebSocket) -> Optional[WebSocket]:
        # Return the socket on failure so the caller can prune it.
        try:
            await ws.send_text(text)
        except Exception:
            return ws
        return None

    results = await asyncio.gather(*(_send(ws) for ws in sockets))
    dead = [ws for ws in results if ws is not None]
    if not dead:
        return
    # Re-read the registry: sockets may have connected or left during the awaits.
    live = connections.get(account_id)
    if live is None:
        return
    live.difference_update(dead)
    if not live:
        del connections[account_id]
|
||
|
|
|
||
|
|
def _fetch_ed_pub(account_id: str) -> Optional[bytes]:
    """Return the stored Ed25519 public key for ``account_id``, or None if unknown.

    This is a blocking sqlite call; ws_endpoint runs it via ``asyncio.to_thread``.
    """
    with db() as c:
        row = c.execute("SELECT ed_pub FROM accounts WHERE account_id=?", (account_id,)).fetchone()
    return None if row is None else row["ed_pub"]


@app.websocket("/api/ws")
async def ws_endpoint(ws: WebSocket):
    """Authenticate a WebSocket client, then keep it registered for pushes.

    Handshake: within 10 s the client sends a JSON text frame
    ``{"account_id": ..., "ts": ..., "sig": ...}``. ``sig`` is an Ed25519
    signature over ``str(ts) + "\\nws"``, and ``ts`` must be an integer within
    SIG_WINDOW seconds of server time.

    Close codes:
        4400 - malformed or missing hello.
        4401 - stale timestamp or bad signature.
        4404 - unknown account.

    On success the server replies ``{"hello": "ok"}``, answers "ping" with "pong",
    and the socket receives broadcast() pushes until it disconnects.
    """
    await ws.accept()
    try:
        hello = await asyncio.wait_for(ws.receive_text(), timeout=10)
    except WebSocketDisconnect:
        # The client already left; there is nothing to close.
        return
    except Exception:
        await ws.close(code=4400)
        return
    try:
        data = json.loads(hello)
        aid = data["account_id"]
        ts_str = str(data["ts"])
        # Parse here so a non-integer ts is a 4400, not an unhandled ValueError.
        ts = int(ts_str)
        sig_b64 = data["sig"]
    except Exception:
        await ws.close(code=4400)
        return
    if not isinstance(aid, str):
        # Non-string ids (lists, objects) cannot be bound as sqlite parameters.
        await ws.close(code=4400)
        return
    if abs(int(time.time()) - ts) > SIG_WINDOW:
        await ws.close(code=4401)
        return
    ed_pub = await asyncio.to_thread(_fetch_ed_pub, aid)
    if ed_pub is None:
        await ws.close(code=4404)
        return
    try:
        Ed25519PublicKey.from_public_bytes(ed_pub).verify(b64d(sig_b64), ts_str.encode() + b"\nws")
    except Exception:
        await ws.close(code=4401)
        return
    connections.setdefault(aid, set()).add(ws)
    try:
        # The ack is inside the try so a failed send still unregisters the socket.
        await ws.send_text(json.dumps({"hello": "ok"}))
        while True:
            msg = await ws.receive_text()
            if msg == "ping":
                await ws.send_text("pong")
    except WebSocketDisconnect:
        pass
    finally:
        peers = connections.get(aid)
        if peers is not None:
            peers.discard(ws)
            if not peers:
                # Drop empty entries so the registry does not grow without bound.
                del connections[aid]
|
||
|
|
|
||
|
|
@app.get("/api/health")
def health():
    # Unauthenticated liveness probe; also reports server time in Unix seconds.
    now = int(time.time())
    return dict(ok=True, ts=now)
|