001/*-
002 * #%L
003 * HAPI FHIR Storage api
004 * %%
005 * Copyright (C) 2014 - 2025 Smile CDR, Inc.
006 * %%
007 * Licensed under the Apache License, Version 2.0 (the "License");
008 * you may not use this file except in compliance with the License.
009 * You may obtain a copy of the License at
010 *
011 *      http://www.apache.org/licenses/LICENSE-2.0
012 *
013 * Unless required by applicable law or agreed to in writing, software
014 * distributed under the License is distributed on an "AS IS" BASIS,
015 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
016 * See the License for the specific language governing permissions and
017 * limitations under the License.
018 * #L%
019 */
020package ca.uhn.fhir.jpa.util;
021
022import ca.uhn.fhir.util.StopWatch;
023import com.google.common.collect.Queues;
024import jakarta.annotation.Nonnull;
025import jakarta.annotation.Nullable;
026import net.ttddyy.dsproxy.support.ProxyDataSourceBuilder;
027import org.apache.commons.collections4.queue.CircularFifoQueue;
028import org.hl7.fhir.r4.model.InstantType;
029import org.slf4j.Logger;
030import org.slf4j.LoggerFactory;
031
032import java.util.ArrayList;
033import java.util.Arrays;
034import java.util.Collections;
035import java.util.Date;
036import java.util.HashSet;
037import java.util.List;
038import java.util.Locale;
039import java.util.Queue;
040import java.util.Set;
041import java.util.concurrent.atomic.AtomicInteger;
042import java.util.function.Predicate;
043import java.util.stream.Collectors;
044import java.util.stream.Stream;
045
046/**
047 * This is a query listener designed to be plugged into a {@link ProxyDataSourceBuilder proxy DataSource}.
048 * This listener keeps the last 1000 queries across all threads in a {@link CircularFifoQueue}, dropping queries off the
049 * end of the list as new ones are added.
050 * <p>
051 * Note that this class is really only designed for use in testing - It adds a non-trivial overhead
052 * to each query.
053 * </p>
054 */
055public class CircularQueueCaptureQueriesListener extends BaseCaptureQueriesListener {
056
057        public static final Predicate<String> DEFAULT_SELECT_INCLUSION_CRITERIA =
058                        t -> t.toLowerCase(Locale.US).startsWith("select");
059        private static final int CAPACITY = 10000;
060        private static final Logger ourLog = LoggerFactory.getLogger(CircularQueueCaptureQueriesListener.class);
061        private Queue<SqlQuery> myQueries;
062        private AtomicInteger myGetConnectionCounter;
063        private AtomicInteger myCommitCounter;
064        private AtomicInteger myRollbackCounter;
065
066        @Nonnull
067        private Predicate<String> mySelectQueryInclusionCriteria = DEFAULT_SELECT_INCLUSION_CRITERIA;
068
069        /**
070         * Constructor
071         */
072        public CircularQueueCaptureQueriesListener() {
073                startCollecting();
074        }
075
076        /**
077         * Sets an alternate inclusion criteria for select queries. This can be used to add
078         * additional criteria beyond the default value of {@link #DEFAULT_SELECT_INCLUSION_CRITERIA}.
079         */
080        public CircularQueueCaptureQueriesListener setSelectQueryInclusionCriteria(
081                        @Nonnull Predicate<String> theSelectQueryInclusionCriteria) {
082                mySelectQueryInclusionCriteria = theSelectQueryInclusionCriteria;
083                return this;
084        }
085
086        @Override
087        protected Queue<SqlQuery> provideQueryList() {
088                return myQueries;
089        }
090
091        @Override
092        protected AtomicInteger provideCommitCounter() {
093                return myCommitCounter;
094        }
095
096        @Nullable
097        @Override
098        protected AtomicInteger provideGetConnectionCounter() {
099                return myGetConnectionCounter;
100        }
101
102        @Override
103        protected AtomicInteger provideRollbackCounter() {
104                return myRollbackCounter;
105        }
106
107        /**
108         * Clear all stored queries
109         */
110        public void clear() {
111                myQueries.clear();
112                myGetConnectionCounter.set(0);
113                myCommitCounter.set(0);
114                myRollbackCounter.set(0);
115        }
116
117        /**
118         * Start collecting queries (this is the default)
119         */
120        public void startCollecting() {
121                myQueries = Queues.synchronizedQueue(new CircularFifoQueue<>(CAPACITY));
122                myGetConnectionCounter = new AtomicInteger(0);
123                myCommitCounter = new AtomicInteger(0);
124                myRollbackCounter = new AtomicInteger(0);
125        }
126
127        /**
128         * Stop collecting queries and discard any collected ones
129         */
130        public void stopCollecting() {
131                myQueries = null;
132                myCommitCounter = null;
133                myRollbackCounter = null;
134        }
135
136        /**
137         * Index 0 is oldest
138         */
139        @SuppressWarnings("UseBulkOperation")
140        public List<SqlQuery> getCapturedQueries() {
141                // Make a copy so that we aren't affected by changes to the list outside of the
142                // synchronized block
143                ArrayList<SqlQuery> retVal = new ArrayList<>(CAPACITY);
144                myQueries.forEach(retVal::add);
145                return Collections.unmodifiableList(retVal);
146        }
147
148        private List<SqlQuery> getQueriesForCurrentThreadStartingWith(String theStart) {
149                String threadName = Thread.currentThread().getName();
150                return getQueriesStartingWith(theStart, threadName);
151        }
152
153        private List<SqlQuery> getQueriesStartingWith(String theStart, String theThreadName) {
154                return getCapturedQueries().stream()
155                                .filter(t -> theThreadName == null || t.getThreadName().equals(theThreadName))
156                                .filter(t -> t.getSql(false, false).toLowerCase().startsWith(theStart))
157                                .collect(Collectors.toList());
158        }
159
160        private List<SqlQuery> getQueriesStartingWith(String theStart) {
161                return getQueriesStartingWith(theStart, null);
162        }
163
164        private List<SqlQuery> getQueriesMatching(Predicate<String> thePredicate, String theThreadName) {
165                return getCapturedQueries().stream()
166                                .filter(t -> theThreadName == null || t.getThreadName().equals(theThreadName))
167                                .filter(t -> thePredicate.test(t.getSql(false, false)))
168                                .collect(Collectors.toList());
169        }
170
171        private List<SqlQuery> getQueriesMatching(Predicate<String> thePredicate) {
172                return getQueriesMatching(thePredicate, null);
173        }
174
175        private List<SqlQuery> getQueriesForCurrentThreadMatching(Predicate<String> thePredicate) {
176                String threadName = Thread.currentThread().getName();
177                return getQueriesMatching(thePredicate, threadName);
178        }
179
180        /**
181         * @deprecated Use {@link #countCommits()}
182         */
183        @Deprecated
184        public int getCommitCount() {
185                return myCommitCounter.get();
186        }
187
188        /**
189         * @deprecated Use {@link #countRollbacks()}
190         */
191        @Deprecated
192        public int getRollbackCount() {
193                return myRollbackCounter.get();
194        }
195
196        /**
197         * Returns all SELECT queries executed on the current thread - Index 0 is oldest
198         */
199        public List<SqlQuery> getSelectQueries() {
200                return getQueriesMatching(mySelectQueryInclusionCriteria);
201        }
202
203        /**
204         * Returns all INSERT queries executed on the current thread - Index 0 is oldest
205         */
206        public List<SqlQuery> getInsertQueries() {
207                return getQueriesStartingWith("insert");
208        }
209
210        /**
211         * Returns all INSERT queries executed on the current thread - Index 0 is oldest
212         */
213        public List<SqlQuery> getInsertQueries(Predicate<SqlQuery> theFilter) {
214                return getQueriesStartingWith("insert").stream().filter(theFilter).collect(Collectors.toList());
215        }
216
217        /**
218         * Returns all UPDATE queries executed on the current thread - Index 0 is oldest
219         */
220        public List<SqlQuery> getUpdateQueries() {
221                return getQueriesStartingWith("update");
222        }
223
224        /**
225         * Returns all UPDATE queries executed on the current thread - Index 0 is oldest
226         */
227        public List<SqlQuery> getUpdateQueries(Predicate<SqlQuery> theFilter) {
228                return getQueriesStartingWith("update").stream().filter(theFilter).collect(Collectors.toList());
229        }
230
231        /**
232         * Returns all UPDATE queries executed on the current thread - Index 0 is oldest
233         */
234        public List<SqlQuery> getDeleteQueries() {
235                return getQueriesStartingWith("delete");
236        }
237
238        /**
239         * Returns all SELECT queries executed on the current thread - Index 0 is oldest
240         */
241        public List<SqlQuery> getSelectQueriesForCurrentThread() {
242                return getQueriesForCurrentThreadMatching(mySelectQueryInclusionCriteria);
243        }
244
245        /**
246         * Returns all INSERT queries executed on the current thread - Index 0 is oldest
247         */
248        public List<SqlQuery> getInsertQueriesForCurrentThread() {
249                return getQueriesForCurrentThreadStartingWith("insert");
250        }
251
252        /**
253         * Returns all queries executed on the current thread - Index 0 is oldest
254         */
255        public List<SqlQuery> getAllQueriesForCurrentThread() {
256                return getQueriesForCurrentThreadStartingWith("");
257        }
258
259        /**
260         * Returns all UPDATE queries executed on the current thread - Index 0 is oldest
261         */
262        public List<SqlQuery> getUpdateQueriesForCurrentThread() {
263                return getQueriesForCurrentThreadStartingWith("update");
264        }
265
266        /**
267         * Returns all UPDATE queries executed on the current thread - Index 0 is oldest
268         */
269        public List<SqlQuery> getDeleteQueriesForCurrentThread() {
270                return getQueriesForCurrentThreadStartingWith("delete");
271        }
272
273        /**
274         * Log all captured UPDATE queries
275         */
276        public String logUpdateQueriesForCurrentThread() {
277                List<SqlQuery> queries = getUpdateQueriesForCurrentThread();
278                List<String> queriesStrings = renderQueriesForLogging(true, true, queries);
279                String joined = String.join("\n", queriesStrings);
280                ourLog.info("Update Queries:\n{}", joined);
281                return joined;
282        }
283
284        /**
285         * Log all captured SELECT queries
286         */
287        public String logSelectQueriesForCurrentThread(int... theIndexes) {
288                List<SqlQuery> queries = getSelectQueriesForCurrentThread();
289                List<String> queriesStrings = renderQueriesForLogging(true, true, queries);
290
291                List<String> newList = new ArrayList<>();
292                if (theIndexes != null && theIndexes.length > 0) {
293                        for (int index : theIndexes) {
294                                newList.add(queriesStrings.get(index));
295                        }
296                        queriesStrings = newList;
297                }
298
299                String joined = String.join("\n", queriesStrings);
300                ourLog.info("Select Queries:\n{}", joined);
301                return joined;
302        }
303
304        /**
305         * Log all captured SELECT queries
306         */
307        public List<SqlQuery> logSelectQueries() {
308                return logSelectQueries(true, true);
309        }
310
311        /**
312         * Log all captured SELECT queries
313         */
314        public List<SqlQuery> logSelectQueries(boolean theInlineParams, boolean theFormatSql) {
315                List<SqlQuery> queries = getSelectQueries();
316                List<String> queriesStrings = renderQueriesForLogging(theInlineParams, theFormatSql, queries);
317                ourLog.info("Select Queries:\n{}", String.join("\n", queriesStrings));
318                return queries;
319        }
320
321        @Nonnull
322        private static List<String> renderQueriesForLogging(
323                        boolean theInlineParams, boolean theFormatSql, List<SqlQuery> queries) {
324                List<String> queriesStrings = new ArrayList<>();
325                for (int i = 0; i < queries.size(); i++) {
326                        SqlQuery query = queries.get(i);
327                        String remderedString = "[" + i + "] "
328                                        + CircularQueueCaptureQueriesListener.formatQueryAsSql(query, theInlineParams, theFormatSql);
329                        queriesStrings.add(remderedString);
330                }
331                return queriesStrings;
332        }
333
334        /**
335         * Log first captured SELECT query
336         */
337        public void logFirstSelectQueryForCurrentThread() {
338                boolean inlineParams = true;
339                String firstSelectQuery = getSelectQueriesForCurrentThread().stream()
340                                .findFirst()
341                                .map(t -> CircularQueueCaptureQueriesListener.formatQueryAsSql(t, inlineParams, inlineParams))
342                                .orElse("NONE FOUND");
343                ourLog.info("First select SqlQuery:\n{}", firstSelectQuery);
344        }
345
346        /**
347         * Log all captured INSERT queries
348         */
349        public String logInsertQueriesForCurrentThread() {
350                List<SqlQuery> queries = getInsertQueriesForCurrentThread();
351                List<String> queriesStrings = renderQueriesForLogging(true, true, queries);
352                String queriesAsString = String.join("\n", queriesStrings);
353                ourLog.info("Insert Queries:\n{}", queriesAsString);
354                return queriesAsString;
355        }
356
357        /**
358         * Log all captured queries
359         */
360        public void logAllQueriesForCurrentThread() {
361                List<SqlQuery> queries = getAllQueriesForCurrentThread();
362                List<String> queriesStrings = renderQueriesForLogging(true, true, queries);
363                ourLog.info("Queries:\n{}", String.join("\n", queriesStrings));
364        }
365
366        /**
367         * Log all captured queries
368         */
369        public void logAllQueries() {
370                List<SqlQuery> queries = getCapturedQueries();
371                List<String> queriesStrings = renderQueriesForLogging(true, true, queries);
372                ourLog.info("Queries:\n{}", String.join("\n", queriesStrings));
373        }
374
375        /**
376         * Log all captured INSERT queries
377         */
378        public int logInsertQueries() {
379                return logInsertQueries(null);
380        }
381
382        /**
383         * Log all captured INSERT queries
384         */
385        public int logInsertQueries(Predicate<SqlQuery> theInclusionPredicate) {
386                List<SqlQuery> insertQueries = getInsertQueries().stream()
387                                .filter(t -> theInclusionPredicate == null || theInclusionPredicate.test(t))
388                                .collect(Collectors.toList());
389                List<String> queriesStrings = renderQueriesForLogging(true, true, insertQueries);
390                ourLog.info("Insert Queries:\n{}", String.join("\n", queriesStrings));
391
392                return countQueries(insertQueries);
393        }
394
395        /**
396         * Log all captured INSERT queries
397         */
398        public int logUpdateQueries() {
399                List<SqlQuery> queries = getUpdateQueries();
400                List<String> queriesStrings = renderQueriesForLogging(true, true, queries);
401                ourLog.info("Update Queries:\n{}", String.join("\n", queriesStrings));
402
403                return countQueries(queries);
404        }
405
406        /**
407         * Log all captured DELETE queries
408         */
409        public String logDeleteQueriesForCurrentThread() {
410                List<SqlQuery> queries = getDeleteQueriesForCurrentThread();
411                List<String> queriesStrings = renderQueriesForLogging(true, true, queries);
412                String joined = String.join("\n", queriesStrings);
413                ourLog.info("Delete Queries:\n{}", joined);
414                return joined;
415        }
416
417        /**
418         * Log all captured DELETE queries
419         */
420        public int logDeleteQueries() {
421                List<SqlQuery> queries = getDeleteQueries();
422                List<String> queriesStrings = renderQueriesForLogging(true, true, queries);
423                ourLog.info("Delete Queries:\n{}", String.join("\n", queriesStrings));
424
425                return countQueries(queries);
426        }
427
428        public int countSelectQueries() {
429                return countQueries(getSelectQueries());
430        }
431
432        public int countInsertQueries() {
433                return countQueries(getInsertQueries());
434        }
435
436        /**
437         * This method returns a count of insert queries which were issued more than
438         * once. If a query was issued twice with batched parameters, that does not
439         * count as a repeated query, but if it is issued as two separate insert
440         * statements with a single set of parameters this counts as a repeated
441         * query.
442         */
443        public int countInsertQueriesRepeated() {
444                return getInsertQueries(new DuplicateCheckingPredicate()).size();
445        }
446
447        public int countUpdateQueries() {
448                return countQueries(getUpdateQueries());
449        }
450
451        public int countDeleteQueries() {
452                return countQueries(getDeleteQueries());
453        }
454
455        public int countSelectQueriesForCurrentThread() {
456                return countQueries(getSelectQueriesForCurrentThread());
457        }
458
459        public int countInsertQueriesForCurrentThread() {
460                return countQueries(getInsertQueriesForCurrentThread());
461        }
462
463        public int countUpdateQueriesForCurrentThread() {
464                return countQueries(getUpdateQueriesForCurrentThread());
465        }
466
467        public int countDeleteQueriesForCurrentThread() {
468                return countQueries(getDeleteQueriesForCurrentThread());
469        }
470
471        @Nonnull
472        private static Integer countQueries(List<SqlQuery> theQueries) {
473                return theQueries.stream().map(t -> t.getSize()).reduce(0, Integer::sum);
474        }
475
476        @Nonnull
477        static String formatQueryAsSql(SqlQuery theQuery) {
478                boolean inlineParams = true;
479                boolean formatSql = true;
480                return formatQueryAsSql(theQuery, inlineParams, formatSql);
481        }
482
483        @Nonnull
484        static String formatQueryAsSql(SqlQuery theQuery, boolean inlineParams, boolean formatSql) {
485                String formattedSql = theQuery.getSql(inlineParams, formatSql);
486                StringBuilder b = new StringBuilder();
487                b.append("SqlQuery at ");
488                b.append(new InstantType(new Date(theQuery.getQueryTimestamp())).getValueAsString());
489                if (theQuery.getRequestPartitionId() != null
490                                && theQuery.getRequestPartitionId().hasPartitionIds()) {
491                        b.append(" on partition ");
492                        b.append(theQuery.getRequestPartitionId().getPartitionIds());
493                }
494                b.append(" took ").append(StopWatch.formatMillis(theQuery.getElapsedTime()));
495                b.append(" on Thread: ").append(theQuery.getThreadName());
496                if (theQuery.getSize() > 1) {
497                        b.append("\nExecution Count: ")
498                                        .append(theQuery.getSize())
499                                        .append(" (parameters shown are for first execution)");
500                }
501                b.append("\nSQL:\n").append(formattedSql);
502                if (theQuery.getStackTrace() != null) {
503                        b.append("\nStack:\n   ");
504                        Stream<String> stackTraceStream = Arrays.stream(theQuery.getStackTrace())
505                                        .map(StackTraceElement::toString)
506                                        .filter(t -> t.startsWith("ca."));
507                        b.append(stackTraceStream.collect(Collectors.joining("\n   ")));
508                }
509                b.append("\n");
510                return b.toString();
511        }
512
513        /**
514         * Create a new instance of this each time you use it
515         */
516        private static class DuplicateCheckingPredicate implements Predicate<SqlQuery> {
517                private final Set<String> mySeenSql = new HashSet<>();
518
519                @Override
520                public boolean test(SqlQuery theSqlQuery) {
521                        String sql = theSqlQuery.getSql(false, false);
522                        boolean retVal = !mySeenSql.add(sql);
523                        ourLog.trace("SQL duplicate[{}]: {}", retVal, sql);
524                        return retVal;
525                }
526        }
527}