/*
 * Decompiled with CFR 0.152.
 */
package org.apache.lucene.sandbox.codecs.quantization;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.internal.hppc.IntCursor;
import org.apache.lucene.internal.hppc.IntHashSet;
import org.apache.lucene.sandbox.codecs.quantization.SampleReader;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.hnsw.NeighborQueue;

/**
 * Lloyd-style k-means clustering over a {@link FloatVectorValues} instance.
 *
 * <p>Centroids are computed on the input (or on a sample of it when the input is larger than the
 * requested sample size), optionally over several restarts, keeping the run with the smallest sum
 * of squared distances. Vectors can then optionally be assigned to their closest centroid.
 */
public class KMeans {
    /** Upper bound on the number of centroids; assignments are stored as {@code short}. */
    public static final int MAX_NUM_CENTROIDS = Short.MAX_VALUE;
    /** Number of restarts used by {@link #cluster(FloatVectorValues, VectorSimilarityFunction, int)}. */
    public static final int DEFAULT_RESTARTS = 5;
    /** Maximum iterations per restart used by the three-argument {@code cluster} overload. */
    public static final int DEFAULT_ITRS = 10;
    /** Sample size used by the three-argument {@code cluster} overload. */
    public static final int DEFAULT_SAMPLE_SIZE = 100000;

    /** Seed used by the three-argument {@code cluster} overload. */
    private static final long DEFAULT_SEED = 42L;
    /** An iteration must lower the summed squared distance by more than this to continue. */
    private static final double CONVERGENCE_TOLERANCE = 1.0E-6;
    /** {@code cluster} aims for at least this many sampled vectors per centroid. */
    private static final int MIN_VECTORS_PER_CLUSTER = 100;

    private final FloatVectorValues vectors;
    private final int numVectors;
    private final int numCentroids;
    private final Random random;
    private final KmeansInitializationMethod initializationMethod;
    private final int restarts;
    private final int iters;

    /**
     * Clusters {@code vectors} with the default settings: k-means++ initialization, seed 42,
     * {@link #DEFAULT_RESTARTS} restarts, {@link #DEFAULT_ITRS} iterations and
     * {@link #DEFAULT_SAMPLE_SIZE} samples. Vectors are assigned to centroids, and centroids are
     * L2-normalized only when {@code similarityFunction} is {@link VectorSimilarityFunction#COSINE}.
     *
     * @param vectors the vectors to cluster
     * @param similarityFunction only used to decide whether centroids are normalized
     * @param numClusters requested number of clusters
     * @return the clustering results, or {@code null} when {@code vectors} is empty
     * @throws IOException if reading a vector fails
     */
    public static Results cluster(FloatVectorValues vectors, VectorSimilarityFunction similarityFunction, int numClusters) throws IOException {
        return KMeans.cluster(
            vectors,
            numClusters,
            true,
            DEFAULT_SEED,
            KmeansInitializationMethod.PLUS_PLUS,
            similarityFunction == VectorSimilarityFunction.COSINE,
            DEFAULT_RESTARTS,
            DEFAULT_ITRS,
            DEFAULT_SAMPLE_SIZE);
    }

    /**
     * Clusters {@code vectors} into at most {@code numClusters} clusters.
     *
     * <p>The sample size is first raised to at least {@code 100 * numClusters}. If that exceeds the
     * number of vectors, all vectors are used and the number of clusters is capped at
     * {@code max(1, vectors.size() / 100)}, so fewer centroids than requested may be returned.
     *
     * @param vectors the vectors to cluster
     * @param numClusters requested number of clusters, between 1 and {@link #MAX_NUM_CENTROIDS}
     * @param assignCentroidsToVectors whether to compute the closest centroid of every vector
     * @param seed seed for the random generator and for sampling
     * @param initializationMethod how the initial centroids are chosen
     * @param normalizeCenters whether centroids are L2-normalized after every step
     * @param restarts number of independent runs; the best one is kept
     * @param iters maximum number of iterations per run
     * @param sampleSize requested number of vectors used to compute the centroids
     * @return the centroids and, if requested, the per-vector assignment ({@code null} otherwise);
     *     {@code null} when {@code vectors} is empty
     * @throws IllegalArgumentException if {@code numClusters} is out of range
     * @throws IOException if reading a vector fails
     */
    public static Results cluster(FloatVectorValues vectors, int numClusters, boolean assignCentroidsToVectors, long seed, KmeansInitializationMethod initializationMethod, boolean normalizeCenters, int restarts, int iters, int sampleSize) throws IOException {
        if (vectors.size() == 0) {
            return null;
        }
        if (numClusters < 1 || numClusters > MAX_NUM_CENTROIDS) {
            throw new IllegalArgumentException("[numClusters] must be between [1] and [" + MAX_NUM_CENTROIDS + "]");
        }
        sampleSize = Math.max(sampleSize, MIN_VECTORS_PER_CLUSTER * numClusters);
        if (sampleSize > vectors.size()) {
            // Not enough vectors for the requested clusters: use them all and shrink the cluster count.
            sampleSize = vectors.size();
            int maxNumClusters = Math.max(1, sampleSize / MIN_VECTORS_PER_CLUSTER);
            numClusters = Math.min(numClusters, maxNumClusters);
        }
        Random random = new Random(seed);
        float[][] centroids;
        if (numClusters == 1) {
            // A single zero centroid; it only becomes the mean if the assignment step below runs.
            centroids = new float[1][vectors.dimension()];
        } else {
            FloatVectorValues sampleVectors = vectors.size() <= sampleSize ? vectors : SampleReader.createSampleReader(vectors, sampleSize, seed);
            KMeans kmeans = new KMeans(sampleVectors, numClusters, random, initializationMethod, restarts, iters);
            centroids = kmeans.computeCentroids(normalizeCenters);
        }
        short[] vectorCentroids = null;
        if (assignCentroidsToVectors) {
            vectorCentroids = new short[vectors.size()];
            // One extra step over ALL vectors (not just the sample), with compensated summation.
            KMeans.runKMeansStep(vectors, centroids, vectorCentroids, true, normalizeCenters);
        }
        return new Results(centroids, vectorCentroids);
    }

    private KMeans(FloatVectorValues vectors, int numCentroids, Random random, KmeansInitializationMethod initializationMethod, int restarts, int iters) {
        this.vectors = vectors;
        this.numVectors = vectors.size();
        this.numCentroids = numCentroids;
        this.random = random;
        this.initializationMethod = initializationMethod;
        this.restarts = restarts;
        this.iters = iters;
    }

    /**
     * Runs up to {@code restarts} independent k-means runs and returns the centroids of the run
     * with the smallest sum of squared distances. Each run stops after {@code iters} steps or as
     * soon as a step fails to improve the previous sum by more than the convergence tolerance.
     *
     * @return the best centroids, or {@code null} if no run produced a finite improvement over
     *     {@link Double#MAX_VALUE} (e.g. when {@code restarts} is not positive)
     */
    private float[][] computeCentroids(boolean normalizeCenters) throws IOException {
        short[] vectorCentroids = new short[this.numVectors];
        double minSquaredDist = Double.MAX_VALUE;
        // Deliberately declared outside the loop: with iters == 0 the previous value is reused.
        double squaredDist = 0.0;
        float[][] bestCentroids = null;
        for (int restart = 0; restart < this.restarts; ++restart) {
            float[][] centroids = switch (this.initializationMethod) {
                case FORGY -> this.initializeForgy();
                case RESERVOIR_SAMPLING -> this.initializeReservoirSampling();
                case PLUS_PLUS -> this.initializePlusPlus();
            };
            double prevSquaredDist = Double.MAX_VALUE;
            for (int iter = 0; iter < this.iters; ++iter) {
                squaredDist = KMeans.runKMeansStep(this.vectors, centroids, vectorCentroids, false, normalizeCenters);
                // Converged: this step did not improve on the previous one.
                if (prevSquaredDist <= squaredDist + CONVERGENCE_TOLERANCE) {
                    break;
                }
                prevSquaredDist = squaredDist;
            }
            if (squaredDist < minSquaredDist) {
                minSquaredDist = squaredDist;
                bestCentroids = centroids;
            }
        }
        return bestCentroids;
    }

    /** Picks {@code numCentroids} distinct random vectors as the initial centroids. */
    private float[][] initializeForgy() throws IOException {
        IntHashSet selection = new IntHashSet();
        while (selection.size() < this.numCentroids) {
            selection.add(this.random.nextInt(this.numVectors));
        }
        float[][] initialCentroids = new float[this.numCentroids][];
        int i = 0;
        for (IntCursor selectedIdx : selection) {
            float[] vector = this.vectors.vectorValue(selectedIdx.value);
            // Copy: the array returned by vectorValue is not retained across calls here.
            initialCentroids[i++] = ArrayUtil.copyOfSubArray(vector, 0, vector.length);
        }
        return initialCentroids;
    }

    /**
     * Seeds the centroids with the first {@code numCentroids} vectors, then lets each later vector
     * at position {@code index} replace a random centroid with probability
     * {@code numCentroids / index}.
     */
    private float[][] initializeReservoirSampling() throws IOException {
        float[][] initialCentroids = new float[this.numCentroids][];
        for (int index = 0; index < this.numVectors; ++index) {
            float[] vector = this.vectors.vectorValue(index);
            if (index < this.numCentroids) {
                initialCentroids[index] = ArrayUtil.copyOfSubArray(vector, 0, vector.length);
            } else if (this.random.nextDouble() < (double) this.numCentroids * (1.0 / (double) index)) {
                int c = this.random.nextInt(this.numCentroids);
                initialCentroids[c] = ArrayUtil.copyOfSubArray(vector, 0, vector.length);
            }
        }
        return initialCentroids;
    }

    /**
     * k-means++ style initialization: the first centroid is a random vector, and each following
     * centroid is drawn with probability proportional to the squared distance of a vector to its
     * closest already-chosen centroid.
     */
    private float[][] initializePlusPlus() throws IOException {
        float[][] initialCentroids = new float[this.numCentroids][];
        int firstIndex = this.random.nextInt(this.numVectors);
        float[] value = this.vectors.vectorValue(firstIndex);
        initialCentroids[0] = ArrayUtil.copyOfSubArray(value, 0, value.length);
        // minDistances[j] = squared distance from vector j to its closest chosen centroid so far.
        float[] minDistances = new float[this.numVectors];
        Arrays.fill(minDistances, Float.MAX_VALUE);
        for (int i = 1; i < this.numCentroids; ++i) {
            // Only the most recently added centroid can lower the minimum distances.
            double totalSum = 0.0;
            for (int j = 0; j < this.numVectors; ++j) {
                float dist = VectorUtil.squareDistance(this.vectors.vectorValue(j), initialCentroids[i - 1]);
                if (dist < minDistances[j]) {
                    minDistances[j] = dist;
                }
                totalSum += (double) minDistances[j];
            }
            double r = totalSum * this.random.nextDouble();
            double cumulativeSum = 0.0;
            int nextCentroidIndex = -1;
            for (int j = 0; j < this.numVectors; ++j) {
                cumulativeSum += (double) minDistances[j];
                // Vectors at distance zero coincide with a chosen centroid and are skipped.
                if (cumulativeSum >= r && minDistances[j] > 0.0f) {
                    nextCentroidIndex = j;
                    break;
                }
            }
            if (nextCentroidIndex < 0) {
                // No vector has a positive distance to the chosen centroids (e.g. all vectors are
                // identical), so weighted selection is impossible; fall back to a uniform pick
                // instead of reading vector -1.
                nextCentroidIndex = this.random.nextInt(this.numVectors);
            }
            value = this.vectors.vectorValue(nextCentroidIndex);
            initialCentroids[i] = ArrayUtil.copyOfSubArray(value, 0, value.length);
        }
        return initialCentroids;
    }

    /**
     * Performs one k-means step: assigns every vector to its closest centroid, then overwrites
     * {@code centroids} in place with the mean of the vectors assigned to each of them. Centroids
     * that received no vector are re-seeded through {@link #assignCentroids}.
     *
     * @param vectors the vectors to assign
     * @param centroids current centroids, updated in place; must be non-empty and hold at most
     *     {@link #MAX_NUM_CENTROIDS} entries so that indices fit in a {@code short}
     * @param docCentroids output: index of the closest centroid for every vector
     * @param useKahanSummation whether component sums use compensated (Kahan) summation
     * @param normalizeCentroids whether every centroid is L2-normalized at the end of the step
     * @return the sum of squared distances of the vectors to their closest centroid, measured
     *     against the centroids as they were on entry (0 when there is a single centroid)
     */
    private static double runKMeansStep(FloatVectorValues vectors, float[][] centroids, short[] docCentroids, boolean useKahanSummation, boolean normalizeCentroids) throws IOException {
        int numCentroids = centroids.length;
        int dimension = centroids[0].length;
        float[][] newCentroids = new float[numCentroids][dimension];
        int[] newCentroidSize = new int[numCentroids];
        float[][] compensations = useKahanSummation ? new float[numCentroids][dimension] : null;
        double sumSquaredDist = 0.0;
        for (int docID = 0; docID < vectors.size(); ++docID) {
            float[] vector = vectors.vectorValue(docID);
            short bestCentroid = 0;
            if (numCentroids > 1) {
                float minSquaredDist = Float.MAX_VALUE;
                for (int c = 0; c < numCentroids; ++c) {
                    float squareDist = VectorUtil.squareDistance(centroids[c], vector);
                    if (squareDist < minSquaredDist) {
                        bestCentroid = (short) c;
                        minSquaredDist = squareDist;
                    }
                }
                sumSquaredDist += (double) minSquaredDist;
            }
            newCentroidSize[bestCentroid]++;
            float[] sums = newCentroids[bestCentroid];
            for (int dim = 0; dim < vector.length; ++dim) {
                if (useKahanSummation) {
                    // Kahan summation: carry the rounding error of each addition forward.
                    float y = vector[dim] - compensations[bestCentroid][dim];
                    float t = sums[dim] + y;
                    compensations[bestCentroid][dim] = t - sums[dim] - y;
                    sums[dim] = t;
                } else {
                    sums[dim] += vector[dim];
                }
            }
            docCentroids[docID] = bestCentroid;
        }
        List<Integer> unassignedCentroids = new ArrayList<>();
        for (int c = 0; c < numCentroids; ++c) {
            if (newCentroidSize[c] > 0) {
                for (int dim = 0; dim < newCentroids[c].length; ++dim) {
                    centroids[c][dim] = newCentroids[c][dim] / (float) newCentroidSize[c];
                }
            } else {
                unassignedCentroids.add(c);
            }
        }
        if (!unassignedCentroids.isEmpty()) {
            KMeans.assignCentroids(vectors, centroids, unassignedCentroids);
        }
        if (normalizeCentroids) {
            for (float[] centroid : centroids) {
                VectorUtil.l2normalize(centroid, false);
            }
        }
        return sumSquaredDist;
    }

    /**
     * Re-seeds centroids that received no vector. Every (vector, assigned centroid) squared
     * distance is offered to a bounded queue holding as many entries as there are unassigned
     * centroids; each unassigned centroid is then replaced by a copy of the vector at the top of
     * the queue, popping one entry per centroid.
     *
     * <p>NOTE(review): the queue is created with {@code maxHeap = false}, which presumably makes it
     * retain the largest distances (i.e. the vectors farthest from an assigned centroid) - confirm
     * against {@code NeighborQueue}. The same vector can be offered once per assigned centroid.
     *
     * @param vectors the vectors new centroids are taken from
     * @param centroids all centroids; entries listed in {@code unassignedCentroidsIdxs} are replaced
     * @param unassignedCentroidsIdxs indices into {@code centroids} of the centroids to replace
     */
    static void assignCentroids(FloatVectorValues vectors, float[][] centroids, List<Integer> unassignedCentroidsIdxs) throws IOException {
        // Mark unassigned centroids once instead of scanning the list for every centroid.
        boolean[] unassigned = new boolean[centroids.length];
        for (int unassignedIdx : unassignedCentroidsIdxs) {
            unassigned[unassignedIdx] = true;
        }
        int[] assignedCentroidsIdxs = new int[centroids.length - unassignedCentroidsIdxs.size()];
        int assignedIndex = 0;
        for (int i = 0; i < centroids.length; ++i) {
            if (!unassigned[i]) {
                assignedCentroidsIdxs[assignedIndex++] = i;
            }
        }
        NeighborQueue queue = new NeighborQueue(unassignedCentroidsIdxs.size(), false);
        for (int i = 0; i < vectors.size(); ++i) {
            float[] vector = vectors.vectorValue(i);
            for (int assignedCentroidIdx : assignedCentroidsIdxs) {
                float squareDist = VectorUtil.squareDistance(centroids[assignedCentroidIdx], vector);
                queue.insertWithOverflow(i, squareDist);
            }
        }
        for (int i = 0; i < unassignedCentroidsIdxs.size(); ++i) {
            float[] vector = vectors.vectorValue(queue.topNode());
            int unassignedCentroidIdx = unassignedCentroidsIdxs.get(i);
            centroids[unassignedCentroidIdx] = ArrayUtil.copyArray(vector);
            queue.pop();
        }
    }

    /** Strategies for choosing the initial centroids. */
    public static enum KmeansInitializationMethod {
        /** Distinct vectors chosen uniformly at random. */
        FORGY,
        /** First vectors, randomly replaced by later ones while streaming over the input. */
        RESERVOIR_SAMPLING,
        /** Distance-weighted seeding (k-means++). */
        PLUS_PLUS;

    }

    /**
     * Result of a clustering run.
     *
     * @param centroids the computed centroids
     * @param vectorCentroids for each vector, the index of its closest centroid, or {@code null}
     *     when assignment was not requested
     */
    public record Results(float[][] centroids, short[] vectorCentroids) {
    }
}

