WebsocketManager.java

/*
 * The coLAB project
 * Copyright (C) 2021-2023 AlbaSim, MEI, HEIG-VD, HES-SO
 *
 * Licensed under the MIT License
 */
package ch.colabproject.colab.api.controller;

import ch.colabproject.colab.api.controller.document.BlockManager;
import ch.colabproject.colab.api.controller.project.ProjectManager;
import ch.colabproject.colab.api.controller.security.SecurityManager;
import ch.colabproject.colab.api.model.document.TextDataBlock;
import ch.colabproject.colab.api.model.project.Project;
import ch.colabproject.colab.api.model.user.HttpSession;
import ch.colabproject.colab.api.model.user.User;
import ch.colabproject.colab.api.persistence.jpa.card.CardTypeDao;
import ch.colabproject.colab.api.persistence.jpa.project.ProjectDao;
import ch.colabproject.colab.api.persistence.jpa.team.TeamMemberDao;
import ch.colabproject.colab.api.persistence.jpa.user.UserDao;
import ch.colabproject.colab.api.presence.PresenceManager;
import ch.colabproject.colab.api.presence.model.TouchUserPresence;
import ch.colabproject.colab.api.security.permissions.Conditions;
import ch.colabproject.colab.api.ws.WebsocketEndpoint;
import ch.colabproject.colab.api.ws.WebsocketMessagePreparer;
import ch.colabproject.colab.api.ws.channel.model.BlockChannel;
import ch.colabproject.colab.api.ws.channel.model.BroadcastChannel;
import ch.colabproject.colab.api.ws.channel.model.ProjectContentChannel;
import ch.colabproject.colab.api.ws.channel.model.UserChannel;
import ch.colabproject.colab.api.ws.channel.model.WebsocketChannel;
import ch.colabproject.colab.api.ws.message.PrecomputedWsMessages;
import ch.colabproject.colab.api.ws.message.WsChannelUpdate;
import ch.colabproject.colab.api.ws.message.WsSessionIdentifier;
import ch.colabproject.colab.api.ws.message.WsSignOutMessage;
import ch.colabproject.colab.api.ws.utils.CallableGetChannel;
import ch.colabproject.colab.generator.model.exceptions.HttpErrorMessage;
import fish.payara.micro.cdi.Inbound;
import fish.payara.micro.cdi.Outbound;
import java.io.IOException;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.stream.Collectors;
import javax.ejb.Lock;
import javax.ejb.LockType;
import javax.ejb.Singleton;
import javax.ejb.Startup;
import javax.enterprise.event.Event;
import javax.enterprise.event.Observes;
import javax.inject.Inject;
import javax.websocket.EncodeException;
import javax.websocket.Session;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.hazelcast.cluster.Member;
import com.hazelcast.core.HazelcastInstance;
import com.hazelcast.core.IExecutorService;

/**
 * Some methods to handle websocket connections. In the future, we may consider using external
 * services (eg Pusher) to delegate such thing. We may want to challenge this implementation in a
 * real "production-like" env.
 *
 * @author maxence
 */
@Singleton
// make sure the singleton is available as soon as possible with @Startup annotation
@Startup
// Since few methods actually need mutual exclusion, make sure to set default locktype to READ
// mutual exclusion is managed by hand in methods with synchronized blocks
@Lock(LockType.READ)
public class WebsocketManager {

    /** logger */
    private static final Logger logger = LoggerFactory.getLogger(WebsocketManager.class);

    /**
     * Name of the cluster event used to carry subscription requests.
     */
    private static final String WS_SUBSCRIPTION_EVENT_CHANNEL = "WS_SUBSCRIPTION_CHANNEL";

    /**
     * Hazelcast instance.
     */
    @Inject
    private HazelcastInstance hzInstance;

    /**
     * The instance which receives the REST subscription request may not be the same as the one
     * which owns the websocket connection. Use a cluster event to delegate the processing to the
     * correct instance.
     */
    @Inject
    @Outbound(eventName = WS_SUBSCRIPTION_EVENT_CHANNEL, loopBack = true)
    private Event<SubscriptionRequest> subscriptionEvents;

    /**
     * Name of the cluster-wide propagation event channel.
     */
    private static final String WS_PROPAGATION_CHANNEL = "WS_PROPAGATION_CHANNEL";

    /**
     * In order to propagate changes to everyone, each cluster instance must propagate changes to
     * the sessions it is in charge of. This is the hz event to request such a propagation.
     */
    @Inject
    @Outbound(eventName = WS_PROPAGATION_CHANNEL, loopBack = true)
    private Event<PrecomputedWsMessages> messagePropagation;

    /**
     * Access control
     */
    @Inject
    private SecurityManager securityManager;

    /**
     * Request sidekick
     */
    @Inject
    private RequestManager requestManager;

    /**
     * Project specific logic management
     */
    @Inject
    private ProjectManager projectManager;

    /**
     * Block specific logic management
     */
    @Inject
    private BlockManager blockManager;

    /**
     * User persistence handler
     */
    @Inject
    private UserDao userDao;

    /**
     * Team members persistence handler
     */
    @Inject
    private TeamMemberDao teamMemberDao;

    /**
     * Card type persistence handler
     */
    @Inject
    private CardTypeDao cardTypeDao;

    /**
     * Project persistence handler
     */
    @Inject
    private ProjectDao projectDao;

    /**
     * Presence Manager
     */
    @Inject
    private PresenceManager presenceManager;

    /** to propagate changes */
    @Inject
    private EntityGatheringBagForPropagation bag;

    /**
     * Channel subscriptions: maps each channel to the websocket sessions subscribed to it.
     */
    private ConcurrentMap<WebsocketChannel, Set<Session>> subscriptions = new ConcurrentHashMap<>();

    /**
     * HTTP sessions to websocket sessions registry.
     * <p>
     * Since one can open several tabs in one's browser, all tabs share the same httpSessionId
     * (cookie) but each has its own websocket session, so we have to maintain such a map.
     * <p>
     * This is required to cancel all subscriptions on logout.
     */
    private ConcurrentMap<Long, Set<Session>> httpSessionToWsSessions = new ConcurrentHashMap<>();

//    /**
//     * Each websocket session is linked to one http session.
//     * <p>
//     * This is the {@link httpSessionToWsSessions} reverse registry.
//     */
//    private Map<Session, Long> wsSessionToHttpSession = new HashMap<>();
    /**
     * Lists the channels each websocket session is subscribed to.
     */
    private ConcurrentMap<Session, Set<WebsocketChannel>> wsSessionMap = new ConcurrentHashMap<>();

    /**
     * Get list of all occupied channels.
     * <p>
     * This method is cluster-aware: a {@link CallableGetChannel} task is submitted to every member
     * of the cluster through the "COLAB_WS" executor service and the per-member results are summed
     * up. In short, {@link #getSubscriptionsCount() } will be called for each instance of the
     * cluster.
     * <p>
     * A member which fails or does not answer within 5 seconds is logged and skipped, so the
     * result may be partial. If the calling thread is interrupted, its interrupt flag is restored
     * and the counts gathered so far are returned.
     *
     * @return list of all occupied channels URN mapped to the number of subscribers
     */
    public Map<String, Integer> getExistingChannels() {
        IExecutorService executorService = hzInstance.getExecutorService("COLAB_WS");
        Map<Member, Future<Map<String, Integer>>> results = executorService
            .submitToAllMembers(new CallableGetChannel());

        Map<String, Integer> map = new HashMap<>();

        for (Map.Entry<Member, Future<Map<String, Integer>>> pending : results.entrySet()) {
            try {
                Map<String, Integer> counts = pending.getValue().get(5, TimeUnit.SECONDS);
                // sum up the counts of channels known by several members
                counts.forEach((urn, count) -> map.merge(urn, count, Integer::sum));
            } catch (InterruptedException ex) {
                // restore the flag so callers can see the interruption, and stop waiting
                Thread.currentThread().interrupt();
                logger.warn("Interrupted while collecting channels from cluster members", ex);
                break;
            } catch (ExecutionException | TimeoutException ex) {
                // best effort: skip this member but leave a trace
                logger.warn("Failed to get channels from cluster member {}",
                    pending.getKey(), ex);
            }
        }

        return map;
    }

    /**
     * Count the subscribers of each channel registered in this instance's subscription registry.
     *
     * @return channel URN mapped to the number of sessions subscribed to that channel
     */
    public Map<String, Integer> getSubscriptionsCount() {
        var registered = subscriptions.entrySet();
        return registered.stream().collect(
            Collectors.toMap(
                subscription -> subscription.getKey().getUrn(),
                subscription -> subscription.getValue().size()));
    }

    /**
     * Current user wants to subscribe to the broadcast channel.
     * <p>
     * Fires a SUBSCRIBE request for the broadcast channel, bound to the current HTTP session.
     *
     * @param sessionId websocket session identifier
     */
    public void subscribeToBroadcastChannel(WsSessionIdentifier sessionId) {
        logger.debug("Session {} want to subscribe to broadcast channel", sessionId);

        var wsSessionId = sessionId.getSessionId();
        var httpSessionId = requestManager.getAndAssertHttpSession().getId();

        subscriptionEvents.fire(SubscriptionRequest.build(
            SubscriptionRequest.SubscriptionType.SUBSCRIBE,
            SubscriptionRequest.ChannelType.BROADCAST,
            0L,
            wsSessionId,
            httpSessionId));
    }

    /**
     * Current user wants to unsubscribe from the broadcast channel.
     * <p>
     * Asserts there is a current user, then fires an UNSUBSCRIBE request for the broadcast
     * channel, bound to the current HTTP session.
     *
     * @param sessionId websocket session identifier
     */
    public void unsubscribeFromBroadcastChannel(WsSessionIdentifier sessionId) {
        logger.debug("Session {} want to unsubscribe from the broadcast channel", sessionId);

        // only the assertion matters here, the returned user is not needed
        securityManager.assertAndGetCurrentUser();

        var wsSessionId = sessionId.getSessionId();
        var httpSessionId = requestManager.getAndAssertHttpSession().getId();

        subscriptionEvents.fire(SubscriptionRequest.build(
            SubscriptionRequest.SubscriptionType.UNSUBSCRIBE,
            SubscriptionRequest.ChannelType.BROADCAST,
            0L,
            wsSessionId,
            httpSessionId));
    }

    /**
     * Current user wants to subscribe to its own channel.
     * <p>
     * The channel id of the fired SUBSCRIBE request is the id of the current user.
     *
     * @param sessionId websocket session identifier
     */
    public void subscribeToUserChannel(WsSessionIdentifier sessionId) {
        logger.debug("Session {} want to subscribe to its UserChannel", sessionId);

        User currentUser = securityManager.assertAndGetCurrentUser();

        var userId = currentUser.getId();
        var wsSessionId = sessionId.getSessionId();
        var httpSessionId = requestManager.getAndAssertHttpSession().getId();

        subscriptionEvents.fire(SubscriptionRequest.build(
            SubscriptionRequest.SubscriptionType.SUBSCRIBE,
            SubscriptionRequest.ChannelType.USER,
            userId,
            wsSessionId,
            httpSessionId));
    }

    /**
     * Current user wants to unsubscribe from its own channel.
     * <p>
     * The channel id of the fired UNSUBSCRIBE request is the id of the current user.
     *
     * @param sessionId websocket session identifier
     */
    public void unsubscribeFromUserChannel(WsSessionIdentifier sessionId) {
        logger.debug("Session {} want to unsubscribe from its UserChannel", sessionId);

        User currentUser = securityManager.assertAndGetCurrentUser();

        var userId = currentUser.getId();
        var wsSessionId = sessionId.getSessionId();
        var httpSessionId = requestManager.getAndAssertHttpSession().getId();

        subscriptionEvents.fire(SubscriptionRequest.build(
            SubscriptionRequest.SubscriptionType.UNSUBSCRIBE,
            SubscriptionRequest.ChannelType.USER,
            userId,
            wsSessionId,
            httpSessionId));
    }

    /**
     * Current user wants to subscribe to a project channel.
     * <p>
     * The current user must satisfy the "member of project" condition. Once the subscription
     * request is fired, the user presence is registered for this project and websocket session.
     *
     * @param sessionId websocket session identifier
     * @param projectId id of the project
     *
     * @throws HttpErrorMessage notFound if no project could be resolved
     */
    public void subscribeToProjectChannel(WsSessionIdentifier sessionId, Long projectId) {
        logger.debug("Session {} want to subscribe to Project#{}", sessionId, projectId);

        Project project = projectManager.assertAndGetProject(projectId);
        if (project == null) {
            throw HttpErrorMessage.notFound();
        }

        securityManager.assertConditionTx(new Conditions.IsCurrentUserMemberOfProject(project),
            "Subscribe to project channel: Permision denied");

        var channelId = project.getId();
        var wsSessionId = sessionId.getSessionId();
        var httpSessionId = requestManager.getAndAssertHttpSession().getId();

        subscriptionEvents.fire(SubscriptionRequest.build(
            SubscriptionRequest.SubscriptionType.SUBSCRIBE,
            SubscriptionRequest.ChannelType.PROJECT,
            channelId,
            wsSessionId,
            httpSessionId));

        // the user is now present in the project
        TouchUserPresence presence = new TouchUserPresence();
        presence.setProjectId(projectId);
        presence.setWsSessionId(wsSessionId);
        presenceManager.updateUserPresence(presence);
    }

    /**
     * Current user wants to unsubscribe from a project channel.
     * <p>
     * The current user must satisfy the "member of project" condition. Once the unsubscription
     * request is fired, the presence of this websocket session in the project is cleared.
     *
     * @param sessionId websocket session identifier
     * @param projectId id of the project
     *
     * @throws HttpErrorMessage notFound if no project could be resolved
     */
    public void unsubscribeFromProjectChannel(WsSessionIdentifier sessionId, Long projectId) {
        logger.debug("Session {} want to unsubscribe from Project#{}", sessionId, projectId);
        Project project = projectManager.assertAndGetProject(projectId);
        if (project == null) {
            throw HttpErrorMessage.notFound();
        }

        // the message must name the operation which is actually denied
        securityManager.assertConditionTx(new Conditions.IsCurrentUserMemberOfProject(project),
            "Unsubscribe from project channel: Permission denied");
        SubscriptionRequest request = SubscriptionRequest.build(
            SubscriptionRequest.SubscriptionType.UNSUBSCRIBE,
            SubscriptionRequest.ChannelType.PROJECT,
            project.getId(),
            sessionId.getSessionId(),
            requestManager.getAndAssertHttpSession().getId());
        subscriptionEvents.fire(request);

        // user is not present any longer
        presenceManager.clearWsSession(projectId, sessionId.getSessionId());
    }

    /**
     * Current user wants to subscribe to a block channel.
     *
     * @param sessionId websocket session identifier
     * @param blockId   id of the block
     *
     * @throws HttpErrorMessage notFound if the block does not exist
     */
    public void subscribeToBlockChannel(WsSessionIdentifier sessionId, Long blockId) {
        logger.debug("Session {} want to subscribe to Block#{}", sessionId, blockId);

        TextDataBlock block = blockManager.findBlock(blockId);
        if (block == null) {
            throw HttpErrorMessage.notFound();
        }

        // No explicit security check on purpose: if one can load the block, one can subscribe to
        // its channel.
        var channelId = block.getId();
        var wsSessionId = sessionId.getSessionId();
        var httpSessionId = requestManager.getAndAssertHttpSession().getId();

        subscriptionEvents.fire(SubscriptionRequest.build(
            SubscriptionRequest.SubscriptionType.SUBSCRIBE,
            SubscriptionRequest.ChannelType.BLOCK,
            channelId,
            wsSessionId,
            httpSessionId));
    }

    /**
     * Current user wants to unsubscribe from a block channel.
     * <p>
     * The block is not loaded here: the given id is used as-is as the channel id of the request.
     *
     * @param sessionId websocket session identifier
     * @param blockId   id of the block
     */
    public void unsubscribeFromBlockChannel(WsSessionIdentifier sessionId, Long blockId) {
        logger.debug("Session {} want to unsubscribe from Block#{}", sessionId, blockId);

        var wsSessionId = sessionId.getSessionId();
        var httpSessionId = requestManager.getAndAssertHttpSession().getId();

        subscriptionEvents.fire(SubscriptionRequest.build(
            SubscriptionRequest.SubscriptionType.UNSUBSCRIBE,
            SubscriptionRequest.ChannelType.BLOCK,
            blockId,
            wsSessionId,
            httpSessionId));
    }

    /**
     * Process a subscription request received through the cluster event channel.
     * <p>
     * The request is ignored when {@link WebsocketEndpoint#getSession} returns no websocket
     * session for it, or when it cannot be resolved to an effective channel. Otherwise the
     * registries are updated within a block synchronized on this manager.
     *
     * @param request the subscription request
     */
    public void processSubscription(
        @Observes @Inbound(eventName = WS_SUBSCRIPTION_EVENT_CHANNEL) SubscriptionRequest request) {
        // all security checks have been done before firing the subscription event
        requestManager.sudo(() -> {
            logger.debug("Channel subscription request: {}", request);

            Session session = WebsocketEndpoint.getSession(request.getWsSessionId());
            if (session == null) {
                logger.debug("Ignore channel subscription: {}", request);
                return;
            }

            logger.debug("Process channel subscription request: {}", request);

            // resolve the effective channel first
            WebsocketChannel channel = getChannel(request);
            if (channel == null) {
                logger.debug("Failed to resolve {} to an effective channel", request);
                return;
            }

            synchronized (this) {
                // link the websocket session to its http session, whatever the request type is
                Set<Session> sessionsOfHttpSession = httpSessionToWsSessions
                    .computeIfAbsent(request.getColabSessionId(), key -> new HashSet<>());
                sessionsOfHttpSession.add(session);

                boolean subscribing
                    = request.getType() == SubscriptionRequest.SubscriptionType.SUBSCRIBE;

                if (subscribing) {
                    // keep the wsSession-to-channels registry up-to-date
                    Set<WebsocketChannel> channelsOfSession = wsSessionMap
                        .computeIfAbsent(session, key -> new HashSet<>());
                    channelsOfSession.add(channel);

                    subscribe(channel, session);
                } else {
                    // forget the channel, and drop the registry entry once it becomes empty
                    Set<WebsocketChannel> channelsOfSession = wsSessionMap.get(session);
                    if (channelsOfSession != null) {
                        channelsOfSession.remove(channel);
                        if (channelsOfSession.isEmpty()) {
                            wsSessionMap.remove(session);
                        }
                    }
                    unsubscribe(channel, Set.of(session));
                }
            }
        });
    }

    /**
     * Register the given session as a subscriber of the given channel, then propagate the channel
     * change (+1 subscriber).
     *
     * @param channel websocket channel to subscribe to
     * @param session session to add to the subscribers of the channel
     */
    private void subscribe(WebsocketChannel channel, Session session) {
        if (logger.isDebugEnabled()) {
            logger.debug("Session {} subscribes to {}",
                WebsocketEndpoint.getSessionId(session), channel);
        }

        Set<Session> subscribers = subscriptions.computeIfAbsent(channel, key -> new HashSet<>());
        subscribers.add(session);

        // Propagate only once the subscription is effective: the change may have to go through
        // this very subscription (eg an admin subscribing to its own user channel).
        propagateChannelChange(channel, 1);
    }

    /**
     * Unsubscribe all given sessions from the given channel. If the channel is empty after the
     * operation, it is removed from the subscription registry.
     * <p>
     * Nothing happens if the channel is not registered.
     *
     * @param channel  channel to update
     * @param sessions sessions to remove from the channel
     */
    private void unsubscribe(WebsocketChannel channel, Set<Session> sessions) {
        Set<Session> chSessions = subscriptions.get(channel);
        if (logger.isDebugEnabled()) {
            // collect the ids: logging the stream itself would only print the pipeline object
            logger.debug("Sessions {} unsubscribes from {}",
                sessions.stream()
                    .map(session -> WebsocketEndpoint.getSessionId(session))
                    .collect(Collectors.toList()),
                channel);
        }
        if (chSessions != null) {
            int size = chSessions.size();
            chSessions.removeAll(sessions);

            // propagate the (negative) difference while the channel is still registered
            this.propagateChannelChange(channel, chSessions.size() - size);

            if (chSessions.isEmpty()) {
                subscriptions.remove(channel);
            }
        }
    }

    /**
     * Propagate a channel change: a {@link WsChannelUpdate} message is prepared with
     * {@link WebsocketMessagePreparer#prepareWsMessageForAdmins} and propagated.
     * <p>
     * An encoding failure is logged, with its cause, and not rethrown.
     *
     * @param channel the channel
     * @param diff    variation of the number of subscribers
     */
    private void propagateChannelChange(WebsocketChannel channel, int diff) {
        try {
            PrecomputedWsMessages prepareWsMessage = WebsocketMessagePreparer
                .prepareWsMessageForAdmins(
                    userDao,
                    WsChannelUpdate.build(channel, diff)
                );
            this.propagate(prepareWsMessage);
        } catch (EncodeException ex) {
            // trailing throwable: slf4j logs its stack trace in addition to the message
            logger.error("Faild to propagate channel change :{}", channel, ex);
        }
    }

    /**
     * Determine the effective channel a request targets.
     * <p>
     * For PROJECT, BLOCK and USER requests, the targeted entity is loaded from the request
     * channel id; the BROADCAST channel needs no lookup.
     *
     * @param request the request
     *
     * @return the channel which matches the request, or null if the channel type is not handled
     *         or the targeted entity was not found
     */
    private WebsocketChannel getChannel(SubscriptionRequest request) {
        var type = request.getChannelType();

        if (type == SubscriptionRequest.ChannelType.BROADCAST) {
            return BroadcastChannel.build();
        }

        if (type == SubscriptionRequest.ChannelType.PROJECT) {
            Project project = projectManager.assertAndGetProject(request.getChannelId());
            return project != null ? ProjectContentChannel.build(project.getId()) : null;
        }

        if (type == SubscriptionRequest.ChannelType.BLOCK) {
            TextDataBlock block = blockManager.findBlock(request.getChannelId());
            return block != null ? BlockChannel.build(block.getId()) : null;
        }

        if (type == SubscriptionRequest.ChannelType.USER) {
            User user = userDao.findUser(request.getChannelId());
            return user != null ? UserChannel.build(user) : null;
        }

        // unhandled channel type
        return null;
    }

    /**
     * Propagate a precomputed message to clients.
     * <p>
     * The message is fired as a cluster event so that every instance of the hazelcast cluster
     * forwards it to the sessions it is in charge of. This will call
     * {@link #onMessagePropagation(PrecomputedWsMessages) onMessagePropagation} cluster-wide.
     *
     * @param message precomputed message to propagate.
     */
    public void propagate(PrecomputedWsMessages message) {
        messagePropagation.fire(message);
    }

    /**
     * On Hazelcast event. Each instance receives the precomputed messages.
     * <p>
     * Messages are grouped by subscribed websocket session, then each session receives a single
     * text frame: a JSON array made of all its messages. Sessions which are not open are skipped.
     * A failure to send to one session is logged and does not prevent sending to the others.
     *
     * @param payload the messagesByChannels to send to clients through relevant websocket channels
     */
    public void onMessagePropagation(
        @Observes @Inbound(eventName = WS_PROPAGATION_CHANNEL) PrecomputedWsMessages payload) {

        Map<WebsocketChannel, List<String>> messagesByChannels = payload.getMessages();
        if (messagesByChannels == null) {
            return;
        }

        Map<Session, List<String>> mappedMessages = new HashMap<>();

        // Group messages by effective websocket session
        messagesByChannels.forEach((channel, messages) -> {
            Set<Session> subscribers = this.subscriptions.get(channel);
            if (subscribers != null) {
                subscribers.forEach(session -> mappedMessages
                    .computeIfAbsent(session, (k) -> new LinkedList<>())
                    .addAll(messages));
            }
        });

        // send one big message to each session
        mappedMessages.forEach((session, messages) -> {
            String jsonArray = messages.stream()
                .collect(Collectors.joining(", ", "[", "]"));
            try {
                // one placeholder per argument
                logger.debug("Send {} to {}", jsonArray, session.getId());
                if (session.isOpen()) {
                    session.getBasicRemote().sendText(jsonArray);
                }
            } catch (IOException ex) {
                // trailing throwable: slf4j logs its stack trace in addition to the message
                logger.error("Failed to send websocket message {} to {}",
                    jsonArray, session, ex);
            }
        });
    }

    /**
     * Propagate the logout and unsubscribe from all channels linked to the given http session,
     * then forget the http session.
     * <p>
     * The whole operation runs in a block synchronized on this manager.
     *
     * @param session the http session which just logged out
     */
    public void signoutAndUnsubscribeFromAll(HttpSession session) {
        synchronized (this) {
            // propagate before removing the channels
            propagateSignOut(session);

            Set<Session> wsSessions = httpSessionToWsSessions.get(session.getId());
            if (wsSessions != null) {
                // The http session is linked to one or more websocket sessions: cancel all their
                // subscriptions, handling each channel only once.
                Set<WebsocketChannel> handledChannels = new HashSet<>();
                for (Session wsSession : wsSessions) {
                    Set<WebsocketChannel> channels = wsSessionMap.get(wsSession);
                    if (channels == null) {
                        // this websocket session has no registered channel
                        continue;
                    }
                    for (WebsocketChannel channel : channels) {
                        if (handledChannels.add(channel)) {
                            unsubscribe(channel, wsSessions);
                        }
                    }
                }
            }

            httpSessionToWsSessions.remove(session.getId());
        }
    }

    /**
     * Propagate a logout: a {@link WsSignOutMessage} is prepared for the channels given by the
     * channels builder of the http session, and propagated.
     * <p>
     * An encoding failure is logged, with its cause, and not rethrown.
     *
     * @param httpSession The http session that is now closed
     */
    private void propagateSignOut(HttpSession httpSession) {
        try {
            PrecomputedWsMessages prepareWsMessage = WebsocketMessagePreparer
                .prepareWsMessage(
                    userDao,
                    teamMemberDao,
                    cardTypeDao,
                    projectDao,
                    httpSession.getChannelsBuilder(),
                    new WsSignOutMessage(httpSession));
            this.propagate(prepareWsMessage);
        } catch (EncodeException ex) {
            // trailing throwable: slf4j logs its stack trace in addition to the message
            logger.error("Faild to propagate sign out : {}", httpSession, ex);
        }
    }

    /**
     * Clean subscriptions on websocket session close.
     * <p>
     * The session is unsubscribed from every channel it is registered for, its presence is cleared
     * for each project channel, and pending changes are flushed. The session is also removed from
     * the http-session registry, so that a closed websocket session is not retained until the
     * owning http session signs out.
     *
     * @param session   websocket session to clean subscription for
     * @param sessionId public websocket identifier
     */
    public void unsubscribeFromAll(Session session, String sessionId) {
        synchronized (this) {
            Set<WebsocketChannel> set = this.wsSessionMap.get(session);
            if (set != null) {
                Set<Session> setOfSession = Set.of(session);
                set.forEach(channel -> {
                    unsubscribe(channel, setOfSession);
                    if (channel instanceof ProjectContentChannel) {
                        ProjectContentChannel pChannel = (ProjectContentChannel) channel;
                        presenceManager.clearWsSession(pChannel.getProjectId(), sessionId);
                    }
                });
                this.wsSessionMap.remove(session);
                // flush changes ASAP
                WebsocketTxSync synchronizer = bag.getSynchronizer();
                if (synchronizer != null) {
                    synchronizer.flush();
                }
            }

            // Drop the session from the http-session registry too, even when it has no channel
            // left, and remove the entries which become empty.
            this.httpSessionToWsSessions.values().removeIf(wsSessions -> {
                wsSessions.remove(session);
                return wsSessions.isEmpty();
            });
        }
    }
}