001/*-
002 * #%L
003 * HAPI FHIR Storage api
004 * %%
005 * Copyright (C) 2014 - 2024 Smile CDR, Inc.
006 * %%
007 * Licensed under the Apache License, Version 2.0 (the "License");
008 * you may not use this file except in compliance with the License.
009 * You may obtain a copy of the License at
010 *
011 *      http://www.apache.org/licenses/LICENSE-2.0
012 *
013 * Unless required by applicable law or agreed to in writing, software
014 * distributed under the License is distributed on an "AS IS" BASIS,
015 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
016 * See the License for the specific language governing permissions and
017 * limitations under the License.
018 * #L%
019 */
020package ca.uhn.fhir.jpa.util;
021
022import ca.uhn.fhir.util.StopWatch;
023import com.google.common.collect.Queues;
024import jakarta.annotation.Nonnull;
025import net.ttddyy.dsproxy.support.ProxyDataSourceBuilder;
026import org.apache.commons.collections4.queue.CircularFifoQueue;
027import org.hl7.fhir.r4.model.InstantType;
028import org.slf4j.Logger;
029import org.slf4j.LoggerFactory;
030
031import java.util.ArrayList;
032import java.util.Arrays;
033import java.util.Collections;
034import java.util.Date;
035import java.util.List;
036import java.util.Locale;
037import java.util.Queue;
038import java.util.concurrent.atomic.AtomicInteger;
039import java.util.function.Predicate;
040import java.util.stream.Collectors;
041import java.util.stream.Stream;
042
043/**
044 * This is a query listener designed to be plugged into a {@link ProxyDataSourceBuilder proxy DataSource}.
045 * This listener keeps the last 1000 queries across all threads in a {@link CircularFifoQueue}, dropping queries off the
046 * end of the list as new ones are added.
047 * <p>
048 * Note that this class is really only designed for use in testing - It adds a non-trivial overhead
049 * to each query.
050 * </p>
051 */
052public class CircularQueueCaptureQueriesListener extends BaseCaptureQueriesListener {
053
054        public static final Predicate<String> DEFAULT_SELECT_INCLUSION_CRITERIA =
055                        t -> t.toLowerCase(Locale.US).startsWith("select");
056        private static final int CAPACITY = 1000;
057        private static final Logger ourLog = LoggerFactory.getLogger(CircularQueueCaptureQueriesListener.class);
058        private Queue<SqlQuery> myQueries;
059        private AtomicInteger myCommitCounter;
060        private AtomicInteger myRollbackCounter;
061
062        @Nonnull
063        private Predicate<String> mySelectQueryInclusionCriteria = DEFAULT_SELECT_INCLUSION_CRITERIA;
064
065        /**
066         * Constructor
067         */
068        public CircularQueueCaptureQueriesListener() {
069                startCollecting();
070        }
071
072        /**
073         * Sets an alternate inclusion criteria for select queries. This can be used to add
074         * additional criteria beyond the default value of {@link #DEFAULT_SELECT_INCLUSION_CRITERIA}.
075         */
076        public CircularQueueCaptureQueriesListener setSelectQueryInclusionCriteria(
077                        @Nonnull Predicate<String> theSelectQueryInclusionCriteria) {
078                mySelectQueryInclusionCriteria = theSelectQueryInclusionCriteria;
079                return this;
080        }
081
082        @Override
083        protected Queue<SqlQuery> provideQueryList() {
084                return myQueries;
085        }
086
087        @Override
088        protected AtomicInteger provideCommitCounter() {
089                return myCommitCounter;
090        }
091
092        @Override
093        protected AtomicInteger provideRollbackCounter() {
094                return myRollbackCounter;
095        }
096
097        /**
098         * Clear all stored queries
099         */
100        public void clear() {
101                myQueries.clear();
102                myCommitCounter.set(0);
103                myRollbackCounter.set(0);
104        }
105
106        /**
107         * Start collecting queries (this is the default)
108         */
109        public void startCollecting() {
110                myQueries = Queues.synchronizedQueue(new CircularFifoQueue<>(CAPACITY));
111                myCommitCounter = new AtomicInteger(0);
112                myRollbackCounter = new AtomicInteger(0);
113        }
114
115        /**
116         * Stop collecting queries and discard any collected ones
117         */
118        public void stopCollecting() {
119                myQueries = null;
120                myCommitCounter = null;
121                myRollbackCounter = null;
122        }
123
124        /**
125         * Index 0 is oldest
126         */
127        @SuppressWarnings("UseBulkOperation")
128        public List<SqlQuery> getCapturedQueries() {
129                // Make a copy so that we aren't affected by changes to the list outside of the
130                // synchronized block
131                ArrayList<SqlQuery> retVal = new ArrayList<>(CAPACITY);
132                myQueries.forEach(retVal::add);
133                return Collections.unmodifiableList(retVal);
134        }
135
136        private List<SqlQuery> getQueriesForCurrentThreadStartingWith(String theStart) {
137                String threadName = Thread.currentThread().getName();
138                return getQueriesStartingWith(theStart, threadName);
139        }
140
141        private List<SqlQuery> getQueriesStartingWith(String theStart, String theThreadName) {
142                return getCapturedQueries().stream()
143                                .filter(t -> theThreadName == null || t.getThreadName().equals(theThreadName))
144                                .filter(t -> t.getSql(false, false).toLowerCase().startsWith(theStart))
145                                .collect(Collectors.toList());
146        }
147
148        private List<SqlQuery> getQueriesStartingWith(String theStart) {
149                return getQueriesStartingWith(theStart, null);
150        }
151
152        private List<SqlQuery> getQueriesMatching(Predicate<String> thePredicate, String theThreadName) {
153                return getCapturedQueries().stream()
154                                .filter(t -> theThreadName == null || t.getThreadName().equals(theThreadName))
155                                .filter(t -> thePredicate.test(t.getSql(false, false)))
156                                .collect(Collectors.toList());
157        }
158
159        private List<SqlQuery> getQueriesMatching(Predicate<String> thePredicate) {
160                return getQueriesMatching(thePredicate, null);
161        }
162
163        private List<SqlQuery> getQueriesForCurrentThreadMatching(Predicate<String> thePredicate) {
164                String threadName = Thread.currentThread().getName();
165                return getQueriesMatching(thePredicate, threadName);
166        }
167
168        public int getCommitCount() {
169                return myCommitCounter.get();
170        }
171
172        public int getRollbackCount() {
173                return myRollbackCounter.get();
174        }
175
176        /**
177         * Returns all SELECT queries executed on the current thread - Index 0 is oldest
178         */
179        public List<SqlQuery> getSelectQueries() {
180                return getQueriesMatching(mySelectQueryInclusionCriteria);
181        }
182
183        /**
184         * Returns all INSERT queries executed on the current thread - Index 0 is oldest
185         */
186        public List<SqlQuery> getInsertQueries() {
187                return getQueriesStartingWith("insert");
188        }
189
190        /**
191         * Returns all UPDATE queries executed on the current thread - Index 0 is oldest
192         */
193        public List<SqlQuery> getUpdateQueries() {
194                return getQueriesStartingWith("update");
195        }
196
197        /**
198         * Returns all UPDATE queries executed on the current thread - Index 0 is oldest
199         */
200        public List<SqlQuery> getDeleteQueries() {
201                return getQueriesStartingWith("delete");
202        }
203
204        /**
205         * Returns all SELECT queries executed on the current thread - Index 0 is oldest
206         */
207        public List<SqlQuery> getSelectQueriesForCurrentThread() {
208                return getQueriesForCurrentThreadMatching(mySelectQueryInclusionCriteria);
209        }
210
211        /**
212         * Returns all INSERT queries executed on the current thread - Index 0 is oldest
213         */
214        public List<SqlQuery> getInsertQueriesForCurrentThread() {
215                return getQueriesForCurrentThreadStartingWith("insert");
216        }
217
218        /**
219         * Returns all queries executed on the current thread - Index 0 is oldest
220         */
221        public List<SqlQuery> getAllQueriesForCurrentThread() {
222                return getQueriesForCurrentThreadStartingWith("");
223        }
224
225        /**
226         * Returns all UPDATE queries executed on the current thread - Index 0 is oldest
227         */
228        public List<SqlQuery> getUpdateQueriesForCurrentThread() {
229                return getQueriesForCurrentThreadStartingWith("update");
230        }
231
232        /**
233         * Returns all UPDATE queries executed on the current thread - Index 0 is oldest
234         */
235        public List<SqlQuery> getDeleteQueriesForCurrentThread() {
236                return getQueriesForCurrentThreadStartingWith("delete");
237        }
238
239        /**
240         * Log all captured UPDATE queries
241         */
242        public String logUpdateQueriesForCurrentThread() {
243                List<SqlQuery> queries = getUpdateQueriesForCurrentThread();
244                List<String> queriesStrings = renderQueriesForLogging(true, true, queries);
245                String joined = String.join("\n", queriesStrings);
246                ourLog.info("Update Queries:\n{}", joined);
247                return joined;
248        }
249
250        /**
251         * Log all captured SELECT queries
252         */
253        public String logSelectQueriesForCurrentThread(int... theIndexes) {
254                List<SqlQuery> queries = getSelectQueriesForCurrentThread();
255                List<String> queriesStrings = renderQueriesForLogging(true, true, queries);
256
257                List<String> newList = new ArrayList<>();
258                if (theIndexes != null && theIndexes.length > 0) {
259                        for (int index : theIndexes) {
260                                newList.add(queriesStrings.get(index));
261                        }
262                        queriesStrings = newList;
263                }
264
265                String joined = String.join("\n", queriesStrings);
266                ourLog.info("Select Queries:\n{}", joined);
267                return joined;
268        }
269
270        /**
271         * Log all captured SELECT queries
272         */
273        public List<SqlQuery> logSelectQueries() {
274                return logSelectQueries(true, true);
275        }
276
277        /**
278         * Log all captured SELECT queries
279         */
280        public List<SqlQuery> logSelectQueries(boolean theInlineParams, boolean theFormatSql) {
281                List<SqlQuery> queries = getSelectQueries();
282                List<String> queriesStrings = renderQueriesForLogging(theInlineParams, theFormatSql, queries);
283                ourLog.info("Select Queries:\n{}", String.join("\n", queriesStrings));
284                return queries;
285        }
286
287        @Nonnull
288        private static List<String> renderQueriesForLogging(
289                        boolean theInlineParams, boolean theFormatSql, List<SqlQuery> queries) {
290                List<String> queriesStrings = new ArrayList<>();
291                for (int i = 0; i < queries.size(); i++) {
292                        SqlQuery query = queries.get(i);
293                        String remderedString = "[" + i + "] "
294                                        + CircularQueueCaptureQueriesListener.formatQueryAsSql(query, theInlineParams, theFormatSql);
295                        queriesStrings.add(remderedString);
296                }
297                return queriesStrings;
298        }
299
300        /**
301         * Log first captured SELECT query
302         */
303        public void logFirstSelectQueryForCurrentThread() {
304                boolean inlineParams = true;
305                String firstSelectQuery = getSelectQueriesForCurrentThread().stream()
306                                .findFirst()
307                                .map(t -> CircularQueueCaptureQueriesListener.formatQueryAsSql(t, inlineParams, inlineParams))
308                                .orElse("NONE FOUND");
309                ourLog.info("First select SqlQuery:\n{}", firstSelectQuery);
310        }
311
312        /**
313         * Log all captured INSERT queries
314         */
315        public String logInsertQueriesForCurrentThread() {
316                List<SqlQuery> queries = getInsertQueriesForCurrentThread();
317                List<String> queriesStrings = renderQueriesForLogging(true, true, queries);
318                String queriesAsString = String.join("\n", queriesStrings);
319                ourLog.info("Insert Queries:\n{}", queriesAsString);
320                return queriesAsString;
321        }
322
323        /**
324         * Log all captured queries
325         */
326        public void logAllQueriesForCurrentThread() {
327                List<SqlQuery> queries = getAllQueriesForCurrentThread();
328                List<String> queriesStrings = renderQueriesForLogging(true, true, queries);
329                ourLog.info("Queries:\n{}", String.join("\n", queriesStrings));
330        }
331
332        /**
333         * Log all captured queries
334         */
335        public void logAllQueries() {
336                List<SqlQuery> queries = getCapturedQueries();
337                List<String> queriesStrings = renderQueriesForLogging(true, true, queries);
338                ourLog.info("Queries:\n{}", String.join("\n", queriesStrings));
339        }
340
341        /**
342         * Log all captured INSERT queries
343         */
344        public int logInsertQueries() {
345                return logInsertQueries(null);
346        }
347
348        /**
349         * Log all captured INSERT queries
350         */
351        public int logInsertQueries(Predicate<SqlQuery> theInclusionPredicate) {
352                List<SqlQuery> insertQueries = getInsertQueries().stream()
353                                .filter(t -> theInclusionPredicate == null || theInclusionPredicate.test(t))
354                                .collect(Collectors.toList());
355                boolean inlineParams = true;
356                List<String> queries = insertQueries.stream()
357                                .map(t -> CircularQueueCaptureQueriesListener.formatQueryAsSql(t, inlineParams, inlineParams))
358                                .collect(Collectors.toList());
359                ourLog.info("Insert Queries:\n{}", String.join("\n", queries));
360
361                return countQueries(insertQueries);
362        }
363
364        /**
365         * Log all captured INSERT queries
366         */
367        public int logUpdateQueries() {
368                List<SqlQuery> queries = getUpdateQueries();
369                List<String> queriesStrings = renderQueriesForLogging(true, true, queries);
370                ourLog.info("Update Queries:\n{}", String.join("\n", queriesStrings));
371
372                return countQueries(queries);
373        }
374
375        /**
376         * Log all captured DELETE queries
377         */
378        public String logDeleteQueriesForCurrentThread() {
379                List<SqlQuery> queries = getDeleteQueriesForCurrentThread();
380                List<String> queriesStrings = renderQueriesForLogging(true, true, queries);
381                String joined = String.join("\n", queriesStrings);
382                ourLog.info("Delete Queries:\n{}", joined);
383                return joined;
384        }
385
386        /**
387         * Log all captured DELETE queries
388         */
389        public int logDeleteQueries() {
390                List<SqlQuery> queries = getDeleteQueries();
391                List<String> queriesStrings = renderQueriesForLogging(true, true, queries);
392                ourLog.info("Delete Queries:\n{}", String.join("\n", queriesStrings));
393
394                return countQueries(queries);
395        }
396
397        public int countSelectQueries() {
398                return countQueries(getSelectQueries());
399        }
400
401        public int countInsertQueries() {
402                return countQueries(getInsertQueries());
403        }
404
405        public int countUpdateQueries() {
406                return countQueries(getUpdateQueries());
407        }
408
409        public int countDeleteQueries() {
410                return countQueries(getDeleteQueries());
411        }
412
413        public int countSelectQueriesForCurrentThread() {
414                return countQueries(getSelectQueriesForCurrentThread());
415        }
416
417        public int countInsertQueriesForCurrentThread() {
418                return countQueries(getInsertQueriesForCurrentThread());
419        }
420
421        public int countUpdateQueriesForCurrentThread() {
422                return countQueries(getUpdateQueriesForCurrentThread());
423        }
424
425        public int countDeleteQueriesForCurrentThread() {
426                return countQueries(getDeleteQueriesForCurrentThread());
427        }
428
429        @Nonnull
430        private static Integer countQueries(List<SqlQuery> theQueries) {
431                return theQueries.stream().map(t -> t.getSize()).reduce(0, Integer::sum);
432        }
433
434        @Nonnull
435        static String formatQueryAsSql(SqlQuery theQuery) {
436                boolean inlineParams = true;
437                boolean formatSql = true;
438                return formatQueryAsSql(theQuery, inlineParams, formatSql);
439        }
440
441        @Nonnull
442        static String formatQueryAsSql(SqlQuery theQuery, boolean inlineParams, boolean formatSql) {
443                String formattedSql = theQuery.getSql(inlineParams, formatSql);
444                StringBuilder b = new StringBuilder();
445                b.append("SqlQuery at ");
446                b.append(new InstantType(new Date(theQuery.getQueryTimestamp())).getValueAsString());
447                if (theQuery.getRequestPartitionId() != null
448                                && theQuery.getRequestPartitionId().hasPartitionIds()) {
449                        b.append(" on partition ");
450                        b.append(theQuery.getRequestPartitionId().getPartitionIds());
451                }
452                b.append(" took ").append(StopWatch.formatMillis(theQuery.getElapsedTime()));
453                b.append(" on Thread: ").append(theQuery.getThreadName());
454                if (theQuery.getSize() > 1) {
455                        b.append("\nExecution Count: ")
456                                        .append(theQuery.getSize())
457                                        .append(" (parameters shown are for first execution)");
458                }
459                b.append("\nSQL:\n").append(formattedSql);
460                if (theQuery.getStackTrace() != null) {
461                        b.append("\nStack:\n   ");
462                        Stream<String> stackTraceStream = Arrays.stream(theQuery.getStackTrace())
463                                        .map(StackTraceElement::toString)
464                                        .filter(t -> t.startsWith("ca."));
465                        b.append(stackTraceStream.collect(Collectors.joining("\n   ")));
466                }
467                b.append("\n");
468                return b.toString();
469        }
470}