WebsocketEndpoint.java

  1. /*
  2.  * The coLAB project
  3.  * Copyright (C) 2021-2023 AlbaSim, MEI, HEIG-VD, HES-SO
  4.  *
  5.  * Licensed under the MIT License
  6.  */
  7. package ch.colabproject.colab.api.ws;

  8. import ch.colabproject.colab.api.Helper;
  9. import ch.colabproject.colab.api.controller.WebsocketManager;
  10. import ch.colabproject.colab.api.ws.message.WsMessage;
  11. import ch.colabproject.colab.api.ws.message.WsPing;
  12. import ch.colabproject.colab.api.ws.message.WsPong;
  13. import ch.colabproject.colab.api.ws.message.WsSessionIdentifier;
  14. import ch.colabproject.colab.api.ws.utils.JsonDecoder;
  15. import ch.colabproject.colab.api.ws.utils.JsonEncoder;
  16. import ch.colabproject.colab.api.ws.utils.JsonWsMessageListDecoder;
  17. import java.io.IOException;
  18. import java.util.Collections;
  19. import java.util.HashMap;
  20. import java.util.HashSet;
  21. import java.util.Map;
  22. import java.util.Set;
  23. import javax.enterprise.context.ApplicationScoped;
  24. import javax.inject.Inject;
  25. import javax.websocket.CloseReason;
  26. import javax.websocket.EncodeException;
  27. import javax.websocket.OnClose;
  28. import javax.websocket.OnError;
  29. import javax.websocket.OnMessage;
  30. import javax.websocket.OnOpen;
  31. import javax.websocket.Session;
  32. import javax.websocket.server.ServerEndpoint;
  33. import org.slf4j.Logger;
  34. import org.slf4j.LoggerFactory;
  35. import com.hazelcast.core.HazelcastInstance;
  36. import com.hazelcast.flakeidgen.FlakeIdGenerator;

  37. /**
  38.  * Websocket endpoint
  39.  *
  40.  * @author maxence
  41.  */
  42. @ApplicationScoped
  43. @ServerEndpoint(value = "/ws", encoders = JsonEncoder.class,
  44.     decoders = { JsonDecoder.class, JsonWsMessageListDecoder.class })
  45. public class WebsocketEndpoint {

  46.     /**
  47.      * To generate cluster-wide unique id
  48.      */
  49.     @Inject
  50.     private HazelcastInstance hzInstance;

  51.     /**
  52.      * Websocket business logic.
  53.      */
  54.     @Inject
  55.     private WebsocketManager websocketManager;

  56.     /**
  57.      * Logger
  58.      */
  59.     private static final Logger logger = LoggerFactory.getLogger(WebsocketEndpoint.class);

  60.     /**
  61.      * Map of active sessions
  62.      */
  63.     private static Set<Session> sessions = Collections.synchronizedSet(new HashSet<>());

  64.     /**
  65.      * Map session to session id
  66.      */
  67.     private static Map<Session, String> sessionToIds = Collections.synchronizedMap(new HashMap<>());

  68.     /**
  69.      * Map session id to sessions
  70.      */
  71.     private static Map<String, Session> idsToSessions = Collections
  72.         .synchronizedMap(new HashMap<>());

  73.     /**
  74.      * Send a message to all clients
  75.      *
  76.      * @param message the message to send
  77.      */
  78.     public static void broadcastMessage(WsMessage message) {
  79.         for (Session session : sessions) {
  80.             try {
  81.                 session.getBasicRemote().sendObject(message);
  82.             } catch (IOException | EncodeException e) {
  83.                 logger.error("Broadcast message exception: {}", e);
  84.             }
  85.         }
  86.     }

  87.     /**
  88.      * On new connection. Send back a WsInitMessage to let the client known it's own sessionId
  89.      *
  90.      * @param session brand new session
  91.      *
  92.      * @throws IOException     if sending the initMessage fails
  93.      * @throws EncodeException if sending the initMessage fails
  94.      */
  95.     @OnOpen
  96.     public void onOpen(Session session) throws IOException, EncodeException {
  97.         logger.info("WebSocket opened: {}", session.getId());
  98.         sessions.add(session);
  99.         FlakeIdGenerator idGenerator = hzInstance.getFlakeIdGenerator("WS_SESSION_ID_GENERATOR");
  100.         String sessionId = "ws-" + Helper.generateHexSalt(8) + idGenerator.newId();
  101.         sessionToIds.put(session, sessionId);
  102.         idsToSessions.put(sessionId, session);
  103.         session.getBasicRemote().sendObject(new WsSessionIdentifier(sessionId));

  104.         long maxIdleTimeout = session.getMaxIdleTimeout();
  105.         logger.trace("Session Timeout: {} ms", maxIdleTimeout);
  106.     }

  107.     /**
  108.      * called when client send messages
  109.      *
  110.      * @param message message received from client
  111.      * @param session client session
  112.      */
  113.     @OnMessage
  114.     public void onMessage(WsMessage message, Session session) {
  115.         logger.trace("Message received: {} from {}", message, session.getId());
  116.         if (message instanceof WsPing) {
  117.             try {
  118.                 session.getBasicRemote().sendObject(new WsPong());
  119.             } catch (IOException | EncodeException ex) {
  120.                 logger.warn("Fail to reply to ping", ex);
  121.             }
  122.         }
  123.     }

  124.     /**
  125.      * TO handle errors. TBD TODO
  126.      *
  127.      * @param session   erroneous session
  128.      * @param throwable error
  129.      */
  130.     @OnError
  131.     public void onError(Session session, Throwable throwable) {
  132.         logger.info("WebSocket error for {} {}", session.getId(), throwable.getMessage());
  133.     }

  134.     /**
  135.      * When a client is leaving
  136.      *
  137.      * @param session     session to close
  138.      * @param closeReason some reason...
  139.      */
  140.     @OnClose
  141.     public void onClose(Session session, CloseReason closeReason) {
  142.         logger.info("WebSocket closed for {} with reason {}: {}",
  143.             session.getId(), closeReason.getCloseCode(), closeReason.getReasonPhrase());
  144.         sessions.remove(session);
  145.         String id = sessionToIds.get(session);
  146.         idsToSessions.remove(id);
  147.         sessionToIds.remove(session);
  148.         websocketManager.unsubscribeFromAll(session, id);
  149.     }

  150.     /**
  151.      * Get session by its id
  152.      *
  153.      * @param sessionId id of the session
  154.      *
  155.      * @return the session or null
  156.      */
  157.     public static Session getSession(String sessionId) {
  158.         return idsToSessions.get(sessionId);
  159.     }

  160.     /**
  161.      * Get id by session
  162.      *
  163.      * @param session the session
  164.      *
  165.      * @return sessionid or null
  166.      */
  167.     public static String getSessionId(Session session) {
  168.         return sessionToIds.get(session);
  169.     }
  170. }