
/*-
 * #%L
 * HAPI FHIR - Core Library
 * %%
 * Copyright (C) 2014 - 2023 Smile CDR, Inc.
 * %%
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * #L%
 */
package ca.uhn.fhir.interceptor.executor;

import ca.uhn.fhir.i18n.Msg;
import ca.uhn.fhir.interceptor.api.HookParams;
import ca.uhn.fhir.interceptor.api.IBaseInterceptorBroadcaster;
import ca.uhn.fhir.interceptor.api.IBaseInterceptorService;
import ca.uhn.fhir.interceptor.api.IPointcut;
import ca.uhn.fhir.interceptor.api.Interceptor;
import ca.uhn.fhir.interceptor.api.Pointcut;
import ca.uhn.fhir.rest.server.exceptions.InternalErrorException;
import ca.uhn.fhir.util.ReflectionUtil;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.ListMultimap;
import org.apache.commons.lang3.Validate;
import org.apache.commons.lang3.builder.ToStringBuilder;
import org.apache.commons.lang3.builder.ToStringStyle;
import org.apache.commons.lang3.reflect.MethodUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.lang.annotation.Annotation;
import java.lang.reflect.AnnotatedElement;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;

/**
 * Base class for an interceptor registry that is also able to broadcast pointcut
 * invocations to the hooks of its registered interceptors.
 * <p>
 * The registry keeps three structures: the list of registered interceptor objects,
 * the invokers that were discovered by scanning interceptors for hook methods
 * ("global" invokers), and the invokers that were registered directly for a single
 * pointcut ("anonymous" invokers). All reads and writes of these structures happen
 * while holding an internal mutex.
 *
 * @param <POINTCUT> The enum type listing the pointcuts this service supports
 */
public abstract class BaseInterceptorService<POINTCUT extends Enum<POINTCUT> & IPointcut>
		implements IBaseInterceptorService<POINTCUT>, IBaseInterceptorBroadcaster<POINTCUT> {
	private static final Logger ourLog = LoggerFactory.getLogger(BaseInterceptorService.class);
	private final List<Object> myInterceptors = new ArrayList<>();
	private final ListMultimap<POINTCUT, BaseInvoker> myGlobalInvokers = ArrayListMultimap.create();
	private final ListMultimap<POINTCUT, BaseInvoker> myAnonymousInvokers = ArrayListMultimap.create();
	private final Object myRegistryMutex = new Object();
	private final Class<POINTCUT> myPointcutType;
	private volatile EnumSet<POINTCUT> myRegisteredPointcuts;
	private String myName;
	private boolean myWarnOnInterceptorWithNoHooks = true;

	/**
	 * Constructor which uses a default name of "default"
	 *
	 * @param thePointcutType The enum class of the pointcuts handled by this service
	 */
	public BaseInterceptorService(Class<POINTCUT> thePointcutType) {
		this(thePointcutType, "default");
	}

	/**
	 * Constructor
	 *
	 * @param thePointcutType The enum class of the pointcuts handled by this service
	 * @param theName         The name for this registry (useful for troubleshooting)
	 */
	public BaseInterceptorService(Class<POINTCUT> thePointcutType, String theName) {
		super();
		myName = theName;
		myPointcutType = thePointcutType;
		rebuildRegisteredPointcutSet();
	}

	/**
	 * Should a warning be issued if an interceptor is registered and it has no hooks
	 */
	public void setWarnOnInterceptorWithNoHooks(boolean theWarnOnInterceptorWithNoHooks) {
		myWarnOnInterceptorWithNoHooks = theWarnOnInterceptorWithNoHooks;
	}

	/**
	 * Exposes the internal (live, mutable) interceptor list. Intended for unit tests only.
	 */
	@VisibleForTesting
	List<Object> getGlobalInterceptorsForUnitTest() {
		return myInterceptors;
	}

	/**
	 * Sets the name of this registry.
	 */
	public void setName(String theName) {
		myName = theName;
	}

	/**
	 * Registers an invoker for a single pointcut, and records its interceptor object
	 * in the interceptor list if that exact instance is not already present.
	 *
	 * @param thePointcut    The pointcut the invoker should be called for (must not be null)
	 * @param theInterceptor The interceptor object the invoker belongs to (must not be null)
	 * @param theInvoker     The invoker to call when the pointcut is broadcast
	 */
	protected void registerAnonymousInterceptor(POINTCUT thePointcut, Object theInterceptor, BaseInvoker theInvoker) {
		Validate.notNull(thePointcut);
		Validate.notNull(theInterceptor);
		synchronized (myRegistryMutex) {
			myAnonymousInvokers.put(thePointcut, theInvoker);
			if (!isInterceptorAlreadyRegistered(theInterceptor)) {
				myInterceptors.add(theInterceptor);
			}

			rebuildRegisteredPointcutSet();
		}
	}

	/**
	 * @return An unmodifiable snapshot of the currently registered interceptor objects
	 */
	@Override
	public List<Object> getAllRegisteredInterceptors() {
		synchronized (myRegistryMutex) {
			List<Object> retVal = new ArrayList<>(myInterceptors);
			return Collections.unmodifiableList(retVal);
		}
	}

	/**
	 * Unregisters every interceptor known to this service, including interceptors that
	 * are only reachable through an invoker.
	 */
	@Override
	@VisibleForTesting
	public void unregisterAllInterceptors() {
		synchronized (myRegistryMutex) {
			// unregisterInterceptor() matches on the interceptor object, so we need to hand it
			// the interceptors owning the invokers and not the invokers themselves
			unregisterInterceptors(toInterceptors(myAnonymousInvokers.values()));
			unregisterInterceptors(toInterceptors(myGlobalInvokers.values()));
			unregisterInterceptors(myInterceptors);
		}
	}

	/**
	 * Returns a new list holding the interceptor object of each of the given invokers.
	 */
	private static List<Object> toInterceptors(Collection<BaseInvoker> theInvokers) {
		return theInvokers.stream().map(BaseInvoker::getInterceptor).collect(Collectors.toList());
	}

	/**
	 * Unregisters each of the given interceptors. A null collection is ignored.
	 */
	@Override
	public void unregisterInterceptors(@Nullable Collection<?> theInterceptors) {
		if (theInterceptors != null) {
			// We construct a new list before iterating because the service's internal
			// interceptor lists get passed into this method, and we get concurrent
			// modification errors if we modify them at the same time as we iterate them
			new ArrayList<>(theInterceptors).forEach(this::unregisterInterceptor);
		}
	}

	/**
	 * Registers each of the given interceptors. A null collection is ignored.
	 */
	@Override
	public void registerInterceptors(@Nullable Collection<?> theInterceptors) {
		if (theInterceptors != null) {
			theInterceptors.forEach(this::registerInterceptor);
		}
	}

	/**
	 * Unregisters every interceptor that owns at least one anonymous invoker.
	 */
	@Override
	public void unregisterAllAnonymousInterceptors() {
		synchronized (myRegistryMutex) {
			unregisterInterceptorsIf(t -> true, myAnonymousInvokers);
		}
	}

	/**
	 * Unregisters every interceptor owning an invoker (global or anonymous) for which
	 * the given predicate returns <code>true</code>.
	 *
	 * @param theShouldUnregisterFunction Tested against the interceptor object of each invoker
	 */
	@Override
	public void unregisterInterceptorsIf(Predicate<Object> theShouldUnregisterFunction) {
		unregisterInterceptorsIf(theShouldUnregisterFunction, myGlobalInvokers);
		unregisterInterceptorsIf(theShouldUnregisterFunction, myAnonymousInvokers);
	}

	private void unregisterInterceptorsIf(
			Predicate<Object> theShouldUnregisterFunction, ListMultimap<POINTCUT, BaseInvoker> theInvokers) {
		synchronized (myRegistryMutex) {
			// Iterate over a copy, since unregisterInterceptor() modifies the multimap
			for (Map.Entry<POINTCUT, BaseInvoker> nextInvoker : new ArrayList<>(theInvokers.entries())) {
				Object nextInterceptor = nextInvoker.getValue().getInterceptor();
				if (theShouldUnregisterFunction.test(nextInterceptor)) {
					unregisterInterceptor(nextInterceptor);
				}
			}

			rebuildRegisteredPointcutSet();
		}
	}

	/**
	 * Scans the given interceptor for hook methods and registers it.
	 *
	 * @return <code>false</code> if this exact instance is already registered or if no hooks
	 * were found on it, <code>true</code> if the interceptor was registered
	 */
	@Override
	public boolean registerInterceptor(Object theInterceptor) {
		synchronized (myRegistryMutex) {
			if (isInterceptorAlreadyRegistered(theInterceptor)) {
				return false;
			}

			List<HookInvoker> addedInvokers = scanInterceptorAndAddToInvokerMultimap(theInterceptor, myGlobalInvokers);
			if (addedInvokers.isEmpty()) {
				if (myWarnOnInterceptorWithNoHooks) {
					ourLog.warn(
							"Interceptor registered with no valid hooks - Type was: {}",
							theInterceptor.getClass().getName());
				}
				return false;
			}

			// Add to the global list
			myInterceptors.add(theInterceptor);
			sortByOrderAnnotation(myInterceptors);

			rebuildRegisteredPointcutSet();

			return true;
		}
	}

	/**
	 * Recomputes the set of pointcuts that currently have at least one invoker. The set
	 * is replaced (not mutated) so that {@link #hasHooks(Enum)} can read it without locking.
	 */
	private void rebuildRegisteredPointcutSet() {
		EnumSet<POINTCUT> registeredPointcuts = EnumSet.noneOf(myPointcutType);
		registeredPointcuts.addAll(myAnonymousInvokers.keySet());
		registeredPointcuts.addAll(myGlobalInvokers.keySet());
		myRegisteredPointcuts = registeredPointcuts;
	}

	/**
	 * Tests by object identity (not equals) whether the given instance is in the interceptor list.
	 */
	private boolean isInterceptorAlreadyRegistered(Object theInterceptor) {
		for (Object next : myInterceptors) {
			if (next == theInterceptor) {
				return true;
			}
		}
		return false;
	}

	/**
	 * Removes the given interceptor instance (matched by identity) along with all of
	 * its global and anonymous invokers.
	 *
	 * @return <code>true</code> if anything was removed
	 */
	@Override
	public boolean unregisterInterceptor(Object theInterceptor) {
		synchronized (myRegistryMutex) {
			boolean removed = myInterceptors.removeIf(t -> t == theInterceptor);
			removed |= myGlobalInvokers.entries().removeIf(t -> t.getValue().getInterceptor() == theInterceptor);
			removed |= myAnonymousInvokers.entries().removeIf(t -> t.getValue().getInterceptor() == theInterceptor);
			rebuildRegisteredPointcutSet();
			return removed;
		}
	}

	/**
	 * Sorts the given objects in place by the order declared in the {@link Interceptor}
	 * annotation on their class. Objects whose class has no such annotation get an order of 0.
	 */
	private void sortByOrderAnnotation(List<Object> theObjects) {
		// Keyed by identity, since interceptors are tracked by instance and not by equals()
		IdentityHashMap<Object, Integer> interceptorToOrder = new IdentityHashMap<>();
		for (Object next : theObjects) {
			Interceptor orderAnnotation = next.getClass().getAnnotation(Interceptor.class);
			int order = orderAnnotation != null ? orderAnnotation.order() : 0;
			interceptorToOrder.put(next, order);
		}

		// Integer.compare rather than subtraction, which overflows for extreme order values
		theObjects.sort((a, b) -> Integer.compare(interceptorToOrder.get(a), interceptorToOrder.get(b)));
	}

	/**
	 * Calls the hooks for a pointcut which has a non-void return type.
	 *
	 * @return The first non-null value returned by a hook, or <code>null</code> if there was none
	 */
	@Override
	public Object callHooksAndReturnObject(POINTCUT thePointcut, HookParams theParams) {
		assert haveAppropriateParams(thePointcut, theParams);
		assert thePointcut.getReturnType() != void.class;

		return doCallHooks(thePointcut, theParams, null);
	}

	/**
	 * @return <code>true</code> if at least one invoker is registered for the given pointcut
	 */
	@Override
	public boolean hasHooks(POINTCUT thePointcut) {
		return myRegisteredPointcuts.contains(thePointcut);
	}

	/**
	 * @return The pointcut return type which is treated as a boolean outcome by this
	 * service. Returns <code>boolean.class</code> unless overridden.
	 */
	protected Class<?> getBooleanReturnType() {
		return boolean.class;
	}

	/**
	 * Calls the hooks for a pointcut which has a void or boolean return type.
	 *
	 * @return <code>false</code> if a hook returned <code>false</code> (in which case no
	 * further hooks are called), <code>true</code> otherwise
	 */
	@Override
	public boolean callHooks(POINTCUT thePointcut, HookParams theParams) {
		assert haveAppropriateParams(thePointcut, theParams);
		assert thePointcut.getReturnType() == void.class || thePointcut.getReturnType() == getBooleanReturnType();

		Object retValObj = doCallHooks(thePointcut, theParams, true);
		return (Boolean) retValObj;
	}

	private Object doCallHooks(POINTCUT thePointcut, HookParams theParams, Object theRetVal) {
		// This is a snapshot taken under the registry lock, so it is safe to loop over
		// even if an invoker gets added or removed while looping
		List<BaseInvoker> invokers = getInvokersForPointcut(thePointcut);
		Class<?> pointcutReturnType = thePointcut.getReturnType();

		/*
		 * Call each hook in order
		 */
		for (BaseInvoker nextInvoker : invokers) {
			Object nextOutcome = nextInvoker.invoke(theParams);
			if (pointcutReturnType.equals(getBooleanReturnType())) {
				Boolean nextOutcomeAsBoolean = (Boolean) nextOutcome;
				if (Boolean.FALSE.equals(nextOutcomeAsBoolean)) {
					ourLog.trace("callHooks({}) for invoker({}) returned false", thePointcut, nextInvoker);
					theRetVal = false;
					break;
				} else {
					theRetVal = true;
				}
			} else if (!pointcutReturnType.equals(void.class)) {
				// The first hook to return a value wins
				if (nextOutcome != null) {
					theRetVal = nextOutcome;
					break;
				}
			}
		}

		return theRetVal;
	}

	@VisibleForTesting
	List<Object> getInterceptorsWithInvokersForPointcut(POINTCUT thePointcut) {
		return getInvokersForPointcut(thePointcut).stream()
				.map(BaseInvoker::getInterceptor)
				.collect(Collectors.toList());
	}

	/**
	 * Returns an ordered list of invokers for the given pointcut. The returned list is
	 * a snapshot which is detached from the registry: it is safe to iterate without
	 * holding the registry lock. It may be immutable (this is the case when it is empty).
	 */
	private List<BaseInvoker> getInvokersForPointcut(POINTCUT thePointcut) {
		List<BaseInvoker> invokers;

		synchronized (myRegistryMutex) {
			List<BaseInvoker> globalInvokers = myGlobalInvokers.get(thePointcut);
			List<BaseInvoker> anonymousInvokers = myAnonymousInvokers.get(thePointcut);
			List<BaseInvoker> threadLocalInvokers = null;
			invokers = union(globalInvokers, anonymousInvokers, threadLocalInvokers);
		}

		return invokers;
	}

	/**
	 * Merges the given invoker lists into a single ordered list which does not share
	 * state with any of the inputs. Null and empty inputs are skipped.
	 * <p>
	 * First argument must be the global invoker list!!
	 * </p>
	 */
	@SafeVarargs
	private List<BaseInvoker> union(List<BaseInvoker>... theInvokersLists) {
		List<BaseInvoker> haveOne = null;
		boolean haveMultiple = false;
		for (List<BaseInvoker> nextInvokerList : theInvokersLists) {
			if (nextInvokerList == null || nextInvokerList.isEmpty()) {
				continue;
			}

			if (haveOne == null) {
				haveOne = nextInvokerList;
			} else {
				haveMultiple = true;
			}
		}

		if (haveOne == null) {
			return Collections.emptyList();
		}

		List<BaseInvoker> retVal;

		if (!haveMultiple) {

			// Always copy: the input may be a live view onto the registry, and the
			// caller iterates the result after the registry lock has been released
			retVal = new ArrayList<>(haveOne);

			// The global list doesn't need to be sorted every time since it's sorted on
			// insertion each time. Doing so is a waste of cycles..
			if (haveOne != theInvokersLists[0]) {
				retVal.sort(Comparator.naturalOrder());
			}

		} else {

			retVal = Arrays.stream(theInvokersLists)
					.filter(Objects::nonNull)
					.flatMap(Collection::stream)
					.sorted()
					.collect(Collectors.toList());
		}

		return retVal;
	}

	/**
	 * Verifies that the given parameters match the parameter types declared by the pointcut.
	 * <p>
	 * Only call this when assertions are enabled, it's expensive
	 * </p>
	 *
	 * @return Always <code>true</code> (so that this can be used in an <code>assert</code>)
	 * @throws IllegalArgumentException If the parameters do not match the pointcut
	 */
	final boolean haveAppropriateParams(POINTCUT thePointcut, HookParams theParams) {
		if (theParams.getParamsForType().values().size()
				!= thePointcut.getParameterTypes().size()) {
			throw new IllegalArgumentException(Msg.code(1909)
					+ String.format(
							"Wrong number of params for pointcut %s - Wanted %s but found %s",
							thePointcut.name(),
							toErrorString(thePointcut.getParameterTypes()),
							theParams.getParamsForType().values().stream()
									.map(t -> t != null ? t.getClass().getSimpleName() : "null")
									.sorted()
									.collect(Collectors.toList())));
		}

		// Each supplied parameter removes one entry, so that duplicated types are counted
		List<String> wantedTypes = new ArrayList<>(thePointcut.getParameterTypes());

		ListMultimap<Class<?>, Object> givenTypes = theParams.getParamsForType();
		for (Class<?> nextTypeClass : givenTypes.keySet()) {
			String nextTypeName = nextTypeClass.getName();
			for (Object nextParamValue : givenTypes.get(nextTypeClass)) {
				Validate.isTrue(
						nextParamValue == null || nextTypeClass.isAssignableFrom(nextParamValue.getClass()),
						"Invalid params for pointcut %s - %s is not of type %s",
						thePointcut.name(),
						nextParamValue != null ? nextParamValue.getClass() : "null",
						nextTypeClass);
				Validate.isTrue(
						wantedTypes.remove(nextTypeName),
						"Invalid params for pointcut %s - Wanted %s but found %s",
						thePointcut.name(),
						toErrorString(thePointcut.getParameterTypes()),
						nextTypeName);
			}
		}

		return true;
	}

	/**
	 * Scans the interceptor for hook methods, immediately invokes any hooks for the
	 * {@link Pointcut#INTERCEPTOR_REGISTERED} pointcut, and adds all other hooks to
	 * the given multimap.
	 *
	 * @return All invokers found on the interceptor, including the ones for
	 * {@link Pointcut#INTERCEPTOR_REGISTERED} which are not added to the multimap
	 */
	private List<HookInvoker> scanInterceptorAndAddToInvokerMultimap(
			Object theInterceptor, ListMultimap<POINTCUT, BaseInvoker> theInvokers) {
		Class<?> interceptorClass = theInterceptor.getClass();
		int typeOrder = determineOrder(interceptorClass);

		List<HookInvoker> addedInvokers = scanInterceptorForHookMethods(theInterceptor, typeOrder);

		// Invoke the REGISTERED pointcut for any added hooks
		addedInvokers.stream()
				.filter(t -> Pointcut.INTERCEPTOR_REGISTERED.equals(t.getPointcut()))
				.forEach(t -> t.invoke(new HookParams()));

		// Register the interceptor and its various hooks
		for (HookInvoker nextAddedHook : addedInvokers) {
			POINTCUT nextPointcut = nextAddedHook.getPointcut();
			if (nextPointcut.equals(Pointcut.INTERCEPTOR_REGISTERED)) {
				continue;
			}
			theInvokers.put(nextPointcut, nextAddedHook);
		}

		// Make sure we're always sorted according to the order declared in @Order. We use
		// keySet() and not keys() so that each pointcut's list is only sorted once.
		for (POINTCUT nextPointcut : theInvokers.keySet()) {
			List<BaseInvoker> nextInvokerList = theInvokers.get(nextPointcut);
			nextInvokerList.sort(Comparator.naturalOrder());
		}

		return addedInvokers;
	}

	/**
	 * @param theTypeOrder The order to use for any hook that doesn't declare its own order
	 * @return Returns a list of any added invokers
	 */
	private List<HookInvoker> scanInterceptorForHookMethods(Object theInterceptor, int theTypeOrder) {
		List<HookInvoker> retVal = new ArrayList<>();
		for (Method nextMethod : ReflectionUtil.getDeclaredMethods(theInterceptor.getClass(), true)) {
			Optional<HookDescriptor> hook = scanForHook(nextMethod);

			if (hook.isPresent()) {
				// A non-default order on the hook itself takes precedence over the type's order
				int methodOrder = theTypeOrder;
				int methodOrderAnnotation = hook.get().getOrder();
				if (methodOrderAnnotation != Interceptor.DEFAULT_ORDER) {
					methodOrder = methodOrderAnnotation;
				}

				retVal.add(new HookInvoker(hook.get(), theInterceptor, nextMethod, methodOrder));
			}
		}

		return retVal;
	}

	/**
	 * Inspects a method of an interceptor class and describes the hook it declares, if any.
	 *
	 * @return The hook descriptor, or an empty optional if the method is not a hook
	 */
	protected abstract Optional<HookDescriptor> scanForHook(Method nextMethod);

	/**
	 * An invoker which calls a hook method on an interceptor object via reflection.
	 */
	private class HookInvoker extends BaseInvoker {

		private final Method myMethod;
		private final Class<?>[] myParameterTypes;
		private final int[] myParameterIndexes;
		private final POINTCUT myPointcut;

		/**
		 * Constructor
		 *
		 * @throws IllegalArgumentException If the return type of the hook method is not
		 *                                  compatible with the return type of the pointcut
		 */
		private HookInvoker(
				HookDescriptor theHook, @Nonnull Object theInterceptor, @Nonnull Method theHookMethod, int theOrder) {
			super(theInterceptor, theOrder);
			myPointcut = theHook.getPointcut();
			myParameterTypes = theHookMethod.getParameterTypes();
			myMethod = theHookMethod;

			Class<?> returnType = theHookMethod.getReturnType();
			if (myPointcut.getReturnType().equals(getBooleanReturnType())) {
				Validate.isTrue(
						getBooleanReturnType().equals(returnType) || void.class.equals(returnType),
						"Method does not return boolean or void: %s",
						theHookMethod);
			} else if (myPointcut.getReturnType().equals(void.class)) {
				Validate.isTrue(void.class.equals(returnType), "Method does not return void: %s", theHookMethod);
			} else {
				Validate.isTrue(
						myPointcut.getReturnType().isAssignableFrom(returnType) || void.class.equals(returnType),
						"Method does not return %s or void: %s",
						myPointcut.getReturnType(),
						theHookMethod);
			}

			// For each method parameter, record how many earlier parameters have the same
			// type. This index selects which of the supplied params of that type to pass in.
			myParameterIndexes = new int[myParameterTypes.length];
			Map<Class<?>, AtomicInteger> typeToCount = new HashMap<>();
			for (int i = 0; i < myParameterTypes.length; i++) {
				AtomicInteger counter = typeToCount.computeIfAbsent(myParameterTypes[i], t -> new AtomicInteger(0));
				myParameterIndexes[i] = counter.getAndIncrement();
			}

			myMethod.setAccessible(true);
		}

		@Override
		public String toString() {
			return new ToStringBuilder(this, ToStringStyle.SHORT_PREFIX_STYLE)
					.append("method", myMethod)
					.toString();
		}

		public POINTCUT getPointcut() {
			return myPointcut;
		}

		/**
		 * Invokes the hook method, resolving each of its arguments from the given params.
		 *
		 * @return Returns the value returned by the hook method (<code>null</code> if the hook
		 * method is void), or <code>null</code> if the hook method threw an exception
		 * which the pointcut says should be logged and swallowed
		 * @throws InternalErrorException If the hook method threw a checked exception, or
		 *                                if the method could not be invoked
		 */
		@Override
		Object invoke(HookParams theParams) {

			Object[] args = new Object[myParameterTypes.length];
			for (int i = 0; i < myParameterTypes.length; i++) {
				Class<?> nextParamType = myParameterTypes[i];
				if (nextParamType.equals(Pointcut.class)) {
					args[i] = myPointcut;
				} else {
					int nextParamIndex = myParameterIndexes[i];
					Object nextParamValue = theParams.get(nextParamType, nextParamIndex);
					args[i] = nextParamValue;
				}
			}

			// Invoke the method
			try {
				return myMethod.invoke(getInterceptor(), args);
			} catch (InvocationTargetException e) {
				Throwable targetException = e.getTargetException();
				if (myPointcut.isShouldLogAndSwallowException(targetException)) {
					ourLog.error("Exception thrown by interceptor: {}", targetException.toString(), targetException);
					return null;
				}

				if (targetException instanceof RuntimeException) {
					throw ((RuntimeException) targetException);
				} else {
					throw new InternalErrorException(
							Msg.code(1910) + "Failure invoking interceptor for pointcut(s) " + getPointcut(),
							targetException);
				}
			} catch (Exception e) {
				// Keep the original exception as the cause so its stack trace isn't lost
				throw new InternalErrorException(Msg.code(1911) + e, e);
			}
		}
	}

	/**
	 * Describes a hook found on an interceptor method: the pointcut it applies to
	 * and the order declared for it.
	 */
	protected class HookDescriptor {

		private final POINTCUT myPointcut;
		private final int myOrder;

		public HookDescriptor(POINTCUT thePointcut, int theOrder) {
			myPointcut = thePointcut;
			myOrder = theOrder;
		}

		POINTCUT getPointcut() {
			return myPointcut;
		}

		int getOrder() {
			return myOrder;
		}
	}

	/**
	 * Base class for an object able to invoke a hook on an interceptor. Invokers are
	 * naturally ordered by ascending order value.
	 */
	protected abstract static class BaseInvoker implements Comparable<BaseInvoker> {

		private final int myOrder;
		private final Object myInterceptor;

		BaseInvoker(Object theInterceptor, int theOrder) {
			myInterceptor = theInterceptor;
			myOrder = theOrder;
		}

		public Object getInterceptor() {
			return myInterceptor;
		}

		abstract Object invoke(HookParams theParams);

		@Override
		public int compareTo(BaseInvoker theInvoker) {
			// Not a subtraction: that overflows (and inverts the result) for extreme order values
			return Integer.compare(myOrder, theInvoker.myOrder);
		}
	}

	/**
	 * Looks up an annotation on a class or method. For methods, the lookup is delegated to
	 * {@link MethodUtils#getAnnotation(Method, Class, boolean, boolean)} with both boolean
	 * flags enabled; for any other element only the element itself is inspected.
	 */
	protected static <T extends Annotation> Optional<T> findAnnotation(
			AnnotatedElement theObject, Class<T> theHookClass) {
		T annotation;
		if (theObject instanceof Method) {
			annotation = MethodUtils.getAnnotation((Method) theObject, theHookClass, true, true);
		} else {
			annotation = theObject.getAnnotation(theHookClass);
		}
		return Optional.ofNullable(annotation);
	}

	/**
	 * @return The order declared by the {@link Interceptor} annotation on the given class,
	 * or {@link Interceptor#DEFAULT_ORDER} if the class has no such annotation
	 */
	private static int determineOrder(Class<?> theInterceptorClass) {
		return findAnnotation(theInterceptorClass, Interceptor.class)
				.map(Interceptor::order)
				.orElse(Interceptor.DEFAULT_ORDER);
	}

	/**
	 * @return The given type names, sorted and joined with commas (for use in error messages)
	 */
	private static String toErrorString(List<String> theParameterTypes) {
		return theParameterTypes.stream().sorted().collect(Collectors.joining(","));
	}
}