/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.nn.convolutional;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.LayoutType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.convolutional.Deconvolution;
import ai.djl.util.Preconditions;

/**
 * A two-dimensional transposed-convolution block whose input uses the {@code NCHW} layout.
 *
 * <p>Instances are configured through {@link Builder}. The static {@code conv2dTranspose}
 * overloads expose the same operation functionally. They fill in defaults for any omitted
 * argument and delegate to {@link Deconvolution#deconvolution}.
 */
public class Conv2dTranspose
extends Deconvolution {
    /** Expected input layout: batch, channel, height, width. */
    private static final LayoutType[] EXPECTED_LAYOUT = new LayoutType[]{LayoutType.BATCH, LayoutType.CHANNEL, LayoutType.HEIGHT, LayoutType.WIDTH};
    /** String form of {@link #EXPECTED_LAYOUT}. */
    private static final String STRING_LAYOUT = "NCHW";
    /** Rank required of both the {@code input} and {@code weight} arrays. */
    private static final int NUM_DIMENSIONS = 4;
    /** Rank required of the stride, padding, outPadding and dilation shapes (one entry per spatial axis). */
    private static final int NUM_SPATIAL_DIMENSIONS = 2;

    /**
     * Creates the block from a configured builder.
     *
     * @param builder the builder holding stride, padding, outPadding, dilation and the other settings
     */
    Conv2dTranspose(Builder builder) {
        super(builder);
    }

    /**
     * Returns the expected input layout ({@code NCHW}).
     *
     * <p>The returned array is a shared static instance; callers must not modify it.
     *
     * @return the layout types for batch, channel, height and width
     */
    @Override
    protected LayoutType[] getExpectedLayout() {
        return EXPECTED_LAYOUT;
    }

    /**
     * Returns the expected input layout as a string.
     *
     * @return {@code "NCHW"}
     */
    @Override
    protected String getStringLayout() {
        return STRING_LAYOUT;
    }

    /**
     * Returns the number of dimensions of the expected input.
     *
     * @return {@value #NUM_DIMENSIONS}
     */
    @Override
    protected int numDimensions() {
        // Use the named constant rather than a literal so the rank is defined in one place.
        return NUM_DIMENSIONS;
    }

    /**
     * Applies a 2-D transposed convolution with no bias, stride (1, 1), padding (0, 0),
     * outPadding (0, 0), dilation (1, 1) and a single group.
     *
     * @param input the 4-D input array
     * @param weight the 4-D weight array
     * @return the result produced by {@link Deconvolution#deconvolution}
     */
    public static NDList conv2dTranspose(NDArray input, NDArray weight) {
        return conv2dTranspose(input, weight, null);
    }

    /**
     * Applies a 2-D transposed convolution with stride (1, 1), padding (0, 0), outPadding (0, 0),
     * dilation (1, 1) and a single group.
     *
     * @param input the 4-D input array
     * @param weight the 4-D weight array
     * @param bias the bias array; may be {@code null} (the two-argument overload passes {@code null})
     * @return the result produced by {@link Deconvolution#deconvolution}
     */
    public static NDList conv2dTranspose(NDArray input, NDArray weight, NDArray bias) {
        return conv2dTranspose(input, weight, bias, new Shape(1L, 1L));
    }

    /**
     * Applies a 2-D transposed convolution with padding (0, 0), outPadding (0, 0), dilation (1, 1)
     * and a single group.
     *
     * @param input the 4-D input array
     * @param weight the 4-D weight array
     * @param bias the bias array; may be {@code null}
     * @param stride the 2-D stride
     * @return the result produced by {@link Deconvolution#deconvolution}
     */
    public static NDList conv2dTranspose(NDArray input, NDArray weight, NDArray bias, Shape stride) {
        return conv2dTranspose(input, weight, bias, stride, new Shape(0L, 0L));
    }

    /**
     * Applies a 2-D transposed convolution with outPadding (0, 0), dilation (1, 1) and a single
     * group.
     *
     * @param input the 4-D input array
     * @param weight the 4-D weight array
     * @param bias the bias array; may be {@code null}
     * @param stride the 2-D stride
     * @param padding the 2-D padding
     * @return the result produced by {@link Deconvolution#deconvolution}
     */
    public static NDList conv2dTranspose(NDArray input, NDArray weight, NDArray bias, Shape stride, Shape padding) {
        return conv2dTranspose(input, weight, bias, stride, padding, new Shape(0L, 0L));
    }

    /**
     * Applies a 2-D transposed convolution with dilation (1, 1) and a single group.
     *
     * @param input the 4-D input array
     * @param weight the 4-D weight array
     * @param bias the bias array; may be {@code null}
     * @param stride the 2-D stride
     * @param padding the 2-D padding
     * @param outPadding the 2-D output padding
     * @return the result produced by {@link Deconvolution#deconvolution}
     */
    public static NDList conv2dTranspose(NDArray input, NDArray weight, NDArray bias, Shape stride, Shape padding, Shape outPadding) {
        return conv2dTranspose(input, weight, bias, stride, padding, outPadding, new Shape(1L, 1L));
    }

    /**
     * Applies a 2-D transposed convolution with a single group.
     *
     * @param input the 4-D input array
     * @param weight the 4-D weight array
     * @param bias the bias array; may be {@code null}
     * @param stride the 2-D stride
     * @param padding the 2-D padding
     * @param outPadding the 2-D output padding
     * @param dilation the 2-D dilation
     * @return the result produced by {@link Deconvolution#deconvolution}
     */
    public static NDList conv2dTranspose(NDArray input, NDArray weight, NDArray bias, Shape stride, Shape padding, Shape outPadding, Shape dilation) {
        return conv2dTranspose(input, weight, bias, stride, padding, outPadding, dilation, 1);
    }

    /**
     * Applies a 2-D transposed convolution.
     *
     * <p>The ranks of all array and shape arguments are validated with
     * {@link Preconditions#checkArgument} before delegating to {@link Deconvolution#deconvolution}.
     *
     * @param input the input array; must have {@value #NUM_DIMENSIONS} dimensions
     * @param weight the weight array; must have {@value #NUM_DIMENSIONS} dimensions
     * @param bias the bias array; may be {@code null} (the shorter overloads pass {@code null})
     * @param stride the stride; must have {@value #NUM_SPATIAL_DIMENSIONS} dimensions
     * @param padding the padding; must have {@value #NUM_SPATIAL_DIMENSIONS} dimensions
     * @param outPadding the output padding; must have {@value #NUM_SPATIAL_DIMENSIONS} dimensions
     * @param dilation the dilation; must have {@value #NUM_SPATIAL_DIMENSIONS} dimensions
     * @param groups the number of groups, forwarded unchanged
     * @return the result produced by {@link Deconvolution#deconvolution}
     */
    public static NDList conv2dTranspose(NDArray input, NDArray weight, NDArray bias, Shape stride, Shape padding, Shape outPadding, Shape dilation, int groups) {
        Preconditions.checkArgument(
                input.getShape().dimension() == NUM_DIMENSIONS
                        && weight.getShape().dimension() == NUM_DIMENSIONS,
                "the shape of input or weight doesn't match the conv2dTranspose");
        // outPadding is validated here too, so the message must name it as well.
        Preconditions.checkArgument(
                stride.dimension() == NUM_SPATIAL_DIMENSIONS
                        && padding.dimension() == NUM_SPATIAL_DIMENSIONS
                        && outPadding.dimension() == NUM_SPATIAL_DIMENSIONS
                        && dilation.dimension() == NUM_SPATIAL_DIMENSIONS,
                "the shape of stride or padding or outPadding or dilation doesn't match the conv2dTranspose");
        return Deconvolution.deconvolution(input, weight, bias, stride, padding, outPadding, dilation, groups);
    }

    /**
     * Creates a builder pre-populated with the 2-D defaults.
     *
     * @return a new {@link Builder}
     */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Builder for {@link Conv2dTranspose}.
     *
     * <p>Defaults: stride (1, 1), padding (0, 0), outPadding (0, 0), dilation (1, 1).
     */
    public static final class Builder
    extends Deconvolution.DeconvolutionBuilder<Builder> {
        /** Initializes the spatial hyper-parameters to their 2-D defaults. */
        Builder() {
            this.stride = new Shape(1L, 1L);
            this.padding = new Shape(0L, 0L);
            this.outPadding = new Shape(0L, 0L);
            this.dilation = new Shape(1L, 1L);
        }

        /** {@inheritDoc} */
        @Override
        protected Builder self() {
            return this;
        }

        /**
         * Validates the configuration via {@code validate()} and builds the block.
         *
         * @return a new {@link Conv2dTranspose}
         */
        public Conv2dTranspose build() {
            this.validate();
            return new Conv2dTranspose(this);
        }
    }
}

