Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions src/query/frontend/ast/ast.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -893,7 +893,7 @@ class Aggregation : public memgraph::query::BinaryOperator {
static const utils::TypeInfo kType;
const utils::TypeInfo &GetTypeInfo() const override { return kType; }

enum class Op { COUNT, MIN, MAX, SUM, AVG, COLLECT_LIST, COLLECT_MAP, PROJECT };
enum class Op { COUNT, MIN, MAX, SUM, AVG, COLLECT_LIST, COLLECT_MAP, PROJECT_PATH, PROJECT_LISTS };

Aggregation() = default;

Expand All @@ -906,7 +906,7 @@ class Aggregation : public memgraph::query::BinaryOperator {
static const constexpr char *const kProject = "PROJECT";

static std::string OpToString(Op op) {
const char *op_strings[] = {kCount, kMin, kMax, kSum, kAvg, kCollect, kCollect, kProject};
const char *op_strings[] = {kCount, kMin, kMax, kSum, kAvg, kCollect, kCollect, kProject, kProject};
return op_strings[static_cast<int>(op)];
}

Expand Down Expand Up @@ -945,15 +945,15 @@ class Aggregation : public memgraph::query::BinaryOperator {
// Use only for serialization.
explicit Aggregation(Op op) : op_(op) {}

/// Aggregation's first expression is the value being aggregated. The second
/// expression is the key used only in COLLECT_MAP.
/// Aggregation's first expression is the value being aggregated. The second expression is used either as a key in
/// COLLECT_MAP or for the relationships list in the two-argument overload of PROJECT_PATH; no other aggregate
/// functions use this parameter.
Aggregation(Expression *expression1, Expression *expression2, Op op, bool distinct)
: BinaryOperator(expression1, expression2), op_(op), distinct_(distinct) {
// COUNT without expression denotes COUNT(*) in cypher.
DMG_ASSERT(expression1 || op == Aggregation::Op::COUNT, "All aggregations, except COUNT require expression");
DMG_ASSERT((expression2 == nullptr) ^ (op == Aggregation::Op::COLLECT_MAP),
"The second expression is obligatory in COLLECT_MAP and "
"invalid otherwise");
DMG_ASSERT(expression1 || op == Aggregation::Op::COUNT, "All aggregations, except COUNT require expression1");
DMG_ASSERT((expression2 == nullptr) ^ (op == Aggregation::Op::PROJECT_LISTS || op == Aggregation::Op::COLLECT_MAP),
"expression2 is obligatory in COLLECT_MAP and PROJECT_LISTS, and invalid otherwise");
}

private:
Expand Down
6 changes: 5 additions & 1 deletion src/query/frontend/ast/cypher_main_visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3013,7 +3013,7 @@ antlrcpp::Any CypherMainVisitor::visitFunctionInvocation(MemgraphCypher::Functio
}
if (upper_function_name == Aggregation::kProject) {
return static_cast<Expression *>(
storage_->Create<Aggregation>(expressions[0], nullptr, Aggregation::Op::PROJECT, is_distinct));
storage_->Create<Aggregation>(expressions[0], nullptr, Aggregation::Op::PROJECT_PATH, is_distinct));
}
}

Expand All @@ -3022,6 +3022,10 @@ antlrcpp::Any CypherMainVisitor::visitFunctionInvocation(MemgraphCypher::Functio
return static_cast<Expression *>(
storage_->Create<Aggregation>(expressions[1], expressions[0], Aggregation::Op::COLLECT_MAP, is_distinct));
}
if (upper_function_name == Aggregation::kProject) {
return static_cast<Expression *>(
storage_->Create<Aggregation>(expressions[0], expressions[1], Aggregation::Op::PROJECT_LISTS, is_distinct));
}
}

auto *function_expr = storage_->Create<Function>(function_name, expressions);
Expand Down
26 changes: 25 additions & 1 deletion src/query/graph.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2022 Memgraph Ltd.
// Copyright 2025 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
Expand All @@ -10,10 +10,15 @@
// licenses/APL.txt.

#include "query/graph.hpp"
#include <ranges>
#include "query/exceptions.hpp"
#include "query/path.hpp"

namespace memgraph::query {

namespace r = std::ranges;
namespace rv = r::views;

Graph::Graph(utils::MemoryResource *memory) : vertices_(memory), edges_(memory) {}

Graph::Graph(const Graph &other, utils::MemoryResource *memory)
Expand All @@ -36,6 +41,25 @@ void Graph::Expand(const Path &path) {
std::for_each(path_edges_.begin(), path_edges_.end(), [this](const EdgeAccessor e) { edges_.insert(e); });
}

void Graph::Expand(std::span<TypedValue const> const nodes, std::span<TypedValue const> const edges) {
auto actual_nodes = nodes | rv::filter([](auto const &each) { return each.type() == TypedValue::Type::Vertex; }) |
rv::transform([](auto const &each) { return each.ValueVertex(); });

auto actual_edges = edges | rv::filter([](auto const &each) { return each.type() == TypedValue::Type::Edge; }) |
rv::transform([](auto const &each) { return each.ValueEdge(); });

if (r::any_of(actual_edges, [&](auto const &edge) {
return r::find(actual_nodes, edge.From()) == actual_nodes.end() ||
r::find(actual_nodes, edge.To()) == actual_nodes.end();
})) {
throw memgraph::query::QueryRuntimeException(
"Cannot project graph with any projected relationships whose start or end nodes are not also projected.");
}

vertices_.insert(actual_nodes.begin(), actual_nodes.end());
edges_.insert(actual_edges.begin(), actual_edges.end());
}

void Graph::InsertVertex(const VertexAccessor &vertex) { vertices_.insert(vertex); }

void Graph::InsertEdge(const EdgeAccessor &edge) { edges_.insert(edge); }
Expand Down
7 changes: 6 additions & 1 deletion src/query/graph.hpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2024 Memgraph Ltd.
// Copyright 2025 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
Expand All @@ -14,6 +14,7 @@
#include <functional>
#include <utility>
#include "query/edge_accessor.hpp"
#include "query/typed_value.hpp"
#include "query/vertex_accessor.hpp"
#include "utils/logging.hpp"
#include "utils/memory.hpp"
Expand Down Expand Up @@ -69,6 +70,10 @@ class Graph final {
/** Expands the graph with the given path. */
void Expand(const Path &path);

/** Expand the graph from lists of nodes and edges. Any nulls or duplicate nodes/edges in these lists are ignored.
*/
void Expand(std::span<TypedValue const> nodes, std::span<TypedValue const> edges);

/** Inserts the vertex in the graph. */
void InsertVertex(const VertexAccessor &vertex);

Expand Down
52 changes: 40 additions & 12 deletions src/query/plan/operator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4077,7 +4077,8 @@ TypedValue DefaultAggregationOpValue(const Aggregate::Element &element, utils::M
return TypedValue(TypedValue::TVector(memory));
case Aggregation::Op::COLLECT_MAP:
return TypedValue(TypedValue::TMap(memory));
case Aggregation::Op::PROJECT:
case Aggregation::Op::PROJECT_PATH:
case Aggregation::Op::PROJECT_LISTS:
return TypedValue(query::Graph(memory));
}
}
Expand Down Expand Up @@ -4240,7 +4241,8 @@ class AggregateCursor : public Cursor {
case Aggregation::Op::SUM:
case Aggregation::Op::COLLECT_LIST:
case Aggregation::Op::COLLECT_MAP:
case Aggregation::Op::PROJECT:
case Aggregation::Op::PROJECT_PATH:
case Aggregation::Op::PROJECT_LISTS:
break;
}
}
Expand Down Expand Up @@ -4307,7 +4309,7 @@ class AggregateCursor : public Cursor {
for (; count_it != counts_end; ++count_it, ++value_it, ++unique_values_it, ++agg_elem_it) {
// COUNT(*) is the only case where input expression is optional
// handle it here
auto input_expr_ptr = agg_elem_it->value;
auto *input_expr_ptr = agg_elem_it->arg1;
if (!input_expr_ptr) {
*count_it += 1;
// value is deferred to post-processing
Expand Down Expand Up @@ -4345,13 +4347,17 @@ class AggregateCursor : public Cursor {
case Aggregation::Op::COLLECT_LIST:
value_it->ValueList().push_back(std::move(input_value));
break;
case Aggregation::Op::PROJECT: {
EnsureOkForProject(input_value);
case Aggregation::Op::PROJECT_PATH: {
EnsureOkForProjectPath(input_value);
value_it->ValueGraph().Expand(input_value.ValuePath());
break;
}
case Aggregation::Op::PROJECT_LISTS: {
ProjectList(input_value, agg_elem_it->arg2->Accept(*evaluator), value_it->ValueGraph());
break;
}
case Aggregation::Op::COLLECT_MAP:
auto key = agg_elem_it->key->Accept(*evaluator);
auto key = agg_elem_it->arg2->Accept(*evaluator);
if (key.type() != TypedValue::Type::String) throw QueryRuntimeException("Map key must be a string.");
value_it->ValueMap().emplace(key.ValueString(), std::move(input_value));
break;
Expand Down Expand Up @@ -4398,20 +4404,43 @@ class AggregateCursor : public Cursor {
case Aggregation::Op::COLLECT_LIST:
value_it->ValueList().push_back(std::move(input_value));
break;
case Aggregation::Op::PROJECT: {
EnsureOkForProject(input_value);
case Aggregation::Op::PROJECT_PATH: {
EnsureOkForProjectPath(input_value);
value_it->ValueGraph().Expand(input_value.ValuePath());
break;
}

case Aggregation::Op::PROJECT_LISTS: {
ProjectList(input_value, agg_elem_it->arg2->Accept(*evaluator), value_it->ValueGraph());
break;
}
case Aggregation::Op::COLLECT_MAP:
auto key = agg_elem_it->key->Accept(*evaluator);
auto key = agg_elem_it->arg2->Accept(*evaluator);
if (key.type() != TypedValue::Type::String) throw QueryRuntimeException("Map key must be a string.");
value_it->ValueMap().emplace(key.ValueString(), std::move(input_value));
break;
} // end switch over Aggregation::Op enum
} // end loop over all aggregations
}

/** Project a subgraph from lists of nodes and lists of edges. Any nulls in these lists are ignored.
*/
static void ProjectList(TypedValue const &arg1, TypedValue const &arg2, Graph &projectedGraph) {
if (arg1.type() != TypedValue::Type::List || !std::ranges::all_of(arg1.ValueList(), [](TypedValue const &each) {
return each.type() == TypedValue::Type::Vertex || each.type() == TypedValue::Type::Null;
})) {
throw QueryRuntimeException("project() argument 1 must be a list of nodes or nulls.");
}

if (arg2.type() != TypedValue::Type::List || !std::ranges::all_of(arg2.ValueList(), [](TypedValue const &each) {
return each.type() == TypedValue::Type::Edge || each.type() == TypedValue::Type::Null;
})) {
throw QueryRuntimeException("project() argument 2 must be a list of relationships or nulls.");
}

projectedGraph.Expand(arg1.ValueList(), arg2.ValueList());
}

/** Checks if the given TypedValue is legal in MIN and MAX. If not
* an appropriate exception is thrown. */
void EnsureOkForMinMax(const TypedValue &value) const {
Expand Down Expand Up @@ -4443,10 +4472,9 @@ class AggregateCursor : public Cursor {
}
}

/** Checks if the given TypedValue is legal in PROJECT and PROJECT_TRANSITIVE. If not
* an appropriate exception is thrown. */
/** Checks if the given TypedValue is legal in PROJECT_PATH. If not an appropriate exception is thrown. */
// NOLINTNEXTLINE(readability-convert-member-functions-to-static)
void EnsureOkForProject(const TypedValue &value) const {
void EnsureOkForProjectPath(const TypedValue &value) const {
switch (value.type()) {
case TypedValue::Type::Path:
return;
Expand Down
14 changes: 7 additions & 7 deletions src/query/plan/operator.hpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2024 Memgraph Ltd.
// Copyright 2025 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
Expand Down Expand Up @@ -1987,22 +1987,22 @@ class Aggregate : public memgraph::query::plan::LogicalOperator {
const utils::TypeInfo &GetTypeInfo() const override { return kType; }

/// An aggregation element, contains:
/// (input data expression, key expression - only used in COLLECT_MAP, type of
/// aggregation, output symbol).
/// (input data expression, secondary data expression - only used in COLLECT_MAP and PROJECT_LISTS,
/// type of aggregation, output symbol, distinct)
struct Element {
static const utils::TypeInfo kType;
const utils::TypeInfo &GetTypeInfo() const { return kType; }

Expression *value;
Expression *key;
Expression *arg1;
Expression *arg2;
Aggregation::Op op;
Symbol output_sym;
bool distinct{false};

Element Clone(AstStorage *storage) const {
Element object;
object.value = value ? value->Clone(storage) : nullptr;
object.key = key ? key->Clone(storage) : nullptr;
object.arg1 = arg1 ? arg1->Clone(storage) : nullptr;
object.arg2 = arg2 ? arg2->Clone(storage) : nullptr;
object.op = op;
object.output_sym = output_sym;
object.distinct = distinct;
Expand Down
26 changes: 20 additions & 6 deletions src/query/plan/pretty_print.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2024 Memgraph Ltd.
// Copyright 2025 Memgraph Ltd.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt; by using this file, you agree to be bound by the terms of the Business Source
Expand Down Expand Up @@ -441,12 +441,26 @@ json ToJson(const EdgeCreationInfo &edge_info, const DbAccessor &dba) {

json ToJson(const Aggregate::Element &elem, const DbAccessor &dba) {
json json;
if (elem.value) {
json["value"] = ToJson(elem.value, dba);
}
if (elem.key) {
json["key"] = ToJson(elem.key, dba);
if (elem.op == Aggregation::Op::PROJECT_LISTS) {
if (elem.arg1) {
json["nodes"] = ToJson(elem.arg1, dba);
}
if (elem.arg2) {
json["relationships"] = ToJson(elem.arg2, dba);
}
} else if (elem.op == Aggregation::Op::COLLECT_MAP) {
if (elem.arg1) {
json["value"] = ToJson(elem.arg1, dba);
}
if (elem.arg2) {
json["key"] = ToJson(elem.arg2, dba);
}
} else {
if (elem.arg1) {
json["value"] = ToJson(elem.arg1, dba);
}
}

json["op"] = utils::ToLowerCase(Aggregation::OpToString(elem.op));
json["output_symbol"] = ToJson(elem.output_sym);
json["distinct"] = elem.distinct;
Expand Down
8 changes: 4 additions & 4 deletions src/query/plan/rule_based_planner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -388,10 +388,10 @@ class ReturnBodyContext : public HierarchicalTreeVisitor {
const auto &symbol = symbol_table_.at(aggr);
aggregations_.emplace_back(
Aggregate::Element{aggr.expression1_, aggr.expression2_, aggr.op_, symbol, aggr.distinct_});
// Aggregation expression1_ is optional in COUNT(*), and COLLECT_MAP uses
// two expressions, so we can have 0, 1 or 2 elements on the
// has_aggregation_stack for this Aggregation expression.
if (aggr.op_ == Aggregation::Op::COLLECT_MAP) has_aggregation_.pop_back();
// Aggregation expression1_ is optional in COUNT(*), and COLLECT_MAP and PROJECT_LISTS use two expressions, so we
// can have 0, 1 or 2 elements on the has_aggregation_stack for this Aggregation expression.
if (aggr.op_ == Aggregation::Op::COLLECT_MAP || aggr.op_ == Aggregation::Op::PROJECT_LISTS)
has_aggregation_.pop_back();
if (aggr.expression1_)
has_aggregation_.back() = true;
else
Expand Down
4 changes: 4 additions & 0 deletions src/query/typed_value.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1224,6 +1224,10 @@ TypedValue operator+(const TypedValue &a, const TypedValue &b) {

if (a.IsList() || b.IsList()) {
TypedValue::TVector list(a.GetMemoryResource());

size_t const new_list_size{(a.IsList() ? a.ValueList().size() : 1) + (b.IsList() ? b.ValueList().size() : 1)};
list.reserve(new_list_size);

auto append_list = [&list](const TypedValue &v) {
if (v.IsList()) {
auto list2 = v.ValueList();
Expand Down
Loading