/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.matrix.data;

import jcuda.jcudnn.JCudnn;
import jcuda.jcudnn.cudnnPoolingDescriptor;
import jcuda.jcudnn.cudnnTensorDescriptor;
import org.apache.sysds.runtime.instructions.gpu.context.GPUContext;
import org.apache.sysds.runtime.matrix.data.LibMatrixCUDA;
import org.apache.sysds.runtime.matrix.data.LibMatrixDNN;

/**
 * Holder for the cuDNN tensor and pooling descriptors used by one pooling operation.
 *
 * <p>Instances are built by {@link #cudnnPoolingDescriptors} (forward: {@code xDesc}, {@code yDesc},
 * {@code poolingDesc}) or {@link #cudnnPoolingBackwardDescriptors} (additionally {@code dxDesc},
 * {@code dyDesc}). The holder owns the native descriptors it references; release them with
 * {@link #close()}, preferably via try-with-resources. {@code close()} is idempotent.
 */
public class LibMatrixCuDNNPoolingDescriptors implements AutoCloseable {
    /** cuDNN {@code cudnnTensorFormat_t} value {@code CUDNN_TENSOR_NCHW}. */
    private static final int CUDNN_TENSOR_NCHW = 0;
    /** cuDNN {@code cudnnPoolingMode_t} value {@code CUDNN_POOLING_MAX}. */
    private static final int CUDNN_POOLING_MAX = 0;
    /** cuDNN {@code cudnnPoolingMode_t} value {@code CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING}. */
    private static final int CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING = 1;
    /** cuDNN {@code cudnnNanPropagation_t} value {@code CUDNN_PROPAGATE_NAN}. */
    private static final int CUDNN_PROPAGATE_NAN = 1;

    /** Input tensor descriptor, shape N x C x H x W. */
    public cudnnTensorDescriptor xDesc;
    /** Output tensor descriptor, shape N x C x P x Q. */
    public cudnnTensorDescriptor yDesc;
    /** Input-gradient tensor descriptor, shape N x C x H x W; only set by the backward factory. */
    public cudnnTensorDescriptor dxDesc;
    /** Output-gradient tensor descriptor, shape N x C x P x Q; only set by the backward factory. */
    public cudnnTensorDescriptor dyDesc;
    /** Pooling window/padding/stride descriptor. */
    public cudnnPoolingDescriptor poolingDesc;

    /**
     * Destroys every non-null descriptor held by this object and clears the corresponding field.
     *
     * <p>Each field is cleared <em>before</em> its handle is destroyed. A repeated {@code close()}
     * therefore never hands an already-destroyed handle back to cuDNN. If one destroy call throws,
     * the fields not yet processed stay set, so a later {@code close()} can still release them.
     */
    @Override
    public void close() {
        cudnnTensorDescriptor tensor = xDesc;
        xDesc = null;
        destroyTensorDescriptor(tensor);

        tensor = yDesc;
        yDesc = null;
        destroyTensorDescriptor(tensor);

        tensor = dxDesc;
        dxDesc = null;
        destroyTensorDescriptor(tensor);

        tensor = dyDesc;
        dyDesc = null;
        destroyTensorDescriptor(tensor);

        cudnnPoolingDescriptor pooling = poolingDesc;
        poolingDesc = null;
        if (pooling != null) {
            JCudnn.cudnnDestroyPoolingDescriptor(pooling);
        }
    }

    /**
     * Allocates the descriptors needed for a backward pooling pass: everything produced by
     * {@link #cudnnPoolingDescriptors} plus {@code dxDesc} (N x C x H x W) and {@code dyDesc}
     * (N x C x P x Q).
     *
     * <p>If any allocation fails, every descriptor created so far is destroyed before the exception
     * propagates.
     *
     * @param gCtx        GPU context; not used by this method
     * @param instName    instruction name; not used by this method
     * @param N           batch size
     * @param C           number of channels
     * @param H           input height
     * @param W           input width
     * @param K2          not used by this method
     * @param R           pooling window height
     * @param S           pooling window width
     * @param pad_h       vertical padding
     * @param pad_w       horizontal padding
     * @param stride_h    vertical stride
     * @param stride_w    horizontal stride
     * @param P           output height
     * @param Q           output width
     * @param poolingType {@code MAX} selects max pooling; any other value selects average pooling
     *                    (count includes padding)
     * @return a new descriptor holder that the caller must {@link #close()}
     */
    public static LibMatrixCuDNNPoolingDescriptors cudnnPoolingBackwardDescriptors(GPUContext gCtx, String instName,
            int N, int C, int H, int W, int K2, int R, int S, int pad_h, int pad_w, int stride_h, int stride_w,
            int P, int Q, LibMatrixDNN.PoolingType poolingType) {
        LibMatrixCuDNNPoolingDescriptors ret = cudnnPoolingDescriptors(gCtx, instName, N, C, H, W, K2, R, S,
                pad_h, pad_w, stride_h, stride_w, P, Q, poolingType);
        try {
            ret.dxDesc = allocateTensorDescriptor(N, C, H, W);
            ret.dyDesc = allocateTensorDescriptor(N, C, P, Q);
            return ret;
        } catch (RuntimeException | Error e) {
            ret.closeAfterFailure(e);
            throw e;
        }
    }

    /**
     * Allocates the descriptors needed for a forward pooling pass: {@code xDesc} (N x C x H x W),
     * {@code yDesc} (N x C x P x Q) and {@code poolingDesc}.
     *
     * <p>If any allocation fails, every descriptor created so far is destroyed before the exception
     * propagates.
     *
     * @param gCtx        GPU context; not used by this method
     * @param instName    instruction name; not used by this method
     * @param N           batch size
     * @param C           number of channels
     * @param H           input height
     * @param W           input width
     * @param K2          not used by this method
     * @param R           pooling window height
     * @param S           pooling window width
     * @param pad_h       vertical padding
     * @param pad_w       horizontal padding
     * @param stride_h    vertical stride
     * @param stride_w    horizontal stride
     * @param P           output height
     * @param Q           output width
     * @param poolingType {@code MAX} selects max pooling; any other value selects average pooling
     *                    (count includes padding)
     * @return a new descriptor holder that the caller must {@link #close()}
     */
    public static LibMatrixCuDNNPoolingDescriptors cudnnPoolingDescriptors(GPUContext gCtx, String instName,
            int N, int C, int H, int W, int K2, int R, int S, int pad_h, int pad_w, int stride_h, int stride_w,
            int P, int Q, LibMatrixDNN.PoolingType poolingType) {
        LibMatrixCuDNNPoolingDescriptors ret = new LibMatrixCuDNNPoolingDescriptors();
        try {
            ret.xDesc = allocateTensorDescriptor(N, C, H, W);
            ret.yDesc = allocateTensorDescriptor(N, C, P, Q);
            ret.poolingDesc = allocatePoolingDescriptor(R, S, pad_h, pad_w, stride_h, stride_w, poolingType);
            return ret;
        } catch (RuntimeException | Error e) {
            ret.closeAfterFailure(e);
            throw e;
        }
    }

    /**
     * Releases whatever this holder has allocated so far after a failed construction. Any exception
     * thrown during cleanup is attached to {@code primary} as a suppressed exception, so the
     * original failure is never masked.
     */
    private void closeAfterFailure(Throwable primary) {
        try {
            close();
        } catch (RuntimeException | Error suppressed) {
            primary.addSuppressed(suppressed);
        }
    }

    /** Destroys {@code desc} if it is non-null. */
    private static void destroyTensorDescriptor(cudnnTensorDescriptor desc) {
        if (desc != null) {
            JCudnn.cudnnDestroyTensorDescriptor(desc);
        }
    }

    /**
     * Creates an NCHW 4-D tensor descriptor whose data type is {@code LibMatrixCUDA.CUDNN_DATA_TYPE}.
     * If configuring the new descriptor throws, the descriptor is destroyed before the exception
     * propagates.
     */
    private static cudnnTensorDescriptor allocateTensorDescriptor(int N, int C, int H, int W) {
        cudnnTensorDescriptor tensorDescriptor = new cudnnTensorDescriptor();
        JCudnn.cudnnCreateTensorDescriptor(tensorDescriptor);
        try {
            JCudnn.cudnnSetTensor4dDescriptor(tensorDescriptor, CUDNN_TENSOR_NCHW, LibMatrixCUDA.CUDNN_DATA_TYPE,
                    N, C, H, W);
        } catch (RuntimeException | Error e) {
            try {
                JCudnn.cudnnDestroyTensorDescriptor(tensorDescriptor);
            } catch (RuntimeException | Error suppressed) {
                e.addSuppressed(suppressed);
            }
            throw e;
        }
        return tensorDescriptor;
    }

    /**
     * Creates a 2-D pooling descriptor with NaN propagation enabled. {@code MAX} maps to max pooling;
     * every other {@code poolingType} value (including {@code null}) maps to average pooling that
     * counts padded elements. If configuring the new descriptor throws, the descriptor is destroyed
     * before the exception propagates.
     */
    private static cudnnPoolingDescriptor allocatePoolingDescriptor(int R, int S, int pad_h, int pad_w,
            int stride_h, int stride_w, LibMatrixDNN.PoolingType poolingType) {
        cudnnPoolingDescriptor poolingDescriptor = new cudnnPoolingDescriptor();
        JCudnn.cudnnCreatePoolingDescriptor(poolingDescriptor);
        int poolingMode = poolingType == LibMatrixDNN.PoolingType.MAX
                ? CUDNN_POOLING_MAX
                : CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING;
        try {
            JCudnn.cudnnSetPooling2dDescriptor(poolingDescriptor, poolingMode, CUDNN_PROPAGATE_NAN,
                    R, S, pad_h, pad_w, stride_h, stride_w);
        } catch (RuntimeException | Error e) {
            try {
                JCudnn.cudnnDestroyPoolingDescriptor(poolingDescriptor);
            } catch (RuntimeException | Error suppressed) {
                e.addSuppressed(suppressed);
            }
            throw e;
        }
        return poolingDescriptor;
    }
}

