/*
 * Decompiled with CFR 0.152.
 */
package org.eclipse.viatra.query.runtime.matchers.psystem.rewriters;

import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.eclipse.viatra.query.runtime.matchers.psystem.PBody;
import org.eclipse.viatra.query.runtime.matchers.psystem.PConstraint;
import org.eclipse.viatra.query.runtime.matchers.psystem.basicenumerables.PositivePatternCall;
import org.eclipse.viatra.query.runtime.matchers.psystem.queries.PDisjunction;
import org.eclipse.viatra.query.runtime.matchers.psystem.queries.PQuery;
import org.eclipse.viatra.query.runtime.matchers.psystem.rewriters.FlattenerCopier;
import org.eclipse.viatra.query.runtime.matchers.psystem.rewriters.IConstraintFilter;
import org.eclipse.viatra.query.runtime.matchers.psystem.rewriters.IFlattenCallPredicate;
import org.eclipse.viatra.query.runtime.matchers.psystem.rewriters.IVariableRenamer;
import org.eclipse.viatra.query.runtime.matchers.psystem.rewriters.PDisjunctionRewriter;
import org.eclipse.viatra.query.runtime.matchers.psystem.rewriters.RewriterException;
import org.eclipse.viatra.query.runtime.matchers.util.Preconditions;
import org.eclipse.viatra.query.runtime.matchers.util.Sets;

/**
 * Disjunction rewriter that inlines ("flattens") positive pattern calls into the bodies that
 * contain them.
 *
 * <p>Only calls accepted by the {@link IFlattenCallPredicate} passed to the constructor are
 * inlined; every other constraint (including rejected pattern calls) is handed to the
 * {@link FlattenerCopier} together with the rest of the calling body. When a called query has
 * several bodies, the calling body is multiplied: one flat body is produced for every
 * combination of called bodies.
 */
public class PQueryFlattener extends PDisjunctionRewriter {
    /** Decides, per positive pattern call, whether the call is to be inlined. */
    private final IFlattenCallPredicate flattenCallPredicate;

    /**
     * Expands a map of alternatives into every concrete assignment.
     *
     * <p>For {@code {a -> {1, 2}, b -> {3}}} the result contains the maps
     * {@code {a -> 1, b -> 3}} and {@code {a -> 2, b -> 3}}.
     *
     * @param values for each key, the set of values the key may take
     * @return one map per combination, in the iteration order of the cartesian product
     */
    private static <K, V> Set<Map<K, V>> permutation(Map<K, Set<V>> values) {
        // Freeze a key order so positions in the product lists can be mapped back to keys.
        List<K> keyList = new ArrayList<>(values.keySet());
        List<Set<V>> valuesList = new ArrayList<>(keyList.size());
        for (K key : keyList) {
            valuesList.add(values.get(key));
        }
        // NOTE(review): relies on Sets.cartesianProduct yielding lists whose i-th element comes
        // from the i-th input set - confirm against the utility's documentation.
        Set<List<V>> valueMappings = Sets.cartesianProduct(valuesList);
        Set<Map<K, V>> result = new LinkedHashSet<>();
        for (List<V> valueList : valueMappings) {
            Map<K, V> map = new HashMap<>();
            for (int i = 0; i < keyList.size(); i++) {
                map.put(keyList.get(i), valueList.get(i));
            }
            result.add(map);
        }
        return result;
    }

    /**
     * Creates a flattener.
     *
     * @param flattenCallPredicate consulted for every positive pattern call to decide whether
     *        that call is inlined
     */
    public PQueryFlattener(IFlattenCallPredicate flattenCallPredicate) {
        this.flattenCallPredicate = flattenCallPredicate;
    }

    /**
     * Flattens the given disjunction.
     *
     * @param disjunction the disjunction to rewrite
     * @return a new disjunction, owned by the same query, made of the flattened bodies
     * @throws RewriterException if any query referred to by the disjunction refers to itself
     */
    @Override
    public PDisjunction rewrite(PDisjunction disjunction) {
        PQuery query = disjunction.getQuery();
        for (PQuery referredQuery : disjunction.getAllReferredQueries()) {
            // A query that appears among its own referred queries is recursive.
            if (referredQuery.getAllReferredQueries().contains(referredQuery)) {
                throw new RewriterException("Recursive queries are not supported, can't flatten query named \"{1}\"", new String[]{query.getFullyQualifiedName()}, "Unsupported recursive query", query);
            }
        }
        return doFlatten(disjunction);
    }

    /**
     * Collects the root disjunction and every disjunction reachable from it through calls that
     * are to be flattened, ordered so that each called disjunction precedes all of its callers.
     *
     * @param rootDisjunction the disjunction the traversal starts from
     * @return the distinct disjunctions, callees first and {@code rootDisjunction} last
     */
    private List<PDisjunction> disjunctionDependencies(PDisjunction rootDisjunction) {
        ArrayDeque<PDisjunction> stack = new ArrayDeque<>();
        // Discovery log; the same disjunction may be recorded several times on purpose.
        List<PDisjunction> discovered = new ArrayList<>();
        stack.push(rootDisjunction);
        discovered.add(rootDisjunction);
        while (!stack.isEmpty()) {
            PDisjunction disjunction = stack.pop();
            for (PBody body : disjunction.getBodies()) {
                for (PConstraint constraint : body.getConstraints()) {
                    if (!(constraint instanceof PositivePatternCall)) {
                        continue;
                    }
                    PositivePatternCall call = (PositivePatternCall) constraint;
                    if (!flattenCallPredicate.shouldFlatten(call)) {
                        continue;
                    }
                    PDisjunction calledDisjunction = call.getReferredQuery().getDisjunctBodies();
                    stack.push(calledDisjunction);
                    discovered.add(calledDisjunction);
                }
            }
        }
        // A callee is always logged after the caller that led to it, so walking the log backwards
        // and keeping only the first sighting of each disjunction puts callees before callers.
        Set<PDisjunction> visited = new HashSet<>();
        List<PDisjunction> result = new ArrayList<>(discovered.size());
        for (int i = discovered.size() - 1; i >= 0; i--) {
            PDisjunction candidate = discovered.get(i);
            if (visited.add(candidate)) {
                result.add(candidate);
            }
        }
        return result;
    }

    /**
     * Flattens the root disjunction by processing its dependencies callees-first, so that the
     * flat bodies of a called disjunction are already available when a caller is processed.
     *
     * @param rootDisjunction the disjunction to flatten
     * @return a new disjunction for the root's query containing only flat bodies
     */
    private PDisjunction doFlatten(PDisjunction rootDisjunction) {
        Map<PDisjunction, Set<PBody>> flatBodyMapping = new HashMap<>();
        for (PDisjunction disjunction : disjunctionDependencies(rootDisjunction)) {
            Set<PBody> flatBodies = new LinkedHashSet<>();
            for (PBody body : disjunction.getBodies()) {
                if (!isFlatteningNeeded(body)) {
                    // Nothing to inline: a plain copy of the body is enough.
                    flatBodies.add(prepareFlatPBody(body));
                    continue;
                }
                Map<PositivePatternCall, Set<PBody>> flattenedBodies = new HashMap<>();
                for (PConstraint constraint : body.getConstraints()) {
                    if (!(constraint instanceof PositivePatternCall)) {
                        continue;
                    }
                    PositivePatternCall call = (PositivePatternCall) constraint;
                    if (!flattenCallPredicate.shouldFlatten(call)) {
                        continue;
                    }
                    PDisjunction calledDisjunction = call.getReferredQuery().getDisjunctBodies();
                    Set<PBody> flattenedBodySet = flatBodyMapping.get(calledDisjunction);
                    // The callee must already have been processed and must offer a body to inline.
                    Preconditions.checkArgument(flattenedBodySet != null && !flattenedBodySet.isEmpty());
                    flattenedBodies.put(call, flattenedBodySet);
                }
                flatBodies.addAll(createSetOfFlatPBodies(body, flattenedBodies));
            }
            flatBodyMapping.put(disjunction, flatBodies);
        }
        return new PDisjunction(rootDisjunction.getQuery(), flatBodyMapping.get(rootDisjunction));
    }

    /**
     * Builds the flat replacements of one body: one new body for each way of choosing a single
     * flat body for every call that is being inlined.
     *
     * @param pBody the calling body
     * @param flattenedCalls for each call to inline, the flat bodies of the called disjunction
     * @return the merged bodies, in the order the combinations were generated
     */
    private Set<PBody> createSetOfFlatPBodies(PBody pBody, Map<PositivePatternCall, Set<PBody>> flattenedCalls) {
        PQuery pQuery = pBody.getPattern();
        Set<Map<PositivePatternCall, PBody>> conjunctedCalls = permutation(flattenedCalls);
        // Insertion-ordered, like the other body sets built by this class.
        Set<PBody> conjunctedBodies = new LinkedHashSet<>();
        for (Map<PositivePatternCall, PBody> calledBodies : conjunctedCalls) {
            FlattenerCopier copier = createBodyCopier(pQuery, calledBodies);
            IVariableRenamer.HierarchicalName hierarchicalNamingTool = new IVariableRenamer.HierarchicalName();
            int callIndex = 0;
            for (PositivePatternCall patternCall : calledBodies.keySet()) {
                // Each inlined call gets its own call count in the renamer.
                hierarchicalNamingTool.setCallCount(callIndex++);
                copier.mergeBody(patternCall, hierarchicalNamingTool, new IConstraintFilter.ExportedParameterFilter());
            }
            // The calling body itself is merged after all called bodies.
            copier.mergeBody(pBody);
            PBody copiedBody = copier.getCopiedBody();
            copiedBody.setStatus(PQuery.PQueryStatus.OK);
            conjunctedBodies.add(copiedBody);
        }
        return conjunctedBodies;
    }

    /**
     * Creates a copier for the given query and call-to-body assignment, wired to this rewriter's
     * trace collector.
     *
     * @param query the query the copied body belongs to
     * @param calledBodies for each call being inlined, the flat body chosen for it
     * @return the configured copier
     */
    private FlattenerCopier createBodyCopier(PQuery query, Map<PositivePatternCall, PBody> calledBodies) {
        FlattenerCopier flattenerCopier = new FlattenerCopier(query, calledBodies);
        flattenerCopier.setTraceCollector(getTraceCollector());
        return flattenerCopier;
    }

    /**
     * Copies a body that has no call to inline, keeping variable names and allowing every
     * constraint through the filter.
     *
     * @param pBody the body to copy
     * @return the copy produced by the copier
     */
    private PBody prepareFlatPBody(PBody pBody) {
        FlattenerCopier copier = createBodyCopier(pBody.getPattern(), Collections.emptyMap());
        copier.mergeBody(pBody, new IVariableRenamer.SameName(), new IConstraintFilter.AllowAllFilter());
        return copier.getCopiedBody();
    }

    /**
     * Tells whether a body contains at least one positive pattern call that the predicate wants
     * inlined.
     *
     * <p>Every call is examined: stopping at the first positive pattern call would miss a body
     * whose first call is rejected by the predicate but whose later call is accepted.
     *
     * @param pBody the body to inspect
     * @return {@code true} if any positive pattern call in the body should be flattened
     */
    private boolean isFlatteningNeeded(PBody pBody) {
        for (PConstraint pConstraint : pBody.getConstraints()) {
            if (pConstraint instanceof PositivePatternCall
                    && flattenCallPredicate.shouldFlatten((PositivePatternCall) pConstraint)) {
                return true;
            }
        }
        return false;
    }
}

