/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.nn.core;

import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractBlock;
import ai.djl.nn.Parameter;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
import ai.djl.util.Preconditions;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.Collections;

/**
 * A block that applies a collection of linear transformations with a single batched matrix
 * multiplication.
 *
 * <p>As set up in {@link #prepare(Shape[])}, the weight parameter has the shape of the input with
 * its first axis removed and {@code units} appended, and the optional bias has the shape of the
 * input with its first and last axes removed and {@code units} appended. The first input axis is
 * treated as the batch axis and the last one as the feature axis that is replaced by {@code units}
 * in the output (see {@link #getOutputShapes(Shape[])}).
 *
 * <p>The block keeps no per-call mutable state: the axis permutations needed by {@link
 * #linear(NDArray, NDArray, NDArray)} are derived from the rank of each input instead of being
 * cached in instance fields.
 */
public class LinearCollection extends AbstractBlock {
    /** Metadata encoding version written by this block and the only one accepted on load. */
    private static final byte VERSION = 1;

    /** Size of the last output axis. */
    private long units;
    /** Number of elements in one input sample (all axes except the first), set on initialize. */
    private long inputFeatures;
    /** First axis of the input shape seen during initialization; reported by describeInput. */
    private Shape inputShape;
    private final Parameter weight;
    /** {@code null} when the block was built with {@code optBias(false)}. */
    private final Parameter bias;

    /**
     * Creates the block and registers its parameters.
     *
     * @param builder the builder carrying the number of units and the bias flag
     */
    LinearCollection(Builder builder) {
        super(VERSION);
        this.units = builder.units;
        this.weight = this.addParameter(Parameter.builder().setName("weight").setType(Parameter.Type.WEIGHT).build());
        if (builder.bias) {
            this.bias = this.addParameter(Parameter.builder().setName("bias").setType(Parameter.Type.BIAS).build());
        } else {
            this.bias = null;
        }
    }

    /**
     * Runs the forward pass on the single input array.
     *
     * @param parameterStore the store from which the weight (and bias, if any) values are fetched
     * @param inputs a list that must contain exactly one array
     * @param training whether the call happens during training
     * @param params optional extra parameters; not used by this block
     * @return a one-element list holding the result of {@link #linear(NDArray, NDArray, NDArray)}
     */
    @Override
    protected NDList forwardInternal(ParameterStore parameterStore, NDList inputs, boolean training, PairList<String, Object> params) {
        NDArray input = inputs.singletonOrThrow();
        Device device = input.getDevice();
        NDArray weightArr = parameterStore.getValue(this.weight, device, training);
        // The bias parameter is optional; do not hand a null Parameter to the store.
        NDArray biasArr = this.bias == null ? null : parameterStore.getValue(this.bias, device, training);
        return this.linear(input, weightArr, biasArr);
    }

    /**
     * Computes the output shape: the first input shape with its last axis replaced by {@code
     * units}.
     *
     * @param inputs the input shapes; only the first one is used
     * @return a one-element array with the output shape
     */
    @Override
    public Shape[] getOutputShapes(Shape[] inputs) {
        Shape input = inputs[0];
        return new Shape[]{input.slice(0, input.dimension() - 1).add(this.units)};
    }

    /**
     * Describes the single input of this block.
     *
     * @return a one-entry list mapping {@code "linearInput"} to the shape recorded during
     *     initialization (or loaded from metadata)
     */
    @Override
    public PairList<String, Shape> describeInput() {
        return new PairList<>(Collections.singletonList("linearInput"), Collections.singletonList(this.inputShape));
    }

    /**
     * Validates the input shapes and records the input feature count and the leading input axis.
     *
     * @param inputShapes the shapes of the inputs; exactly one is expected
     * @throws IllegalArgumentException if more or fewer than one shape is given
     */
    @Override
    protected void beforeInitialize(Shape ... inputShapes) {
        super.beforeInitialize(inputShapes);
        Preconditions.checkArgument(inputShapes.length == 1, "Linear block only support 1 input");
        Shape input = inputShapes[0];
        this.inputFeatures = input.slice(1).size();
        this.inputShape = input.slice(0, 1);
    }

    /**
     * Sets the parameter shapes from the first input shape.
     *
     * <p>The weight gets the input shape without its first axis plus {@code units}; the bias, if
     * present, gets the input shape without its first and last axes plus {@code units}.
     *
     * @param inputShapes the shapes of the inputs; only the first one is used
     */
    @Override
    public void prepare(Shape[] inputShapes) {
        Shape input = inputShapes[0];
        this.weight.setShape(input.slice(1).add(this.units));
        if (this.bias != null) {
            this.bias.setShape(input.slice(1, input.dimension() - 1).add(this.units));
        }
    }

    /**
     * Writes {@code units}, {@code inputFeatures} and the encoded input shape, in that order.
     *
     * @param os the stream to write to
     * @throws IOException if writing fails
     */
    @Override
    protected void saveMetadata(DataOutputStream os) throws IOException {
        os.writeLong(this.units);
        os.writeLong(this.inputFeatures);
        os.write(this.inputShape.getEncoded());
    }

    /**
     * Reads back the values written by {@link #saveMetadata(DataOutputStream)}.
     *
     * @param loadVersion the encoding version found in the stream
     * @param is the stream to read from
     * @throws IOException if reading fails
     * @throws MalformedModelException if {@code loadVersion} is not the supported version
     */
    @Override
    public void loadMetadata(byte loadVersion, DataInputStream is) throws IOException, MalformedModelException {
        if (loadVersion != VERSION) {
            throw new MalformedModelException("Unsupported encoding version: " + loadVersion);
        }
        this.units = is.readLong();
        this.inputFeatures = is.readLong();
        this.inputShape = Shape.decode(is);
    }

    /**
     * Applies the batched linear transformation.
     *
     * <p>The first axis of {@code input} is moved to the second-to-last position, the result is
     * matrix-multiplied with {@code weight}, and the axis is then moved back to the front. If a
     * bias is given it is added in place to the result.
     *
     * <p>The two axis permutations are rebuilt on every call from the rank of {@code input}. They
     * are deliberately not cached on the instance: a cache would be wrong for an input of a
     * different rank and could be observed half-initialized by concurrent callers.
     *
     * @param input the input array; must have at least two axes
     * @param weight the weight array
     * @param bias the bias array, or {@code null} to skip the bias addition
     * @return a one-element list holding the result
     * @throws IllegalArgumentException if {@code input} has fewer than two axes
     */
    public NDList linear(NDArray input, NDArray weight, NDArray bias) {
        int dim = input.getShape().dimension();
        Preconditions.checkArgument(dim >= 2, "LinearCollection input must have at least 2 dimensions, but got " + dim);

        // shiftBatchAxis         = (1, 2, ..., dim-2, 0, dim-1): first axis moved next to the last.
        // reverseShiftBatchAxis  = (dim-2, 0, 1, ..., dim-3, dim-1): the inverse permutation.
        int[] shiftBatchAxis = new int[dim];
        int[] reverseShiftBatchAxis = new int[dim];
        for (int d = 0; d < dim - 2; ++d) {
            shiftBatchAxis[d] = d + 1;
            reverseShiftBatchAxis[d + 1] = d;
        }
        shiftBatchAxis[dim - 2] = 0;
        reverseShiftBatchAxis[0] = dim - 2;
        shiftBatchAxis[dim - 1] = dim - 1;
        reverseShiftBatchAxis[dim - 1] = dim - 1;

        NDArray resultArr = input.transpose(shiftBatchAxis).matMul(weight).transpose(reverseShiftBatchAxis);
        if (bias != null) {
            resultArr.addi(bias);
        }
        return new NDList(resultArr);
    }

    /**
     * Creates a new builder for a {@code LinearCollection}.
     *
     * @return a new builder
     */
    public static Builder builder() {
        return new Builder();
    }

    /** Builder for {@link LinearCollection}; a bias is included unless disabled. */
    public static final class Builder {
        private long units;
        private boolean bias = true;

        Builder() {
        }

        /**
         * Sets the number of output units (required).
         *
         * @param units the number of output units; must be positive for {@link #build()} to succeed
         * @return this builder
         */
        public Builder setUnits(long units) {
            this.units = units;
            return this;
        }

        /**
         * Sets whether the block has a bias parameter.
         *
         * @param bias {@code true} to include a bias (the default), {@code false} to omit it
         * @return this builder
         */
        public Builder optBias(boolean bias) {
            this.bias = bias;
            return this;
        }

        /**
         * Builds the block.
         *
         * @return the new {@code LinearCollection}
         * @throws IllegalArgumentException if no positive number of units was set
         */
        public LinearCollection build() {
            Preconditions.checkArgument(this.units > 0L, "You must specify unit");
            return new LinearCollection(this);
        }
    }
}

