
WebSocket Notifications - Project Reference Pattern

Component: Real-time WebSocket Notifications
Status: 🟢 Stable
Created: 2025-12-29
Last Updated: 2025-12-29


Overview

Purpose

This PRP defines implementation patterns for WebSocket-based real-time notifications including connection management, message formats, subscription handling, and client reconnection logic.

Scope

Responsibilities:
- WebSocket connection lifecycle
- Message format definitions (JSON schemas)
- Subscription management
- Client reconnection protocol
- Rate limiting

Out of Scope:
- Alert engine business logic (see alert-engine architecture)
- Telegram/Push notifications (see notifications-prp)
- Authentication (see global.md)


Patterns

Pattern: Connection Endpoint

Problem: Clients need a WebSocket endpoint with authentication.

Solution: Single endpoint authenticated via HTTP-only cookie.

Implementation:

# central/src/api/websocket.py
from fastapi import APIRouter, Depends, WebSocket, WebSocketDisconnect

from ..services.auth import get_current_user_ws
from ..services.websocket_manager import manager

router = APIRouter()

@router.websocket("/ws")
async def websocket_endpoint(
    websocket: WebSocket,
    user=Depends(get_current_user_ws),
):
    """
    WebSocket endpoint for real-time notifications.

    Authentication via HTTP-only cookie (same as REST API).
    """
    await manager.connect(websocket, user.id)
    try:
        while True:
            data = await websocket.receive_json()
            await manager.handle_message(websocket, user.id, data)
    except WebSocketDisconnect:
        pass
    finally:
        # Release manager state on every exit path, not only clean disconnects
        manager.disconnect(websocket, user.id)

Endpoint: wss://server.example.com/ws


Pattern: Message Format

Problem: Need consistent message structure for all WebSocket communication.

Solution: JSON messages with required type and payload fields. Server-to-client messages additionally carry a timestamp and a unique id.

Base Message Structure:

{
  "type": "message_type",
  "payload": { },
  "timestamp": "2025-01-15T14:32:05.123Z",
  "id": "msg_abc123"
}

Field Definitions:

Field      Type    Required        Description
type       string  Yes             Message type identifier
payload    object  Yes             Type-specific data
timestamp  string  Server→Client   ISO 8601 UTC timestamp
id         string  Server→Client   Unique message ID for deduplication

Implementation:

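from datetime import datetime, timezone
from typing import Any
import uuid

from pydantic import BaseModel, Field

class WebSocketMessage(BaseModel):
    """Base WebSocket message structure."""
    type: str
    payload: dict[str, Any]
    # Timezone-aware UTC; Pydantic v2 serializes it with a "Z" suffix
    timestamp: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
    id: str = Field(default_factory=lambda: f"msg_{uuid.uuid4().hex[:12]}")

    # With websocket.send_json(), pass model_dump(mode="json") rather than
    # model_dump(): JSON mode turns the datetime into a string.
    def to_json(self) -> str:
        return self.model_dump_json()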


# Server-to-client messages
class DetectionMessage(WebSocketMessage):
    type: str = "detection"

class AlertMessage(WebSocketMessage):
    type: str = "alert"

class CollectorStatusMessage(WebSocketMessage):
    type: str = "collector_status"

class PingMessage(WebSocketMessage):
    type: str = "ping"

class ErrorMessage(WebSocketMessage):
    type: str = "error"

Pattern: Server → Client Messages

Detection:

def create_detection_message(detection: Detection) -> WebSocketMessage:
    """Create a detection message from a detection record."""
    return WebSocketMessage(
        type="detection",
        payload={
            "detection_id": str(detection.id),
            "plate_number": detection.plate_number,
            "plate_normalized": detection.plate_normalized,
            "confidence": detection.confidence,
            "captured_at": detection.captured_at.isoformat(),
            "collector": {
                "id": str(detection.collector_id),
                "name": detection.collector.site_name
            },
            "camera": {
                "id": str(detection.camera_id),
                "name": detection.camera.name
            },
            "direction": detection.direction,
            "image_urls": {
                "full": get_presigned_url(detection.id, "full"),
                "crop": get_presigned_url(detection.id, "crop")
            }
        }
    )

Alert Notification:

def create_alert_message(alert: Alert) -> WebSocketMessage:
    """Create alert message from alert record."""
    return WebSocketMessage(
        type="alert",
        payload={
            "alert_id": str(alert.id),
            "detection_id": str(alert.detection_id),
            "plate_number": alert.detection.plate_number,
            "watchlist": {
                "id": str(alert.watchlist_id),
                "name": alert.watchlist.name,
                "priority": alert.watchlist.priority
            },
            "entry": {
                "notes": alert.watchlist_entry.notes,
                "added_by": alert.watchlist_entry.added_by_name,
                "added_at": alert.watchlist_entry.created_at.isoformat()
            },
            "detection": {
                "captured_at": alert.detection.captured_at.isoformat(),
                "confidence": alert.detection.confidence,
                "collector_name": alert.detection.collector.site_name,
                "camera_name": alert.detection.camera.name,
                "image_urls": {
                    "full": get_presigned_url(alert.detection_id, "full"),
                    "crop": get_presigned_url(alert.detection_id, "crop")
                }
            }
        }
    )

Collector Status:

def create_collector_status_message(
    collector: Collector,
    previous_status: str
) -> WebSocketMessage:
    """Create collector status change message."""
    return WebSocketMessage(
        type="collector_status",
        payload={
            "collector_id": str(collector.id),
            "collector_name": collector.site_name,
            "status": collector.status,
            "previous_status": previous_status,
            "last_heartbeat_at": collector.last_heartbeat_at.isoformat(),
            "cameras_online": collector.cameras_online,
            "cameras_total": collector.cameras_total
        }
    )

Ping (Health Check):

async def send_ping(websocket: WebSocket):
    """Send ping to check connection health."""
    message = WebSocketMessage(
        type="ping",
        payload={"server_time": datetime.utcnow().isoformat()}
    )
    # mode="json" renders the datetime field as a string for send_json
    await websocket.send_json(message.model_dump(mode="json"))

Error:

def create_error_message(
    code: str,
    message: str,
    details: dict | None = None
) -> WebSocketMessage:
    """Create error message."""
    return WebSocketMessage(
        type="error",
        payload={
            "code": code,
            "message": message,
            **(details or {})
        }
    )

Pattern: Client → Server Messages

Pong Response:

async def handle_pong(websocket: WebSocket, user_id: UUID, payload: dict):
    """Handle pong response from client."""
    # Any pong proves the connection is alive; payload["client_time"] is
    # informational and unused here.
    manager.update_activity(websocket, user_id)

Subscribe:

async def handle_subscribe(websocket: WebSocket, user_id: UUID, payload: dict):
    """Handle subscription request."""
    feeds = payload.get("feeds", [])

    for feed in feeds:
        # Validate feed format and permissions
        if not validate_feed_permission(user_id, feed):
            await websocket.send_json(
                create_error_message(
                    "SUBSCRIPTION_DENIED",
                    f"You do not have permission to subscribe to {feed}",
                    {"feed": feed}
                ).model_dump(mode="json")
            )
            continue

        manager.subscribe(websocket, user_id, feed)
        await websocket.send_json(
            WebSocketMessage(
                type="subscribed",
                payload={"feed": feed, "success": True}
            ).model_dump(mode="json")
        )


def validate_feed_permission(user_id: UUID, feed: str) -> bool:
    """Check if user can subscribe to feed."""
    user = get_user(user_id)

    if feed == "collectors:status":
        return user.role == "admin"

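    if feed.startswith("alerts:"):
        watchlist_id = feed.split(":", 1)[1]
        if watchlist_id == "*":
            # Allowed for everyone, but _matches_feed delivers every
            # alerts:<id> feed to "alerts:*" subscribers, so watchlist
            # access must also be checked when alerts are routed.
            return True
        return user_has_watchlist_access(user_id, watchlist_id)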

    if feed == "detections:live":
        return user.role in ("admin", "operator")

    return False
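In the Connection Manager pattern, _matches_feed treats alerts:* as matching every alerts:<watchlist_id> feed, and every connection is subscribed to alerts:* on connect, so the permission check above does not by itself limit delivery to a user's own alerts. One option, sketched here with hypothetical names (alert_recipients, a connection-to-user map, and a has_watchlist_access callback standing in for user_has_watchlist_access), is to re-check access when an alert is routed:

```python
from collections.abc import Callable, Hashable


def alert_recipients(
    watchlist_id: str,
    subscriptions: dict[Hashable, set[str]],
    user_of: dict[Hashable, str],
    has_watchlist_access: Callable[[str, str], bool],
) -> list[Hashable]:
    """Connections that should receive an alert for `watchlist_id`.

    A connection qualifies if it subscribed to this watchlist's feed or to
    the alerts:* wildcard AND its user can access the watchlist.
    """
    feed = f"alerts:{watchlist_id}"
    recipients = []
    for conn, feeds in subscriptions.items():
        if feed not in feeds and "alerts:*" not in feeds:
            continue
        user_id = user_of.get(conn)
        if user_id is not None and has_watchlist_access(user_id, watchlist_id):
            recipients.append(conn)
    return recipients
```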

Unsubscribe:

async def handle_unsubscribe(websocket: WebSocket, user_id: UUID, payload: dict):
    """Handle unsubscription request."""
    feeds = payload.get("feeds", [])

    for feed in feeds:
        manager.unsubscribe(websocket, user_id, feed)
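The endpoint passes every received message to a single handle_message entry point, and the Error Codes table lists INVALID_MESSAGE for malformed messages. A minimal sketch of validating the envelope and routing on type through a dispatch table; dispatch and the send_error callback are hypothetical, and handlers use the same (websocket, user_id, payload) signature as the handlers above:

```python
import asyncio
from collections.abc import Awaitable, Callable
from typing import Any

# Handlers take (websocket, user_id, payload), like handle_pong.
Handler = Callable[[Any, Any, dict], Awaitable[None]]
# Hypothetical callback that sends an error message given (code, message).
SendError = Callable[[str, str], Awaitable[None]]


async def dispatch(
    handlers: dict[str, Handler],
    websocket: Any,
    user_id: Any,
    data: Any,
    send_error: SendError,
) -> None:
    """Validate the envelope, then route it to the handler for its type."""
    if not isinstance(data, dict) or not isinstance(data.get("type"), str):
        await send_error("INVALID_MESSAGE", "Message must be an object with a string 'type'")
        return
    payload = data.get("payload")
    if not isinstance(payload, dict):
        await send_error("INVALID_MESSAGE", "Message 'payload' must be an object")
        return
    handler = handlers.get(data["type"])
    if handler is None:
        await send_error("INVALID_MESSAGE", f"Unknown message type: {data['type']}")
        return
    await handler(websocket, user_id, payload)


if __name__ == "__main__":
    async def demo() -> None:
        async def on_pong(websocket: Any, user_id: Any, payload: dict) -> None:
            print("pong from", user_id, payload)

        async def send_error(code: str, message: str) -> None:
            print("error:", code, message)

        handlers: dict[str, Handler] = {"pong": on_pong}
        await dispatch(handlers, None, "user-1", {"type": "pong", "payload": {}}, send_error)
        await dispatch(handlers, None, "user-1", {"type": "bogus", "payload": {}}, send_error)

    asyncio.run(demo())
```

Wired to this document's handlers, the table would be {"pong": handle_pong, "subscribe": handle_subscribe, "unsubscribe": handle_unsubscribe}, with send_error sending create_error_message(code, message).model_dump(mode="json").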

Pattern: Connection Manager

Problem: Need to track connections and route messages to appropriate clients.

Solution: Centralized manager with subscription routing.

Implementation:

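# central/src/services/websocket_manager.py
# Postponed annotations: WebSocketMessage (Message Format pattern) is only
# referenced in type hints, so it need not be imported at runtime.
from __future__ import annotations

from collections import defaultdict
from datetime import datetime
from uuid import UUID

from fastapi import WebSocket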

class WebSocketManager:
    """Manage WebSocket connections and message routing."""

    def __init__(self):
        # user_id -> list of websockets
        self.connections: dict[UUID, list[WebSocket]] = defaultdict(list)
        # websocket -> set of subscribed feeds
        self.subscriptions: dict[WebSocket, set[str]] = defaultdict(set)
        # websocket -> last activity timestamp
        self.last_activity: dict[WebSocket, datetime] = {}

    async def connect(self, websocket: WebSocket, user_id: UUID):
        """Accept new connection."""
        await websocket.accept()
        self.connections[user_id].append(websocket)
        self.last_activity[websocket] = datetime.utcnow()

        # Default subscriptions
        self.subscriptions[websocket].add("alerts:*")

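    def disconnect(self, websocket: WebSocket, user_id: UUID):
        """Remove connection (safe to call more than once)."""
        sockets = self.connections.get(user_id)
        if sockets and websocket in sockets:
            sockets.remove(websocket)
        if not sockets:
            # Drop empty entries so disconnected users do not accumulate
            self.connections.pop(user_id, None)
        self.subscriptions.pop(websocket, None)
        self.last_activity.pop(websocket, None)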

    def subscribe(self, websocket: WebSocket, user_id: UUID, feed: str):
        """Add feed subscription."""
        self.subscriptions[websocket].add(feed)

    def unsubscribe(self, websocket: WebSocket, user_id: UUID, feed: str):
        """Remove feed subscription."""
        self.subscriptions[websocket].discard(feed)

    async def broadcast_to_feed(self, feed: str, message: WebSocketMessage):
        """Send message to all subscribers of a feed."""
        data = message.model_dump(mode="json")  # JSON-safe: datetime -> str
        # Iterate over a snapshot: a connection may disconnect mid-broadcast
        for websocket, feeds in list(self.subscriptions.items()):
            if self._matches_feed(feed, feeds):
                try:
                    await websocket.send_json(data)
                except Exception:
                    # Connection may be closed; the health checker cleans it up
                    pass

    async def send_to_user(self, user_id: UUID, message: WebSocketMessage):
        """Send message to all connections for a user."""
        data = message.model_dump(mode="json")
        for websocket in list(self.connections.get(user_id, [])):
            try:
                await websocket.send_json(data)
            except Exception:
                pass

    def _matches_feed(self, target: str, subscriptions: set[str]) -> bool:
        """Check if target feed matches any subscription."""
        if target in subscriptions:
            return True

        # Wildcard matching: "alerts:*" matches "alerts:wl_123"
        for sub in subscriptions:
            if sub.endswith(":*"):
                prefix = sub[:-1]  # "alerts:"
                if target.startswith(prefix):
                    return True

        return False

    def update_activity(self, websocket: WebSocket, user_id: UUID):
        """Update last activity timestamp."""
        self.last_activity[websocket] = datetime.utcnow()


# Singleton instance
manager = WebSocketManager()
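broadcast_to_feed awaits each send in turn, so a single slow client delays delivery to every subscriber after it. A hedged alternative is to send concurrently with a per-client timeout and collect the failures, which the caller could then pass to disconnect. A sketch using only asyncio; fan_out, its send callback, and the 5-second default timeout are assumptions:

```python
import asyncio
from collections.abc import Awaitable, Callable, Hashable, Iterable


async def fan_out(
    targets: Iterable[Hashable],
    send: Callable[[Hashable], Awaitable[None]],
    timeout: float = 5.0,
) -> list[Hashable]:
    """Send to all targets concurrently; return those that failed or timed out."""
    # Snapshot first: the caller's collection may change while we await
    snapshot = list(targets)

    async def send_one(target: Hashable) -> None:
        await asyncio.wait_for(send(target), timeout)

    results = await asyncio.gather(
        *(send_one(t) for t in snapshot), return_exceptions=True
    )
    return [t for t, r in zip(snapshot, results) if isinstance(r, BaseException)]
```

Inside the manager, targets would be the websockets whose subscriptions match the feed, and send could be lambda ws: ws.send_json(data).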

Pattern: Reconnection Protocol

Problem: Clients lose connection; need seamless recovery without missing messages.

Solution: Client-side reconnection with exponential backoff and jitter, re-subscription to saved feeds, and recovery of missed alerts through the REST API.

Client Implementation (JavaScript):

class WebSocketClient {
    constructor(url) {
        this.url = url;
        this.ws = null;
        this.reconnectAttempts = 0;
        this.maxReconnectDelay = 60000; // 60 seconds
        this.subscriptions = new Set(['alerts:*']);
        this.lastMessageTimestamp = null;
        this.messageHandlers = new Map();
    }

    connect() {
        this.ws = new WebSocket(this.url);

        this.ws.onopen = () => {
            console.log('WebSocket connected');
            this.reconnectAttempts = 0;

            // Re-subscribe to feeds
            this.ws.send(JSON.stringify({
                type: 'subscribe',
                payload: { feeds: Array.from(this.subscriptions) }
            }));

            // Fetch missed messages
            this.fetchMissedMessages();
        };

        this.ws.onmessage = (event) => {
            const message = JSON.parse(event.data);
            this.lastMessageTimestamp = message.timestamp;
            this.handleMessage(message);
        };

        this.ws.onclose = () => {
            console.log('WebSocket disconnected');
            this.scheduleReconnect();
        };

        this.ws.onerror = (error) => {
            console.error('WebSocket error:', error);
        };
    }

    scheduleReconnect() {
        // Exponential backoff with jitter
        const baseDelay = Math.min(
            1000 * Math.pow(2, this.reconnectAttempts),
            this.maxReconnectDelay
        );
        const jitter = Math.random() * 1000;
        const delay = Math.round(baseDelay + jitter);

        console.log(`Reconnecting in ${delay}ms...`);

        setTimeout(() => {
            this.reconnectAttempts++;
            this.connect();
        }, delay);
    }

    async fetchMissedMessages() {
        if (!this.lastMessageTimestamp) return;

        try {
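            // Encode: ISO timestamps may contain characters such as "+"
            const since = encodeURIComponent(this.lastMessageTimestamp);
            const response = await fetch(`/api/v1/alerts?since=${since}`);
            if (!response.ok) {
                throw new Error(`HTTP ${response.status}`);
            }
            const alerts = await response.json();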

            for (const alert of alerts.data) {
                this.handleMessage({
                    type: 'alert',
                    payload: alert,
                    timestamp: alert.created_at,
                    id: `recovered_${alert.id}`
                });
            }
        } catch (error) {
            console.error('Failed to fetch missed messages:', error);
        }
    }

    handleMessage(message) {
        const handler = this.messageHandlers.get(message.type);
        if (handler) {
            handler(message.payload);
        }

        // Handle ping
        if (message.type === 'ping') {
            this.ws.send(JSON.stringify({
                type: 'pong',
                payload: { client_time: new Date().toISOString() }
            }));
        }
    }

    on(type, handler) {
        this.messageHandlers.set(type, handler);
    }

    subscribe(feed) {
        this.subscriptions.add(feed);
        if (this.ws?.readyState === WebSocket.OPEN) {
            this.ws.send(JSON.stringify({
                type: 'subscribe',
                payload: { feeds: [feed] }
            }));
        }
    }

    unsubscribe(feed) {
        this.subscriptions.delete(feed);
        if (this.ws?.readyState === WebSocket.OPEN) {
            this.ws.send(JSON.stringify({
                type: 'unsubscribe',
                payload: { feeds: [feed] }
            }));
        }
    }
}

// Usage
const client = new WebSocketClient('wss://server.example.com/ws');
client.on('alert', (payload) => showAlertNotification(payload));
client.on('detection', (payload) => updateLiveFeed(payload));
client.connect();

Pattern: Rate Limiting

Problem: Prevent clients from flooding server with messages.

Solution: Per-connection rate limiting with graceful rejection.

Implementation:

import time
from collections import defaultdict, deque
from uuid import UUID

from fastapi import WebSocket

class RateLimiter:
    """Rate limit WebSocket messages per connection (sliding window)."""

    def __init__(self, max_messages: int = 10, window_seconds: float = 1):
        self.max_messages = max_messages
        self.window = window_seconds
        # websocket -> monotonic timestamps of recently accepted messages
        self.message_times: dict[WebSocket, deque[float]] = defaultdict(deque)

    def check(self, websocket: WebSocket) -> bool:
        """
        Check if message is allowed.

        Returns:
            True if allowed, False if rate limited
        """
        # Monotonic clock: unaffected by system clock adjustments
        now = time.monotonic()
        cutoff = now - self.window
        times = self.message_times[websocket]

        # Drop timestamps that have left the window (oldest first)
        while times and times[0] <= cutoff:
            times.popleft()

        if len(times) >= self.max_messages:
            return False

        times.append(now)
        return True

    def forget(self, websocket: WebSocket) -> None:
        """Drop state for a closed connection (call on disconnect)."""
        self.message_times.pop(websocket, None)


rate_limiter = RateLimiter(max_messages=10, window_seconds=1)

async def handle_message(websocket: WebSocket, user_id: UUID, data: dict):
    """Handle incoming message with rate limiting."""
    if not rate_limiter.check(websocket):
        await websocket.send_json(
            create_error_message(
                "RATE_LIMITED",
                "Too many messages. Please slow down."
            ).model_dump(mode="json")
        )
        return

    # Process message: route on data["type"] to handle_pong,
    # handle_subscribe, or handle_unsubscribe
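The limiter above keeps a per-connection log of timestamps (a sliding-window log). A token bucket is a common alternative, named here as a swap rather than the documented approach: it stores two numbers per connection, permits bursts up to its capacity, and enforces the average rate. A sketch with an injectable clock for testing; TokenBucket is hypothetical, and capacity 10 with rate 10 per second mirrors the limiter's defaults:

```python
import time
from collections.abc import Callable


class TokenBucket:
    """Allow bursts of up to `capacity` messages, refilled at `rate` per second."""

    def __init__(
        self,
        capacity: float = 10,
        rate: float = 10,
        clock: Callable[[], float] = time.monotonic,
    ):
        self.capacity = capacity
        self.rate = rate
        self.clock = clock
        self.tokens = capacity
        self.updated = clock()

    def allow(self) -> bool:
        now = self.clock()
        # Refill for the time elapsed since the last call, never above capacity
        self.tokens = min(self.capacity, self.tokens + (now - self.updated) * self.rate)
        self.updated = now
        if self.tokens >= 1:
            self.tokens -= 1
            return True
        return False
```

One bucket per connection would be kept (for example in a dict keyed by WebSocket) and removed on disconnect. Unlike the sliding window, a bucket that has been idle allows a full burst of 10 immediately, then roughly one message every 100 ms.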

Pattern: Health Check (Ping/Pong)

Problem: Detect dead connections that haven't properly closed.

Solution: Periodic ping from server, disconnect on missed pongs.

Implementation:

import asyncio
from datetime import datetime, timedelta
from uuid import UUID

from fastapi import WebSocket

class ConnectionHealthChecker:
    """Monitor connection health via ping/pong."""

    def __init__(
        self,
        manager: WebSocketManager,
        ping_interval: int = 30,
        max_missed_pongs: int = 3
    ):
        self.manager = manager
        self.ping_interval = ping_interval
        self.max_missed_pongs = max_missed_pongs

    async def run(self):
        """Run health check loop."""
        while True:
            await asyncio.sleep(self.ping_interval)
            await self.check_all_connections()

    async def check_all_connections(self):
        """Close connections that stopped answering pings; ping the rest."""
        # No pong (or other activity) for max_missed_pongs ping intervals
        # means that many consecutive pings went unanswered. Naive UTC,
        # matching WebSocketManager.last_activity.
        stale_threshold = datetime.utcnow() - timedelta(
            seconds=self.ping_interval * self.max_missed_pongs
        )
        # Snapshot both levels: disconnects mutate these collections
        for user_id, websockets in list(self.manager.connections.items()):
            for websocket in list(websockets):
                last_activity = self.manager.last_activity.get(websocket)
                if last_activity is None or last_activity < stale_threshold:
                    await self.close_dead_connection(websocket, user_id)
                    continue

                try:
                    await send_ping(websocket)
                except Exception:
                    await self.close_dead_connection(websocket, user_id)

    async def close_dead_connection(self, websocket: WebSocket, user_id: UUID):
        """Close and cleanup dead connection."""
        try:
            await websocket.close()
        except Exception:
            pass
        self.manager.disconnect(websocket, user_id)
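ConnectionHealthChecker.run loops forever, so it needs to be started as a background task when the application starts and cancelled on shutdown. A standard-library sketch of that lifecycle, where run_periodic and the demo tick are hypothetical stand-ins for run and check_all_connections; this version also logs and survives a failing tick, since an unhandled exception would otherwise end health checks silently:

```python
import asyncio
import contextlib
import logging
from collections.abc import Awaitable, Callable

logger = logging.getLogger(__name__)


async def run_periodic(interval: float, tick: Callable[[], Awaitable[None]]) -> None:
    """Await tick() every `interval` seconds until cancelled."""
    while True:
        await asyncio.sleep(interval)
        try:
            await tick()
        except Exception:
            # Log and keep looping: one bad pass should not stop health checks
            logger.exception("periodic task failed")


async def demo() -> int:
    ticks = 0

    async def tick() -> None:
        nonlocal ticks
        ticks += 1
        if ticks == 2:
            raise RuntimeError("simulated failure")  # the loop survives this

    task = asyncio.create_task(run_periodic(0.01, tick))  # on app startup
    await asyncio.sleep(0.1)
    task.cancel()  # on app shutdown
    with contextlib.suppress(asyncio.CancelledError):
        await task
    return ticks


if __name__ == "__main__":
    print(asyncio.run(demo()))
```

In FastAPI this would typically live in a lifespan handler: create the task with asyncio.create_task(checker.run()) before the yield and cancel it afterwards. The asyncio documentation recommends keeping a reference to created tasks so they are not garbage-collected mid-execution.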

Error Codes

Code                 Description                                   Client Action
AUTH_FAILED          Authentication failed or session expired      Re-authenticate
AUTH_REQUIRED        Message sent before authentication completed  Wait for connection
SUBSCRIPTION_DENIED  User lacks permission for feed                Don't retry
INVALID_MESSAGE      Malformed JSON or missing fields              Fix message
RATE_LIMITED         Too many messages                             Slow down
INTERNAL_ERROR       Server-side error                             Retry later
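The handlers above pass error codes as string literals. Collecting them in one enum keeps server code and this table in sync. A sketch (ErrorCode, RETRY_LATER, and is_retryable are hypothetical names); the retry grouping is one reading of the Client Action column (wait, slow down, or retry later), not something the table states explicitly:

```python
from enum import Enum


class ErrorCode(str, Enum):
    """Codes sent in the payload of `error` messages."""
    AUTH_FAILED = "AUTH_FAILED"
    AUTH_REQUIRED = "AUTH_REQUIRED"
    SUBSCRIPTION_DENIED = "SUBSCRIPTION_DENIED"
    INVALID_MESSAGE = "INVALID_MESSAGE"
    RATE_LIMITED = "RATE_LIMITED"
    INTERNAL_ERROR = "INTERNAL_ERROR"


# Codes whose client action implies resending later (after waiting or backing off);
# the others need re-authentication, a corrected message, or no retry.
RETRY_LATER = frozenset(
    {ErrorCode.AUTH_REQUIRED, ErrorCode.RATE_LIMITED, ErrorCode.INTERNAL_ERROR}
)


def is_retryable(code: str) -> bool:
    try:
        return ErrorCode(code) in RETRY_LATER
    except ValueError:  # unknown code: do not retry
        return False
```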

Testing Strategies

import pytest

@pytest.fixture
def ws_client(client, user_token):
    """WebSocket test client with authentication.

    Assumes `client` is a fastapi.testclient.TestClient and `user_token`
    a valid session token (fixtures defined elsewhere).
    """
    with client.websocket_connect(
        "/ws",
        cookies={"session": user_token}
    ) as websocket:
        yield websocket

# TestClient's WebSocket session is synchronous, so the tests are plain `def`
def test_subscribe_to_alerts(ws_client):
    """Test subscribing to alert feed."""
    ws_client.send_json({
        "type": "subscribe",
        "payload": {"feeds": ["alerts:*"]}
    })

    response = ws_client.receive_json()
    assert response["type"] == "subscribed"
    assert response["payload"]["success"] is True

def test_rate_limiting(ws_client):
    """Test that rate limiting kicks in."""
    # Send 15 messages rapidly (limit is 10 per second)
    for _ in range(15):
        ws_client.send_json({"type": "pong", "payload": {}})

    # Pongs get no reply, so the first message received is the rate-limit
    # error. receive_json() blocks and takes no timeout argument.
    response = ws_client.receive_json()
    assert response["type"] == "error"
    assert response["payload"]["code"] == "RATE_LIMITED"
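The wildcard rule in WebSocketManager._matches_feed is pure string logic, so it can be unit-tested without opening a socket. A sketch using a standalone copy of the rule (matches_feed is a hypothetical module-level name) and pytest-style test functions; the last case checks that the colon in the prefix stops alerts:* from matching a different namespace:

```python
def matches_feed(target: str, subscriptions: set[str]) -> bool:
    """Same rule as WebSocketManager._matches_feed."""
    if target in subscriptions:
        return True
    for sub in subscriptions:
        # "alerts:*" matches any feed that starts with "alerts:"
        if sub.endswith(":*") and target.startswith(sub[:-1]):
            return True
    return False


def test_exact_match():
    assert matches_feed("detections:live", {"detections:live"})


def test_wildcard_matches_same_namespace_only():
    assert matches_feed("alerts:wl_123", {"alerts:*"})
    assert not matches_feed("detections:live", {"alerts:*"})


def test_wildcard_requires_full_namespace():
    # "alerts:*" must not match a different namespace such as "alertsx:1"
    assert not matches_feed("alertsx:1", {"alerts:*"})
```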


Maintainer: Development Team
Review Cycle: Quarterly