/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.knn.quantization.models.quantizationState;

import java.io.IOException;
import java.util.Arrays;
import lombok.Generated;
import lombok.NonNull;
import org.apache.lucene.util.RamUsageEstimator;
import org.opensearch.Version;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationStateSerializer;
import org.opensearch.knn.quantization.util.QuantizationUtils;

public final class OneBitScalarQuantizationState
implements QuantizationState {
    @NonNull
    private final ScalarQuantizationParams quantizationParams;
    @NonNull
    private final float[] meanThresholds;
    private float[] belowThresholdMeans;
    private float[] aboveThresholdMeans;
    private float[][] rotationMatrix;

    @Override
    public ScalarQuantizationParams getQuantizationParams() {
        return this.quantizationParams;
    }

    public void writeTo(StreamOutput out) throws IOException {
        QuantizationUtils.FloatArrayWrapper[] floatArrayWrapperArray;
        QuantizationUtils.FloatArrayWrapper[] floatArrayWrapperArray2;
        out.writeVInt(Version.CURRENT.id);
        this.quantizationParams.writeTo(out);
        out.writeFloatArray(this.meanThresholds);
        if (this.rotationMatrix != null) {
            out.writeBoolean(true);
            out.writeVInt(this.rotationMatrix.length);
            for (float[] row : this.rotationMatrix) {
                out.writeFloatArray(row);
            }
        } else {
            out.writeBoolean(false);
        }
        if (this.belowThresholdMeans != null) {
            QuantizationUtils.FloatArrayWrapper[] floatArrayWrapperArray3 = new QuantizationUtils.FloatArrayWrapper[1];
            floatArrayWrapperArray2 = floatArrayWrapperArray3;
            floatArrayWrapperArray3[0] = new QuantizationUtils.FloatArrayWrapper(this.belowThresholdMeans);
        } else {
            floatArrayWrapperArray2 = null;
        }
        out.writeOptionalArray(floatArrayWrapperArray2);
        if (this.aboveThresholdMeans != null) {
            QuantizationUtils.FloatArrayWrapper[] floatArrayWrapperArray4 = new QuantizationUtils.FloatArrayWrapper[1];
            floatArrayWrapperArray = floatArrayWrapperArray4;
            floatArrayWrapperArray4[0] = new QuantizationUtils.FloatArrayWrapper(this.aboveThresholdMeans);
        } else {
            floatArrayWrapperArray = null;
        }
        out.writeOptionalArray(floatArrayWrapperArray);
    }

    public OneBitScalarQuantizationState(StreamInput in) throws IOException {
        int version = in.readVInt();
        this.quantizationParams = new ScalarQuantizationParams(in, version);
        this.meanThresholds = in.readFloatArray();
        if (Version.fromId((int)version).onOrAfter(Version.V_3_2_0) && in.readBoolean()) {
            int dimensions = in.readVInt();
            this.rotationMatrix = new float[dimensions][];
            for (int i = 0; i < dimensions; ++i) {
                this.rotationMatrix[i] = in.readFloatArray();
            }
        }
        if (Version.fromId((int)version).onOrAfter(Version.V_3_2_0)) {
            QuantizationUtils.FloatArrayWrapper[] wrappedBelowThresholdMeans = (QuantizationUtils.FloatArrayWrapper[])in.readOptionalArray(QuantizationUtils.FloatArrayWrapper::new, QuantizationUtils.FloatArrayWrapper[]::new);
            this.belowThresholdMeans = wrappedBelowThresholdMeans != null ? wrappedBelowThresholdMeans[0].getArray() : null;
            QuantizationUtils.FloatArrayWrapper[] wrappedAboveThresholdMeans = (QuantizationUtils.FloatArrayWrapper[])in.readOptionalArray(QuantizationUtils.FloatArrayWrapper::new, QuantizationUtils.FloatArrayWrapper[]::new);
            this.aboveThresholdMeans = wrappedAboveThresholdMeans != null ? wrappedAboveThresholdMeans[0].getArray() : null;
        }
    }

    public OneBitScalarQuantizationState(@NonNull ScalarQuantizationParams quantizationParams, @NonNull float[] meanThresholds) {
        if (quantizationParams == null) {
            throw new NullPointerException("quantizationParams is marked non-null but is null");
        }
        if (meanThresholds == null) {
            throw new NullPointerException("meanThresholds is marked non-null but is null");
        }
        this.quantizationParams = quantizationParams;
        this.meanThresholds = meanThresholds;
        this.rotationMatrix = null;
    }

    @Override
    public byte[] toByteArray() throws IOException {
        return QuantizationStateSerializer.serialize(this);
    }

    public static OneBitScalarQuantizationState fromByteArray(byte[] bytes) throws IOException {
        return (OneBitScalarQuantizationState)QuantizationStateSerializer.deserialize(bytes, OneBitScalarQuantizationState::new);
    }

    @Override
    public int getBytesPerVector() {
        return (this.meanThresholds.length + 7) / 8;
    }

    @Override
    public int getDimensions() {
        return this.meanThresholds.length + 7 & 0xFFFFFFF8;
    }

    @Override
    public long ramBytesUsed() {
        long size = RamUsageEstimator.shallowSizeOfInstance(OneBitScalarQuantizationState.class);
        size += RamUsageEstimator.shallowSizeOf((Object)this.quantizationParams);
        size += RamUsageEstimator.sizeOf((float[])this.meanThresholds);
        if (this.rotationMatrix != null) {
            size += RamUsageEstimator.shallowSizeOf((Object[])this.rotationMatrix);
            for (float[] row : this.rotationMatrix) {
                size += RamUsageEstimator.sizeOf((float[])row);
            }
        }
        if (this.belowThresholdMeans != null) {
            size += RamUsageEstimator.sizeOf((float[])this.belowThresholdMeans);
        }
        if (this.aboveThresholdMeans != null) {
            size += RamUsageEstimator.sizeOf((float[])this.aboveThresholdMeans);
        }
        return size;
    }

    @Generated
    private static float[] $default$belowThresholdMeans() {
        return null;
    }

    @Generated
    private static float[] $default$aboveThresholdMeans() {
        return null;
    }

    @Generated
    private static float[][] $default$rotationMatrix() {
        return null;
    }

    @Generated
    public static OneBitScalarQuantizationStateBuilder builder() {
        return new OneBitScalarQuantizationStateBuilder();
    }

    @NonNull
    @Generated
    public float[] getMeanThresholds() {
        return this.meanThresholds;
    }

    @Generated
    public float[] getBelowThresholdMeans() {
        return this.belowThresholdMeans;
    }

    @Generated
    public float[] getAboveThresholdMeans() {
        return this.aboveThresholdMeans;
    }

    @Generated
    public float[][] getRotationMatrix() {
        return this.rotationMatrix;
    }

    @Generated
    public OneBitScalarQuantizationState(@NonNull ScalarQuantizationParams quantizationParams, @NonNull float[] meanThresholds, float[] belowThresholdMeans, float[] aboveThresholdMeans, float[][] rotationMatrix) {
        if (quantizationParams == null) {
            throw new NullPointerException("quantizationParams is marked non-null but is null");
        }
        if (meanThresholds == null) {
            throw new NullPointerException("meanThresholds is marked non-null but is null");
        }
        this.quantizationParams = quantizationParams;
        this.meanThresholds = meanThresholds;
        this.belowThresholdMeans = belowThresholdMeans;
        this.aboveThresholdMeans = aboveThresholdMeans;
        this.rotationMatrix = rotationMatrix;
    }

    @Generated
    public OneBitScalarQuantizationState() {
        this.quantizationParams = null;
        this.meanThresholds = null;
        this.belowThresholdMeans = OneBitScalarQuantizationState.$default$belowThresholdMeans();
        this.aboveThresholdMeans = OneBitScalarQuantizationState.$default$aboveThresholdMeans();
        this.rotationMatrix = OneBitScalarQuantizationState.$default$rotationMatrix();
    }

    @Generated
    public static class OneBitScalarQuantizationStateBuilder {
        @Generated
        private ScalarQuantizationParams quantizationParams;
        @Generated
        private float[] meanThresholds;
        @Generated
        private boolean belowThresholdMeans$set;
        @Generated
        private float[] belowThresholdMeans$value;
        @Generated
        private boolean aboveThresholdMeans$set;
        @Generated
        private float[] aboveThresholdMeans$value;
        @Generated
        private boolean rotationMatrix$set;
        @Generated
        private float[][] rotationMatrix$value;

        @Generated
        OneBitScalarQuantizationStateBuilder() {
        }

        @Generated
        public OneBitScalarQuantizationStateBuilder quantizationParams(@NonNull ScalarQuantizationParams quantizationParams) {
            if (quantizationParams == null) {
                throw new NullPointerException("quantizationParams is marked non-null but is null");
            }
            this.quantizationParams = quantizationParams;
            return this;
        }

        @Generated
        public OneBitScalarQuantizationStateBuilder meanThresholds(@NonNull float[] meanThresholds) {
            if (meanThresholds == null) {
                throw new NullPointerException("meanThresholds is marked non-null but is null");
            }
            this.meanThresholds = meanThresholds;
            return this;
        }

        @Generated
        public OneBitScalarQuantizationStateBuilder belowThresholdMeans(float[] belowThresholdMeans) {
            this.belowThresholdMeans$value = belowThresholdMeans;
            this.belowThresholdMeans$set = true;
            return this;
        }

        @Generated
        public OneBitScalarQuantizationStateBuilder aboveThresholdMeans(float[] aboveThresholdMeans) {
            this.aboveThresholdMeans$value = aboveThresholdMeans;
            this.aboveThresholdMeans$set = true;
            return this;
        }

        @Generated
        public OneBitScalarQuantizationStateBuilder rotationMatrix(float[][] rotationMatrix) {
            this.rotationMatrix$value = rotationMatrix;
            this.rotationMatrix$set = true;
            return this;
        }

        @Generated
        public OneBitScalarQuantizationState build() {
            float[] belowThresholdMeans$value = this.belowThresholdMeans$value;
            if (!this.belowThresholdMeans$set) {
                belowThresholdMeans$value = OneBitScalarQuantizationState.$default$belowThresholdMeans();
            }
            float[] aboveThresholdMeans$value = this.aboveThresholdMeans$value;
            if (!this.aboveThresholdMeans$set) {
                aboveThresholdMeans$value = OneBitScalarQuantizationState.$default$aboveThresholdMeans();
            }
            float[][] rotationMatrix$value = this.rotationMatrix$value;
            if (!this.rotationMatrix$set) {
                rotationMatrix$value = OneBitScalarQuantizationState.$default$rotationMatrix();
            }
            return new OneBitScalarQuantizationState(this.quantizationParams, this.meanThresholds, belowThresholdMeans$value, aboveThresholdMeans$value, rotationMatrix$value);
        }

        @Generated
        public String toString() {
            return "OneBitScalarQuantizationState.OneBitScalarQuantizationStateBuilder(quantizationParams=" + String.valueOf(this.quantizationParams) + ", meanThresholds=" + Arrays.toString(this.meanThresholds) + ", belowThresholdMeans$value=" + Arrays.toString(this.belowThresholdMeans$value) + ", aboveThresholdMeans$value=" + Arrays.toString(this.aboveThresholdMeans$value) + ", rotationMatrix$value=" + Arrays.deepToString((Object[])this.rotationMatrix$value) + ")";
        }
    }
}

