/*
 * Decompiled with CFR 0.152.
 */
package org.apache.calcite.rel.metadata;

import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.LinkedHashSet;
import java.util.List;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.rel.RelCollation;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Join;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.core.Minus;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.core.Union;
import org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexLocalRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexProgram;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.rex.RexVisitorImpl;
import org.apache.calcite.sql.SqlFunction;
import org.apache.calcite.sql.SqlFunctionCategory;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.type.OperandTypes;
import org.apache.calcite.sql.type.ReturnTypes;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.NumberUtil;
import org.apache.flink.calcite.shaded.com.google.common.base.Preconditions;
import org.apache.flink.calcite.shaded.com.google.common.collect.ImmutableList;

/**
 * Static helpers for estimating metadata of relational expressions: selectivity, distinct
 * row counts, row counts, population sizes and column uniqueness, plus validation of
 * metadata results.
 *
 * <p>All members are static; the class cannot be instantiated.
 */
public class RelMdUtil {
    /**
     * Marker function that carries a pre-computed selectivity inside a predicate.
     *
     * <p>Its single numeric operand is the selectivity value. {@link #guessSelectivity(RexNode, boolean)}
     * multiplies that value in instead of guessing. {@link #getSelectivityValue(RexNode)} extracts it.
     * Calls are built by {@link #makeSemiJoinSelectivityRexNode(RelMetadataQuery, Join)}.
     */
    public static final SqlFunction ARTIFICIAL_SELECTIVITY_FUNC = new SqlFunction("ARTIFICIAL_SELECTIVITY", SqlKind.OTHER_FUNCTION, ReturnTypes.BOOLEAN, null, OperandTypes.NUMERIC, SqlFunctionCategory.SYSTEM);

    private RelMdUtil() {
    }

    /**
     * Wraps the semi-join selectivity of {@code rel} in a call to {@link #ARTIFICIAL_SELECTIVITY_FUNC}.
     *
     * <p>The selectivity is computed by
     * {@link #computeSemiJoinSelectivity(RelMetadataQuery, RelNode, RelNode, Join)}, using the left
     * input as the fact side and the right input as the dimension side.
     *
     * @param mq  metadata query
     * @param rel join whose selectivity is wrapped
     * @return an {@code ARTIFICIAL_SELECTIVITY(<approx literal>)} call
     */
    public static RexNode makeSemiJoinSelectivityRexNode(RelMetadataQuery mq, Join rel) {
        RexBuilder rexBuilder = rel.getCluster().getRexBuilder();
        double selectivity = computeSemiJoinSelectivity(mq, rel.getLeft(), rel.getRight(), rel);
        return rexBuilder.makeCall((SqlOperator) ARTIFICIAL_SELECTIVITY_FUNC, rexBuilder.makeApproxLiteral(new BigDecimal(selectivity)));
    }

    /**
     * Extracts the selectivity stored in an {@link #ARTIFICIAL_SELECTIVITY_FUNC} call.
     *
     * @param artificialSelectivityFuncNode a {@link RexCall} to {@link #ARTIFICIAL_SELECTIVITY_FUNC}
     *                                      whose first operand is a {@link RexLiteral}
     * @return the literal's value as a double
     */
    public static double getSelectivityValue(RexNode artificialSelectivityFuncNode) {
        assert artificialSelectivityFuncNode instanceof RexCall;
        RexCall call = (RexCall) artificialSelectivityFuncNode;
        assert call.getOperator() == ARTIFICIAL_SELECTIVITY_FUNC;
        RexNode operand = call.getOperands().get(0);
        return ((RexLiteral) operand).getValueAs(Double.class);
    }

    /**
     * Computes semi-join selectivity using the equi-join keys found by {@link Join#analyzeCondition()}.
     *
     * @param mq      metadata query
     * @param factRel fact-side input (its keys are {@code leftKeys})
     * @param dimRel  dimension-side input (its keys are {@code rightKeys})
     * @param rel     join whose condition supplies the keys
     * @return selectivity in {@code [0, 1]} (capped at 1)
     */
    public static double computeSemiJoinSelectivity(RelMetadataQuery mq, RelNode factRel, RelNode dimRel, Join rel) {
        return computeSemiJoinSelectivity(mq, factRel, dimRel, rel.analyzeCondition().leftKeys, rel.analyzeCondition().rightKeys);
    }

    /**
     * Computes semi-join selectivity as {@code distinct(dimKeys) / population(factKeys)}.
     *
     * <p>Fallbacks, in order:
     * <ul>
     *   <li>If the fact population is unknown, the dimension population is used instead.</li>
     *   <li>If the ratio cannot be formed, {@code getPercentageOriginalRows(dimRel)} is used.</li>
     *   <li>If that is also unknown, {@code 0.1^k} is used, where {@code k} is the number of
     *       distinct dimension key columns.</li>
     * </ul>
     * Results above 1 are capped at 1.
     *
     * @param mq          metadata query
     * @param factRel     fact-side input
     * @param dimRel      dimension-side input
     * @param factKeyList key column ordinals on the fact side
     * @param dimKeyList  key column ordinals on the dimension side
     * @return estimated selectivity
     */
    public static double computeSemiJoinSelectivity(RelMetadataQuery mq, RelNode factRel, RelNode dimRel, List<Integer> factKeyList, List<Integer> dimKeyList) {
        ImmutableBitSet.Builder factKeys = ImmutableBitSet.builder();
        for (int factCol : factKeyList) {
            factKeys.set(factCol);
        }
        ImmutableBitSet.Builder dimKeyBuilder = ImmutableBitSet.builder();
        for (int dimCol : dimKeyList) {
            dimKeyBuilder.set(dimCol);
        }
        ImmutableBitSet dimKeys = dimKeyBuilder.build();

        // Prefer the fact side's population size; fall back to the dimension side's.
        Double factPop = mq.getPopulationSize(factRel, factKeys.build());
        if (factPop == null) {
            factPop = mq.getPopulationSize(dimRel, dimKeys);
        }

        Double selectivity;
        Double dimCard = mq.getDistinctRowCount(dimRel, dimKeys, null);
        if (dimCard != null && factPop != null) {
            // Keep the denominator at least 1 so we never divide by zero.
            if (factPop < 1.0) {
                factPop = 1.0;
            }
            selectivity = dimCard / factPop;
        } else {
            selectivity = mq.getPercentageOriginalRows(dimRel);
        }

        if (selectivity == null) {
            // No statistics: assume each key column keeps 10% of the rows.
            selectivity = Math.pow(0.1, dimKeys.cardinality());
        } else if (selectivity > 1.0) {
            selectivity = 1.0;
        }
        return selectivity;
    }

    /**
     * Returns whether metadata reports {@code colMask} as unique in {@code rel}.
     *
     * <p>Passes {@code false} as the null-handling flag of {@code areColumnsUnique}. An unknown
     * ({@code null}) answer is treated as "not unique".
     */
    public static boolean areColumnsDefinitelyUnique(RelMetadataQuery mq, RelNode rel, ImmutableBitSet colMask) {
        Boolean b = mq.areColumnsUnique(rel, colMask, false);
        return b != null && b;
    }

    /**
     * Asks metadata whether the columns referenced by {@code columnRefs} are unique in {@code rel}.
     *
     * @return the metadata answer, or {@code null} if unknown
     */
    public static Boolean areColumnsUnique(RelMetadataQuery mq, RelNode rel, List<RexInputRef> columnRefs) {
        ImmutableBitSet.Builder colMask = ImmutableBitSet.builder();
        for (RexInputRef columnRef : columnRefs) {
            colMask.set(columnRef.getIndex());
        }
        return mq.areColumnsUnique(rel, colMask.build());
    }

    /**
     * Like {@link #areColumnsUnique(RelMetadataQuery, RelNode, List)}, but maps an unknown answer to
     * {@code false}.
     */
    public static boolean areColumnsDefinitelyUnique(RelMetadataQuery mq, RelNode rel, List<RexInputRef> columnRefs) {
        Boolean b = areColumnsUnique(mq, rel, columnRefs);
        return b != null && b;
    }

    /**
     * Returns whether metadata reports {@code colMask} as unique, passing {@code true} as the
     * null-handling flag of {@code areColumnsUnique}. An unknown answer is treated as {@code false}.
     */
    public static boolean areColumnsDefinitelyUniqueWhenNullsFiltered(RelMetadataQuery mq, RelNode rel, ImmutableBitSet colMask) {
        Boolean b = mq.areColumnsUnique(rel, colMask, true);
        return b != null && b;
    }

    /**
     * Asks metadata whether the referenced columns are unique, passing {@code true} as the
     * null-handling flag of {@code areColumnsUnique}.
     *
     * @return the metadata answer, or {@code null} if unknown
     */
    public static Boolean areColumnsUniqueWhenNullsFiltered(RelMetadataQuery mq, RelNode rel, List<RexInputRef> columnRefs) {
        ImmutableBitSet.Builder colMask = ImmutableBitSet.builder();
        for (RexInputRef columnRef : columnRefs) {
            colMask.set(columnRef.getIndex());
        }
        return mq.areColumnsUnique(rel, colMask.build(), true);
    }

    /**
     * Like {@link #areColumnsUniqueWhenNullsFiltered(RelMetadataQuery, RelNode, List)}, but maps an
     * unknown answer to {@code false}.
     */
    public static boolean areColumnsDefinitelyUniqueWhenNullsFiltered(RelMetadataQuery mq, RelNode rel, List<RexInputRef> columnRefs) {
        Boolean b = areColumnsUniqueWhenNullsFiltered(mq, rel, columnRefs);
        return b != null && b;
    }

    /**
     * Splits a bit set over the concatenated fields of two inputs into per-input bit sets.
     *
     * <p>Bits below {@code nFieldsOnLeft} go to {@code leftMask} unchanged. Other bits go to
     * {@code rightMask}, shifted down by {@code nFieldsOnLeft}.
     */
    public static void setLeftRightBitmaps(ImmutableBitSet groupKey, ImmutableBitSet.Builder leftMask, ImmutableBitSet.Builder rightMask, int nFieldsOnLeft) {
        for (int bit : groupKey) {
            if (bit < nFieldsOnLeft) {
                leftMask.set(bit);
            } else {
                rightMask.set(bit - nFieldsOnLeft);
            }
        }
    }

    /**
     * Estimates how many distinct values appear when {@code numSelected} values are drawn from a
     * domain of {@code domainSize} values.
     *
     * <p>The estimate is {@code (1 - exp(-numSelected / domainSize)) * domainSize}, clamped to
     * {@code [0, min(domainSize, numSelected)]}. Infinite inputs are first capped to
     * {@link Double#MAX_VALUE}.
     *
     * @return the estimate, or {@code null} if either argument is {@code null}
     */
    public static Double numDistinctVals(Double domainSize, Double numSelected) {
        if (domainSize == null || numSelected == null) {
            return null;
        }
        double dSize = capInfinity(domainSize);
        double numSel = capInfinity(numSelected);
        double res = dSize > 0.0 ? (1.0 - Math.exp(-1.0 * numSel / dSize)) * dSize : 0.0;
        if (res > dSize) {
            res = dSize;
        }
        if (res > numSel) {
            res = numSel;
        }
        if (res < 0.0) {
            res = 0.0;
        }
        return res;
    }

    /** Returns {@code d}, replacing either infinity with {@link Double#MAX_VALUE}. */
    public static double capInfinity(Double d) {
        return d.isInfinite() ? Double.MAX_VALUE : d;
    }

    /** Equivalent to {@code guessSelectivity(predicate, false)}. */
    public static double guessSelectivity(RexNode predicate) {
        return guessSelectivity(predicate, false);
    }

    /**
     * Guesses the selectivity of a predicate by multiplying fixed factors for each conjunct.
     *
     * <p>Factors per conjunct:
     * <ul>
     *   <li>{@code IS NOT NULL}: 0.9</li>
     *   <li>{@code =}: 0.15</li>
     *   <li>other comparisons: 0.5</li>
     *   <li>anything else: 0.25</li>
     * </ul>
     * Conjuncts that call {@link #ARTIFICIAL_SELECTIVITY_FUNC} contribute their stored value to a
     * separate product.
     *
     * @param predicate      predicate to estimate; {@code null} or always-true yields 1
     * @param artificialOnly if true, return only the artificial-selectivity product
     * @return the estimated selectivity
     */
    public static double guessSelectivity(RexNode predicate, boolean artificialOnly) {
        double sel = 1.0;
        if (predicate == null || predicate.isAlwaysTrue()) {
            return sel;
        }
        double artificialSel = 1.0;
        for (RexNode pred : RelOptUtil.conjunctions(predicate)) {
            if (pred.getKind() == SqlKind.IS_NOT_NULL) {
                sel *= 0.9;
            } else if (pred instanceof RexCall && ((RexCall) pred).getOperator() == ARTIFICIAL_SELECTIVITY_FUNC) {
                artificialSel *= getSelectivityValue(pred);
            } else if (pred.isA(SqlKind.EQUALS)) {
                sel *= 0.15;
            } else if (pred.isA(SqlKind.COMPARISON)) {
                sel *= 0.5;
            } else {
                sel *= 0.25;
            }
        }
        return artificialOnly ? artificialSel : sel * artificialSel;
    }

    /**
     * Returns the AND of the distinct conjuncts of {@code pred1} and {@code pred2}.
     *
     * <p>Conjuncts of {@code pred1} come first; duplicates are dropped.
     */
    public static RexNode unionPreds(RexBuilder rexBuilder, RexNode pred1, RexNode pred2) {
        LinkedHashSet<RexNode> unionList = new LinkedHashSet<RexNode>();
        unionList.addAll(RelOptUtil.conjunctions(pred1));
        unionList.addAll(RelOptUtil.conjunctions(pred2));
        return RexUtil.composeConjunction(rexBuilder, unionList, true);
    }

    /** Returns the AND of the conjuncts of {@code pred1} that are not conjuncts of {@code pred2}. */
    public static RexNode minusPreds(RexBuilder rexBuilder, RexNode pred1, RexNode pred2) {
        ArrayList<RexNode> minusList = new ArrayList<RexNode>(RelOptUtil.conjunctions(pred1));
        minusList.removeAll(RelOptUtil.conjunctions(pred2));
        return RexUtil.composeConjunction(rexBuilder, minusList, true);
    }

    /**
     * Maps output columns of an {@link Aggregate} to the input columns they depend on.
     *
     * <p>Group-by columns map to themselves. Aggregate-call columns map to the call's argument
     * columns.
     */
    public static void setAggChildKeys(ImmutableBitSet groupKey, Aggregate aggRel, ImmutableBitSet.Builder childKey) {
        List<AggregateCall> aggCalls = aggRel.getAggCallList();
        int groupCount = aggRel.getGroupCount();
        for (int bit : groupKey) {
            if (bit < groupCount) {
                childKey.set(bit);
                continue;
            }
            AggregateCall agg = aggCalls.get(bit - groupCount);
            for (Integer arg : agg.getArgList()) {
                childKey.set(arg);
            }
        }
    }

    /**
     * Partitions {@code groupKey} by projection kind.
     *
     * <p>Columns whose expression is a plain {@link RexInputRef} add the referenced input index to
     * {@code baseCols}. All other columns add their own ordinal to {@code projCols}.
     */
    public static void splitCols(List<RexNode> projExprs, ImmutableBitSet groupKey, ImmutableBitSet.Builder baseCols, ImmutableBitSet.Builder projCols) {
        for (int bit : groupKey) {
            RexNode e = projExprs.get(bit);
            if (e instanceof RexInputRef) {
                baseCols.set(((RexInputRef) e).getIndex());
            } else {
                projCols.set(bit);
            }
        }
    }

    /**
     * Estimates the number of distinct values of a projection expression.
     *
     * @return the estimate, or {@code null} if it cannot be determined
     */
    public static Double cardOfProjExpr(RelMetadataQuery mq, Project rel, RexNode expr) {
        return expr.accept(new CardOfProjExpr(mq, rel));
    }

    /**
     * Estimates the population size of {@code groupKey} over a join.
     *
     * <p>If the join type does not project the right input, the left input's population size is
     * returned. Otherwise the per-side population sizes are multiplied and passed through
     * {@link #numDistinctVals(Double, Double)} together with the join's row count.
     *
     * @param joinRel a {@link Join}
     */
    public static Double getJoinPopulationSize(RelMetadataQuery mq, RelNode joinRel, ImmutableBitSet groupKey) {
        Join join = (Join) joinRel;
        if (!join.getJoinType().projectsRight()) {
            return mq.getPopulationSize(join.getLeft(), groupKey);
        }
        ImmutableBitSet.Builder leftMask = ImmutableBitSet.builder();
        ImmutableBitSet.Builder rightMask = ImmutableBitSet.builder();
        RelNode left = joinRel.getInputs().get(0);
        RelNode right = joinRel.getInputs().get(1);
        setLeftRightBitmaps(groupKey, leftMask, rightMask, left.getRowType().getFieldCount());
        Double population = NumberUtil.multiply(mq.getPopulationSize(left, leftMask.build()), mq.getPopulationSize(right, rightMask.build()));
        return numDistinctVals(population, mq.getRowCount(joinRel));
    }

    /**
     * Returns a value strictly greater than {@code d} (except for NaN).
     *
     * <p>The value is chosen by trying, in order:
     * <ol>
     *   <li>multiplying by 1.001 (only when {@code d < 10});</li>
     *   <li>adding 1;</li>
     *   <li>multiplying by 1.001.</li>
     * </ol>
     * The first step that changes the value wins.
     *
     * @param d a non-negative value
     */
    public static double addEpsilon(double d) {
        assert d >= 0.0;
        final double d0 = d;
        if (d < 10.0) {
            // For small values, adding 1 would be a large relative change.
            d *= 1.001;
            if (d != d0) {
                return d;
            }
        }
        // Medium values: adding 1 keeps integral values integral.
        d += 1.0;
        if (d != d0) {
            return d;
        }
        // Very large values: adding 1 may be lost to rounding, so scale instead.
        d *= 1.001;
        return d;
    }

    /**
     * Estimates the distinct row count of a semi-join.
     *
     * <p>The estimate is the left input's distinct row count for {@code groupKey}, under
     * {@code predicate} ANDed with the join's artificial selectivity.
     *
     * @return 1 if there is no effective predicate and {@code groupKey} is empty
     */
    public static Double getSemiJoinDistinctRowCount(Join semiJoinRel, RelMetadataQuery mq, ImmutableBitSet groupKey, RexNode predicate) {
        if ((predicate == null || predicate.isAlwaysTrue()) && groupKey.isEmpty()) {
            return 1.0;
        }
        RexNode newPred = makeSemiJoinSelectivityRexNode(mq, semiJoinRel);
        if (predicate != null) {
            RexBuilder rexBuilder = semiJoinRel.getCluster().getRexBuilder();
            newPred = rexBuilder.makeCall((SqlOperator) SqlStdOperatorTable.AND, newPred, predicate);
        }
        return mq.getDistinctRowCount(semiJoinRel.getLeft(), groupKey, newPred);
    }

    /**
     * Estimates the distinct row count of {@code groupKey} over a join.
     *
     * <p>Semi-joins are delegated to
     * {@link #getSemiJoinDistinctRowCount(Join, RelMetadataQuery, ImmutableBitSet, RexNode)}.
     * Otherwise the predicate is split into left-only and right-only filters (join filters are
     * not used). The per-side distinct counts are combined by product, or by maximum when
     * {@code useMaxNdv} is set, then capped via {@link #numDistinctVals(Double, Double)} by the
     * join's row count.
     *
     * @param joinRel a {@link Join}
     * @return the estimate, or {@code null} if either side's distinct count is unknown
     */
    public static Double getJoinDistinctRowCount(RelMetadataQuery mq, RelNode joinRel, JoinRelType joinType, ImmutableBitSet groupKey, RexNode predicate, boolean useMaxNdv) {
        if ((predicate == null || predicate.isAlwaysTrue()) && groupKey.isEmpty()) {
            return 1.0;
        }
        Join join = (Join) joinRel;
        if (join.isSemiJoin()) {
            return getSemiJoinDistinctRowCount(join, mq, groupKey, predicate);
        }
        ImmutableBitSet.Builder leftMask = ImmutableBitSet.builder();
        ImmutableBitSet.Builder rightMask = ImmutableBitSet.builder();
        RelNode left = joinRel.getInputs().get(0);
        RelNode right = joinRel.getInputs().get(1);
        setLeftRightBitmaps(groupKey, leftMask, rightMask, left.getRowType().getFieldCount());
        RexNode leftPred = null;
        RexNode rightPred = null;
        if (predicate != null) {
            ArrayList<RexNode> leftFilters = new ArrayList<RexNode>();
            ArrayList<RexNode> rightFilters = new ArrayList<RexNode>();
            ArrayList<RexNode> joinFilters = new ArrayList<RexNode>();
            List<RexNode> predList = RelOptUtil.conjunctions(predicate);
            RelOptUtil.classifyFilters(joinRel, predList, joinType, !joinType.isOuterJoin(), !joinType.generatesNullsOnLeft(), !joinType.generatesNullsOnRight(), joinFilters, leftFilters, rightFilters);
            RexBuilder rexBuilder = joinRel.getCluster().getRexBuilder();
            leftPred = RexUtil.composeConjunction(rexBuilder, leftFilters, true);
            rightPred = RexUtil.composeConjunction(rexBuilder, rightFilters, true);
        }
        Double leftDistinct = mq.getDistinctRowCount(left, leftMask.build(), leftPred);
        Double rightDistinct = mq.getDistinctRowCount(right, rightMask.build(), rightPred);
        Double distRowCount;
        if (useMaxNdv) {
            // Unboxing an unknown (null) estimate into Math.max would throw NPE.
            if (leftDistinct == null || rightDistinct == null) {
                distRowCount = null;
            } else {
                distRowCount = Math.max(leftDistinct, rightDistinct);
            }
        } else {
            distRowCount = NumberUtil.multiply(leftDistinct, rightDistinct);
        }
        return numDistinctVals(distRowCount, mq.getRowCount(joinRel));
    }

    /**
     * Returns the sum of the row counts of all inputs of {@code rel}.
     *
     * <p>NOTE(review): this unboxes each input's row count and will throw if any is {@code null}.
     */
    public static double getUnionAllRowCount(RelMetadataQuery mq, Union rel) {
        double rowCount = 0.0;
        for (RelNode input : rel.getInputs()) {
            rowCount += mq.getRowCount(input).doubleValue();
        }
        return rowCount;
    }

    /**
     * Estimates the row count of a {@link Minus}.
     *
     * <p>The estimate is the first input's row count minus half of each subsequent input's row
     * count, floored at 0.
     */
    public static double getMinusRowCount(RelMetadataQuery mq, Minus minus) {
        List<RelNode> inputs = minus.getInputs();
        double dRows = mq.getRowCount(inputs.get(0));
        for (int i = 1; i < inputs.size(); ++i) {
            dRows -= 0.5 * mq.getRowCount(inputs.get(i));
        }
        return dRows < 0.0 ? 0.0 : dRows;
    }

    /**
     * Estimates the row count of a join.
     *
     * <p>Cases:
     * <ul>
     *   <li>Join types that do not project the right input: left row count times the artificial
     *       semi-join selectivity.</li>
     *   <li>If either input has at most one row and the join's max row count is known to be at
     *       most 1: that max row count.</li>
     *   <li>Otherwise: {@code left * right * selectivity(condition)}.</li>
     * </ul>
     *
     * @return the estimate, or {@code null} if a needed input estimate is unknown
     */
    public static Double getJoinRowCount(RelMetadataQuery mq, Join join, RexNode condition) {
        if (!join.getJoinType().projectsRight()) {
            RexNode semiJoinSelectivity = makeSemiJoinSelectivityRexNode(mq, join);
            return NumberUtil.multiply(mq.getSelectivity(join.getLeft(), semiJoinSelectivity), mq.getRowCount(join.getLeft()));
        }
        Double left = mq.getRowCount(join.getLeft());
        Double right = mq.getRowCount(join.getRight());
        if (left == null || right == null) {
            return null;
        }
        if (left <= 1.0 || right <= 1.0) {
            Double max = mq.getMaxRowCount(join);
            if (max != null && max <= 1.0) {
                return max;
            }
        }
        Double selectivity = mq.getSelectivity(join, condition);
        if (selectivity == null) {
            // Unknown selectivity: report an unknown row count instead of unboxing null.
            return null;
        }
        return left * right * selectivity;
    }

    /**
     * Estimates the row count of a semi-join.
     *
     * <p>The estimate is the left row count times {@link RexUtil#getSelectivity(RexNode)} of
     * {@code condition}.
     *
     * <p>NOTE(review): {@code right} and {@code joinType} are currently unused.
     *
     * @return the estimate, or {@code null} if the left row count is unknown
     */
    public static Double getSemiJoinRowCount(RelMetadataQuery mq, RelNode left, RelNode right, JoinRelType joinType, RexNode condition) {
        Double leftCount = mq.getRowCount(left);
        if (leftCount == null) {
            return null;
        }
        return leftCount * RexUtil.getSelectivity(condition);
    }

    /** Estimates rows of {@code child} surviving the (expanded) condition of {@code program}, if any. */
    public static double estimateFilteredRows(RelNode child, RexProgram program, RelMetadataQuery mq) {
        RexLocalRef programCondition = program.getCondition();
        RexNode condition = programCondition == null ? null : program.expandLocalRef(programCondition);
        return estimateFilteredRows(child, condition, mq);
    }

    /** Returns {@code rowCount(child) * selectivity(child, condition)}. */
    public static double estimateFilteredRows(RelNode child, RexNode condition, RelMetadataQuery mq) {
        return mq.getRowCount(child) * mq.getSelectivity(child, condition);
    }

    /**
     * Linearly maps {@code x} from {@code [minX, maxX]} to {@code [minY, maxY]}.
     *
     * <p>Values of {@code x} outside the domain are clamped to the nearest endpoint.
     *
     * @throws IllegalArgumentException if {@code minX >= maxX} or {@code minY >= maxY}
     */
    public static double linear(int x, int minX, int maxX, double minY, double maxY) {
        Preconditions.checkArgument(minX < maxX);
        Preconditions.checkArgument(minY < maxY);
        if (x < minX) {
            return minY;
        }
        if (x > maxX) {
            return maxY;
        }
        return minY + (double) (x - minX) / (double) (maxX - minX) * (maxY - minY);
    }

    /**
     * Returns true when {@code input} is known to be both sorted and small enough.
     *
     * <p>Both of the following must hold:
     * <ul>
     *   <li>{@code collation} is empty or satisfied by one of the input's collations;</li>
     *   <li>the input's max row count does not exceed {@code offset + fetch}. This part is
     *       treated as satisfied when the max row count or {@code fetch} is unknown.</li>
     * </ul>
     *
     * @param offset literal offset, or {@code null} for 0
     * @param fetch  literal fetch, or {@code null} for no limit
     */
    public static boolean checkInputForCollationAndLimit(RelMetadataQuery mq, RelNode input, RelCollation collation, RexNode offset, RexNode fetch) {
        boolean alreadySorted = collation.getFieldCollations().isEmpty();
        List<RelCollation> inputCollations = mq.collations(input);
        if (inputCollations != null) {
            // Unknown collations (null) leave alreadySorted unchanged instead of throwing.
            for (RelCollation inputCollation : inputCollations) {
                if (inputCollation.satisfies(collation)) {
                    alreadySorted = true;
                    break;
                }
            }
        }
        boolean alreadySmaller = true;
        Double rowCount = mq.getMaxRowCount(input);
        if (rowCount != null && fetch != null) {
            int offsetVal = offset == null ? 0 : RexLiteral.intValue(offset);
            int limit = RexLiteral.intValue(fetch);
            if ((double) offsetVal + (double) limit < rowCount) {
                alreadySmaller = false;
            }
        }
        return alreadySorted && alreadySmaller;
    }

    /**
     * Returns {@code result} unchanged.
     *
     * <p>With assertions enabled, fails if it is non-null and outside {@code [0, 1]}.
     */
    public static Double validatePercentage(Double result) {
        assert isPercentage(result, true);
        return result;
    }

    /**
     * Returns false if {@code result} is non-null and outside {@code [0, 1]}.
     *
     * <p>With {@code fail} set and assertions enabled, asserts instead of returning false.
     */
    private static boolean isPercentage(Double result, boolean fail) {
        if (result != null) {
            double d = result;
            if (d < 0.0 || d > 1.0) {
                assert !fail;
                return false;
            }
        }
        return true;
    }

    /**
     * Normalizes a row-count-like result.
     *
     * <p>{@code null} passes through. Infinity is capped to {@link Double#MAX_VALUE}. Values below 1
     * are raised to 1. With assertions enabled, negative values fail.
     */
    public static Double validateResult(Double result) {
        if (result == null) {
            return null;
        }
        if (result.isInfinite()) {
            result = Double.MAX_VALUE;
        }
        assert isNonNegative(result, true);
        if (result < 1.0) {
            result = 1.0;
        }
        return result;
    }

    /**
     * Returns false if {@code result} is non-null and negative.
     *
     * <p>With {@code fail} set and assertions enabled, asserts instead.
     */
    private static boolean isNonNegative(Double result, boolean fail) {
        if (result != null && result.doubleValue() < 0.0) {
            assert !fail;
            return false;
        }
        return true;
    }

    /** Delegates to {@code clearCache(rel)} on the metadata query of {@code rel}'s cluster. */
    public static boolean clearCache(RelNode rel) {
        return rel.getCluster().getMetadataQuery().clearCache(rel);
    }

    /** Visitor that estimates the distinct-value count of an expression projected by {@link #rel}. */
    private static class CardOfProjExpr
    extends RexVisitorImpl<Double> {
        private final RelMetadataQuery mq;
        private final Project rel;

        CardOfProjExpr(RelMetadataQuery mq, Project rel) {
            super(true);
            this.mq = mq;
            this.rel = rel;
        }

        /** Input column: its distinct count in the project's input, capped by the project's row count. */
        @Override
        public Double visitInputRef(RexInputRef var) {
            ImmutableBitSet col = ImmutableBitSet.of(var.getIndex());
            Double distinctRowCount = this.mq.getDistinctRowCount(this.rel.getInput(), col, null);
            if (distinctRowCount == null) {
                return null;
            }
            return numDistinctVals(distinctRowCount, this.mq.getRowCount(this.rel));
        }

        /** Literal: a single distinct value. */
        @Override
        public Double visitLiteral(RexLiteral literal) {
            return numDistinctVals(1.0, this.mq.getRowCount(this.rel));
        }

        /**
         * Estimates the distinct count of a call from its operands.
         *
         * <ul>
         *   <li>Unary minus: the operand's count.</li>
         *   <li>{@code +}/{@code -}: the larger of the two operand counts.</li>
         *   <li>{@code *}/{@code /}: the product of the two operand counts.</li>
         *   <li>Other single-operand calls: the operand's count.</li>
         *   <li>Anything else: a tenth of the row count.</li>
         * </ul>
         * The result is capped by the project's row count.
         */
        @Override
        public Double visitCall(RexCall call) {
            Double rowCount = this.mq.getRowCount(this.rel);
            List<RexNode> operands = call.getOperands();
            Double distinctRowCount;
            if (call.isA(SqlKind.MINUS_PREFIX)) {
                distinctRowCount = cardOfProjExpr(this.mq, this.rel, operands.get(0));
            } else if (call.isA(ImmutableList.of(SqlKind.PLUS, SqlKind.MINUS))) {
                Double card0 = cardOfProjExpr(this.mq, this.rel, operands.get(0));
                if (card0 == null) {
                    return null;
                }
                Double card1 = cardOfProjExpr(this.mq, this.rel, operands.get(1));
                if (card1 == null) {
                    return null;
                }
                distinctRowCount = Math.max(card0, card1);
            } else if (call.isA(ImmutableList.of(SqlKind.TIMES, SqlKind.DIVIDE))) {
                distinctRowCount = NumberUtil.multiply(cardOfProjExpr(this.mq, this.rel, operands.get(0)), cardOfProjExpr(this.mq, this.rel, operands.get(1)));
            } else if (operands.size() == 1) {
                distinctRowCount = cardOfProjExpr(this.mq, this.rel, operands.get(0));
            } else if (rowCount == null) {
                // Unknown row count: avoid unboxing null; numDistinctVals propagates null.
                distinctRowCount = null;
            } else {
                distinctRowCount = rowCount / 10.0;
            }
            return numDistinctVals(distinctRowCount, rowCount);
        }
    }
}

