001/*-
002 * #%L
003 * HAPI FHIR - Core Library
004 * %%
005 * Copyright (C) 2014 - 2024 Smile CDR, Inc.
006 * %%
007 * Licensed under the Apache License, Version 2.0 (the "License");
008 * you may not use this file except in compliance with the License.
009 * You may obtain a copy of the License at
010 *
011 *      http://www.apache.org/licenses/LICENSE-2.0
012 *
013 * Unless required by applicable law or agreed to in writing, software
014 * distributed under the License is distributed on an "AS IS" BASIS,
015 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
016 * See the License for the specific language governing permissions and
017 * limitations under the License.
018 * #L%
019 */
020package ca.uhn.fhir.interceptor.executor;
021
022import ca.uhn.fhir.i18n.Msg;
023import ca.uhn.fhir.interceptor.api.HookParams;
024import ca.uhn.fhir.interceptor.api.IBaseInterceptorBroadcaster;
025import ca.uhn.fhir.interceptor.api.IBaseInterceptorService;
026import ca.uhn.fhir.interceptor.api.IPointcut;
027import ca.uhn.fhir.interceptor.api.Interceptor;
028import ca.uhn.fhir.interceptor.api.Pointcut;
029import ca.uhn.fhir.rest.server.exceptions.InternalErrorException;
030import ca.uhn.fhir.util.ReflectionUtil;
031import com.google.common.annotations.VisibleForTesting;
032import com.google.common.collect.ArrayListMultimap;
033import com.google.common.collect.ListMultimap;
034import jakarta.annotation.Nonnull;
035import jakarta.annotation.Nullable;
036import org.apache.commons.lang3.Validate;
037import org.apache.commons.lang3.builder.ToStringBuilder;
038import org.apache.commons.lang3.builder.ToStringStyle;
039import org.apache.commons.lang3.reflect.MethodUtils;
040import org.slf4j.Logger;
041import org.slf4j.LoggerFactory;
042
043import java.lang.annotation.Annotation;
044import java.lang.reflect.AnnotatedElement;
045import java.lang.reflect.InvocationTargetException;
046import java.lang.reflect.Method;
047import java.util.ArrayList;
048import java.util.Arrays;
049import java.util.Collection;
050import java.util.Collections;
051import java.util.Comparator;
052import java.util.EnumSet;
053import java.util.HashMap;
054import java.util.IdentityHashMap;
055import java.util.List;
056import java.util.Map;
057import java.util.Objects;
058import java.util.Optional;
059import java.util.concurrent.atomic.AtomicInteger;
060import java.util.function.Predicate;
061import java.util.stream.Collectors;
062
063public abstract class BaseInterceptorService<POINTCUT extends Enum<POINTCUT> & IPointcut>
064                implements IBaseInterceptorService<POINTCUT>, IBaseInterceptorBroadcaster<POINTCUT> {
065        private static final Logger ourLog = LoggerFactory.getLogger(BaseInterceptorService.class);
066        private final List<Object> myInterceptors = new ArrayList<>();
067        private final ListMultimap<POINTCUT, BaseInvoker> myGlobalInvokers = ArrayListMultimap.create();
068        private final ListMultimap<POINTCUT, BaseInvoker> myAnonymousInvokers = ArrayListMultimap.create();
069        private final Object myRegistryMutex = new Object();
070        private final Class<POINTCUT> myPointcutType;
071        private volatile EnumSet<POINTCUT> myRegisteredPointcuts;
072        private String myName;
073        private boolean myWarnOnInterceptorWithNoHooks = true;
074
075        /**
076         * Constructor which uses a default name of "default"
077         */
078        public BaseInterceptorService(Class<POINTCUT> thePointcutType) {
079                this(thePointcutType, "default");
080        }
081
082        /**
083         * Constructor
084         *
085         * @param theName The name for this registry (useful for troubleshooting)
086         */
087        public BaseInterceptorService(Class<POINTCUT> thePointcutType, String theName) {
088                super();
089                myName = theName;
090                myPointcutType = thePointcutType;
091                rebuildRegisteredPointcutSet();
092        }
093
094        /**
095         * Should a warning be issued if an interceptor is registered and it has no hooks
096         */
097        public void setWarnOnInterceptorWithNoHooks(boolean theWarnOnInterceptorWithNoHooks) {
098                myWarnOnInterceptorWithNoHooks = theWarnOnInterceptorWithNoHooks;
099        }
100
101        @VisibleForTesting
102        List<Object> getGlobalInterceptorsForUnitTest() {
103                return myInterceptors;
104        }
105
106        public void setName(String theName) {
107                myName = theName;
108        }
109
110        protected void registerAnonymousInterceptor(POINTCUT thePointcut, Object theInterceptor, BaseInvoker theInvoker) {
111                Validate.notNull(thePointcut);
112                Validate.notNull(theInterceptor);
113                synchronized (myRegistryMutex) {
114                        myAnonymousInvokers.put(thePointcut, theInvoker);
115                        if (!isInterceptorAlreadyRegistered(theInterceptor)) {
116                                myInterceptors.add(theInterceptor);
117                        }
118
119                        rebuildRegisteredPointcutSet();
120                }
121        }
122
123        @Override
124        public List<Object> getAllRegisteredInterceptors() {
125                synchronized (myRegistryMutex) {
126                        List<Object> retVal = new ArrayList<>(myInterceptors);
127                        return Collections.unmodifiableList(retVal);
128                }
129        }
130
131        @Override
132        @VisibleForTesting
133        public void unregisterAllInterceptors() {
134                synchronized (myRegistryMutex) {
135                        unregisterInterceptors(myAnonymousInvokers.values());
136                        unregisterInterceptors(myGlobalInvokers.values());
137                        unregisterInterceptors(myInterceptors);
138                }
139        }
140
141        @Override
142        public void unregisterInterceptors(@Nullable Collection<?> theInterceptors) {
143                if (theInterceptors != null) {
144                        // We construct a new list before iterating because the service's internal
145                        // interceptor lists get passed into this method, and we get concurrent
146                        // modification errors if we modify them at the same time as we iterate them
147                        new ArrayList<>(theInterceptors).forEach(this::unregisterInterceptor);
148                }
149        }
150
151        @Override
152        public void registerInterceptors(@Nullable Collection<?> theInterceptors) {
153                if (theInterceptors != null) {
154                        theInterceptors.forEach(this::registerInterceptor);
155                }
156        }
157
158        @Override
159        public void unregisterAllAnonymousInterceptors() {
160                synchronized (myRegistryMutex) {
161                        unregisterInterceptorsIf(t -> true, myAnonymousInvokers);
162                }
163        }
164
165        @Override
166        public void unregisterInterceptorsIf(Predicate<Object> theShouldUnregisterFunction) {
167                unregisterInterceptorsIf(theShouldUnregisterFunction, myGlobalInvokers);
168                unregisterInterceptorsIf(theShouldUnregisterFunction, myAnonymousInvokers);
169        }
170
171        private void unregisterInterceptorsIf(
172                        Predicate<Object> theShouldUnregisterFunction, ListMultimap<POINTCUT, BaseInvoker> theGlobalInvokers) {
173                synchronized (myRegistryMutex) {
174                        for (Map.Entry<POINTCUT, BaseInvoker> nextInvoker : new ArrayList<>(theGlobalInvokers.entries())) {
175                                if (theShouldUnregisterFunction.test(nextInvoker.getValue().getInterceptor())) {
176                                        unregisterInterceptor(nextInvoker.getValue().getInterceptor());
177                                }
178                        }
179
180                        rebuildRegisteredPointcutSet();
181                }
182        }
183
184        @Override
185        public boolean registerInterceptor(Object theInterceptor) {
186                synchronized (myRegistryMutex) {
187                        if (isInterceptorAlreadyRegistered(theInterceptor)) {
188                                return false;
189                        }
190
191                        List<HookInvoker> addedInvokers = scanInterceptorAndAddToInvokerMultimap(theInterceptor, myGlobalInvokers);
192                        if (addedInvokers.isEmpty()) {
193                                if (myWarnOnInterceptorWithNoHooks) {
194                                        ourLog.warn(
195                                                        "Interceptor registered with no valid hooks - Type was: {}",
196                                                        theInterceptor.getClass().getName());
197                                }
198                                return false;
199                        }
200
201                        // Add to the global list
202                        myInterceptors.add(theInterceptor);
203                        sortByOrderAnnotation(myInterceptors);
204
205                        rebuildRegisteredPointcutSet();
206
207                        return true;
208                }
209        }
210
211        private void rebuildRegisteredPointcutSet() {
212                EnumSet<POINTCUT> registeredPointcuts = EnumSet.noneOf(myPointcutType);
213                registeredPointcuts.addAll(myAnonymousInvokers.keySet());
214                registeredPointcuts.addAll(myGlobalInvokers.keySet());
215                myRegisteredPointcuts = registeredPointcuts;
216        }
217
218        private boolean isInterceptorAlreadyRegistered(Object theInterceptor) {
219                for (Object next : myInterceptors) {
220                        if (next == theInterceptor) {
221                                return true;
222                        }
223                }
224                return false;
225        }
226
227        @Override
228        public boolean unregisterInterceptor(Object theInterceptor) {
229                synchronized (myRegistryMutex) {
230                        boolean removed = myInterceptors.removeIf(t -> t == theInterceptor);
231                        removed |= myGlobalInvokers.entries().removeIf(t -> t.getValue().getInterceptor() == theInterceptor);
232                        removed |= myAnonymousInvokers.entries().removeIf(t -> t.getValue().getInterceptor() == theInterceptor);
233                        rebuildRegisteredPointcutSet();
234                        return removed;
235                }
236        }
237
238        private void sortByOrderAnnotation(List<Object> theObjects) {
239                IdentityHashMap<Object, Integer> interceptorToOrder = new IdentityHashMap<>();
240                for (Object next : theObjects) {
241                        Interceptor orderAnnotation = next.getClass().getAnnotation(Interceptor.class);
242                        int order = orderAnnotation != null ? orderAnnotation.order() : 0;
243                        interceptorToOrder.put(next, order);
244                }
245
246                theObjects.sort((a, b) -> {
247                        Integer orderA = interceptorToOrder.get(a);
248                        Integer orderB = interceptorToOrder.get(b);
249                        return orderA - orderB;
250                });
251        }
252
253        @Override
254        public Object callHooksAndReturnObject(POINTCUT thePointcut, HookParams theParams) {
255                assert haveAppropriateParams(thePointcut, theParams);
256                assert thePointcut.getReturnType() != void.class;
257
258                return doCallHooks(thePointcut, theParams, null);
259        }
260
261        @Override
262        public boolean hasHooks(POINTCUT thePointcut) {
263                return myRegisteredPointcuts.contains(thePointcut);
264        }
265
266        protected Class<?> getBooleanReturnType() {
267                return boolean.class;
268        }
269
270        @Override
271        public boolean callHooks(POINTCUT thePointcut, HookParams theParams) {
272                assert haveAppropriateParams(thePointcut, theParams);
273                assert thePointcut.getReturnType() == void.class || thePointcut.getReturnType() == getBooleanReturnType();
274
275                Object retValObj = doCallHooks(thePointcut, theParams, true);
276                return (Boolean) retValObj;
277        }
278
279        private Object doCallHooks(POINTCUT thePointcut, HookParams theParams, Object theRetVal) {
280                // use new list for loop to avoid ConcurrentModificationException in case invoker gets added while looping
281                List<BaseInvoker> invokers = new ArrayList<>(getInvokersForPointcut(thePointcut));
282
283                /*
284                 * Call each hook in order
285                 */
286                for (BaseInvoker nextInvoker : invokers) {
287                        Object nextOutcome = nextInvoker.invoke(theParams);
288                        Class<?> pointcutReturnType = thePointcut.getReturnType();
289                        if (pointcutReturnType.equals(getBooleanReturnType())) {
290                                Boolean nextOutcomeAsBoolean = (Boolean) nextOutcome;
291                                if (Boolean.FALSE.equals(nextOutcomeAsBoolean)) {
292                                        ourLog.trace("callHooks({}) for invoker({}) returned false", thePointcut, nextInvoker);
293                                        theRetVal = false;
294                                        break;
295                                } else {
296                                        theRetVal = true;
297                                }
298                        } else if (!pointcutReturnType.equals(void.class)) {
299                                if (nextOutcome != null) {
300                                        theRetVal = nextOutcome;
301                                        break;
302                                }
303                        }
304                }
305
306                return theRetVal;
307        }
308
309        @VisibleForTesting
310        List<Object> getInterceptorsWithInvokersForPointcut(POINTCUT thePointcut) {
311                return getInvokersForPointcut(thePointcut).stream()
312                                .map(BaseInvoker::getInterceptor)
313                                .collect(Collectors.toList());
314        }
315
316        /**
317         * Returns an ordered list of invokers for the given pointcut. Note that
318         * a new and stable list is returned to.. do whatever you want with it.
319         */
320        private List<BaseInvoker> getInvokersForPointcut(POINTCUT thePointcut) {
321                List<BaseInvoker> invokers;
322
323                synchronized (myRegistryMutex) {
324                        List<BaseInvoker> globalInvokers = myGlobalInvokers.get(thePointcut);
325                        List<BaseInvoker> anonymousInvokers = myAnonymousInvokers.get(thePointcut);
326                        List<BaseInvoker> threadLocalInvokers = null;
327                        invokers = union(globalInvokers, anonymousInvokers, threadLocalInvokers);
328                }
329
330                return invokers;
331        }
332
333        /**
334         * First argument must be the global invoker list!!
335         */
336        @SafeVarargs
337        private List<BaseInvoker> union(List<BaseInvoker>... theInvokersLists) {
338                List<BaseInvoker> haveOne = null;
339                boolean haveMultiple = false;
340                for (List<BaseInvoker> nextInvokerList : theInvokersLists) {
341                        if (nextInvokerList == null || nextInvokerList.isEmpty()) {
342                                continue;
343                        }
344
345                        if (haveOne == null) {
346                                haveOne = nextInvokerList;
347                        } else {
348                                haveMultiple = true;
349                        }
350                }
351
352                if (haveOne == null) {
353                        return Collections.emptyList();
354                }
355
356                List<BaseInvoker> retVal;
357
358                if (!haveMultiple) {
359
360                        // The global list doesn't need to be sorted every time since it's sorted on
361                        // insertion each time. Doing so is a waste of cycles..
362                        if (haveOne == theInvokersLists[0]) {
363                                retVal = haveOne;
364                        } else {
365                                retVal = new ArrayList<>(haveOne);
366                                retVal.sort(Comparator.naturalOrder());
367                        }
368
369                } else {
370
371                        retVal = Arrays.stream(theInvokersLists)
372                                        .filter(Objects::nonNull)
373                                        .flatMap(Collection::stream)
374                                        .sorted()
375                                        .collect(Collectors.toList());
376                }
377
378                return retVal;
379        }
380
381        /**
382         * Only call this when assertions are enabled, it's expensive
383         */
384        final boolean haveAppropriateParams(POINTCUT thePointcut, HookParams theParams) {
385                if (theParams.getParamsForType().values().size()
386                                != thePointcut.getParameterTypes().size()) {
387                        throw new IllegalArgumentException(Msg.code(1909)
388                                        + String.format(
389                                                        "Wrong number of params for pointcut %s - Wanted %s but found %s",
390                                                        thePointcut.name(),
391                                                        toErrorString(thePointcut.getParameterTypes()),
392                                                        theParams.getParamsForType().values().stream()
393                                                                        .map(t -> t != null ? t.getClass().getSimpleName() : "null")
394                                                                        .sorted()
395                                                                        .collect(Collectors.toList())));
396                }
397
398                List<String> wantedTypes = new ArrayList<>(thePointcut.getParameterTypes());
399
400                ListMultimap<Class<?>, Object> givenTypes = theParams.getParamsForType();
401                for (Class<?> nextTypeClass : givenTypes.keySet()) {
402                        String nextTypeName = nextTypeClass.getName();
403                        for (Object nextParamValue : givenTypes.get(nextTypeClass)) {
404                                Validate.isTrue(
405                                                nextParamValue == null || nextTypeClass.isAssignableFrom(nextParamValue.getClass()),
406                                                "Invalid params for pointcut %s - %s is not of type %s",
407                                                thePointcut.name(),
408                                                nextParamValue != null ? nextParamValue.getClass() : "null",
409                                                nextTypeClass);
410                                Validate.isTrue(
411                                                wantedTypes.remove(nextTypeName),
412                                                "Invalid params for pointcut %s - Wanted %s but found %s",
413                                                thePointcut.name(),
414                                                toErrorString(thePointcut.getParameterTypes()),
415                                                nextTypeName);
416                        }
417                }
418
419                return true;
420        }
421
422        private List<HookInvoker> scanInterceptorAndAddToInvokerMultimap(
423                        Object theInterceptor, ListMultimap<POINTCUT, BaseInvoker> theInvokers) {
424                Class<?> interceptorClass = theInterceptor.getClass();
425                int typeOrder = determineOrder(interceptorClass);
426
427                List<HookInvoker> addedInvokers = scanInterceptorForHookMethods(theInterceptor, typeOrder);
428
429                // Invoke the REGISTERED pointcut for any added hooks
430                addedInvokers.stream()
431                                .filter(t -> Pointcut.INTERCEPTOR_REGISTERED.equals(t.getPointcut()))
432                                .forEach(t -> t.invoke(new HookParams()));
433
434                // Register the interceptor and its various hooks
435                for (HookInvoker nextAddedHook : addedInvokers) {
436                        POINTCUT nextPointcut = nextAddedHook.getPointcut();
437                        if (nextPointcut.equals(Pointcut.INTERCEPTOR_REGISTERED)) {
438                                continue;
439                        }
440                        theInvokers.put(nextPointcut, nextAddedHook);
441                }
442
443                // Make sure we're always sorted according to the order declared in @Order
444                for (POINTCUT nextPointcut : theInvokers.keys()) {
445                        List<BaseInvoker> nextInvokerList = theInvokers.get(nextPointcut);
446                        nextInvokerList.sort(Comparator.naturalOrder());
447                }
448
449                return addedInvokers;
450        }
451
452        /**
453         * @return Returns a list of any added invokers
454         */
455        private List<HookInvoker> scanInterceptorForHookMethods(Object theInterceptor, int theTypeOrder) {
456                ArrayList<HookInvoker> retVal = new ArrayList<>();
457                for (Method nextMethod : ReflectionUtil.getDeclaredMethods(theInterceptor.getClass(), true)) {
458                        Optional<HookDescriptor> hook = scanForHook(nextMethod);
459
460                        if (hook.isPresent()) {
461                                int methodOrder = theTypeOrder;
462                                int methodOrderAnnotation = hook.get().getOrder();
463                                if (methodOrderAnnotation != Interceptor.DEFAULT_ORDER) {
464                                        methodOrder = methodOrderAnnotation;
465                                }
466
467                                retVal.add(new HookInvoker(hook.get(), theInterceptor, nextMethod, methodOrder));
468                        }
469                }
470
471                return retVal;
472        }
473
474        protected abstract Optional<HookDescriptor> scanForHook(Method nextMethod);
475
476        private class HookInvoker extends BaseInvoker {
477
478                private final Method myMethod;
479                private final Class<?>[] myParameterTypes;
480                private final int[] myParameterIndexes;
481                private final POINTCUT myPointcut;
482
483                /**
484                 * Constructor
485                 */
486                private HookInvoker(
487                                HookDescriptor theHook, @Nonnull Object theInterceptor, @Nonnull Method theHookMethod, int theOrder) {
488                        super(theInterceptor, theOrder);
489                        myPointcut = theHook.getPointcut();
490                        myParameterTypes = theHookMethod.getParameterTypes();
491                        myMethod = theHookMethod;
492
493                        Class<?> returnType = theHookMethod.getReturnType();
494                        if (myPointcut.getReturnType().equals(getBooleanReturnType())) {
495                                Validate.isTrue(
496                                                getBooleanReturnType().equals(returnType) || void.class.equals(returnType),
497                                                "Method does not return boolean or void: %s",
498                                                theHookMethod);
499                        } else if (myPointcut.getReturnType().equals(void.class)) {
500                                Validate.isTrue(void.class.equals(returnType), "Method does not return void: %s", theHookMethod);
501                        } else {
502                                Validate.isTrue(
503                                                myPointcut.getReturnType().isAssignableFrom(returnType) || void.class.equals(returnType),
504                                                "Method does not return %s or void: %s",
505                                                myPointcut.getReturnType(),
506                                                theHookMethod);
507                        }
508
509                        myParameterIndexes = new int[myParameterTypes.length];
510                        Map<Class<?>, AtomicInteger> typeToCount = new HashMap<>();
511                        for (int i = 0; i < myParameterTypes.length; i++) {
512                                AtomicInteger counter = typeToCount.computeIfAbsent(myParameterTypes[i], t -> new AtomicInteger(0));
513                                myParameterIndexes[i] = counter.getAndIncrement();
514                        }
515
516                        myMethod.setAccessible(true);
517                }
518
519                @Override
520                public String toString() {
521                        return new ToStringBuilder(this, ToStringStyle.SHORT_PREFIX_STYLE)
522                                        .append("method", myMethod)
523                                        .toString();
524                }
525
526                public POINTCUT getPointcut() {
527                        return myPointcut;
528                }
529
530                /**
531                 * @return Returns true/false if the hook method returns a boolean, returns true otherwise
532                 */
533                @Override
534                Object invoke(HookParams theParams) {
535
536                        Object[] args = new Object[myParameterTypes.length];
537                        for (int i = 0; i < myParameterTypes.length; i++) {
538                                Class<?> nextParamType = myParameterTypes[i];
539                                if (nextParamType.equals(Pointcut.class)) {
540                                        args[i] = myPointcut;
541                                } else {
542                                        int nextParamIndex = myParameterIndexes[i];
543                                        Object nextParamValue = theParams.get(nextParamType, nextParamIndex);
544                                        args[i] = nextParamValue;
545                                }
546                        }
547
548                        // Invoke the method
549                        try {
550                                return myMethod.invoke(getInterceptor(), args);
551                        } catch (InvocationTargetException e) {
552                                Throwable targetException = e.getTargetException();
553                                if (myPointcut.isShouldLogAndSwallowException(targetException)) {
554                                        ourLog.error("Exception thrown by interceptor: " + targetException.toString(), targetException);
555                                        return null;
556                                }
557
558                                if (targetException instanceof RuntimeException) {
559                                        throw ((RuntimeException) targetException);
560                                } else {
561                                        throw new InternalErrorException(
562                                                        Msg.code(1910) + "Failure invoking interceptor for pointcut(s) " + getPointcut(),
563                                                        targetException);
564                                }
565                        } catch (Exception e) {
566                                throw new InternalErrorException(Msg.code(1911) + e);
567                        }
568                }
569        }
570
571        protected class HookDescriptor {
572
573                private final POINTCUT myPointcut;
574                private final int myOrder;
575
576                public HookDescriptor(POINTCUT thePointcut, int theOrder) {
577                        myPointcut = thePointcut;
578                        myOrder = theOrder;
579                }
580
581                POINTCUT getPointcut() {
582                        return myPointcut;
583                }
584
585                int getOrder() {
586                        return myOrder;
587                }
588        }
589
590        protected abstract static class BaseInvoker implements Comparable<BaseInvoker> {
591
592                private final int myOrder;
593                private final Object myInterceptor;
594
595                BaseInvoker(Object theInterceptor, int theOrder) {
596                        myInterceptor = theInterceptor;
597                        myOrder = theOrder;
598                }
599
600                public Object getInterceptor() {
601                        return myInterceptor;
602                }
603
604                abstract Object invoke(HookParams theParams);
605
606                @Override
607                public int compareTo(BaseInvoker theInvoker) {
608                        return myOrder - theInvoker.myOrder;
609                }
610        }
611
612        protected static <T extends Annotation> Optional<T> findAnnotation(
613                        AnnotatedElement theObject, Class<T> theHookClass) {
614                T annotation;
615                if (theObject instanceof Method) {
616                        annotation = MethodUtils.getAnnotation((Method) theObject, theHookClass, true, true);
617                } else {
618                        annotation = theObject.getAnnotation(theHookClass);
619                }
620                return Optional.ofNullable(annotation);
621        }
622
623        private static int determineOrder(Class<?> theInterceptorClass) {
624                return findAnnotation(theInterceptorClass, Interceptor.class)
625                                .map(Interceptor::order)
626                                .orElse(Interceptor.DEFAULT_ORDER);
627        }
628
629        private static String toErrorString(List<String> theParameterTypes) {
630                return theParameterTypes.stream().sorted().collect(Collectors.joining(","));
631        }
632}