001/*
002 * #%L
003 * HAPI FHIR Subscription Server
004 * %%
005 * Copyright (C) 2014 - 2025 Smile CDR, Inc.
006 * %%
007 * Licensed under the Apache License, Version 2.0 (the "License");
008 * you may not use this file except in compliance with the License.
009 * You may obtain a copy of the License at
010 *
011 *      http://www.apache.org/licenses/LICENSE-2.0
012 *
013 * Unless required by applicable law or agreed to in writing, software
014 * distributed under the License is distributed on an "AS IS" BASIS,
015 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
016 * See the License for the specific language governing permissions and
017 * limitations under the License.
018 * #L%
019 */
020package ca.uhn.fhir.jpa.subscription.match.deliver.websocket;
021
022import ca.uhn.fhir.broker.api.IMessageListener;
023import ca.uhn.fhir.i18n.Msg;
024import ca.uhn.fhir.jpa.subscription.channel.subscription.SubscriptionChannelRegistry;
025import ca.uhn.fhir.jpa.subscription.channel.subscription.SubscriptionResourceDeliveryMessageConsumer;
026import ca.uhn.fhir.jpa.subscription.match.registry.ActiveSubscription;
027import ca.uhn.fhir.jpa.subscription.model.ResourceDeliveryMessage;
028import ca.uhn.fhir.rest.server.exceptions.InternalErrorException;
029import ca.uhn.fhir.rest.server.messaging.IMessage;
030import jakarta.annotation.Nonnull;
031import jakarta.annotation.PostConstruct;
032import jakarta.annotation.PreDestroy;
033import org.hl7.fhir.instance.model.api.IIdType;
034import org.hl7.fhir.r4.model.IdType;
035import org.slf4j.Logger;
036import org.slf4j.LoggerFactory;
037import org.springframework.beans.factory.annotation.Autowired;
038import org.springframework.web.socket.CloseStatus;
039import org.springframework.web.socket.TextMessage;
040import org.springframework.web.socket.WebSocketHandler;
041import org.springframework.web.socket.WebSocketSession;
042import org.springframework.web.socket.handler.TextWebSocketHandler;
043
044import java.io.IOException;
045import java.util.Optional;
046
047public class SubscriptionWebsocketHandler extends TextWebSocketHandler implements WebSocketHandler {
048        private static final Logger ourLog = LoggerFactory.getLogger(SubscriptionWebsocketHandler.class);
049
050        @Autowired
051        protected WebsocketConnectionValidator myWebsocketConnectionValidator;
052
053        @Autowired
054        SubscriptionChannelRegistry mySubscriptionChannelRegistry;
055
056        private IState myState = new InitialState();
057
058        /**
059         * Constructor
060         */
061        public SubscriptionWebsocketHandler() {
062                super();
063        }
064
065        @Override
066        public void afterConnectionClosed(WebSocketSession theSession, CloseStatus theStatus) throws Exception {
067                super.afterConnectionClosed(theSession, theStatus);
068                ourLog.info("Closing WebSocket connection from {}", theSession.getRemoteAddress());
069        }
070
071        @Override
072        public void afterConnectionEstablished(WebSocketSession theSession) throws Exception {
073                super.afterConnectionEstablished(theSession);
074                ourLog.info("Incoming WebSocket connection from {}", theSession.getRemoteAddress());
075        }
076
077        protected void handleFailure(Exception theE) {
078                ourLog.error("Failure during communication", theE);
079        }
080
081        @Override
082        protected void handleTextMessage(WebSocketSession theSession, TextMessage theMessage) throws Exception {
083                ourLog.info("Textmessage: " + theMessage.getPayload());
084                myState.handleTextMessage(theSession, theMessage);
085        }
086
087        @Override
088        public void handleTransportError(WebSocketSession theSession, Throwable theException) throws Exception {
089                super.handleTransportError(theSession, theException);
090                ourLog.error("Transport error", theException);
091        }
092
093        @PostConstruct
094        public synchronized void postConstruct() {
095                ourLog.info("Websocket connection has been created");
096        }
097
098        @PreDestroy
099        public synchronized void preDescroy() {
100                ourLog.info("Websocket connection is closing");
101                IState state = myState;
102                if (state != null) {
103                        state.closing();
104                }
105        }
106
107        private interface IState {
108
109                void closing();
110
111                void handleTextMessage(WebSocketSession theSession, TextMessage theMessage);
112        }
113
114        private class BoundStaticSubscriptionState implements IState, IMessageListener<ResourceDeliveryMessage> {
115
116                private final WebSocketSession mySession;
117                private final ActiveSubscription myActiveSubscription;
118
119                public BoundStaticSubscriptionState(WebSocketSession theSession, ActiveSubscription theActiveSubscription) {
120                        mySession = theSession;
121                        myActiveSubscription = theActiveSubscription;
122
123                        SubscriptionResourceDeliveryMessageConsumer subscriptionResourceDeliveryMessageConsumer =
124                                        mySubscriptionChannelRegistry.getDeliveryConsumerWithListeners(
125                                                        theActiveSubscription.getChannelName());
126                        subscriptionResourceDeliveryMessageConsumer.addListener(this);
127                }
128
129                public Class<ResourceDeliveryMessage> getPayloadType() {
130                        return ResourceDeliveryMessage.class;
131                }
132
133                @Override
134                public void closing() {
135                        SubscriptionResourceDeliveryMessageConsumer subscriptionResourceDeliveryMessageConsumer =
136                                        mySubscriptionChannelRegistry.getDeliveryConsumerWithListeners(
137                                                        myActiveSubscription.getChannelName());
138                        subscriptionResourceDeliveryMessageConsumer.removeListener(this);
139                }
140
141                /**
142                 * Send the payload to the client
143                 *
144                 * @param payload The payload
145                 */
146                private void deliver(String payload) {
147                        try {
148                                // Log it
149                                ourLog.info("Sending WebSocket message: {}", payload);
150
151                                // Send message
152                                mySession.sendMessage(new TextMessage(payload));
153                        } catch (IOException e) {
154                                handleFailure(e);
155                        }
156                }
157
158                @Override
159                public void handleMessage(@Nonnull IMessage<ResourceDeliveryMessage> theMessage) {
160                        try {
161                                ResourceDeliveryMessage msg = theMessage.getPayload();
162                                handleSubscriptionPayload(msg);
163                        } catch (Exception e) {
164                                ourLog.error("Failure handling subscription payload", e);
165                                throw new InternalErrorException(Msg.code(6) + "Failure handling subscription payload", e);
166                        }
167                }
168
169                /**
170                 * Handle the subscription payload
171                 *
172                 * @param msg The message
173                 */
174                private void handleSubscriptionPayload(ResourceDeliveryMessage msg) {
175                        // Check if the subscription exists and is the same as the active subscription
176                        if (!myActiveSubscription.getSubscription().equals(msg.getSubscription())) {
177                                return;
178                        }
179
180                        // Default payload
181                        String defaultPayload = "ping " + myActiveSubscription.getId();
182                        String payload = defaultPayload;
183
184                        // Check if the subscription is a topic subscription
185                        if (msg.getSubscription().isTopicSubscription()) {
186                                // Get the payload by content
187                                payload = getPayloadByContent(msg).orElse(defaultPayload);
188                        }
189
190                        // Deliver the payload
191                        deliver(payload);
192                }
193
194                /**
195                 * Get the payload based on the subscription content
196                 *
197                 * @param msg The message
198                 * @return The payload
199                 */
200                private Optional<String> getPayloadByContent(ResourceDeliveryMessage msg) {
201                        if (msg.getSubscription().getContent() == null) {
202                                return Optional.empty();
203                        }
204                        switch (msg.getSubscription().getContent()) {
205                                case IDONLY:
206                                        return Optional.of(msg.getPayloadId());
207                                case FULLRESOURCE:
208                                        return Optional.of(msg.getPayloadString());
209                                case EMPTY:
210                                case NULL:
211                                default:
212                                        return Optional.empty();
213                        }
214                }
215
216                @Override
217                public void handleTextMessage(WebSocketSession theSession, TextMessage theMessage) {
218                        try {
219                                theSession.sendMessage(new TextMessage("Unexpected client message: " + theMessage.getPayload()));
220                        } catch (IOException e) {
221                                handleFailure(e);
222                        }
223                }
224        }
225
226        private class InitialState implements IState {
227
228                private IIdType bindSimple(WebSocketSession theSession, String theBindString) {
229                        IdType id = new IdType(theBindString);
230
231                        WebsocketValidationResponse response = myWebsocketConnectionValidator.validate(id);
232                        if (!response.isValid()) {
233                                try {
234                                        ourLog.warn(response.getMessage());
235                                        theSession.close(new CloseStatus(CloseStatus.PROTOCOL_ERROR.getCode(), response.getMessage()));
236                                } catch (IOException e) {
237                                        handleFailure(e);
238                                }
239                                return null;
240                        }
241
242                        myState = new BoundStaticSubscriptionState(theSession, response.getActiveSubscription());
243
244                        return id;
245                }
246
247                @Override
248                public void closing() {
249                        // nothing
250                }
251
252                @Override
253                public void handleTextMessage(WebSocketSession theSession, TextMessage theMessage) {
254                        String message = theMessage.getPayload();
255                        if (message.startsWith("bind ")) {
256                                String remaining = message.substring("bind ".length());
257
258                                IIdType subscriptionId;
259                                subscriptionId = bindSimple(theSession, remaining);
260                                if (subscriptionId == null) {
261                                        return;
262                                }
263
264                                try {
265                                        theSession.sendMessage(new TextMessage("bound " + subscriptionId.getIdPart()));
266                                } catch (IOException e) {
267                                        handleFailure(e);
268                                }
269                        }
270                }
271        }
272}