Backend: add messenger.py

This commit is contained in:
efir369999 2026-05-05 17:17:02 +03:00
parent 42311bcb74
commit 6374b72a1a

View File

@ -0,0 +1,274 @@
import asyncio
import base64
import hashlib
import json
import os
import sqlite3
import time
from contextlib import asynccontextmanager
from typing import Optional
from cryptography.exceptions import InvalidSignature
from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PublicKey
from fastapi import FastAPI, HTTPException, Header, WebSocket, WebSocketDisconnect
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field
DB_PATH = os.environ.get("MESS_DB", "/opt/mess/data.db")
MAX_BODY = 1 << 20
SIG_WINDOW = 120
def b64d(s: str) -> bytes:
s += "=" * (-len(s) % 4)
return base64.urlsafe_b64decode(s)
def b64e(b: bytes) -> str:
return base64.urlsafe_b64encode(b).rstrip(b"=").decode()
def account_id_for(ed_pub: bytes) -> str:
return hashlib.sha256(ed_pub).digest()[:8].hex()
def db():
c = sqlite3.connect(DB_PATH, isolation_level=None)
c.execute("PRAGMA journal_mode=WAL")
c.execute("PRAGMA foreign_keys=ON")
c.row_factory = sqlite3.Row
return c
SCHEMA = """
CREATE TABLE IF NOT EXISTS accounts (
account_id TEXT PRIMARY KEY,
ed_pub BLOB NOT NULL,
x_pub BLOB NOT NULL,
name TEXT NOT NULL,
created_at INTEGER NOT NULL
);
CREATE TABLE IF NOT EXISTS messages (
id INTEGER PRIMARY KEY AUTOINCREMENT,
sender TEXT NOT NULL,
recipient TEXT NOT NULL,
nonce BLOB NOT NULL,
ciphertext BLOB NOT NULL,
sig BLOB NOT NULL,
sent_at INTEGER NOT NULL,
FOREIGN KEY (sender) REFERENCES accounts(account_id),
FOREIGN KEY (recipient) REFERENCES accounts(account_id)
);
CREATE INDEX IF NOT EXISTS idx_msg_recipient ON messages(recipient, id);
CREATE INDEX IF NOT EXISTS idx_msg_pair ON messages(sender, recipient, id);
"""
connections: dict[str, set[WebSocket]] = {}
@asynccontextmanager
async def lifespan(app: FastAPI):
os.makedirs(os.path.dirname(DB_PATH), exist_ok=True)
with db() as c:
c.executescript(SCHEMA)
yield
app = FastAPI(lifespan=lifespan, docs_url=None, redoc_url=None)
class RegisterRequest(BaseModel):
name: str = Field(min_length=1, max_length=64)
ed_pub: str
x_pub: str
class RegisterResponse(BaseModel):
account_id: str
@app.post("/api/accounts", response_model=RegisterResponse)
def register(req: RegisterRequest):
try:
ed_pub = b64d(req.ed_pub)
x_pub = b64d(req.x_pub)
except Exception:
raise HTTPException(400, "bad_pubkey_b64")
if len(ed_pub) != 32 or len(x_pub) != 32:
raise HTTPException(400, "pubkey_len")
try:
Ed25519PublicKey.from_public_bytes(ed_pub)
except Exception:
raise HTTPException(400, "bad_ed25519")
aid = account_id_for(ed_pub)
now = int(time.time())
with db() as c:
try:
c.execute(
"INSERT INTO accounts(account_id, ed_pub, x_pub, name, created_at) VALUES (?,?,?,?,?)",
(aid, ed_pub, x_pub, req.name.strip(), now)
)
except sqlite3.IntegrityError:
row = c.execute("SELECT name FROM accounts WHERE account_id=?", (aid,)).fetchone()
if row and row["name"] == req.name.strip():
return RegisterResponse(account_id=aid)
raise HTTPException(409, "account_exists")
return RegisterResponse(account_id=aid)
class AccountInfo(BaseModel):
account_id: str
name: str
ed_pub: str
x_pub: str
created_at: int
@app.get("/api/accounts/{aid}", response_model=AccountInfo)
def get_account(aid: str):
with db() as c:
row = c.execute("SELECT * FROM accounts WHERE account_id=?", (aid,)).fetchone()
if not row:
raise HTTPException(404, "not_found")
return AccountInfo(
account_id=row["account_id"],
name=row["name"],
ed_pub=b64e(row["ed_pub"]),
x_pub=b64e(row["x_pub"]),
created_at=row["created_at"],
)
def verify_auth(account_id: str, x_montana_auth: str, body: bytes) -> None:
try:
ts_str, sig_b64 = x_montana_auth.split(":", 1)
ts = int(ts_str)
except Exception:
raise HTTPException(401, "bad_auth_header")
if abs(int(time.time()) - ts) > SIG_WINDOW:
raise HTTPException(401, "stale_signature")
try:
sig = b64d(sig_b64)
except Exception:
raise HTTPException(401, "bad_sig_b64")
with db() as c:
row = c.execute("SELECT ed_pub FROM accounts WHERE account_id=?", (account_id,)).fetchone()
if not row:
raise HTTPException(401, "unknown_account")
msg = ts_str.encode() + b"\n" + body
try:
Ed25519PublicKey.from_public_bytes(row["ed_pub"]).verify(sig, msg)
except InvalidSignature:
raise HTTPException(401, "invalid_signature")
class SendRequest(BaseModel):
to: str
nonce: str
ciphertext: str
@app.post("/api/messages/send")
async def send_message(
req: SendRequest,
x_montana_account: str = Header(...),
x_montana_auth: str = Header(...),
):
raw = json.dumps(req.model_dump(), separators=(",", ":"), sort_keys=True).encode()
if len(raw) > MAX_BODY:
raise HTTPException(413, "too_large")
verify_auth(x_montana_account, x_montana_auth, raw)
try:
nonce = b64d(req.nonce)
ct = b64d(req.ciphertext)
except Exception:
raise HTTPException(400, "bad_b64")
if len(nonce) != 12 or len(ct) > 64 * 1024:
raise HTTPException(400, "bad_lengths")
with db() as c:
if not c.execute("SELECT 1 FROM accounts WHERE account_id=?", (req.to,)).fetchone():
raise HTTPException(404, "recipient_not_found")
sent_at = int(time.time() * 1000)
cur = c.execute(
"INSERT INTO messages(sender, recipient, nonce, ciphertext, sig, sent_at) VALUES (?,?,?,?,?,?)",
(x_montana_account, req.to, nonce, ct, b"", sent_at)
)
msg_id = cur.lastrowid
payload = {
"id": msg_id,
"from": x_montana_account,
"to": req.to,
"nonce": req.nonce,
"ciphertext": req.ciphertext,
"sent_at": sent_at,
}
await broadcast(req.to, payload)
return payload
@app.get("/api/messages/inbox")
def inbox(
since: int = 0,
x_montana_account: str = Header(...),
x_montana_auth: str = Header(...),
):
body = f"since={since}".encode()
verify_auth(x_montana_account, x_montana_auth, body)
with db() as c:
rows = c.execute(
"SELECT id, sender, recipient, nonce, ciphertext, sent_at FROM messages "
"WHERE (recipient=? OR sender=?) AND id > ? ORDER BY id ASC LIMIT 500",
(x_montana_account, x_montana_account, since)
).fetchall()
return {
"messages": [
{
"id": r["id"],
"from": r["sender"],
"to": r["recipient"],
"nonce": b64e(r["nonce"]),
"ciphertext": b64e(r["ciphertext"]),
"sent_at": r["sent_at"],
}
for r in rows
]
}
async def broadcast(account_id: str, payload: dict):
sockets = list(connections.get(account_id, ()))
if not sockets:
return
text = json.dumps(payload)
dead = []
for ws in sockets:
try:
await ws.send_text(text)
except Exception:
dead.append(ws)
for ws in dead:
connections.get(account_id, set()).discard(ws)
@app.websocket("/api/ws")
async def ws_endpoint(ws: WebSocket):
await ws.accept()
try:
hello = await asyncio.wait_for(ws.receive_text(), timeout=10)
data = json.loads(hello)
aid = data["account_id"]
ts_str = str(data["ts"])
sig_b64 = data["sig"]
except Exception:
await ws.close(code=4400)
return
if abs(int(time.time()) - int(ts_str)) > SIG_WINDOW:
await ws.close(code=4401)
return
with db() as c:
row = c.execute("SELECT ed_pub FROM accounts WHERE account_id=?", (aid,)).fetchone()
if not row:
await ws.close(code=4404)
return
try:
Ed25519PublicKey.from_public_bytes(row["ed_pub"]).verify(b64d(sig_b64), ts_str.encode() + b"\nws")
except Exception:
await ws.close(code=4401)
return
connections.setdefault(aid, set()).add(ws)
await ws.send_text(json.dumps({"hello": "ok"}))
try:
while True:
msg = await ws.receive_text()
if msg == "ping":
await ws.send_text("pong")
except WebSocketDisconnect:
pass
finally:
connections.get(aid, set()).discard(ws)
@app.get("/api/health")
def health():
return {"ok": True, "ts": int(time.time())}