001/*- 002 * #%L 003 * HAPI FHIR - Core Library 004 * %% 005 * Copyright (C) 2014 - 2024 Smile CDR, Inc. 006 * %% 007 * Licensed under the Apache License, Version 2.0 (the "License"); 008 * you may not use this file except in compliance with the License. 009 * You may obtain a copy of the License at 010 * 011 * http://www.apache.org/licenses/LICENSE-2.0 012 * 013 * Unless required by applicable law or agreed to in writing, software 014 * distributed under the License is distributed on an "AS IS" BASIS, 015 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 016 * See the License for the specific language governing permissions and 017 * limitations under the License. 018 * #L% 019 */ 020package ca.uhn.fhir.interceptor.executor; 021 022import ca.uhn.fhir.i18n.Msg; 023import ca.uhn.fhir.interceptor.api.HookParams; 024import ca.uhn.fhir.interceptor.api.IBaseInterceptorBroadcaster; 025import ca.uhn.fhir.interceptor.api.IBaseInterceptorService; 026import ca.uhn.fhir.interceptor.api.IPointcut; 027import ca.uhn.fhir.interceptor.api.Interceptor; 028import ca.uhn.fhir.interceptor.api.Pointcut; 029import ca.uhn.fhir.rest.server.exceptions.InternalErrorException; 030import ca.uhn.fhir.util.ReflectionUtil; 031import com.google.common.annotations.VisibleForTesting; 032import com.google.common.collect.ArrayListMultimap; 033import com.google.common.collect.ListMultimap; 034import io.opentelemetry.api.trace.Span; 035import io.opentelemetry.instrumentation.annotations.WithSpan; 036import jakarta.annotation.Nonnull; 037import jakarta.annotation.Nullable; 038import org.apache.commons.lang3.Validate; 039import org.apache.commons.lang3.builder.ToStringBuilder; 040import org.apache.commons.lang3.builder.ToStringStyle; 041import org.apache.commons.lang3.reflect.MethodUtils; 042import org.slf4j.Logger; 043import org.slf4j.LoggerFactory; 044 045import java.lang.annotation.Annotation; 046import java.lang.reflect.AnnotatedElement; 047import java.lang.reflect.InvocationTargetException; 048import java.lang.reflect.Method; 049import java.util.ArrayList; 050import java.util.Arrays; 051import java.util.Collection; 052import java.util.Collections; 053import java.util.Comparator; 054import java.util.EnumSet; 055import java.util.HashMap; 056import java.util.IdentityHashMap; 057import java.util.List; 058import java.util.Map; 059import java.util.Objects; 060import java.util.Optional; 061import java.util.concurrent.atomic.AtomicInteger; 062import java.util.function.Predicate; 063import java.util.stream.Collectors; 064 065public abstract class BaseInterceptorService<POINTCUT extends Enum<POINTCUT> & IPointcut> 066 implements IBaseInterceptorService<POINTCUT>, IBaseInterceptorBroadcaster<POINTCUT> { 067 private static final Logger ourLog = LoggerFactory.getLogger(BaseInterceptorService.class); 068 private final List<Object> myInterceptors = new ArrayList<>(); 069 private final ListMultimap<POINTCUT, BaseInvoker> myGlobalInvokers = ArrayListMultimap.create(); 070 private final ListMultimap<POINTCUT, BaseInvoker> myAnonymousInvokers = ArrayListMultimap.create(); 071 private final Object myRegistryMutex = new Object(); 072 private final Class<POINTCUT> myPointcutType; 073 private volatile EnumSet<POINTCUT> myRegisteredPointcuts; 074 private String myName; 075 private boolean myWarnOnInterceptorWithNoHooks = true; 076 077 /** 078 * Constructor which uses a default name of "default" 079 */ 080 public BaseInterceptorService(Class<POINTCUT> thePointcutType) { 081 this(thePointcutType, "default"); 082 } 083 084 /** 085 * Constructor 086 * 087 * @param theName The name for this registry (useful for troubleshooting) 088 */ 089 public BaseInterceptorService(Class<POINTCUT> thePointcutType, String theName) { 090 super(); 091 myName = theName; 092 myPointcutType = thePointcutType; 093 rebuildRegisteredPointcutSet(); 094 } 095 096 /** 097 * Should a warning be issued if an interceptor is registered and it has no hooks 098 */ 099 public void setWarnOnInterceptorWithNoHooks(boolean theWarnOnInterceptorWithNoHooks) { 100 myWarnOnInterceptorWithNoHooks = theWarnOnInterceptorWithNoHooks; 101 } 102 103 @VisibleForTesting 104 List<Object> getGlobalInterceptorsForUnitTest() { 105 return myInterceptors; 106 } 107 108 public void setName(String theName) { 109 myName = theName; 110 } 111 112 protected void registerAnonymousInterceptor(POINTCUT thePointcut, Object theInterceptor, BaseInvoker theInvoker) { 113 Validate.notNull(thePointcut); 114 Validate.notNull(theInterceptor); 115 synchronized (myRegistryMutex) { 116 myAnonymousInvokers.put(thePointcut, theInvoker); 117 if (!isInterceptorAlreadyRegistered(theInterceptor)) { 118 myInterceptors.add(theInterceptor); 119 } 120 121 rebuildRegisteredPointcutSet(); 122 } 123 } 124 125 @Override 126 public List<Object> getAllRegisteredInterceptors() { 127 synchronized (myRegistryMutex) { 128 List<Object> retVal = new ArrayList<>(myInterceptors); 129 return Collections.unmodifiableList(retVal); 130 } 131 } 132 133 @Override 134 @VisibleForTesting 135 public void unregisterAllInterceptors() { 136 synchronized (myRegistryMutex) { 137 unregisterInterceptors(myAnonymousInvokers.values()); 138 unregisterInterceptors(myGlobalInvokers.values()); 139 unregisterInterceptors(myInterceptors); 140 } 141 } 142 143 @Override 144 public void unregisterInterceptors(@Nullable Collection<?> theInterceptors) { 145 if (theInterceptors != null) { 146 // We construct a new list before iterating because the service's internal 147 // interceptor lists get passed into this method, and we get concurrent 148 // modification errors if we modify them at the same time as we iterate them 149 new ArrayList<>(theInterceptors).forEach(this::unregisterInterceptor); 150 } 151 } 152 153 @Override 154 public void registerInterceptors(@Nullable Collection<?> theInterceptors) { 155 if (theInterceptors != null) { 156 theInterceptors.forEach(this::registerInterceptor); 157 } 158 } 159 160 @Override 161 public void unregisterAllAnonymousInterceptors() { 162 synchronized (myRegistryMutex) { 163 unregisterInterceptorsIf(t -> true, myAnonymousInvokers); 164 } 165 } 166 167 @Override 168 public void unregisterInterceptorsIf(Predicate<Object> theShouldUnregisterFunction) { 169 unregisterInterceptorsIf(theShouldUnregisterFunction, myGlobalInvokers); 170 unregisterInterceptorsIf(theShouldUnregisterFunction, myAnonymousInvokers); 171 } 172 173 private void unregisterInterceptorsIf( 174 Predicate<Object> theShouldUnregisterFunction, ListMultimap<POINTCUT, BaseInvoker> theGlobalInvokers) { 175 synchronized (myRegistryMutex) { 176 for (Map.Entry<POINTCUT, BaseInvoker> nextInvoker : new ArrayList<>(theGlobalInvokers.entries())) { 177 if (theShouldUnregisterFunction.test(nextInvoker.getValue().getInterceptor())) { 178 unregisterInterceptor(nextInvoker.getValue().getInterceptor()); 179 } 180 } 181 182 rebuildRegisteredPointcutSet(); 183 } 184 } 185 186 @Override 187 public boolean registerInterceptor(Object theInterceptor) { 188 synchronized (myRegistryMutex) { 189 if (isInterceptorAlreadyRegistered(theInterceptor)) { 190 return false; 191 } 192 193 List<HookInvoker> addedInvokers = scanInterceptorAndAddToInvokerMultimap(theInterceptor, myGlobalInvokers); 194 if (addedInvokers.isEmpty()) { 195 if (myWarnOnInterceptorWithNoHooks) { 196 ourLog.warn( 197 "Interceptor registered with no valid hooks - Type was: {}", 198 theInterceptor.getClass().getName()); 199 } 200 return false; 201 } 202 203 // Add to the global list 204 myInterceptors.add(theInterceptor); 205 sortByOrderAnnotation(myInterceptors); 206 207 rebuildRegisteredPointcutSet(); 208 209 return true; 210 } 211 } 212 213 private void rebuildRegisteredPointcutSet() { 214 EnumSet<POINTCUT> registeredPointcuts = EnumSet.noneOf(myPointcutType); 215 registeredPointcuts.addAll(myAnonymousInvokers.keySet()); 216 registeredPointcuts.addAll(myGlobalInvokers.keySet()); 217 myRegisteredPointcuts = registeredPointcuts; 218 } 219 220 private boolean isInterceptorAlreadyRegistered(Object theInterceptor) { 221 for (Object next : myInterceptors) { 222 if (next == theInterceptor) { 223 return true; 224 } 225 } 226 return false; 227 } 228 229 @Override 230 public boolean unregisterInterceptor(Object theInterceptor) { 231 synchronized (myRegistryMutex) { 232 boolean removed = myInterceptors.removeIf(t -> t == theInterceptor); 233 removed |= myGlobalInvokers.entries().removeIf(t -> t.getValue().getInterceptor() == theInterceptor); 234 removed |= myAnonymousInvokers.entries().removeIf(t -> t.getValue().getInterceptor() == theInterceptor); 235 rebuildRegisteredPointcutSet(); 236 return removed; 237 } 238 } 239 240 private void sortByOrderAnnotation(List<Object> theObjects) { 241 IdentityHashMap<Object, Integer> interceptorToOrder = new IdentityHashMap<>(); 242 for (Object next : theObjects) { 243 Interceptor orderAnnotation = next.getClass().getAnnotation(Interceptor.class); 244 int order = orderAnnotation != null ? orderAnnotation.order() : 0; 245 interceptorToOrder.put(next, order); 246 } 247 248 theObjects.sort((a, b) -> { 249 Integer orderA = interceptorToOrder.get(a); 250 Integer orderB = interceptorToOrder.get(b); 251 return orderA - orderB; 252 }); 253 } 254 255 @Override 256 public Object callHooksAndReturnObject(POINTCUT thePointcut, HookParams theParams) { 257 assert haveAppropriateParams(thePointcut, theParams); 258 assert thePointcut.getReturnType() != void.class; 259 260 return doCallHooks(thePointcut, theParams, null); 261 } 262 263 @Override 264 public boolean hasHooks(POINTCUT thePointcut) { 265 return myRegisteredPointcuts.contains(thePointcut); 266 } 267 268 protected Class<?> getBooleanReturnType() { 269 return boolean.class; 270 } 271 272 @Override 273 public boolean callHooks(POINTCUT thePointcut, HookParams theParams) { 274 assert haveAppropriateParams(thePointcut, theParams); 275 assert thePointcut.getReturnType() == void.class || thePointcut.getReturnType() == getBooleanReturnType(); 276 277 Object retValObj = doCallHooks(thePointcut, theParams, true); 278 return (Boolean) retValObj; 279 } 280 281 private Object doCallHooks(POINTCUT thePointcut, HookParams theParams, Object theRetVal) { 282 // use new list for loop to avoid ConcurrentModificationException in case invoker gets added while looping 283 List<BaseInvoker> invokers = new ArrayList<>(getInvokersForPointcut(thePointcut)); 284 285 /* 286 * Call each hook in order 287 */ 288 for (BaseInvoker nextInvoker : invokers) { 289 Object nextOutcome = nextInvoker.invoke(theParams); 290 Class<?> pointcutReturnType = thePointcut.getReturnType(); 291 if (pointcutReturnType.equals(getBooleanReturnType())) { 292 Boolean nextOutcomeAsBoolean = (Boolean) nextOutcome; 293 if (Boolean.FALSE.equals(nextOutcomeAsBoolean)) { 294 ourLog.trace("callHooks({}) for invoker({}) returned false", thePointcut, nextInvoker); 295 theRetVal = false; 296 break; 297 } else { 298 theRetVal = true; 299 } 300 } else if (!pointcutReturnType.equals(void.class)) { 301 if (nextOutcome != null) { 302 theRetVal = nextOutcome; 303 break; 304 } 305 } 306 } 307 308 return theRetVal; 309 } 310 311 @VisibleForTesting 312 List<Object> getInterceptorsWithInvokersForPointcut(POINTCUT thePointcut) { 313 return getInvokersForPointcut(thePointcut).stream() 314 .map(BaseInvoker::getInterceptor) 315 .collect(Collectors.toList()); 316 } 317 318 /** 319 * Returns an ordered list of invokers for the given pointcut. Note that 320 * a new and stable list is returned to.. do whatever you want with it. 321 */ 322 private List<BaseInvoker> getInvokersForPointcut(POINTCUT thePointcut) { 323 List<BaseInvoker> invokers; 324 325 synchronized (myRegistryMutex) { 326 List<BaseInvoker> globalInvokers = myGlobalInvokers.get(thePointcut); 327 List<BaseInvoker> anonymousInvokers = myAnonymousInvokers.get(thePointcut); 328 List<BaseInvoker> threadLocalInvokers = null; 329 invokers = union(globalInvokers, anonymousInvokers, threadLocalInvokers); 330 } 331 332 return invokers; 333 } 334 335 /** 336 * First argument must be the global invoker list!! 337 */ 338 @SafeVarargs 339 private List<BaseInvoker> union(List<BaseInvoker>... theInvokersLists) { 340 List<BaseInvoker> haveOne = null; 341 boolean haveMultiple = false; 342 for (List<BaseInvoker> nextInvokerList : theInvokersLists) { 343 if (nextInvokerList == null || nextInvokerList.isEmpty()) { 344 continue; 345 } 346 347 if (haveOne == null) { 348 haveOne = nextInvokerList; 349 } else { 350 haveMultiple = true; 351 } 352 } 353 354 if (haveOne == null) { 355 return Collections.emptyList(); 356 } 357 358 List<BaseInvoker> retVal; 359 360 if (!haveMultiple) { 361 362 // The global list doesn't need to be sorted every time since it's sorted on 363 // insertion each time. Doing so is a waste of cycles.. 364 if (haveOne == theInvokersLists[0]) { 365 retVal = haveOne; 366 } else { 367 retVal = new ArrayList<>(haveOne); 368 retVal.sort(Comparator.naturalOrder()); 369 } 370 371 } else { 372 373 retVal = Arrays.stream(theInvokersLists) 374 .filter(Objects::nonNull) 375 .flatMap(Collection::stream) 376 .sorted() 377 .collect(Collectors.toList()); 378 } 379 380 return retVal; 381 } 382 383 /** 384 * Only call this when assertions are enabled, it's expensive 385 */ 386 final boolean haveAppropriateParams(POINTCUT thePointcut, HookParams theParams) { 387 if (theParams.getParamsForType().values().size() 388 != thePointcut.getParameterTypes().size()) { 389 throw new IllegalArgumentException(Msg.code(1909) 390 + String.format( 391 "Wrong number of params for pointcut %s - Wanted %s but found %s", 392 thePointcut.name(), 393 toErrorString(thePointcut.getParameterTypes()), 394 theParams.getParamsForType().values().stream() 395 .map(t -> t != null ? t.getClass().getSimpleName() : "null") 396 .sorted() 397 .collect(Collectors.toList()))); 398 } 399 400 List<String> wantedTypes = new ArrayList<>(thePointcut.getParameterTypes()); 401 402 ListMultimap<Class<?>, Object> givenTypes = theParams.getParamsForType(); 403 for (Class<?> nextTypeClass : givenTypes.keySet()) { 404 String nextTypeName = nextTypeClass.getName(); 405 for (Object nextParamValue : givenTypes.get(nextTypeClass)) { 406 Validate.isTrue( 407 nextParamValue == null || nextTypeClass.isAssignableFrom(nextParamValue.getClass()), 408 "Invalid params for pointcut %s - %s is not of type %s", 409 thePointcut.name(), 410 nextParamValue != null ? nextParamValue.getClass() : "null", 411 nextTypeClass); 412 Validate.isTrue( 413 wantedTypes.remove(nextTypeName), 414 "Invalid params for pointcut %s - Wanted %s but found %s", 415 thePointcut.name(), 416 toErrorString(thePointcut.getParameterTypes()), 417 nextTypeName); 418 } 419 } 420 421 return true; 422 } 423 424 private List<HookInvoker> scanInterceptorAndAddToInvokerMultimap( 425 Object theInterceptor, ListMultimap<POINTCUT, BaseInvoker> theInvokers) { 426 Class<?> interceptorClass = theInterceptor.getClass(); 427 int typeOrder = determineOrder(interceptorClass); 428 429 List<HookInvoker> addedInvokers = scanInterceptorForHookMethods(theInterceptor, typeOrder); 430 431 // Invoke the REGISTERED pointcut for any added hooks 432 addedInvokers.stream() 433 .filter(t -> Pointcut.INTERCEPTOR_REGISTERED.equals(t.getPointcut())) 434 .forEach(t -> t.invoke(new HookParams())); 435 436 // Register the interceptor and its various hooks 437 for (HookInvoker nextAddedHook : addedInvokers) { 438 POINTCUT nextPointcut = nextAddedHook.getPointcut(); 439 if (nextPointcut.equals(Pointcut.INTERCEPTOR_REGISTERED)) { 440 continue; 441 } 442 theInvokers.put(nextPointcut, nextAddedHook); 443 } 444 445 // Make sure we're always sorted according to the order declared in @Order 446 for (POINTCUT nextPointcut : theInvokers.keys()) { 447 List<BaseInvoker> nextInvokerList = theInvokers.get(nextPointcut); 448 nextInvokerList.sort(Comparator.naturalOrder()); 449 } 450 451 return addedInvokers; 452 } 453 454 /** 455 * @return Returns a list of any added invokers 456 */ 457 private List<HookInvoker> scanInterceptorForHookMethods(Object theInterceptor, int theTypeOrder) { 458 ArrayList<HookInvoker> retVal = new ArrayList<>(); 459 for (Method nextMethod : ReflectionUtil.getDeclaredMethods(theInterceptor.getClass(), true)) { 460 Optional<HookDescriptor> hook = scanForHook(nextMethod); 461 462 if (hook.isPresent()) { 463 int methodOrder = theTypeOrder; 464 int methodOrderAnnotation = hook.get().getOrder(); 465 if (methodOrderAnnotation != Interceptor.DEFAULT_ORDER) { 466 methodOrder = methodOrderAnnotation; 467 } 468 469 retVal.add(new HookInvoker(hook.get(), theInterceptor, nextMethod, methodOrder)); 470 } 471 } 472 473 return retVal; 474 } 475 476 protected abstract Optional<HookDescriptor> scanForHook(Method nextMethod); 477 478 private class HookInvoker extends BaseInvoker { 479 480 private final Method myMethod; 481 private final Class<?>[] myParameterTypes; 482 private final int[] myParameterIndexes; 483 private final POINTCUT myPointcut; 484 485 /** 486 * Constructor 487 */ 488 private HookInvoker( 489 HookDescriptor theHook, @Nonnull Object theInterceptor, @Nonnull Method theHookMethod, int theOrder) { 490 super(theInterceptor, theOrder); 491 myPointcut = theHook.getPointcut(); 492 myParameterTypes = theHookMethod.getParameterTypes(); 493 myMethod = theHookMethod; 494 495 Class<?> returnType = theHookMethod.getReturnType(); 496 if (myPointcut.getReturnType().equals(getBooleanReturnType())) { 497 Validate.isTrue( 498 getBooleanReturnType().equals(returnType) || void.class.equals(returnType), 499 "Method does not return boolean or void: %s", 500 theHookMethod); 501 } else if (myPointcut.getReturnType().equals(void.class)) { 502 Validate.isTrue(void.class.equals(returnType), "Method does not return void: %s", theHookMethod); 503 } else { 504 Validate.isTrue( 505 myPointcut.getReturnType().isAssignableFrom(returnType) || void.class.equals(returnType), 506 "Method does not return %s or void: %s", 507 myPointcut.getReturnType(), 508 theHookMethod); 509 } 510 511 myParameterIndexes = new int[myParameterTypes.length]; 512 Map<Class<?>, AtomicInteger> typeToCount = new HashMap<>(); 513 for (int i = 0; i < myParameterTypes.length; i++) { 514 AtomicInteger counter = typeToCount.computeIfAbsent(myParameterTypes[i], t -> new AtomicInteger(0)); 515 myParameterIndexes[i] = counter.getAndIncrement(); 516 } 517 518 myMethod.setAccessible(true); 519 } 520 521 @Override 522 public String toString() { 523 return new ToStringBuilder(this, ToStringStyle.SHORT_PREFIX_STYLE) 524 .append("method", myMethod) 525 .toString(); 526 } 527 528 public POINTCUT getPointcut() { 529 return myPointcut; 530 } 531 532 /** 533 * @return Returns true/false if the hook method returns a boolean, returns true otherwise 534 */ 535 @Override 536 Object invoke(HookParams theParams) { 537 538 Object[] args = new Object[myParameterTypes.length]; 539 for (int i = 0; i < myParameterTypes.length; i++) { 540 Class<?> nextParamType = myParameterTypes[i]; 541 if (nextParamType.equals(Pointcut.class)) { 542 args[i] = myPointcut; 543 } else { 544 int nextParamIndex = myParameterIndexes[i]; 545 Object nextParamValue = theParams.get(nextParamType, nextParamIndex); 546 args[i] = nextParamValue; 547 } 548 } 549 550 // Invoke the method 551 try { 552 return invokeMethod(args); 553 } catch (InvocationTargetException e) { 554 Throwable targetException = e.getTargetException(); 555 if (myPointcut.isShouldLogAndSwallowException(targetException)) { 556 ourLog.error("Exception thrown by interceptor: " + targetException.toString(), targetException); 557 return null; 558 } 559 560 if (targetException instanceof RuntimeException) { 561 throw ((RuntimeException) targetException); 562 } else { 563 throw new InternalErrorException( 564 Msg.code(1910) + "Failure invoking interceptor for pointcut(s) " + getPointcut(), 565 targetException); 566 } 567 } catch (Exception e) { 568 throw new InternalErrorException(Msg.code(1911) + e); 569 } 570 } 571 572 @WithSpan("hapifhir.interceptor") 573 private Object invokeMethod(Object[] args) throws InvocationTargetException, IllegalAccessException { 574 // Add attributes to the opentelemetry span 575 Span currentSpan = Span.current(); 576 currentSpan.setAttribute("hapifhir.interceptor.pointcut_name", myPointcut.name()); 577 currentSpan.setAttribute( 578 "hapifhir.interceptor.class_name", 579 myMethod.getDeclaringClass().getName()); 580 currentSpan.setAttribute("hapifhir.interceptor.method_name", myMethod.getName()); 581 582 return myMethod.invoke(getInterceptor(), args); 583 } 584 } 585 586 protected class HookDescriptor { 587 588 private final POINTCUT myPointcut; 589 private final int myOrder; 590 591 public HookDescriptor(POINTCUT thePointcut, int theOrder) { 592 myPointcut = thePointcut; 593 myOrder = theOrder; 594 } 595 596 POINTCUT getPointcut() { 597 return myPointcut; 598 } 599 600 int getOrder() { 601 return myOrder; 602 } 603 } 604 605 protected abstract static class BaseInvoker implements Comparable<BaseInvoker> { 606 607 private final int myOrder; 608 private final Object myInterceptor; 609 610 BaseInvoker(Object theInterceptor, int theOrder) { 611 myInterceptor = theInterceptor; 612 myOrder = theOrder; 613 } 614 615 public Object getInterceptor() { 616 return myInterceptor; 617 } 618 619 abstract Object invoke(HookParams theParams); 620 621 @Override 622 public int compareTo(BaseInvoker theInvoker) { 623 return myOrder - theInvoker.myOrder; 624 } 625 } 626 627 protected static <T extends Annotation> Optional<T> findAnnotation( 628 AnnotatedElement theObject, Class<T> theHookClass) { 629 T annotation; 630 if (theObject instanceof Method) { 631 annotation = MethodUtils.getAnnotation((Method) theObject, theHookClass, true, true); 632 } else { 633 annotation = theObject.getAnnotation(theHookClass); 634 } 635 return Optional.ofNullable(annotation); 636 } 637 638 private static int determineOrder(Class<?> theInterceptorClass) { 639 return findAnnotation(theInterceptorClass, Interceptor.class) 640 .map(Interceptor::order) 641 .orElse(Interceptor.DEFAULT_ORDER); 642 } 643 644 private static String toErrorString(List<String> theParameterTypes) { 645 return theParameterTypes.stream().sorted().collect(Collectors.joining(",")); 646 } 647}