From 6374b72a1a074e6949cb1069ed398d51beb14a4f Mon Sep 17 00:00:00 2001 From: efir369999 Date: Tue, 5 May 2026 17:17:02 +0300 Subject: [PATCH] Backend: add messenger.py --- Backend/Messenger/messenger.py | 274 +++++++++++++++++++++++++++++++++ 1 file changed, 274 insertions(+) create mode 100644 Backend/Messenger/messenger.py diff --git a/Backend/Messenger/messenger.py b/Backend/Messenger/messenger.py new file mode 100644 index 0000000..bdca837 --- /dev/null +++ b/Backend/Messenger/messenger.py @@ -0,0 +1,274 @@ +import asyncio +import base64 +import hashlib +import json +import os +import sqlite3 +import time +from contextlib import asynccontextmanager +from typing import Optional + +from cryptography.exceptions import InvalidSignature +from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PublicKey +from fastapi import FastAPI, HTTPException, Header, WebSocket, WebSocketDisconnect +from fastapi.responses import JSONResponse +from pydantic import BaseModel, Field + +DB_PATH = os.environ.get("MESS_DB", "/opt/mess/data.db") +MAX_BODY = 1 << 20 +SIG_WINDOW = 120 + +def b64d(s: str) -> bytes: + s += "=" * (-len(s) % 4) + return base64.urlsafe_b64decode(s) + +def b64e(b: bytes) -> str: + return base64.urlsafe_b64encode(b).rstrip(b"=").decode() + +def account_id_for(ed_pub: bytes) -> str: + return hashlib.sha256(ed_pub).digest()[:8].hex() + +def db(): + c = sqlite3.connect(DB_PATH, isolation_level=None) + c.execute("PRAGMA journal_mode=WAL") + c.execute("PRAGMA foreign_keys=ON") + c.row_factory = sqlite3.Row + return c + +SCHEMA = """ +CREATE TABLE IF NOT EXISTS accounts ( + account_id TEXT PRIMARY KEY, + ed_pub BLOB NOT NULL, + x_pub BLOB NOT NULL, + name TEXT NOT NULL, + created_at INTEGER NOT NULL +); +CREATE TABLE IF NOT EXISTS messages ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + sender TEXT NOT NULL, + recipient TEXT NOT NULL, + nonce BLOB NOT NULL, + ciphertext BLOB NOT NULL, + sig BLOB NOT NULL, + sent_at INTEGER NOT NULL, + FOREIGN KEY (sender) REFERENCES accounts(account_id), + FOREIGN KEY (recipient) REFERENCES accounts(account_id) +); +CREATE INDEX IF NOT EXISTS idx_msg_recipient ON messages(recipient, id); +CREATE INDEX IF NOT EXISTS idx_msg_pair ON messages(sender, recipient, id); +""" + +connections: dict[str, set[WebSocket]] = {} + +@asynccontextmanager +async def lifespan(app: FastAPI): + os.makedirs(os.path.dirname(DB_PATH), exist_ok=True) + with db() as c: + c.executescript(SCHEMA) + yield + +app = FastAPI(lifespan=lifespan, docs_url=None, redoc_url=None) + +class RegisterRequest(BaseModel): + name: str = Field(min_length=1, max_length=64) + ed_pub: str + x_pub: str + +class RegisterResponse(BaseModel): + account_id: str + +@app.post("/api/accounts", response_model=RegisterResponse) +def register(req: RegisterRequest): + try: + ed_pub = b64d(req.ed_pub) + x_pub = b64d(req.x_pub) + except Exception: + raise HTTPException(400, "bad_pubkey_b64") + if len(ed_pub) != 32 or len(x_pub) != 32: + raise HTTPException(400, "pubkey_len") + try: + Ed25519PublicKey.from_public_bytes(ed_pub) + except Exception: + raise HTTPException(400, "bad_ed25519") + aid = account_id_for(ed_pub) + now = int(time.time()) + with db() as c: + try: + c.execute( + "INSERT INTO accounts(account_id, ed_pub, x_pub, name, created_at) VALUES (?,?,?,?,?)", + (aid, ed_pub, x_pub, req.name.strip(), now) + ) + except sqlite3.IntegrityError: + row = c.execute("SELECT name FROM accounts WHERE account_id=?", (aid,)).fetchone() + if row and row["name"] == req.name.strip(): + return RegisterResponse(account_id=aid) + raise HTTPException(409, "account_exists") + return RegisterResponse(account_id=aid) + +class AccountInfo(BaseModel): + account_id: str + name: str + ed_pub: str + x_pub: str + created_at: int + +@app.get("/api/accounts/{aid}", response_model=AccountInfo) +def get_account(aid: str): + with db() as c: + row = c.execute("SELECT * FROM accounts WHERE account_id=?", (aid,)).fetchone() + if not row: + raise HTTPException(404, "not_found") + return AccountInfo( + account_id=row["account_id"], + name=row["name"], + ed_pub=b64e(row["ed_pub"]), + x_pub=b64e(row["x_pub"]), + created_at=row["created_at"], + ) + +def verify_auth(account_id: str, x_montana_auth: str, body: bytes) -> None: + try: + ts_str, sig_b64 = x_montana_auth.split(":", 1) + ts = int(ts_str) + except Exception: + raise HTTPException(401, "bad_auth_header") + if abs(int(time.time()) - ts) > SIG_WINDOW: + raise HTTPException(401, "stale_signature") + try: + sig = b64d(sig_b64) + except Exception: + raise HTTPException(401, "bad_sig_b64") + with db() as c: + row = c.execute("SELECT ed_pub FROM accounts WHERE account_id=?", (account_id,)).fetchone() + if not row: + raise HTTPException(401, "unknown_account") + msg = ts_str.encode() + b"\n" + body + try: + Ed25519PublicKey.from_public_bytes(row["ed_pub"]).verify(sig, msg) + except InvalidSignature: + raise HTTPException(401, "invalid_signature") + +class SendRequest(BaseModel): + to: str + nonce: str + ciphertext: str + +@app.post("/api/messages/send") +async def send_message( + req: SendRequest, + x_montana_account: str = Header(...), + x_montana_auth: str = Header(...), +): + raw = json.dumps(req.model_dump(), separators=(",", ":"), sort_keys=True).encode() + if len(raw) > MAX_BODY: + raise HTTPException(413, "too_large") + verify_auth(x_montana_account, x_montana_auth, raw) + try: + nonce = b64d(req.nonce) + ct = b64d(req.ciphertext) + except Exception: + raise HTTPException(400, "bad_b64") + if len(nonce) != 12 or len(ct) > 64 * 1024: + raise HTTPException(400, "bad_lengths") + with db() as c: + if not c.execute("SELECT 1 FROM accounts WHERE account_id=?", (req.to,)).fetchone(): + raise HTTPException(404, "recipient_not_found") + sent_at = int(time.time() * 1000) + cur = c.execute( + "INSERT INTO messages(sender, recipient, nonce, ciphertext, sig, sent_at) VALUES (?,?,?,?,?,?)", + (x_montana_account, req.to, nonce, ct, b"", sent_at) + ) + msg_id = cur.lastrowid + payload = { + "id": msg_id, + "from": x_montana_account, + "to": req.to, + "nonce": req.nonce, + "ciphertext": req.ciphertext, + "sent_at": sent_at, + } + await broadcast(req.to, payload) + return payload + +@app.get("/api/messages/inbox") +def inbox( + since: int = 0, + x_montana_account: str = Header(...), + x_montana_auth: str = Header(...), +): + body = f"since={since}".encode() + verify_auth(x_montana_account, x_montana_auth, body) + with db() as c: + rows = c.execute( + "SELECT id, sender, recipient, nonce, ciphertext, sent_at FROM messages " + "WHERE (recipient=? OR sender=?) AND id > ? ORDER BY id ASC LIMIT 500", + (x_montana_account, x_montana_account, since) + ).fetchall() + return { + "messages": [ + { + "id": r["id"], + "from": r["sender"], + "to": r["recipient"], + "nonce": b64e(r["nonce"]), + "ciphertext": b64e(r["ciphertext"]), + "sent_at": r["sent_at"], + } + for r in rows + ] + } + +async def broadcast(account_id: str, payload: dict): + sockets = list(connections.get(account_id, ())) + if not sockets: + return + text = json.dumps(payload) + dead = [] + for ws in sockets: + try: + await ws.send_text(text) + except Exception: + dead.append(ws) + for ws in dead: + connections.get(account_id, set()).discard(ws) + +@app.websocket("/api/ws") +async def ws_endpoint(ws: WebSocket): + await ws.accept() + try: + hello = await asyncio.wait_for(ws.receive_text(), timeout=10) + data = json.loads(hello) + aid = data["account_id"] + ts_str = str(data["ts"]) + sig_b64 = data["sig"] + except Exception: + await ws.close(code=4400) + return + if abs(int(time.time()) - int(ts_str)) > SIG_WINDOW: + await ws.close(code=4401) + return + with db() as c: + row = c.execute("SELECT ed_pub FROM accounts WHERE account_id=?", (aid,)).fetchone() + if not row: + await ws.close(code=4404) + return + try: + Ed25519PublicKey.from_public_bytes(row["ed_pub"]).verify(b64d(sig_b64), ts_str.encode() + b"\nws") + except Exception: + await ws.close(code=4401) + return + connections.setdefault(aid, set()).add(ws) + await ws.send_text(json.dumps({"hello": "ok"})) + try: + while True: + msg = await ws.receive_text() + if msg == "ping": + await ws.send_text("pong") + except WebSocketDisconnect: + pass + finally: + connections.get(aid, set()).discard(ws) + +@app.get("/api/health") +def health(): + return {"ok": True, "ts": int(time.time())}