/*
 * Decompiled with CFR 0.152.
 */
package org.eclipse.viatra.query.runtime.localsearch.planner;

import com.google.common.collect.Sets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.eclipse.emf.ecore.EReference;
import org.eclipse.emf.ecore.EStructuralFeature;
import org.eclipse.viatra.query.runtime.base.api.BaseIndexOptions;
import org.eclipse.viatra.query.runtime.base.comprehension.EMFModelComprehension;
import org.eclipse.viatra.query.runtime.emf.types.EStructuralFeatureInstancesKey;
import org.eclipse.viatra.query.runtime.localsearch.planner.PConstraintInfo;
import org.eclipse.viatra.query.runtime.localsearch.planner.cost.IConstraintEvaluationContext;
import org.eclipse.viatra.query.runtime.matchers.context.IInputKey;
import org.eclipse.viatra.query.runtime.matchers.context.IQueryBackendContext;
import org.eclipse.viatra.query.runtime.matchers.psystem.PConstraint;
import org.eclipse.viatra.query.runtime.matchers.psystem.PVariable;
import org.eclipse.viatra.query.runtime.matchers.psystem.basicdeferred.AggregatorConstraint;
import org.eclipse.viatra.query.runtime.matchers.psystem.basicdeferred.ExportedParameter;
import org.eclipse.viatra.query.runtime.matchers.psystem.basicdeferred.ExpressionEvaluation;
import org.eclipse.viatra.query.runtime.matchers.psystem.basicdeferred.Inequality;
import org.eclipse.viatra.query.runtime.matchers.psystem.basicdeferred.PatternMatchCounter;
import org.eclipse.viatra.query.runtime.matchers.psystem.basicdeferred.TypeFilterConstraint;
import org.eclipse.viatra.query.runtime.matchers.psystem.basicenumerables.BinaryTransitiveClosure;
import org.eclipse.viatra.query.runtime.matchers.psystem.basicenumerables.ConstantValue;
import org.eclipse.viatra.query.runtime.matchers.psystem.basicenumerables.PositivePatternCall;
import org.eclipse.viatra.query.runtime.matchers.psystem.basicenumerables.TypeConstraint;
import org.eclipse.viatra.query.runtime.matchers.psystem.queries.PParameter;
import org.eclipse.viatra.query.runtime.matchers.psystem.queries.PParameterDirection;
import org.eclipse.viatra.query.runtime.matchers.tuple.Tuple;

/**
 * Builds the {@link PConstraintInfo} candidates considered by the local search planner.
 * <p>
 * For every constraint, one {@link PConstraintInfo} is created per admissible set of bound variables; the
 * remaining affected variables of the constraint are recorded as free. Which bound-variable sets are admissible
 * depends on the concrete constraint type (see the {@code createConstraintInfo*} methods).
 */
class PConstraintInfoInferrer {

    /**
     * Matches non-null variables that are referred to by exactly one constraint. In the generic and aggregator
     * cases such variables are allowed to stay unbound.
     */
    private static final Predicate<PVariable> SINGLE_USE_VARIABLE =
            input -> input != null && input.getReferringConstraints().size() == 1;

    private final boolean useIndex;
    private final Function<IConstraintEvaluationContext, Double> costFunction;
    private final EMFModelComprehension modelComprehension;
    private final IQueryBackendContext context;

    /**
     * @param useIndex when {@code true}, inverse navigation is considered possible for every structural feature
     *        that the model comprehension reports as representable
     * @param backendContext query backend context handed to every created {@link PConstraintInfo}
     * @param costFunction cost function handed to every created {@link PConstraintInfo}
     */
    public PConstraintInfoInferrer(boolean useIndex, IQueryBackendContext backendContext,
            Function<IConstraintEvaluationContext, Double> costFunction) {
        this.useIndex = useIndex;
        this.context = backendContext;
        this.costFunction = costFunction;
        this.modelComprehension = new EMFModelComprehension(new BaseIndexOptions());
    }

    /**
     * Creates the constraint infos for all given constraints.
     *
     * @param constraintSet the constraints to process
     * @return a new mutable list holding, for each constraint (in the iteration order of {@code constraintSet}),
     *         one info per admissible bound-variable set
     */
    public List<PConstraintInfo> createPConstraintInfos(Set<PConstraint> constraintSet) {
        List<PConstraintInfo> constraintInfos = new ArrayList<>();
        for (PConstraint pConstraint : constraintSet) {
            createPConstraintInfoDispatch(constraintInfos, pConstraint);
        }
        return constraintInfos;
    }

    /**
     * Selects the binding rules based on the runtime type of the constraint. The order of the checks is kept
     * as-is, since the first matching type wins.
     */
    private void createPConstraintInfoDispatch(List<PConstraintInfo> resultList, PConstraint pConstraint) {
        if (pConstraint instanceof ExportedParameter) {
            createConstraintInfoExportedParameter(resultList, (ExportedParameter) pConstraint);
        } else if (pConstraint instanceof TypeConstraint) {
            createConstraintInfoTypeConstraint(resultList, (TypeConstraint) pConstraint);
        } else if (pConstraint instanceof TypeFilterConstraint) {
            createConstraintInfoTypeFilterConstraint(resultList, (TypeFilterConstraint) pConstraint);
        } else if (pConstraint instanceof ConstantValue) {
            createConstraintInfoConstantValue(resultList, (ConstantValue) pConstraint);
        } else if (pConstraint instanceof Inequality) {
            createConstraintInfoInequality(resultList, (Inequality) pConstraint);
        } else if (pConstraint instanceof ExpressionEvaluation) {
            createConstraintInfoExpressionEvaluation(resultList, (ExpressionEvaluation) pConstraint);
        } else if (pConstraint instanceof AggregatorConstraint) {
            createConstraintInfoAggregatorConstraint(resultList, pConstraint,
                    ((AggregatorConstraint) pConstraint).getResultVariable());
        } else if (pConstraint instanceof PatternMatchCounter) {
            createConstraintInfoAggregatorConstraint(resultList, pConstraint,
                    ((PatternMatchCounter) pConstraint).getResultVariable());
        } else if (pConstraint instanceof PositivePatternCall) {
            createConstraintInfoPositivePatternCall(resultList, (PositivePatternCall) pConstraint);
        } else if (pConstraint instanceof BinaryTransitiveClosure) {
            createConstraintInfoBinaryTransitiveClosure(resultList, (BinaryTransitiveClosure) pConstraint);
        } else {
            createConstraintInfoGeneric(resultList, pConstraint);
        }
    }

    /** Constant values: any subset of the affected variables may be bound. */
    private void createConstraintInfoConstantValue(List<PConstraintInfo> resultList, ConstantValue pConstraint) {
        Set<PVariable> affectedVariables = pConstraint.getAffectedVariables();
        Set<Set<PVariable>> bindings = Sets.powerSet(affectedVariables);
        doCreateConstraintInfos(resultList, pConstraint, affectedVariables, bindings);
    }

    /**
     * Positive pattern calls: variables passed to IN parameters are always bound, and any subset of the
     * variables passed to INOUT parameters may additionally be bound. Variables passed to OUT parameters are
     * never part of a binding set, so they always end up free.
     */
    private void createConstraintInfoPositivePatternCall(List<PConstraintInfo> resultList, PositivePatternCall pCall) {
        Set<PVariable> affectedVariables = pCall.getAffectedVariables();
        Tuple variables = pCall.getVariablesTuple();
        List<PParameter> parameters = pCall.getReferredQuery().getParameters();
        Set<PVariable> inVariables = new HashSet<>();
        Set<PVariable> inoutVariables = new HashSet<>();
        for (int i = 0; i < parameters.size(); i++) {
            switch (parameters.get(i).getDirection()) {
                case IN:
                    inVariables.add((PVariable) variables.get(i));
                    break;
                case INOUT:
                    inoutVariables.add((PVariable) variables.get(i));
                    break;
                default:
                    // OUT parameters are intentionally not collected: they are never bound by this constraint.
                    break;
            }
        }
        Set<Set<PVariable>> bindings = Sets.powerSet(inoutVariables).stream()
                .map(boundInouts -> Stream.concat(boundInouts.stream(), inVariables.stream())
                        .collect(Collectors.toSet()))
                .collect(Collectors.toSet());
        doCreateConstraintInfos(resultList, pCall, affectedVariables, bindings);
    }

    /**
     * Binary transitive closures: both endpoints bound is always admissible. A single bound endpoint is
     * admissible only if the parameter of the other (unbound) endpoint is not an IN parameter.
     */
    private void createConstraintInfoBinaryTransitiveClosure(List<PConstraintInfo> resultList,
            BinaryTransitiveClosure closure) {
        List<PParameter> parameters = closure.getReferredQuery().getParameters();
        Tuple variables = closure.getVariablesTuple();
        PVariable firstVariable = (PVariable) variables.get(0);
        PVariable secondVariable = (PVariable) variables.get(1);
        Set<Set<PVariable>> bindings = new HashSet<>();
        bindings.add(new HashSet<>(Arrays.asList(firstVariable, secondVariable)));
        // First endpoint left unbound: only allowed if the first parameter is not IN.
        if (parameters.get(0).getDirection() != PParameterDirection.IN) {
            bindings.add(Collections.singleton(secondVariable));
        }
        // Second endpoint left unbound: only allowed if the second parameter is not IN.
        if (parameters.get(1).getDirection() != PParameterDirection.IN) {
            bindings.add(Collections.singleton(firstVariable));
        }
        doCreateConstraintInfos(resultList, closure, closure.getAffectedVariables(), bindings);
    }

    /** Exported parameters: only the binding where all affected variables are bound is created. */
    private void createConstraintInfoExportedParameter(List<PConstraintInfo> resultList, ExportedParameter parameter) {
        createAllBoundConstraintInfo(resultList, parameter);
    }

    /**
     * Expression evaluations: either all affected variables are bound, or all of them except the output
     * variable. When there is no output variable the two sets coincide and only one binding is created.
     */
    private void createConstraintInfoExpressionEvaluation(List<PConstraintInfo> resultList,
            ExpressionEvaluation expressionEvaluation) {
        PVariable output = expressionEvaluation.getOutputVariable();
        Set<PVariable> affectedVariables = expressionEvaluation.getAffectedVariables();
        Set<Set<PVariable>> bindings = new HashSet<>();
        bindings.add(affectedVariables);
        bindings.add(affectedVariables.stream()
                .filter(variable -> !Objects.equals(variable, output))
                .collect(Collectors.toSet()));
        doCreateConstraintInfos(resultList, expressionEvaluation, affectedVariables, bindings);
    }

    /** Type filter constraints: only the binding where all affected variables are bound is created. */
    private void createConstraintInfoTypeFilterConstraint(List<PConstraintInfo> resultList,
            TypeFilterConstraint filter) {
        createAllBoundConstraintInfo(resultList, filter);
    }

    /** Inequalities: only the binding where all affected variables are bound is created. */
    private void createConstraintInfoInequality(List<PConstraintInfo> resultList, Inequality inequality) {
        createAllBoundConstraintInfo(resultList, inequality);
    }

    /** Shared rule for constraints whose only admissible binding is "every affected variable bound". */
    private void createAllBoundConstraintInfo(List<PConstraintInfo> resultList, PConstraint pConstraint) {
        Set<PVariable> affectedVariables = pConstraint.getAffectedVariables();
        doCreateConstraintInfos(resultList, pConstraint, affectedVariables, Collections.singleton(affectedVariables));
    }

    /**
     * Aggregators and match counters: the result variable and every single-use affected variable may be left
     * unbound; all other affected variables must be bound.
     */
    private void createConstraintInfoAggregatorConstraint(List<PConstraintInfo> resultList, PConstraint pConstraint,
            PVariable resultVariable) {
        Set<PVariable> affectedVariables = pConstraint.getAffectedVariables();
        Set<PVariable> canBeUnboundVariables = Stream.concat(Stream.of(resultVariable),
                affectedVariables.stream().filter(SINGLE_USE_VARIABLE)).collect(Collectors.toSet());
        Set<Set<PVariable>> bindings = calculatePossibleBindings(canBeUnboundVariables, affectedVariables);
        doCreateConstraintInfos(resultList, pConstraint, affectedVariables, bindings);
    }

    /**
     * Enumerates binding sets of the form "all must-bind variables plus any subset of the optional ones".
     *
     * @param canBeUnboundVariables variables that may or may not be bound
     * @param affectedVariables all variables of the constraint; those not in {@code canBeUnboundVariables}
     *        are included in every returned set
     * @return one binding set per subset of {@code canBeUnboundVariables}
     */
    private Set<Set<PVariable>> calculatePossibleBindings(Set<PVariable> canBeUnboundVariables,
            Set<PVariable> affectedVariables) {
        Set<PVariable> mustBindVariables = affectedVariables.stream()
                .filter(variable -> !canBeUnboundVariables.contains(variable))
                .collect(Collectors.toSet());
        return Sets.powerSet(canBeUnboundVariables).stream()
                .map(optionalSubset -> {
                    Set<PVariable> binding = new HashSet<>(optionalSubset);
                    binding.addAll(mustBindVariables);
                    return binding;
                })
                .collect(Collectors.toSet());
    }

    /** Any other constraint: only single-use affected variables may be left unbound. */
    private void createConstraintInfoGeneric(List<PConstraintInfo> resultList, PConstraint pConstraint) {
        Set<PVariable> affectedVariables = pConstraint.getAffectedVariables();
        Set<PVariable> canBeUnboundVariables = affectedVariables.stream()
                .filter(SINGLE_USE_VARIABLE)
                .collect(Collectors.toSet());
        Set<Set<PVariable>> bindings = calculatePossibleBindings(canBeUnboundVariables, affectedVariables);
        doCreateConstraintInfos(resultList, pConstraint, affectedVariables, bindings);
    }

    /**
     * Inverse navigation of a feature is treated as possible when the feature is a reference with an
     * opposite, when it is a containment reference, or when index use is enabled and the model comprehension
     * reports the feature as representable.
     */
    private boolean canPerformInverseNavigation(EStructuralFeature feature) {
        return hasEOpposite(feature)
                || (feature instanceof EReference && ((EReference) feature).isContainment())
                || (useIndex && modelComprehension.representable(feature));
    }

    /**
     * Type constraints. If the input key is enumerable, any subset of the affected variables may be bound;
     * otherwise only the all-bound binding is created. For structural features that cannot be navigated
     * inversely, bindings that would require navigating from the second variable without the first are dropped.
     */
    private void createConstraintInfoTypeConstraint(List<PConstraintInfo> resultList, TypeConstraint typeConstraint) {
        Set<PVariable> affectedVariables = typeConstraint.getAffectedVariables();
        IInputKey inputKey = typeConstraint.getSupplierKey();
        Set<Set<PVariable>> bindings = inputKey.isEnumerable()
                ? Sets.powerSet(affectedVariables)
                : Collections.singleton(affectedVariables);
        if (inputKey instanceof EStructuralFeatureInstancesKey) {
            EStructuralFeature feature = ((EStructuralFeatureInstancesKey) inputKey).getEmfKey();
            if (!canPerformInverseNavigation(feature)) {
                bindings = excludeUnnavigableOperationMasks(typeConstraint, bindings);
            }
        }
        doCreateConstraintInfos(resultList, typeConstraint, affectedVariables, bindings);
    }

    /**
     * Appends one {@link PConstraintInfo} per binding set to {@code constraintInfos}. Its free variables are
     * the affected variables not contained in that binding set.
     */
    private void doCreateConstraintInfos(List<PConstraintInfo> constraintInfos, PConstraint pConstraint,
            Set<PVariable> affectedVariables, Iterable<Set<PVariable>> bindings) {
        // The same set instance is handed to every info of this constraint and is filled while looping.
        // Presumably PConstraintInfo keeps the reference so each info can reach its alternatives -- TODO confirm.
        Set<PConstraintInfo> sameWithDifferentBindings = new HashSet<>();
        for (Set<PVariable> boundVariables : bindings) {
            Set<PVariable> freeVariables = affectedVariables.stream()
                    .filter(variable -> !boundVariables.contains(variable))
                    .collect(Collectors.toSet());
            PConstraintInfo info = new PConstraintInfo(pConstraint, boundVariables, freeVariables,
                    sameWithDifferentBindings, context, costFunction);
            constraintInfos.add(info);
            sameWithDifferentBindings.add(info);
        }
    }

    /**
     * Keeps only the binding sets that are empty or contain the first variable of the type constraint, i.e.
     * drops those where some variables are bound but the first one is not.
     */
    private Set<Set<PVariable>> excludeUnnavigableOperationMasks(TypeConstraint typeConstraint,
            Set<Set<PVariable>> bindings) {
        PVariable firstVariable = typeConstraint.getVariableInTuple(0);
        return bindings.stream()
                .filter(boundVariables -> boundVariables.isEmpty() || boundVariables.contains(firstVariable))
                .collect(Collectors.toSet());
    }

    /** Returns whether the feature is a reference that declares an opposite reference. */
    private boolean hasEOpposite(EStructuralFeature feature) {
        return feature instanceof EReference && ((EReference) feature).getEOpposite() != null;
    }
}

