WebsocketEndpoint.java

/*
 * The coLAB project
 * Copyright (C) 2021-2023 AlbaSim, MEI, HEIG-VD, HES-SO
 *
 * Licensed under the MIT License
 */
package ch.colabproject.colab.api.ws;

import ch.colabproject.colab.api.Helper;
import ch.colabproject.colab.api.controller.WebsocketManager;
import ch.colabproject.colab.api.ws.message.WsMessage;
import ch.colabproject.colab.api.ws.message.WsPing;
import ch.colabproject.colab.api.ws.message.WsPong;
import ch.colabproject.colab.api.ws.message.WsSessionIdentifier;
import ch.colabproject.colab.api.ws.utils.JsonDecoder;
import ch.colabproject.colab.api.ws.utils.JsonEncoder;
import ch.colabproject.colab.api.ws.utils.JsonWsMessageListDecoder;
import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import javax.enterprise.context.ApplicationScoped;
import javax.inject.Inject;
import javax.websocket.CloseReason;
import javax.websocket.EncodeException;
import javax.websocket.OnClose;
import javax.websocket.OnError;
import javax.websocket.OnMessage;
import javax.websocket.OnOpen;
import javax.websocket.Session;
import javax.websocket.server.ServerEndpoint;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.hazelcast.core.HazelcastInstance;
import com.hazelcast.flakeidgen.FlakeIdGenerator;

/**
 * Websocket endpoint.
 * <p>
 * Keeps track of every open websocket session in static registries, assigns each session a
 * generated id, replies to client pings and allows broadcasting a message to all connected
 * clients.
 * <p>
 * The registries are static synchronized collections: individual operations on them are
 * thread-safe, but compound operations (such as iterating) must either hold the collection's lock
 * or work on a snapshot.
 *
 * @author maxence
 */
@ApplicationScoped
@ServerEndpoint(value = "/ws", encoders = JsonEncoder.class,
    decoders = { JsonDecoder.class, JsonWsMessageListDecoder.class })
public class WebsocketEndpoint {

    /**
     * To generate cluster-wide unique id
     */
    @Inject
    private HazelcastInstance hzInstance;

    /**
     * Websocket business logic.
     */
    @Inject
    private WebsocketManager websocketManager;

    /**
     * Logger
     */
    private static final Logger logger = LoggerFactory.getLogger(WebsocketEndpoint.class);

    /**
     * Set of active sessions
     */
    private static final Set<Session> sessions = Collections.synchronizedSet(new HashSet<>());

    /**
     * Map session to session id
     */
    private static final Map<Session, String> sessionToIds = Collections
        .synchronizedMap(new HashMap<>());

    /**
     * Map session id to session
     */
    private static final Map<String, Session> idsToSessions = Collections
        .synchronizedMap(new HashMap<>());

    /**
     * Send a message to all clients.
     * <p>
     * An {@link IOException} or {@link EncodeException} raised while sending to one client is
     * logged and does not prevent the message from being sent to the remaining clients.
     *
     * @param message the message to send
     */
    public static void broadcastMessage(WsMessage message) {
        // Iterating over a synchronized set without holding its lock is not thread-safe: sessions
        // may be added or removed concurrently by onOpen/onClose. toArray is atomic on the
        // synchronized wrapper, so iterate over a snapshot instead of locking during network I/O.
        Session[] snapshot = sessions.toArray(new Session[0]);
        for (Session session : snapshot) {
            try {
                session.getBasicRemote().sendObject(message);
            } catch (IOException | EncodeException e) {
                // the exception is passed as last argument so SLF4J logs its stack trace
                logger.error("Broadcast message exception for session {}", session.getId(), e);
            }
        }
    }

    /**
     * On new connection. Register the session, generate its id and send back a
     * {@link WsSessionIdentifier} to let the client know its own sessionId.
     *
     * @param session brand new session
     *
     * @throws IOException     if sending the session identifier fails
     * @throws EncodeException if sending the session identifier fails
     */
    @OnOpen
    public void onOpen(Session session) throws IOException, EncodeException {
        logger.info("WebSocket opened: {}", session.getId());
        sessions.add(session);
        FlakeIdGenerator idGenerator = hzInstance.getFlakeIdGenerator("WS_SESSION_ID_GENERATOR");
        String sessionId = "ws-" + Helper.generateHexSalt(8) + idGenerator.newId();
        sessionToIds.put(session, sessionId);
        idsToSessions.put(sessionId, session);
        session.getBasicRemote().sendObject(new WsSessionIdentifier(sessionId));

        logger.trace("Session Timeout: {} ms", session.getMaxIdleTimeout());
    }

    /**
     * Called when a client sends a message. A {@link WsPing} is answered with a {@link WsPong};
     * any other message is only traced.
     *
     * @param message message received from client
     * @param session client session
     */
    @OnMessage
    public void onMessage(WsMessage message, Session session) {
        logger.trace("Message received: {} from {}", message, session.getId());
        if (message instanceof WsPing) {
            try {
                session.getBasicRemote().sendObject(new WsPong());
            } catch (IOException | EncodeException ex) {
                logger.warn("Fail to reply to ping", ex);
            }
        }
    }

    /**
     * Called when an error occurs on a session. For now, the error message is only logged.
     * <p>
     * TODO: define proper error handling
     *
     * @param session   erroneous session
     * @param throwable error
     */
    @OnError
    public void onError(Session session, Throwable throwable) {
        logger.info("WebSocket error for {} {}", session.getId(), throwable.getMessage());
    }

    /**
     * When a client is leaving. Remove the session from all registries and cancel all its
     * subscriptions.
     *
     * @param session     session to close
     * @param closeReason some reason...
     */
    @OnClose
    public void onClose(Session session, CloseReason closeReason) {
        logger.info("WebSocket closed for {} with reason {}: {}",
            session.getId(), closeReason.getCloseCode(), closeReason.getReasonPhrase());
        sessions.remove(session);
        // remove returns the previously mapped id (or null), no need for a separate lookup
        String id = sessionToIds.remove(session);
        if (id != null) {
            idsToSessions.remove(id);
        }
        websocketManager.unsubscribeFromAll(session, id);
    }

    /**
     * Get session by its id
     *
     * @param sessionId id of the session
     *
     * @return the session or null
     */
    public static Session getSession(String sessionId) {
        return idsToSessions.get(sessionId);
    }

    /**
     * Get id by session
     *
     * @param session the session
     *
     * @return sessionid or null
     */
    public static String getSessionId(Session session) {
        return sessionToIds.get(session);
    }
}