WebsocketEndpoint.java
/*
* The coLAB project
* Copyright (C) 2021-2023 AlbaSim, MEI, HEIG-VD, HES-SO
*
* Licensed under the MIT License
*/
package ch.colabproject.colab.api.ws;
import ch.colabproject.colab.api.Helper;
import ch.colabproject.colab.api.controller.WebsocketManager;
import ch.colabproject.colab.api.ws.message.WsMessage;
import ch.colabproject.colab.api.ws.message.WsPing;
import ch.colabproject.colab.api.ws.message.WsPong;
import ch.colabproject.colab.api.ws.message.WsSessionIdentifier;
import ch.colabproject.colab.api.ws.utils.JsonDecoder;
import ch.colabproject.colab.api.ws.utils.JsonEncoder;
import ch.colabproject.colab.api.ws.utils.JsonWsMessageListDecoder;
import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import javax.enterprise.context.ApplicationScoped;
import javax.inject.Inject;
import javax.websocket.CloseReason;
import javax.websocket.EncodeException;
import javax.websocket.OnClose;
import javax.websocket.OnError;
import javax.websocket.OnMessage;
import javax.websocket.OnOpen;
import javax.websocket.Session;
import javax.websocket.server.ServerEndpoint;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.hazelcast.core.HazelcastInstance;
import com.hazelcast.flakeidgen.FlakeIdGenerator;
/**
 * Websocket endpoint.
 * <p>
 * Open sessions and their generated ids are tracked in static synchronized collections so that
 * the static helpers ({@link #broadcastMessage(WsMessage)}, {@link #getSession(String)},
 * {@link #getSessionId(Session)}) can reach them.
 *
 * @author maxence
 */
@ApplicationScoped
@ServerEndpoint(value = "/ws", encoders = JsonEncoder.class,
    decoders = { JsonDecoder.class, JsonWsMessageListDecoder.class })
public class WebsocketEndpoint {

    /**
     * Name of the Hazelcast flake id generator used to build session ids
     */
    private static final String SESSION_ID_GENERATOR = "WS_SESSION_ID_GENERATOR";

    /**
     * To generate cluster-wide unique id
     */
    @Inject
    private HazelcastInstance hzInstance;

    /**
     * Websocket business logic.
     */
    @Inject
    private WebsocketManager websocketManager;

    /**
     * Logger
     */
    private static final Logger logger = LoggerFactory.getLogger(WebsocketEndpoint.class);

    /**
     * Set of active sessions
     */
    private static final Set<Session> sessions = Collections.synchronizedSet(new HashSet<>());

    /**
     * Map session to session id
     */
    private static final Map<Session, String> sessionToIds =
        Collections.synchronizedMap(new HashMap<>());

    /**
     * Map session id to sessions
     */
    private static final Map<String, Session> idsToSessions =
        Collections.synchronizedMap(new HashMap<>());

    /**
     * Send a message to all clients.
     * <p>
     * A failure to deliver to one client is logged and does not prevent delivery to the others.
     *
     * @param message the message to send
     */
    public static void broadcastMessage(WsMessage message) {
        // Iterating a Collections.synchronizedSet directly requires holding its lock, otherwise a
        // concurrent onOpen/onClose may trigger a ConcurrentModificationException. toArray() is
        // synchronized by the wrapper, so take a snapshot instead of holding the lock while
        // performing blocking sends.
        Session[] snapshot = sessions.toArray(new Session[0]);
        for (Session session : snapshot) {
            if (!session.isOpen()) {
                // closed since the snapshot was taken: nothing to deliver
                continue;
            }
            try {
                session.getBasicRemote().sendObject(message);
            } catch (IOException | EncodeException | IllegalStateException e) {
                // IllegalStateException: some containers throw it when the session gets closed
                // between the isOpen() check and the send
                logger.error("Broadcast message to session {} failed", session.getId(), e);
            }
        }
    }

    /**
     * On new connection. Register the session, generate its id and send back a
     * {@link WsSessionIdentifier} to let the client know its own sessionId.
     *
     * @param session brand new session
     *
     * @throws IOException     if sending the initMessage fails
     * @throws EncodeException if sending the initMessage fails
     */
    @OnOpen
    public void onOpen(Session session) throws IOException, EncodeException {
        logger.info("WebSocket opened: {}", session.getId());
        sessions.add(session);

        // random hex salt followed by an id from the Hazelcast flake id generator
        FlakeIdGenerator idGenerator = hzInstance.getFlakeIdGenerator(SESSION_ID_GENERATOR);
        String sessionId = "ws-" + Helper.generateHexSalt(8) + idGenerator.newId();
        sessionToIds.put(session, sessionId);
        idsToSessions.put(sessionId, session);

        session.getBasicRemote().sendObject(new WsSessionIdentifier(sessionId));
        long maxIdleTimeout = session.getMaxIdleTimeout();
        logger.trace("Session Timeout: {} ms", maxIdleTimeout);
    }

    /**
     * Called when a client sends a message. Pings are answered with a pong; other messages are
     * only logged.
     *
     * @param message message received from client
     * @param session client session
     */
    @OnMessage
    public void onMessage(WsMessage message, Session session) {
        logger.trace("Message received: {} from {}", message, session.getId());
        if (message instanceof WsPing) {
            try {
                session.getBasicRemote().sendObject(new WsPong());
            } catch (IOException | EncodeException ex) {
                logger.warn("Fail to reply to ping", ex);
            }
        }
    }

    /**
     * Handle errors raised on a session. Currently the error is only logged.
     *
     * @param session   erroneous session
     * @param throwable error
     */
    @OnError
    public void onError(Session session, Throwable throwable) {
        logger.info("WebSocket error for {} {}", session.getId(), throwable.getMessage());
        // keep the info line short, but make the stack trace available when debugging
        logger.debug("WebSocket error details for {}", session.getId(), throwable);
    }

    /**
     * When a client is leaving: unregister the session and unsubscribe it from everything.
     *
     * @param session     session to close
     * @param closeReason reason given for the closure
     */
    @OnClose
    public void onClose(Session session, CloseReason closeReason) {
        logger.info("WebSocket closed for {} with reason {}: {}",
            session.getId(), closeReason.getCloseCode(), closeReason.getReasonPhrase());
        sessions.remove(session);
        // remove() returns the previous mapping: lookup and removal in a single atomic call.
        // id is null if onOpen never registered this session.
        String id = sessionToIds.remove(session);
        idsToSessions.remove(id);
        websocketManager.unsubscribeFromAll(session, id);
    }

    /**
     * Get session by its id
     *
     * @param sessionId id of the session
     *
     * @return the session or null
     */
    public static Session getSession(String sessionId) {
        return idsToSessions.get(sessionId);
    }

    /**
     * Get id by session
     *
     * @param session the session
     *
     * @return sessionid or null
     */
    public static String getSessionId(Session session) {
        return sessionToIds.get(session);
    }
}