Skip to content

Commit

Permalink
[KQP] Added join algo hint without CBO (#10740)
Browse files Browse the repository at this point in the history
  • Loading branch information
pashandor789 authored Oct 28, 2024
1 parent 5bedea9 commit 215a87c
Show file tree
Hide file tree
Showing 7 changed files with 203 additions and 39 deletions.
4 changes: 2 additions & 2 deletions ydb/core/kqp/opt/logical/kqp_opt_log.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -177,14 +177,14 @@ class TKqpLogicalOptTransformer : public TOptimizeTransformerBase {

TMaybeNode<TExprBase> RewriteEquiJoin(TExprBase node, TExprContext& ctx) {
bool useCBO = Config->CostBasedOptimizationLevel.Get().GetOrElse(Config->DefaultCostBasedOptimizationLevel) >= 2;
TExprBase output = DqRewriteEquiJoin(node, KqpCtx.Config->GetHashJoinMode(), useCBO, ctx, TypesCtx, KqpCtx.JoinsCount);
TExprBase output = DqRewriteEquiJoin(node, KqpCtx.Config->GetHashJoinMode(), useCBO, ctx, TypesCtx, KqpCtx.JoinsCount, KqpCtx.GetOptimizerHints());
DumpAppliedRule("RewriteEquiJoin", node.Ptr(), output.Ptr(), ctx);
return output;
}

TMaybeNode<TExprBase> JoinToIndexLookup(TExprBase node, TExprContext& ctx) {
bool useCBO = Config->CostBasedOptimizationLevel.Get().GetOrElse(Config->DefaultCostBasedOptimizationLevel) >= 2;
TExprBase output = KqpJoinToIndexLookup(node, ctx, KqpCtx, useCBO);
TExprBase output = KqpJoinToIndexLookup(node, ctx, KqpCtx, useCBO, KqpCtx.GetOptimizerHints());
DumpAppliedRule("JoinToIndexLookup", node.Ptr(), output.Ptr(), ctx);
return output;
}
Expand Down
60 changes: 52 additions & 8 deletions ydb/core/kqp/opt/logical/kqp_opt_log_join.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ bool IsParameterToListOfStructsRepack(const TExprBase& expr) {
return true;
}

//#define DBG(...) YQL_CLOG(DEBUG, ProviderKqp) << __VA_ARGS__
// #define DBG(...) YQL_CLOG(DEBUG, ProviderKqp) << __VA_ARGS__
#define DBG(...)

TMaybeNode<TExprBase> BuildKqpStreamIndexLookupJoin(
Expand Down Expand Up @@ -928,7 +928,38 @@ TMaybeNode<TExprBase> KqpJoinToIndexLookupImpl(const TDqJoin& join, TExprContext

} // anonymous namespace

TExprBase KqpJoinToIndexLookup(const TExprBase& node, TExprContext& ctx, const TKqpOptimizeContext& kqpCtx, bool useCBO)
/**
 * Gathers the relation labels that take part in the join tree rooted at `node`.
 *
 *  - A TDqPrecompute node is transparent: the labels of its input are returned.
 *  - For a TDqJoin, each side contributes either its own label (when the label is a
 *    single atom) or, otherwise, the labels collected recursively from that side's input.
 *    Left-side labels always precede right-side labels.
 *  - Any other node yields an empty list.
 */
TVector<TString> CollectLabels(const TExprBase& node) {
    // Look through a precompute wrapper: it carries no labels of its own.
    if (node.Maybe<TDqPrecompute>()) {
        return CollectLabels(node.Cast<TDqPrecompute>().Input());
    }

    // Only joins contribute labels.
    if (!node.Maybe<TDqJoin>()) {
        return {};
    }

    auto dqJoin = node.Cast<TDqJoin>();
    TVector<TString> labels;

    // Appends every label found in a nested subtree to the result.
    const auto appendNested = [&labels](const TExprBase& subtree) {
        auto nested = CollectLabels(subtree);
        labels.insert(labels.end(), std::make_move_iterator(nested.begin()), std::make_move_iterator(nested.end()));
    };

    if (auto leftLabel = dqJoin.LeftLabel(); leftLabel.Maybe<TCoAtom>()) {
        labels.push_back(leftLabel.Cast<TCoAtom>().StringValue());
    } else {
        appendNested(dqJoin.LeftInput());
    }

    if (auto rightLabel = dqJoin.RightLabel(); rightLabel.Maybe<TCoAtom>()) {
        labels.push_back(rightLabel.Cast<TCoAtom>().StringValue());
    } else {
        appendNested(dqJoin.RightInput());
    }

    return labels;
}

TExprBase KqpJoinToIndexLookup(const TExprBase& node, TExprContext& ctx, const TKqpOptimizeContext& kqpCtx, bool useCBO, const TOptimizerHints& hints)
{
if (!node.Maybe<TDqJoin>()) {
return node;
Expand All @@ -945,20 +976,33 @@ TExprBase KqpJoinToIndexLookup(const TExprBase& node, TExprContext& ctx, const T
return node;
}

if (useCBO){

if (algo != EJoinAlgoType::LookupJoin && algo != EJoinAlgoType::LookupJoinReverse) {
if (useCBO && algo != EJoinAlgoType::LookupJoin && algo != EJoinAlgoType::LookupJoinReverse){
return node;
}

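/*
 * This loop looks for a hint that has already been applied to exactly this set of join labels.
 * If such a hint requests a non-lookup algorithm, we leave the function and keep the join as is.
 * If it requests LookupJoin or LookupJoinReverse, we stop searching and fall through to
 * KqpJoinToIndexLookupImpl, because a lookup join still has to be rewritten here.
 */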
auto joinLabels = CollectLabels(node);
for (const auto& hint: hints.JoinAlgoHints->Hints) {
if (
std::unordered_set<TString>(hint.JoinLabels.begin(), hint.JoinLabels.end()) ==
std::unordered_set<TString>(joinLabels.begin(), joinLabels.end()) && hint.Applied
) {
if (hint.Algo == EJoinAlgoType::LookupJoin || hint.Algo == EJoinAlgoType::LookupJoinReverse) {
break;
}

return node;
}
}
}

DBG("-- Join: " << KqpExprToPrettyString(join, ctx));

// SqlIn support (preferred lookup direction)
if (join.JoinType().Value() == "LeftSemi") {
auto flipJoin = FlipLeftSemiJoin(join, ctx);
DBG("-- Flip join");

if (auto indexLookupJoin = KqpJoinToIndexLookupImpl(flipJoin, ctx, kqpCtx)) {
return indexLookupJoin.Cast();
}
Expand Down
2 changes: 1 addition & 1 deletion ydb/core/kqp/opt/logical/kqp_opt_log_rules.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ NYql::NNodes::TExprBase KqpPushExtractedPredicateToReadTable(NYql::NNodes::TExpr
const TKqpOptimizeContext& kqpCtx, NYql::TTypeAnnotationContext& typesCtx, const NYql::TParentsMap& parentsMap);

NYql::NNodes::TExprBase KqpJoinToIndexLookup(const NYql::NNodes::TExprBase& node, NYql::TExprContext& ctx,
const TKqpOptimizeContext& kqpCtx, bool useCBO);
const TKqpOptimizeContext& kqpCtx, bool useCBO, const NYql::TOptimizerHints& hints);

NYql::NNodes::TExprBase KqpRewriteSqlInToEquiJoin(const NYql::NNodes::TExprBase& node, NYql::TExprContext& ctx,
const TKqpOptimizeContext& kqpCtx, const NYql::TKikimrConfiguration::TPtr& config);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
-- Test query: a chain of four inner joins over R, S, T, U, V with one join-algorithm
-- hint per join. Each JoinType(...) entry lists a set of table labels followed by the
-- algorithm requested for the join whose inputs cover exactly that set of labels.
-- NOTE(review): presumably this is the query used by the OltpJoinTypeHintCBOTurnOFF
-- unit test — confirm against the test's query path.
PRAGMA TablePathPrefix='/Root';
PRAGMA ydb.OptimizerHints =
'
JoinType(R S Shuffle)
JoinType(R S T Broadcast)
JoinType(R S T U Shuffle)
JoinType(R S T U V Broadcast)
';

SELECT * FROM
R INNER JOIN S on R.id = S.id
INNER JOIN T on R.id = T.id
INNER JOIN U on T.id = U.id
INNER JOIN V on U.id = V.id;
76 changes: 71 additions & 5 deletions ydb/core/kqp/ut/join/kqp_join_order_ut.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ static void CreateSampleTable(TSession session, bool useColumnStore) {
CreateTables(session, "schema/lookupbug.sql", useColumnStore);
}

static TKikimrRunner GetKikimrWithJoinSettings(bool useStreamLookupJoin = false, TString stats = ""){
static TKikimrRunner GetKikimrWithJoinSettings(bool useStreamLookupJoin = false, TString stats = "", bool useCBO = true){
TVector<NKikimrKqp::TKqpSetting> settings;

NKikimrKqp::TKqpSetting setting;
Expand All @@ -96,6 +96,9 @@ static TKikimrRunner GetKikimrWithJoinSettings(bool useStreamLookupJoin = false,
appConfig.MutableTableServiceConfig()->SetEnableKqpDataQueryStreamIdxLookupJoin(useStreamLookupJoin);
appConfig.MutableTableServiceConfig()->SetEnableConstantFolding(true);
appConfig.MutableTableServiceConfig()->SetCompileTimeoutMs(TDuration::Minutes(10).MilliSeconds());
if (!useCBO) {
appConfig.MutableTableServiceConfig()->SetDefaultCostBasedOptimizationLevel(0);
}

auto serverSettings = TKikimrSettings().SetAppConfig(appConfig);
serverSettings.SetKqpSettings(settings);
Expand Down Expand Up @@ -193,8 +196,8 @@ class TChainTester {
size_t ChainSize;
};

void ExplainJoinOrderTestDataQueryWithStats(const TString& queryPath, const TString& statsPath, bool useStreamLookupJoin, bool useColumnStore) {
auto kikimr = GetKikimrWithJoinSettings(useStreamLookupJoin, GetStatic(statsPath));
void ExplainJoinOrderTestDataQueryWithStats(const TString& queryPath, const TString& statsPath, bool useStreamLookupJoin, bool useColumnStore, bool useCBO = true) {
auto kikimr = GetKikimrWithJoinSettings(useStreamLookupJoin, GetStatic(statsPath), useCBO);
auto db = kikimr.GetTableClient();
auto session = db.CreateSession().GetValueSync().GetSession();

Expand Down Expand Up @@ -329,8 +332,8 @@ Y_UNIT_TEST_SUITE(KqpJoinOrder) {
TChainTester(65).Test();
}

TString ExecuteJoinOrderTestDataQueryWithStats(const TString& queryPath, const TString& statsPath, bool useStreamLookupJoin, bool useColumnStore) {
auto kikimr = GetKikimrWithJoinSettings(useStreamLookupJoin, GetStatic(statsPath));
TString ExecuteJoinOrderTestDataQueryWithStats(const TString& queryPath, const TString& statsPath, bool useStreamLookupJoin, bool useColumnStore, bool useCBO = true) {
auto kikimr = GetKikimrWithJoinSettings(useStreamLookupJoin, GetStatic(statsPath), useCBO);
auto db = kikimr.GetTableClient();
auto session = db.CreateSession().GetValueSync().GetSession();

Expand Down Expand Up @@ -514,6 +517,69 @@ Y_UNIT_TEST_SUITE(KqpJoinOrder) {
CheckJoinCardinality("queries/test_join_hint2.sql", "stats/basic.json", "InnerJoin (MapJoin)", 1, StreamLookupJoin, ColumnStore);
}


/**
 * Test helper that searches a JSON join plan for the operator whose subtree is built
 * from exactly the requested set of table labels.
 *
 * Node shapes, as read by FindImpl:
 *  - leaf:     a map containing the key "table" (the label, a string);
 *  - operator: a map containing "op_name" (a string) and "args" (an array of child nodes).
 */
class TFindJoinWithLabels {
public:
    TFindJoinWithLabels(
        const NJson::TJsonValue& plan
    )
        : Plan(plan)
    {}

    /**
     * Returns the "op_name" of the operator whose subtree labels, once sorted, equal the
     * sorted `labels`; returns an empty string when no operator matches.
     */
    TString Find(const TVector<TString>& labels) {
        Labels = labels;
        std::sort(Labels.begin(), Labels.end());

        TVector<TString> rootLabels;
        return FindImpl(Plan, rootLabels);
    }

private:
    /**
     * Depth-first search over `plan`.
     *
     * On return `subtreeLabels` holds the labels gathered so far below `plan`; it is only
     * complete when the function returns an empty string, because a match found in a child
     * is propagated immediately.
     */
    TString FindImpl(const NJson::TJsonValue& plan, TVector<TString>& subtreeLabels) {
        // Bind by const reference: there is no need to copy the node's map (and, below,
        // its argument array) at every level of the recursion.
        const auto& planMap = plan.GetMapSafe();

        // Leaf: report its single label to the caller; a leaf is never a match itself.
        if (planMap.contains("table")) {
            subtreeLabels = {planMap.at("table").GetStringSafe()};
            return "";
        }

        TString opName = planMap.at("op_name").GetStringSafe();

        const auto& inputs = planMap.at("args").GetArraySafe();
        for (const auto& input : inputs) {
            TVector<TString> childLabels;
            // A match deeper in the tree wins over this node.
            if (auto maybeOpName = FindImpl(input, childLabels)) {
                return maybeOpName;
            }
            subtreeLabels.insert(subtreeLabels.end(), childLabels.begin(), childLabels.end());
        }

        if (AreRequestedLabels(subtreeLabels)) {
            return opName;
        }

        return "";
    }

    // Order-insensitive comparison with the requested labels; `labels` is taken by value
    // so that it can be sorted locally without touching the caller's vector.
    bool AreRequestedLabels(TVector<TString> labels) {
        std::sort(labels.begin(), labels.end());
        return Labels == labels;
    }

    NJson::TJsonValue Plan;
    TVector<TString> Labels;  // requested labels, kept sorted by Find()
};

/**
 * Runs the hinted query with the cost-based optimizer switched off (last argument,
 * useCBO = false) and checks which join operator the plan reports for each set of labels.
 */
Y_UNIT_TEST(OltpJoinTypeHintCBOTurnOFF) {
    auto plan = ExecuteJoinOrderTestDataQueryWithStats(
        "queries/oltp_join_type_hint_cbo_turnoff.sql",
        "stats/basic.json",
        /* useStreamLookupJoin = */ false,
        /* useColumnStore = */ false,
        /* useCBO = */ false
    );
    auto detailedPlan = GetDetailedJoinOrder(plan);

    auto joinFinder = TFindJoinWithLabels(detailedPlan);
    // UNIT_ASSERT_VALUES_EQUAL prints both values on failure, so a wrong algorithm
    // shows up directly in the test log.
    UNIT_ASSERT_VALUES_EQUAL(joinFinder.Find({"R", "S"}), "InnerJoin (Grace)");
    UNIT_ASSERT_VALUES_EQUAL(joinFinder.Find({"R", "S", "T"}), "InnerJoin (MapJoin)");
    UNIT_ASSERT_VALUES_EQUAL(joinFinder.Find({"R", "S", "T", "U"}), "InnerJoin (Grace)");
    UNIT_ASSERT_VALUES_EQUAL(joinFinder.Find({"R", "S", "T", "U", "V"}), "InnerJoin (MapJoin)");
}

Y_UNIT_TEST_XOR_OR_BOTH_FALSE(TestJoinOrderHintsSimple, StreamLookupJoin, ColumnStore) {
auto plan = ExecuteJoinOrderTestDataQueryWithStats("queries/join_order_hints_simple.sql", "stats/basic.json", StreamLookupJoin, ColumnStore);
UNIT_ASSERT_VALUES_EQUAL(GetJoinOrder(plan).GetStringRobust(), R"(["T",["R","S"]])") ;
Expand Down
81 changes: 60 additions & 21 deletions ydb/library/yql/dq/opt/dq_opt_join.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,44 +116,67 @@ TExprBase BuildDqJoinInput(TExprContext& ctx, TPositionHandle pos, const TExprBa
return partition;
}

TMaybe<TJoinInputDesc> BuildDqJoin(const TCoEquiJoinTuple& joinTuple,
const THashMap<TStringBuf, TJoinInputDesc>& inputs, EHashJoinMode mode, bool useCBO, TExprContext& ctx, const TTypeAnnotationContext& typeCtx)
TMaybe<TJoinInputDesc> BuildDqJoin(
const TCoEquiJoinTuple& joinTuple,
const THashMap<TStringBuf, TJoinInputDesc>& inputs,
EHashJoinMode mode,
TExprContext& ctx,
const TTypeAnnotationContext& typeCtx,
TVector<TString>& subtreeLabels,
const NYql::TOptimizerHints& hints
)
{
auto options = joinTuple.Options();
auto linkSettings = GetEquiJoinLinkSettings(options.Ref());
YQL_ENSURE(linkSettings.JoinAlgo != EJoinAlgoType::StreamLookupJoin || typeCtx.StreamLookupJoin, "Unsupported join strategy: streamlookup");

if (linkSettings.JoinAlgo == EJoinAlgoType::MapJoin) {
mode = EHashJoinMode::Map;
} else if (linkSettings.JoinAlgo == EJoinAlgoType::GraceJoin) {
mode = EHashJoinMode::GraceAndSelf;
}

bool leftAny = linkSettings.LeftHints.contains("any");
bool rightAny = linkSettings.RightHints.contains("any");

TMaybe<TJoinInputDesc> left;
TVector<TString> lhsLabels;
if (joinTuple.LeftScope().Maybe<TCoAtom>()) {
lhsLabels.push_back(joinTuple.LeftScope().Cast<TCoAtom>().StringValue());
left = inputs.at(joinTuple.LeftScope().Cast<TCoAtom>().Value());
YQL_ENSURE(left, "unknown scope " << joinTuple.LeftScope().Cast<TCoAtom>().Value());
} else {
left = BuildDqJoin(joinTuple.LeftScope().Cast<TCoEquiJoinTuple>(), inputs, mode, useCBO, ctx, typeCtx);
left = BuildDqJoin(joinTuple.LeftScope().Cast<TCoEquiJoinTuple>(), inputs, mode, ctx, typeCtx, lhsLabels, hints);
if (!left) {
return {};
}
}

TMaybe<TJoinInputDesc> right;
TVector<TString> rhsLabels;
if (joinTuple.RightScope().Maybe<TCoAtom>()) {
rhsLabels.push_back(joinTuple.RightScope().Cast<TCoAtom>().StringValue());
right = inputs.at(joinTuple.RightScope().Cast<TCoAtom>().Value());
YQL_ENSURE(right, "unknown scope " << joinTuple.RightScope().Cast<TCoAtom>().Value());
} else {
right = BuildDqJoin(joinTuple.RightScope().Cast<TCoEquiJoinTuple>(), inputs, mode, useCBO, ctx, typeCtx);
right = BuildDqJoin(joinTuple.RightScope().Cast<TCoEquiJoinTuple>(), inputs, mode, ctx, typeCtx, rhsLabels, hints);
if (!right) {
return {};
}
}

subtreeLabels.insert(subtreeLabels.end(), std::make_move_iterator(lhsLabels.begin()), std::make_move_iterator(lhsLabels.end()));
subtreeLabels.insert(subtreeLabels.end(), std::make_move_iterator(rhsLabels.begin()), std::make_move_iterator(rhsLabels.end()));

auto options = joinTuple.Options();
auto linkSettings = GetEquiJoinLinkSettings(options.Ref());
for (auto& hint: hints.JoinAlgoHints->Hints) {
if (
std::unordered_set<std::string>(hint.JoinLabels.begin(), hint.JoinLabels.end()) ==
std::unordered_set<std::string>(subtreeLabels.begin(), subtreeLabels.end())
) {
linkSettings.JoinAlgo = hint.Algo;
hint.Applied = true;
}
}
YQL_ENSURE(linkSettings.JoinAlgo != EJoinAlgoType::StreamLookupJoin || typeCtx.StreamLookupJoin, "Unsupported join strategy: streamlookup");

if (linkSettings.JoinAlgo == EJoinAlgoType::MapJoin) {
mode = EHashJoinMode::Map;
} else if (linkSettings.JoinAlgo == EJoinAlgoType::GraceJoin) {
mode = EHashJoinMode::GraceAndSelf;
}

bool leftAny = linkSettings.LeftHints.contains("any");
bool rightAny = linkSettings.RightHints.contains("any");

TStringBuf joinType = joinTuple.Type().Value();
TSet<std::pair<TStringBuf, TStringBuf>> resultKeys;
if (joinType != TStringBuf("RightOnly") && joinType != TStringBuf("RightSemi")) {
Expand Down Expand Up @@ -392,17 +415,32 @@ bool CheckJoinColumns(const TExprBase& node) {
}
}

TExprBase DqRewriteEquiJoin(const TExprBase& node, EHashJoinMode mode, bool useCBO, TExprContext& ctx, const TTypeAnnotationContext& typeCtx) {
TExprBase DqRewriteEquiJoin(
const TExprBase& node,
EHashJoinMode mode,
bool useCBO,
TExprContext& ctx,
const TTypeAnnotationContext& typeCtx,
const TOptimizerHints& hints
) {
int dummyJoinCounter = 0;
return DqRewriteEquiJoin(node, mode, useCBO, ctx, typeCtx, dummyJoinCounter);
return DqRewriteEquiJoin(node, mode, useCBO, ctx, typeCtx, dummyJoinCounter, hints);
}

/**
 * Rewrites `EquiJoin` into a number of `DqJoin` callables. This simplifies the next step of building
 * physical stages with join operators.
 * Potentially, this optimizer could also reorder joins given cardinality information.
 */
TExprBase DqRewriteEquiJoin(const TExprBase& node, EHashJoinMode mode, bool useCBO, TExprContext& ctx, const TTypeAnnotationContext& typeCtx, int& joinCounter) {
TExprBase DqRewriteEquiJoin(
const TExprBase& node,
EHashJoinMode mode,
bool /* useCBO */,
TExprContext& ctx,
const TTypeAnnotationContext& typeCtx,
int& joinCounter,
const TOptimizerHints& hints
) {
if (!node.Maybe<TCoEquiJoin>()) {
return node;
}
Expand All @@ -419,7 +457,8 @@ TExprBase DqRewriteEquiJoin(const TExprBase& node, EHashJoinMode mode, bool useC
}

auto joinTuple = equiJoin.Arg(equiJoin.ArgCount() - 2).Cast<TCoEquiJoinTuple>();
auto result = BuildDqJoin(joinTuple, inputs, mode, useCBO, ctx, typeCtx);
TVector<TString> dummy;
auto result = BuildDqJoin(joinTuple, inputs, mode, ctx, typeCtx, dummy, hints);
if (!result) {
return node;
}
Expand Down
5 changes: 3 additions & 2 deletions ydb/library/yql/dq/opt/dq_opt_join.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

#include <ydb/library/yql/dq/common/dq_common.h>
#include <ydb/library/yql/core/yql_expr_optimize.h>
#include <ydb/library/yql/core/cbo/cbo_optimizer_new.h>

namespace NYql {

Expand All @@ -12,9 +13,9 @@ struct TRelOptimizerNode;

namespace NDq {

NNodes::TExprBase DqRewriteEquiJoin(const NNodes::TExprBase& node, EHashJoinMode mode, bool useCBO, TExprContext& ctx, const TTypeAnnotationContext& typeCtx);
NNodes::TExprBase DqRewriteEquiJoin(const NNodes::TExprBase& node, EHashJoinMode mode, bool useCBO, TExprContext& ctx, const TTypeAnnotationContext& typeCtx, const TOptimizerHints& hints = {});

NNodes::TExprBase DqRewriteEquiJoin(const NNodes::TExprBase& node, EHashJoinMode mode, bool useCBO, TExprContext& ctx, const TTypeAnnotationContext& typeCtx, int& joinCounter);
NNodes::TExprBase DqRewriteEquiJoin(const NNodes::TExprBase& node, EHashJoinMode mode, bool useCBO, TExprContext& ctx, const TTypeAnnotationContext& typeCtx, int& joinCounter, const TOptimizerHints& hints = {});

NNodes::TExprBase DqBuildPhyJoin(const NNodes::TDqJoin& join, bool pushLeftStage, TExprContext& ctx, IOptimizationContext& optCtx, bool useGraceCoreForMap);

Expand Down

0 comments on commit 215a87c

Please sign in to comment.