diff --git a/server/src/index.ts b/server/src/index.ts index 1c8b1dc..f0676a0 100644 --- a/server/src/index.ts +++ b/server/src/index.ts @@ -4,7 +4,19 @@ import pino from "pino"; import { WebSocketServer } from "ws"; import type { RawData, WebSocket } from "ws"; import type { IncomingMessage } from "http"; -import { ProtocolVersion, isMessage, nowTs } from "@assistenza/protocol"; +import { + ProtocolVersion, + isMessage, + nowTs, + type AgentRegister, + type AgentRegistered, + type AgentHeartbeat, + type ClientLogin, + type ClientLoginResult, + type ClientListDevices, + type ClientDeviceList, + type ErrorMessage, +} from "@assistenza/protocol"; export function serverBanner(): string { return `server v${ProtocolVersion} ${nowTs()}`; @@ -13,6 +25,23 @@ export function serverBanner(): string { const logger = pino({ level: process.env.LOG_LEVEL ?? "info" }); const app = Fastify({ logger }); +const seedUsername = process.env.SEED_USERNAME ?? ""; +const seedPassword = process.env.SEED_PASSWORD ?? ""; +const seedUserId = process.env.SEED_USER_ID ?? ""; +const seedPairingKey = process.env.SEED_PAIRING_KEY ?? ""; + +const clientSessions = new Map(); +const agents = new Map< + string, + { + deviceId: string; + deviceName: string; + userId: string; + ws: WebSocket; + online: boolean; + } +>(); + app.get("/health", async (_request, reply) => { reply.code(200).type("text/plain").send("ok"); }); @@ -55,6 +84,17 @@ function sendBadRequest(ws: WebSocket, requestId: string | undefined, message: s ws.send(JSON.stringify(payload)); } +function sendUnauthorized(ws: WebSocket, requestId: string | undefined, message = "Unauthorized"): void { + const payload: ErrorMessage = { + v: ProtocolVersion, + type: "error", + code: "UNAUTHORIZED", + message, + requestId: requestId ?? "unknown", + }; + ws.send(JSON.stringify(payload)); +} + function rawToString(data: RawData): string { if (typeof data === "string") return data; if (data instanceof Buffer) return data.toString("utf8"); @@ -62,6 +102,100 @@ function rawToString(data: RawData): string { return Buffer.from(data).toString("utf8"); } +function sendClientLoginResult(ws: WebSocket, requestId: string, ok: boolean): void { + const payload: ClientLoginResult = { + v: ProtocolVersion, + type: "client_login_result", + requestId, + ok, + ...(ok ? { clientId: seedUserId } : { message: "Invalid credentials" }), + }; + ws.send(JSON.stringify(payload)); +} + +function sendAgentRegistered(ws: WebSocket, requestId: string, deviceId: string): void { + const payload: AgentRegistered = { + v: ProtocolVersion, + type: "agent_registered", + requestId, + deviceId, + }; + ws.send(JSON.stringify(payload)); +} + +function sendClientDeviceList( + ws: WebSocket, + requestId: string, + devices: ClientDeviceList["devices"] +): void { + const payload: ClientDeviceList = { + v: ProtocolVersion, + type: "client_device_list", + requestId, + devices, + }; + ws.send(JSON.stringify(payload)); +} + +function handleClientMessage(ws: WebSocket, message: ClientLogin | ClientListDevices): void { + switch (message.type) { + case "client_login": { + const ok = message.username === seedUsername && message.password === seedPassword; + if (!ok) { + sendUnauthorized(ws, message.requestId); + return; + } + clientSessions.set(ws, { userId: seedUserId }); + sendClientLoginResult(ws, message.requestId, true); + return; + } + case "client_list_devices": { + const session = clientSessions.get(ws); + const devices = Array.from(agents.values()) + .filter((agent) => agent.userId === session.userId) + .map((agent) => ({ + deviceId: agent.deviceId, + deviceName: agent.deviceName, + online: agent.online, + })); + sendClientDeviceList(ws, message.requestId, devices); + return; + } + default: + sendBadRequest(ws, message.requestId, "Unsupported client message"); + } +} + +function handleAgentMessage(ws: WebSocket, message: AgentRegister | AgentHeartbeat): void { + switch (message.type) { + case "agent_register": { + if (message.pairingKey !== seedPairingKey) { + sendUnauthorized(ws, message.requestId); + return; + } + agents.set(message.deviceId, { + deviceId: message.deviceId, + deviceName: message.deviceName, + userId: seedUserId, + ws, + online: true, + }); + sendAgentRegistered(ws, message.requestId, message.deviceId); + return; + } + case "agent_heartbeat": { + const agent = agents.get(message.deviceId); + if (agent) { + agent.online = true; + agent.ws = ws; + } + return; + } + default: + sendBadRequest(ws, message.requestId, "Unsupported agent message"); + } +} + function handleConnection(kind: WsKind, ws: WebSocket, request: IncomingMessage): void { const path = getPath(request); const ip = getIp(request); @@ -82,6 +216,25 @@ function handleConnection(kind: WsKind, ws: WebSocket, request: IncomingMessage) sendBadRequest(ws, requestId, "Invalid message shape"); return; } + + if (kind === "client") { + if (parsed.type !== "client_login" && !clientSessions.has(ws)) { + sendUnauthorized(ws, parsed.requestId); + return; + } + if (parsed.type === "client_login" || parsed.type === "client_list_devices") { + handleClientMessage(ws, parsed); + return; + } + sendBadRequest(ws, parsed.requestId, "Unsupported client message"); + return; + } + + if (parsed.type === "agent_register" || parsed.type === "agent_heartbeat") { + handleAgentMessage(ws, parsed); + return; + } + sendBadRequest(ws, parsed.requestId, "Unsupported agent message"); }); ws.on("close", (code, reason) => { @@ -89,6 +242,17 @@ function handleConnection(kind: WsKind, ws: WebSocket, request: IncomingMessage) { ip, path, kind, code, reason: reason.toString() }, "ws disconnect" ); + + if (kind === "client") { + clientSessions.delete(ws); + return; + } + + for (const agent of agents.values()) { + if (agent.ws === ws) { + agent.online = false; + } + } }); ws.on("error", (err) => {