package com.atlassian.instrumentation.driver;

import com.atlassian.instrumentation.instruments.EventType;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.LongSummaryStatistics;
import java.util.Map;
import java.util.stream.Collectors;

/**
 * Collector for per-thread (request based) JDBC execution timings.
 *
 * <p>Collection is opt-in per thread:
 * <ul>
 *   <li>call {@link #start()} to begin recording on the current thread;</li>
 *   <li>read the results with {@link #getStatistics()} or {@link #getMedianTime()};</li>
 *   <li>finish with {@link #clear()}.</li>
 * </ul>
 * While recording has not been started on a thread, {@link #add(String, long)} is a no-op there.
 * Recorded data lives in a {@link ThreadLocal}, so each thread only sees its own timings.
 */
public class JdbcThreadLocalCollector {

    /**
     * Timings recorded on the current thread. This is {@code null} when {@link #start()} has not been
     * called on this thread, or when {@link #clear()} has been called since.
     *
     * <p>Declared before the static initializer below. Static initializers run in textual order, so
     * this guarantees the field is assigned before the factory that calls
     * {@link #add(String, long)} is handed to {@link Instrumentation}.
     */
    private static final ThreadLocal<List<Counter>> counts = new ThreadLocal<>();

    // Start listening for JDBC operations. For EXECUTION events the factory captures System.nanoTime()
    // and returns a split that records the elapsed nanoseconds against the event's SQL via add().
    // Every other event type gets a no-op split.
    // NOTE(review): assumes the Split is invoked once the operation completes -- confirm in Instrumentation.
    static {
        Instrumentation.registerFactory(context -> context.getEventType()
                .filter(et -> et == EventType.EXECUTION)
                .map(eventType ->
                {
                    final long start = System.nanoTime();
                    return (Instrumentation.Split) () ->
                            add(context.getSql().get(), System.nanoTime() - start);
                })
                .orElse(() -> {
                }));
    }

    /**
     * Summarises the timings recorded on the current thread, grouped by SQL text.
     *
     * @return a map from SQL text to the count/min/max/sum/average of its recorded times. Times are
     *         in the unit passed to {@link #add(String, long)}, which is nanoseconds for the built-in
     *         JDBC hook. Returns an empty map if recording was not started on this thread.
     */
    public static Map<String, LongSummaryStatistics> getStatistics() {
        final List<Counter> results = counts.get();
        if (results == null) {
            return Collections.emptyMap();
        }
        return results.stream()
                .collect(Collectors.groupingBy(Counter::getSql, Collectors.summarizingLong(Counter::getTime)));
    }

    /**
     * Computes the median recorded time per SQL text for the current thread.
     *
     * <p>For statements recorded an even number of times, the median is the integer mean of the two
     * middle values (truncating division).
     *
     * @return a map from SQL text to its median time, or an empty map if recording was not started
     *         on this thread
     */
    public static Map<String, Long> getMedianTime() {
        final List<Counter> results = counts.get();
        if (results == null) {
            return Collections.emptyMap();
        }
        return results.stream()
                .collect(Collectors.groupingBy(Counter::getSql,
                        Collectors.collectingAndThen(Collectors.toList(), JdbcThreadLocalCollector::median)));
    }

    /**
     * Returns the median time of one SQL group.
     *
     * <p>The group is never empty, because {@code groupingBy} only creates groups for elements it saw.
     */
    private static long median(List<Counter> values) {
        final List<Long> sorted = values.stream().map(Counter::getTime).sorted().collect(Collectors.toList());
        final int mid = sorted.size() / 2;
        if (isEven(sorted.size())) {
            return (sorted.get(mid - 1) + sorted.get(mid)) / 2;
        }
        // Odd size (including a single element): take the middle value.
        return sorted.get(mid);
    }

    /**
     * Begins recording on the current thread. This discards anything previously recorded here.
     */
    public static void start() {
        counts.set(new ArrayList<>());
    }

    /**
     * Discards this thread's recorded timings and stops recording. Subsequent {@link #add} calls
     * become no-ops until {@link #start()} is called again.
     */
    public static void clear() {
        counts.remove();
    }

    /**
     * Records one timing for the current thread.
     *
     * <p>Ignored if {@link #start()} has not been called on this thread. Also ignored if {@code sql}
     * is {@code null}: the reporting methods group by SQL text, and {@code Collectors.groupingBy}
     * rejects null keys, so a single null entry would make them throw.
     *
     * @param sql  the SQL text the timing belongs to
     * @param time the elapsed time (nanoseconds when recorded by the built-in JDBC hook)
     */
    public static void add(String sql, long time) {
        final List<Counter> recorded = counts.get();
        if (recorded == null) {
            return; // Start was not called. Do not collect any stats.
        }
        if (sql == null) {
            return; // Would break grouping in getStatistics()/getMedianTime().
        }
        recorded.add(new Counter(sql, time));
    }

    /** A single immutable (SQL text, elapsed time) observation. */
    static class Counter {
        private final String sql;
        private final long time;

        public Counter(String sql, long totalTime) {
            this.sql = sql;
            this.time = totalTime;
        }

        public String getSql() {
            return sql;
        }

        public long getTime() {
            return time;
        }
    }

    private static boolean isEven(long number) {
        return (number % 2) == 0;
    }
}
