001/*
002 * #%L
003 * HAPI FHIR Subscription Server
004 * %%
005 * Copyright (C) 2014 - 2024 Smile CDR, Inc.
006 * %%
007 * Licensed under the Apache License, Version 2.0 (the "License");
008 * you may not use this file except in compliance with the License.
009 * You may obtain a copy of the License at
010 *
011 *      http://www.apache.org/licenses/LICENSE-2.0
012 *
013 * Unless required by applicable law or agreed to in writing, software
014 * distributed under the License is distributed on an "AS IS" BASIS,
015 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
016 * See the License for the specific language governing permissions and
017 * limitations under the License.
018 * #L%
019 */
020package ca.uhn.fhir.jpa.subscription.match.deliver.websocket;
021
022import ca.uhn.fhir.i18n.Msg;
023import ca.uhn.fhir.jpa.subscription.channel.subscription.SubscriptionChannelRegistry;
024import ca.uhn.fhir.jpa.subscription.channel.subscription.SubscriptionChannelWithHandlers;
025import ca.uhn.fhir.jpa.subscription.match.registry.ActiveSubscription;
026import ca.uhn.fhir.jpa.subscription.model.ResourceDeliveryMessage;
027import jakarta.annotation.PostConstruct;
028import jakarta.annotation.PreDestroy;
029import org.hl7.fhir.instance.model.api.IIdType;
030import org.hl7.fhir.r4.model.IdType;
031import org.slf4j.Logger;
032import org.slf4j.LoggerFactory;
033import org.springframework.beans.factory.annotation.Autowired;
034import org.springframework.messaging.Message;
035import org.springframework.messaging.MessageHandler;
036import org.springframework.messaging.MessagingException;
037import org.springframework.web.socket.CloseStatus;
038import org.springframework.web.socket.TextMessage;
039import org.springframework.web.socket.WebSocketHandler;
040import org.springframework.web.socket.WebSocketSession;
041import org.springframework.web.socket.handler.TextWebSocketHandler;
042
043import java.io.IOException;
044import java.util.Optional;
045
046public class SubscriptionWebsocketHandler extends TextWebSocketHandler implements WebSocketHandler {
047        private static final Logger ourLog = LoggerFactory.getLogger(SubscriptionWebsocketHandler.class);
048
049        @Autowired
050        protected WebsocketConnectionValidator myWebsocketConnectionValidator;
051
052        @Autowired
053        SubscriptionChannelRegistry mySubscriptionChannelRegistry;
054
055        private IState myState = new InitialState();
056
057        /**
058         * Constructor
059         */
060        public SubscriptionWebsocketHandler() {
061                super();
062        }
063
064        @Override
065        public void afterConnectionClosed(WebSocketSession theSession, CloseStatus theStatus) throws Exception {
066                super.afterConnectionClosed(theSession, theStatus);
067                ourLog.info("Closing WebSocket connection from {}", theSession.getRemoteAddress());
068        }
069
070        @Override
071        public void afterConnectionEstablished(WebSocketSession theSession) throws Exception {
072                super.afterConnectionEstablished(theSession);
073                ourLog.info("Incoming WebSocket connection from {}", theSession.getRemoteAddress());
074        }
075
076        protected void handleFailure(Exception theE) {
077                ourLog.error("Failure during communication", theE);
078        }
079
080        @Override
081        protected void handleTextMessage(WebSocketSession theSession, TextMessage theMessage) throws Exception {
082                ourLog.info("Textmessage: " + theMessage.getPayload());
083                myState.handleTextMessage(theSession, theMessage);
084        }
085
086        @Override
087        public void handleTransportError(WebSocketSession theSession, Throwable theException) throws Exception {
088                super.handleTransportError(theSession, theException);
089                ourLog.error("Transport error", theException);
090        }
091
092        @PostConstruct
093        public synchronized void postConstruct() {
094                ourLog.info("Websocket connection has been created");
095        }
096
097        @PreDestroy
098        public synchronized void preDescroy() {
099                ourLog.info("Websocket connection is closing");
100                IState state = myState;
101                if (state != null) {
102                        state.closing();
103                }
104        }
105
106        private interface IState {
107
108                void closing();
109
110                void handleTextMessage(WebSocketSession theSession, TextMessage theMessage);
111        }
112
113        private class BoundStaticSubscriptionState implements IState, MessageHandler {
114
115                private final WebSocketSession mySession;
116                private final ActiveSubscription myActiveSubscription;
117
118                public BoundStaticSubscriptionState(WebSocketSession theSession, ActiveSubscription theActiveSubscription) {
119                        mySession = theSession;
120                        myActiveSubscription = theActiveSubscription;
121
122                        SubscriptionChannelWithHandlers subscriptionChannelWithHandlers =
123                                        mySubscriptionChannelRegistry.getDeliveryReceiverChannel(theActiveSubscription.getChannelName());
124                        subscriptionChannelWithHandlers.addHandler(this);
125                }
126
127                @Override
128                public void closing() {
129                        SubscriptionChannelWithHandlers subscriptionChannelWithHandlers =
130                                        mySubscriptionChannelRegistry.getDeliveryReceiverChannel(myActiveSubscription.getChannelName());
131                        subscriptionChannelWithHandlers.removeHandler(this);
132                }
133
134                /**
135                 * Send the payload to the client
136                 *
137                 * @param payload The payload
138                 */
139                private void deliver(String payload) {
140                        try {
141                                // Log it
142                                ourLog.info("Sending WebSocket message: {}", payload);
143
144                                // Send message
145                                mySession.sendMessage(new TextMessage(payload));
146                        } catch (IOException e) {
147                                handleFailure(e);
148                        }
149                }
150
151                @Override
152                public void handleMessage(Message<?> theMessage) {
153                        if (!(theMessage.getPayload() instanceof ResourceDeliveryMessage)) {
154                                return;
155                        }
156
157                        try {
158                                ResourceDeliveryMessage msg = (ResourceDeliveryMessage) theMessage.getPayload();
159                                handleSubscriptionPayload(msg);
160                        } catch (Exception e) {
161                                handleException(theMessage, e);
162                        }
163                }
164
165                /**
166                 * Handle the subscription payload
167                 *
168                 * @param msg The message
169                 */
170                private void handleSubscriptionPayload(ResourceDeliveryMessage msg) {
171                        // Check if the subscription exists and is the same as the active subscription
172                        if (!myActiveSubscription.getSubscription().equals(msg.getSubscription())) {
173                                return;
174                        }
175
176                        // Default payload
177                        String defaultPayload = "ping " + myActiveSubscription.getId();
178                        String payload = defaultPayload;
179
180                        // Check if the subscription is a topic subscription
181                        if (msg.getSubscription().isTopicSubscription()) {
182                                // Get the payload by content
183                                payload = getPayloadByContent(msg).orElse(defaultPayload);
184                        }
185
186                        // Deliver the payload
187                        deliver(payload);
188                }
189
190                /**
191                 * Handle the exception
192                 *
193                 * @param theMessage The message
194                 * @param e          The exception
195                 */
196                private void handleException(Message<?> theMessage, Exception e) {
197                        ourLog.error("Failure handling subscription payload", e);
198                        throw new MessagingException(theMessage, Msg.code(6) + "Failure handling subscription payload", e);
199                }
200
201                /**
202                 * Get the payload based on the subscription content
203                 *
204                 * @param msg The message
205                 * @return The payload
206                 */
207                private Optional<String> getPayloadByContent(ResourceDeliveryMessage msg) {
208                        if (msg.getSubscription().getContent() == null) {
209                                return Optional.empty();
210                        }
211                        switch (msg.getSubscription().getContent()) {
212                                case IDONLY:
213                                        return Optional.of(msg.getPayloadId());
214                                case FULLRESOURCE:
215                                        return Optional.of(msg.getPayloadString());
216                                case EMPTY:
217                                case NULL:
218                                default:
219                                        return Optional.empty();
220                        }
221                }
222
223                @Override
224                public void handleTextMessage(WebSocketSession theSession, TextMessage theMessage) {
225                        try {
226                                theSession.sendMessage(new TextMessage("Unexpected client message: " + theMessage.getPayload()));
227                        } catch (IOException e) {
228                                handleFailure(e);
229                        }
230                }
231        }
232
233        private class InitialState implements IState {
234
235                private IIdType bindSimple(WebSocketSession theSession, String theBindString) {
236                        IdType id = new IdType(theBindString);
237
238                        WebsocketValidationResponse response = myWebsocketConnectionValidator.validate(id);
239                        if (!response.isValid()) {
240                                try {
241                                        ourLog.warn(response.getMessage());
242                                        theSession.close(new CloseStatus(CloseStatus.PROTOCOL_ERROR.getCode(), response.getMessage()));
243                                } catch (IOException e) {
244                                        handleFailure(e);
245                                }
246                                return null;
247                        }
248
249                        myState = new BoundStaticSubscriptionState(theSession, response.getActiveSubscription());
250
251                        return id;
252                }
253
254                @Override
255                public void closing() {
256                        // nothing
257                }
258
259                @Override
260                public void handleTextMessage(WebSocketSession theSession, TextMessage theMessage) {
261                        String message = theMessage.getPayload();
262                        if (message.startsWith("bind ")) {
263                                String remaining = message.substring("bind ".length());
264
265                                IIdType subscriptionId;
266                                subscriptionId = bindSimple(theSession, remaining);
267                                if (subscriptionId == null) {
268                                        return;
269                                }
270
271                                try {
272                                        theSession.sendMessage(new TextMessage("bound " + subscriptionId.getIdPart()));
273                                } catch (IOException e) {
274                                        handleFailure(e);
275                                }
276                        }
277                }
278        }
279}