"""Message relay service.

FastAPI app that registers accounts (Ed25519 + a second 32-byte public key),
stores opaque nonce/ciphertext blobs in SQLite without decrypting them,
authenticates requests with Ed25519 signatures, and pushes newly stored
messages to the recipient's connected WebSocket clients.
"""
import asyncio
import base64
import hashlib
import json
import os
import sqlite3
import time
from contextlib import asynccontextmanager, closing
from typing import Optional

from cryptography.exceptions import InvalidSignature
from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PublicKey
from fastapi import FastAPI, HTTPException, Header, WebSocket, WebSocketDisconnect
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field

DB_PATH = os.environ.get("MESS_DB", "/opt/mess/data.db")
MAX_BODY = 1 << 20  # max size (bytes) of the canonical JSON of a send request
SIG_WINDOW = 120  # max allowed |now - signed timestamp|, in seconds


def b64d(s: str) -> bytes:
    """Decode URL-safe base64, tolerating missing '=' padding."""
    s += "=" * (-len(s) % 4)
    return base64.urlsafe_b64decode(s)


def b64e(b: bytes) -> str:
    """Encode bytes as URL-safe base64 with the '=' padding stripped."""
    return base64.urlsafe_b64encode(b).rstrip(b"=").decode()


def account_id_for(ed_pub: bytes) -> str:
    """Derive an account id: hex of the first 8 bytes of SHA-256(ed_pub)."""
    return hashlib.sha256(ed_pub).digest()[:8].hex()


def db() -> sqlite3.Connection:
    """Open a new autocommit SQLite connection to DB_PATH.

    The caller owns the connection and must close it, e.g. ``with closing(db())``.
    ``with db() as c`` alone only commits/rolls back and does not close it.
    """
    c = sqlite3.connect(DB_PATH, isolation_level=None)
    c.execute("PRAGMA journal_mode=WAL")
    c.execute("PRAGMA foreign_keys=ON")
    c.row_factory = sqlite3.Row
    return c


# NOTE(review): messages.sig is currently always written as b"" by the send path.
SCHEMA = """
CREATE TABLE IF NOT EXISTS accounts (
    account_id TEXT PRIMARY KEY,
    ed_pub BLOB NOT NULL,
    x_pub BLOB NOT NULL,
    name TEXT NOT NULL,
    created_at INTEGER NOT NULL
);
CREATE TABLE IF NOT EXISTS messages (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    sender TEXT NOT NULL,
    recipient TEXT NOT NULL,
    nonce BLOB NOT NULL,
    ciphertext BLOB NOT NULL,
    sig BLOB NOT NULL,
    sent_at INTEGER NOT NULL,
    FOREIGN KEY (sender) REFERENCES accounts(account_id),
    FOREIGN KEY (recipient) REFERENCES accounts(account_id)
);
CREATE INDEX IF NOT EXISTS idx_msg_recipient ON messages(recipient, id);
CREATE INDEX IF NOT EXISTS idx_msg_pair ON messages(sender, recipient, id);
"""

# account_id -> authenticated WebSocket clients that receive pushed messages.
# Only touched from the event loop, with no await between lookup and mutation.
connections: dict[str, set[WebSocket]] = {}


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create the database directory and schema before serving requests."""
    db_dir = os.path.dirname(DB_PATH)
    if db_dir:  # a bare filename (e.g. MESS_DB=data.db) has no directory to create
        os.makedirs(db_dir, exist_ok=True)
    with closing(db()) as c:
        c.executescript(SCHEMA)
    yield


app = FastAPI(lifespan=lifespan, docs_url=None, redoc_url=None)


class RegisterRequest(BaseModel):
    """Body of POST /api/accounts; keys are URL-safe base64 of 32 raw bytes."""

    name: str = Field(min_length=1, max_length=64)
    ed_pub: str
    x_pub: str


class RegisterResponse(BaseModel):
    """Response of POST /api/accounts."""

    account_id: str


@app.post("/api/accounts", response_model=RegisterResponse)
def register(req: RegisterRequest):
    """Register an account whose id is derived from its Ed25519 public key.

    Re-submitting an identical record (same keys and stripped name) is
    idempotent and returns the existing account_id.

    Raises:
        HTTPException 400: undecodable or wrong-length keys, an invalid Ed25519
            key, or a name that is empty after stripping whitespace.
        HTTPException 409: the account_id already exists with a different record.
    """
    try:
        ed_pub = b64d(req.ed_pub)
        x_pub = b64d(req.x_pub)
    except Exception:
        raise HTTPException(400, "bad_pubkey_b64")
    if len(ed_pub) != 32 or len(x_pub) != 32:
        raise HTTPException(400, "pubkey_len")
    try:
        Ed25519PublicKey.from_public_bytes(ed_pub)
    except Exception:
        raise HTTPException(400, "bad_ed25519")
    name = req.name.strip()
    if not name:
        # min_length is checked before stripping, so "   " would be stored as "".
        raise HTTPException(400, "bad_name")
    aid = account_id_for(ed_pub)
    now = int(time.time())
    with closing(db()) as c:
        try:
            c.execute(
                "INSERT INTO accounts(account_id, ed_pub, x_pub, name, created_at) VALUES (?,?,?,?,?)",
                (aid, ed_pub, x_pub, name, now),
            )
        except sqlite3.IntegrityError:
            row = c.execute(
                "SELECT ed_pub, x_pub, name FROM accounts WHERE account_id=?", (aid,)
            ).fetchone()
            # Only a byte-identical record counts as a retry; matching the name
            # alone would report success while the server keeps different keys.
            if (
                row
                and bytes(row["ed_pub"]) == ed_pub
                and bytes(row["x_pub"]) == x_pub
                and row["name"] == name
            ):
                return RegisterResponse(account_id=aid)
            raise HTTPException(409, "account_exists")
    return RegisterResponse(account_id=aid)


class AccountInfo(BaseModel):
    """Public view of an account; keys are URL-safe base64 without padding."""

    account_id: str
    name: str
    ed_pub: str
    x_pub: str
    created_at: int


@app.get("/api/accounts/{aid}", response_model=AccountInfo)
def get_account(aid: str):
    """Return the public record for account ``aid``; 404 if it does not exist."""
    with closing(db()) as c:
        row = c.execute("SELECT * FROM accounts WHERE account_id=?", (aid,)).fetchone()
    if not row:
        raise HTTPException(404, "not_found")
    return AccountInfo(
        account_id=row["account_id"],
        name=row["name"],
        ed_pub=b64e(row["ed_pub"]),
        x_pub=b64e(row["x_pub"]),
        created_at=row["created_at"],
    )


def _get_ed_pub(account_id: str) -> Optional[bytes]:
    """Return the stored Ed25519 public key for account_id, or None if unknown."""
    with closing(db()) as c:
        row = c.execute(
            "SELECT ed_pub FROM accounts WHERE account_id=?", (account_id,)
        ).fetchone()
    return row["ed_pub"] if row else None


def verify_auth(account_id: str, x_montana_auth: str, body: bytes) -> None:
    """Verify an auth header of the form ``"<unix_ts>:<b64url signature>"``.

    The signature must be made by the account's registered Ed25519 key over
    the exact timestamp text from the header, a newline byte, then ``body``.

    NOTE(review): nothing records used signatures, so a captured header/body
    pair can be replayed for as long as its timestamp is within SIG_WINDOW.

    Raises:
        HTTPException 401 with detail bad_auth_header, stale_signature,
        bad_sig_b64, unknown_account or invalid_signature.
    """
    try:
        ts_str, sig_b64 = x_montana_auth.split(":", 1)
        ts = int(ts_str)
    except Exception:
        raise HTTPException(401, "bad_auth_header")
    if abs(int(time.time()) - ts) > SIG_WINDOW:
        raise HTTPException(401, "stale_signature")
    try:
        sig = b64d(sig_b64)
    except Exception:
        raise HTTPException(401, "bad_sig_b64")
    ed_pub = _get_ed_pub(account_id)
    if ed_pub is None:
        raise HTTPException(401, "unknown_account")
    # Sign over the header's own timestamp text so the client and server byte
    # strings match exactly.
    msg = ts_str.encode() + b"\n" + body
    try:
        Ed25519PublicKey.from_public_bytes(ed_pub).verify(sig, msg)
    except InvalidSignature:
        raise HTTPException(401, "invalid_signature")


class SendRequest(BaseModel):
    """Body of POST /api/messages/send; nonce and ciphertext are URL-safe base64."""

    to: str
    nonce: str
    ciphertext: str


def _store_message(sender: str, auth: str, req: SendRequest) -> dict:
    """Authenticate, validate and persist one send request; return its push payload.

    Blocking (signature check + SQLite), so the async endpoint runs it in a
    worker thread instead of on the event loop.
    """
    # The client signs the canonical JSON (sorted keys, no spaces) of the body.
    raw = json.dumps(req.model_dump(), separators=(",", ":"), sort_keys=True).encode()
    if len(raw) > MAX_BODY:
        raise HTTPException(413, "too_large")
    verify_auth(sender, auth, raw)
    try:
        nonce = b64d(req.nonce)
        ct = b64d(req.ciphertext)
    except Exception:
        raise HTTPException(400, "bad_b64")
    if len(nonce) != 12 or len(ct) > 64 * 1024:
        raise HTTPException(400, "bad_lengths")
    with closing(db()) as c:
        if not c.execute(
            "SELECT 1 FROM accounts WHERE account_id=?", (req.to,)
        ).fetchone():
            raise HTTPException(404, "recipient_not_found")
        sent_at = int(time.time() * 1000)  # milliseconds since the epoch
        cur = c.execute(
            "INSERT INTO messages(sender, recipient, nonce, ciphertext, sig, sent_at) VALUES (?,?,?,?,?,?)",
            (sender, req.to, nonce, ct, b"", sent_at),
        )
        msg_id = cur.lastrowid
    return {
        "id": msg_id,
        "from": sender,
        "to": req.to,
        "nonce": req.nonce,
        "ciphertext": req.ciphertext,
        "sent_at": sent_at,
    }


@app.post("/api/messages/send")
async def send_message(
    req: SendRequest,
    x_montana_account: str = Header(...),
    x_montana_auth: str = Header(...),
):
    """Store a message from the authenticated sender and push it to the recipient.

    Raises HTTPException 413/401/400/404 (see _store_message and verify_auth).
    Returns the stored message as a dict, which is also the WebSocket payload.
    """
    payload = await asyncio.to_thread(
        _store_message, x_montana_account, x_montana_auth, req
    )
    await broadcast(req.to, payload)
    return payload


@app.get("/api/messages/inbox")
def inbox(
    since: int = 0,
    x_montana_account: str = Header(...),
    x_montana_auth: str = Header(...),
):
    """Return up to 500 messages sent or received by the caller with id > since.

    The signed body for this request is the ASCII text ``since=<since>``.
    Messages are ordered by ascending id.
    """
    body = f"since={since}".encode()
    verify_auth(x_montana_account, x_montana_auth, body)
    with closing(db()) as c:
        rows = c.execute(
            "SELECT id, sender, recipient, nonce, ciphertext, sent_at FROM messages "
            "WHERE (recipient=? OR sender=?) AND id > ? ORDER BY id ASC LIMIT 500",
            (x_montana_account, x_montana_account, since),
        ).fetchall()
    return {
        "messages": [
            {
                "id": r["id"],
                "from": r["sender"],
                "to": r["recipient"],
                "nonce": b64e(r["nonce"]),
                "ciphertext": b64e(r["ciphertext"]),
                "sent_at": r["sent_at"],
            }
            for r in rows
        ]
    }


def _unregister(account_id: str, ws: WebSocket) -> None:
    """Remove ws from the registry and drop the account's entry once it is empty."""
    sockets = connections.get(account_id)
    if sockets is None:
        return
    sockets.discard(ws)
    if not sockets:
        del connections[account_id]


async def broadcast(account_id: str, payload: dict):
    """Send payload as JSON text to every WebSocket registered for account_id.

    Best-effort: a socket whose send raises is dropped from the registry
    instead of failing the caller.
    """
    sockets = list(connections.get(account_id, ()))  # snapshot; the set may change
    if not sockets:
        return
    text = json.dumps(payload)
    dead = []
    for ws in sockets:
        try:
            await ws.send_text(text)
        except Exception:
            dead.append(ws)
    for ws in dead:
        _unregister(account_id, ws)


@app.websocket("/api/ws")
async def ws_endpoint(ws: WebSocket):
    """Authenticate a WebSocket client, then keep it registered for pushes.

    The first text frame, sent within 10 s, must be a JSON object with
    "account_id" (str), "ts" (unix seconds) and "sig" (URL-safe base64 Ed25519
    signature over str(ts), a newline, then "ws"). Close codes: 4400 for a
    malformed or missing hello, 4401 for a stale or bad signature, 4404 for an
    unknown account. After "hello ok" the server answers "ping" with "pong".

    NOTE(review): the hello signature is replayable while ts is within SIG_WINDOW.
    """
    await ws.accept()
    try:
        hello = await asyncio.wait_for(ws.receive_text(), timeout=10)
    except WebSocketDisconnect:
        return  # client left before saying hello; the socket is already closed
    except Exception:
        await ws.close(code=4400)
        return
    try:
        data = json.loads(hello)
        aid = data["account_id"]
        ts_str = str(data["ts"])
        sig_b64 = data["sig"]
        ts = int(ts_str)  # parsed here so a non-integer ts is a 4400, not a crash
        if not isinstance(aid, str) or not isinstance(sig_b64, str):
            raise TypeError("account_id and sig must be strings")
    except Exception:
        await ws.close(code=4400)
        return
    if abs(int(time.time()) - ts) > SIG_WINDOW:
        await ws.close(code=4401)
        return
    ed_pub = await asyncio.to_thread(_get_ed_pub, aid)
    if ed_pub is None:
        await ws.close(code=4404)
        return
    try:
        Ed25519PublicKey.from_public_bytes(ed_pub).verify(
            b64d(sig_b64), ts_str.encode() + b"\nws"
        )
    except Exception:
        await ws.close(code=4401)
        return

    connections.setdefault(aid, set()).add(ws)
    # Everything after registration is inside try/finally so a failed send
    # (including the hello ack) never leaves a dead socket registered.
    try:
        await ws.send_text(json.dumps({"hello": "ok"}))
        while True:
            msg = await ws.receive_text()
            if msg == "ping":
                await ws.send_text("pong")
    except WebSocketDisconnect:
        pass
    finally:
        _unregister(aid, ws)


@app.get("/api/health")
def health():
    """Liveness probe returning the server's current unix time."""
    return {"ok": True, "ts": int(time.time())}