/*
 * Decompiled with CFR 0.152.
 */
package org.eclipse.january.dataset;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.eclipse.january.dataset.BooleanDataset;
import org.eclipse.january.dataset.Dataset;
import org.eclipse.january.dataset.DatasetFactory;
import org.eclipse.january.dataset.InterfaceUtils;
import org.eclipse.january.dataset.ShapeUtils;

/**
 * Static helpers for broadcasting dataset shapes, strides and item sizes against each other.
 */
public final class BroadcastUtils {

    /**
     * Calculate the pair of shapes needed to broadcast an existing shape against a new shape.
     * <p>
     * The shorter of the two shapes is left-padded with ones so both have the same rank. The
     * shapes are compatible when every pair of dimensions is equal or one of them is 1.
     *
     * @param oldShape existing shape
     * @param size number of items expected for {@code oldShape} (checked against
     *             {@code ShapeUtils.calcSize(oldShape)})
     * @param newShape shape to broadcast to
     * @return {@code {paddedOldShape, paddedNewShape}}, or {@code null} if {@code newShape} is null,
     *         if {@code newShape} is empty and {@code size != 1}, or if the shapes are incompatible
     * @throws IllegalArgumentException if {@code size} does not match the size of {@code oldShape}
     */
    public static int[][] calculateBroadcastShapes(int[] oldShape, int size, int ... newShape) {
        if (newShape == null) {
            return null;
        }
        if (newShape.length == 0) {
            // a zero-rank target can only hold a single item
            return size == 1 ? new int[][] {oldShape, newShape} : null;
        }
        if (Arrays.equals(oldShape, newShape)) {
            return new int[][] {oldShape, newShape};
        }
        if (ShapeUtils.calcSize(oldShape) != size) {
            throw new IllegalArgumentException("Size must match old shape");
        }

        int offset = newShape.length - oldShape.length;
        if (offset < 0) {
            // new shape has fewer dimensions: left-pad it instead of the old shape
            newShape = padShape(newShape, -offset);
            offset = 0;
        }
        int[] bshape = padShape(oldShape, offset);

        // Compare every dimension of the (possibly padded) shapes. Using the un-padded rank of
        // newShape here would skip the trailing dimensions and accept incompatible shapes.
        for (int i = 0; i < newShape.length; i++) {
            if (newShape[i] != bshape[i] && bshape[i] != 1 && newShape[i] != 1) {
                return null;
            }
        }
        return new int[][] {bshape, newShape};
    }

    /**
     * Left-pad a shape with dimensions of length 1.
     *
     * @param shape shape to pad
     * @param padding number of leading ones to add
     * @return {@code shape} itself when {@code padding} is zero, otherwise a new array of length
     *         {@code shape.length + padding} whose first {@code padding} entries are 1
     * @throws IllegalArgumentException if {@code padding} is negative
     */
    public static int[] padShape(int[] shape, int padding) {
        if (padding < 0) {
            throw new IllegalArgumentException("Padding must be zero or greater");
        }
        if (padding == 0) {
            return shape;
        }
        int[] nshape = new int[shape.length + padding];
        Arrays.fill(nshape, 0, padding, 1);
        System.arraycopy(shape, 0, nshape, padding, shape.length);
        return nshape;
    }

    /**
     * Broadcast a set of shapes against each other.
     * <p>
     * Every non-null shape is left-padded with ones to the largest rank given. For each
     * dimension, the broadcast (maximum) shape takes the largest length seen.
     *
     * @param shapes shapes to broadcast; null entries are allowed and are passed through as null
     * @return list whose first element is the broadcast shape, followed by each input shape
     *         padded to the broadcast rank (in input order). If every input is null, or none are
     *         given, the list holds {@code shapes.length + 1} nulls.
     * @throws IllegalArgumentException if a dimension is neither 1 nor equal to the maximum
     */
    public static List<int[]> broadcastShapes(int[] ... shapes) {
        int maxRank = -1;
        for (int[] s : shapes) {
            if (s != null && s.length > maxRank) {
                maxRank = s.length;
            }
        }

        List<int[]> newShapes = new ArrayList<>();
        if (maxRank < 0) {
            // one null for the absent broadcast shape plus one per input
            for (int i = 0; i <= shapes.length; i++) {
                newShapes.add(null);
            }
            return newShapes;
        }

        for (int[] s : shapes) {
            newShapes.add(s == null ? null : padShape(s, maxRank - s.length));
        }

        int[] maxShape = new int[maxRank];
        for (int i = 0; i < maxRank; i++) {
            int m = -1;
            for (int[] s : newShapes) {
                if (s == null) {
                    continue;
                }
                int l = s[i];
                if (l > m) {
                    // a second, larger length after one already bigger than 1 cannot broadcast
                    if (m > 1) {
                        throw new IllegalArgumentException("A shape's dimension was not one or equal to maximum");
                    }
                    m = l;
                }
            }
            maxShape[i] = m;
        }

        checkShapes(maxShape, newShapes);
        newShapes.add(0, maxShape);
        return newShapes;
    }

    /**
     * Broadcast a set of shapes to a given maximum shape.
     *
     * @param maxShape target shape; if null, only null shapes are accepted
     * @param shapes shapes to broadcast; null entries are passed through as null
     * @return each input shape left-padded with ones to the rank of {@code maxShape}, in input
     *         order (the maximum shape itself is not included)
     * @throws IllegalArgumentException if a shape has a higher rank than {@code maxShape}, or a
     *         dimension is neither 1 nor equal to the corresponding maximum dimension
     */
    public static List<int[]> broadcastShapesToMax(int[] maxShape, int[] ... shapes) {
        int maxRank = maxShape == null ? -1 : maxShape.length;
        for (int[] s : shapes) {
            if (s != null && s.length > maxRank) {
                throw new IllegalArgumentException("A shape exceeds given rank of maximum shape");
            }
        }

        List<int[]> newShapes = new ArrayList<>();
        for (int[] s : shapes) {
            newShapes.add(s == null ? null : padShape(s, maxRank - s.length));
        }
        if (maxShape != null) {
            checkShapes(maxShape, newShapes);
        }
        return newShapes;
    }

    /**
     * Check that every dimension of every non-null shape is either 1 or equal to the maximum.
     *
     * @param maxShape maximum shape
     * @param newShapes shapes already padded to the rank of {@code maxShape}
     * @throws IllegalArgumentException if a dimension is neither 1 nor the maximum
     */
    private static void checkShapes(int[] maxShape, List<int[]> newShapes) {
        for (int i = 0; i < maxShape.length; i++) {
            int m = maxShape[i];
            for (int[] s : newShapes) {
                if (s != null && s[i] != 1 && s[i] != m) {
                    throw new IllegalArgumentException("A shape's dimension was not one or equal to maximum");
                }
            }
        }
    }

    /**
     * Create a zero-filled dataset suitable for holding the result of a binary operation.
     * <p>
     * The dataset class is the best common interface of both inputs. The exception is when
     * exactly one input has rank 0 and that scalar has no floating-point elements: then the
     * other input's class is used. The elements per item is the larger of the two inputs.
     *
     * @param a first operand
     * @param b second operand
     * @param shape shape of the dataset to create
     * @return new dataset
     */
    static Dataset createDataset(Dataset a, Dataset b, int[] shape) {
        int ar = a.getRank();
        int br = b.getRank();
        Class<? extends Dataset> tc = InterfaceUtils.getBestInterface(a.getClass(), b.getClass());
        Class<? extends Dataset> rc = tc;
        if ((ar == 0) != (br == 0)) { // exactly one operand is a scalar
            if (ar == 0) {
                rc = a.hasFloatingPointElements() ? tc : b.getClass();
            } else {
                rc = b.hasFloatingPointElements() ? tc : a.getClass();
            }
        }
        int isize = Math.max(a.getElementsPerItem(), b.getElementsPerItem());
        return DatasetFactory.zeros(isize, rc, shape);
    }

    /**
     * Check that an output dataset's elements per item is compatible with an input's.
     *
     * @param a input dataset
     * @param o output dataset, may be null (no check)
     * @throws IllegalArgumentException if both item sizes differ and neither is 1
     */
    static void checkItemSize(Dataset a, Dataset o) {
        if (o == null) {
            return;
        }
        int isa = a.getElementsPerItem();
        int iso = o.getElementsPerItem();
        if (isa != iso && isa != 1 && iso != 1) {
            throw new IllegalArgumentException("Can not output to dataset whose number of elements per item mismatch inputs'");
        }
    }

    /**
     * Check that two inputs, and optionally an output, have compatible elements per item.
     * <p>
     * The two inputs are compatible when their item sizes match, when either item size is 1, or
     * when either dataset holds a single item. A non-boolean output must have the larger input
     * item size, or an item size of 1. Boolean outputs are exempt from this check.
     *
     * @param a first input
     * @param b second input
     * @param o output dataset, may be null (no output check)
     * @throws IllegalArgumentException if the item sizes are incompatible
     */
    static void checkItemSize(Dataset a, Dataset b, Dataset o) {
        int isa = a.getElementsPerItem();
        int isb = b.getElementsPerItem();
        if (isa != isb && isa != 1 && isb != 1 && b.getSize() != 1 && a.getSize() != 1) {
            throw new IllegalArgumentException("Can not broadcast where number of elements per item mismatch and one does not equal another");
        }
        // Boolean outputs always have one element per item, so checking them could never fail;
        // the check is meant for non-boolean outputs (the original guard was inverted).
        if (o != null && !(o instanceof BooleanDataset)) {
            int ism = Math.max(isa, isb);
            int iso = o.getElementsPerItem();
            if (iso != ism && iso != 1 && ism != 1) {
                throw new IllegalArgumentException("Can not output to dataset whose number of elements per item mismatch inputs'");
            }
        }
    }

    /**
     * Create strides for viewing a dataset as if it had the given broadcast shape.
     *
     * @param a dataset
     * @param broadcastShape shape to broadcast to (same rank as {@code a})
     * @return strides for the broadcast view
     */
    public static int[] createBroadcastStrides(Dataset a, int[] broadcastShape) {
        return createBroadcastStrides(a.getElementsPerItem(), a.getShapeRef(), a.getStrides(), broadcastShape);
    }

    /**
     * Create strides for viewing data of the original shape as if it had the broadcast shape.
     * <p>
     * Dimensions whose length matches the broadcast shape keep their stride. Dimensions that
     * differ get stride 0, so the same element is reused along them. When {@code oStride} is
     * null, strides are computed for a contiguous layout with the last dimension varying
     * fastest, scaled by {@code isize}.
     *
     * @param isize elements per item
     * @param oShape original shape, may be null
     * @param oStride original strides, or null for a contiguous layout
     * @param broadcastShape shape to broadcast to
     * @return strides, or null if both shapes are null
     * @throws IllegalArgumentException if exactly one of the shapes is null, or their ranks differ
     */
    public static int[] createBroadcastStrides(int isize, int[] oShape, int[] oStride, int[] broadcastShape) {
        if (oShape == null) {
            if (broadcastShape == null) {
                return null;
            }
            throw new IllegalArgumentException("Broadcast shape must be null if original shape is null");
        }
        if (broadcastShape == null) {
            throw new IllegalArgumentException("Broadcast shape must not be null if original shape is not null");
        }
        int rank = oShape.length;
        if (broadcastShape.length != rank) {
            throw new IllegalArgumentException("Dataset must have same rank as broadcast shape");
        }

        int[] stride = new int[rank];
        if (oStride == null) {
            int s = isize;
            for (int j = rank - 1; j >= 0; j--) {
                if (broadcastShape[j] == oShape[j]) {
                    stride[j] = s;
                    s *= oShape[j];
                } else {
                    stride[j] = 0;
                }
            }
        } else {
            for (int j = 0; j < rank; j++) {
                stride[j] = broadcastShape[j] == oShape[j] ? oStride[j] : 0;
            }
        }
        return stride;
    }

    /**
     * Convert objects to datasets and broadcast them all to a common shape.
     *
     * @param objects objects accepted by {@code DatasetFactory.createFromObject}
     * @return broadcast views of the converted datasets, in input order
     * @throws IllegalArgumentException if the shapes cannot be broadcast together
     */
    public static Dataset[] convertAndBroadcast(Object ... objects) {
        int n = objects.length;
        Dataset[] datasets = new Dataset[n];
        int[][] shapes = new int[n][];
        for (int i = 0; i < n; i++) {
            Dataset d = DatasetFactory.createFromObject(objects[i]);
            datasets[i] = d;
            shapes[i] = d.getShapeRef();
        }

        int[] mshape = broadcastShapes(shapes).get(0);
        for (int i = 0; i < n; i++) {
            datasets[i] = datasets[i].getBroadcastView(mshape);
        }
        return datasets;
    }
}

