Skip to content

Commit

Permalink
[CPU] Fuse SDPA and Concat as early as possible (openvinotoolkit#28189)
Browse files Browse the repository at this point in the history
### Details:
 - *Move StatefulSDPAFusion before CommonOptimizations*
 - *...*

### Tickets:
 - *[158738](https://jira.devtools.intel.com/browse/CVS-158738)*
  • Loading branch information
luo-cheng2021 authored Jan 6, 2025
1 parent 0ea1ecc commit 552ba66
Show file tree
Hide file tree
Showing 7 changed files with 152 additions and 51 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,14 @@ namespace gen_pattern {

#ifdef CPU_DEBUG_CAPS

// Capture the caller's source location via default arguments.
// GCC/Clang provide __builtin_LINE()/__builtin_FILE(), which evaluate at the
// CALL site when used as a default argument value (pre-C++20 substitute for
// std::source_location). Other compilers fall back to sentinel values that
// disable location-based diagnostics (-1 / empty file name).
# ifdef __GNUC__
# define CURRENT_LINE_NO __builtin_LINE()
# define CURRENT_FILE __builtin_FILE()
# else
# define CURRENT_LINE_NO -1
# define CURRENT_FILE ""
# endif

template <typename... Args>
static inline void _verbose_log(Args&&... args) {
std::stringstream ss;
Expand All @@ -58,6 +66,10 @@ static bool matcher_verbose_enabled() {
if (matcher_verbose_enabled()) \
_verbose_log(__VA_ARGS__)
#else

// Non-debug builds: source-location capture is disabled; the -1 line number
// sentinel makes location-aware code paths (e.g. Symbol::get_name) fall back
// to plain names.
# define CURRENT_LINE_NO -1
# define CURRENT_FILE ""

static bool matcher_verbose_enabled() {
return false;
}
Expand Down Expand Up @@ -181,6 +193,8 @@ class Symbol {
double literal_const_value;
std::shared_ptr<Entity> lhs;
std::shared_ptr<Entity> rhs;
const char* filename = "";
int line_no = -1;
// _,+,-,*,/
// l : literal const
// n : named symbol
Expand Down Expand Up @@ -220,10 +234,12 @@ class Symbol {
entity->op = 'n';
entity->name = name;
}
Symbol(const int value) {
// Construct a literal-constant symbol (op 'l').
// line_no/file default to the CALLER's source location (GCC/Clang builtins
// via CURRENT_LINE_NO/CURRENT_FILE); they are stored so diagnostics can
// report where in the pattern code this literal was written. On compilers
// without the builtins they stay at the -1/"" sentinels.
Symbol(const int value, int line_no = CURRENT_LINE_NO, const char* file = CURRENT_FILE) {
entity = std::make_shared<Entity>();
entity->op = 'l';
entity->literal_const_value = value;
entity->line_no = line_no;
entity->filename = file;
}
Symbol(char op, const Symbol& lhs, const Symbol& rhs) {
entity = std::make_shared<Entity>();
Expand All @@ -246,8 +262,12 @@ class Symbol {
void* get_id() const {
return entity.get();
}
const char* get_name() const {
return entity->name;
// Human-readable name for logging.
// Named (independent) symbols and symbols without captured source info keep
// their given name; literal symbols created with source-location capture are
// reported as "filename:lineno" so log messages point at the pattern source
// line that produced them.
std::string get_name() const {
    if (entity->line_no == -1 || is_independent_var())
        return entity->name;
    // Strip the directory part; call strrchr once instead of twice.
    // NOTE(review): only '/' is handled — Windows '\\' paths would keep the
    // full path; acceptable for log output.
    const char* last_slash = strrchr(entity->filename, '/');
    const char* basename = last_slash ? last_slash + 1 : entity->filename;
    return std::string(basename) + ":" + std::to_string(entity->line_no);
}
bool operator<(const Symbol& rhs) const {
return get_id() < rhs.get_id();
Expand Down Expand Up @@ -739,7 +759,9 @@ class GenericPattern : public ov::pass::pattern::op::Pattern {
explicit GenericPattern(const DiscreteTypeInfo& type_info,
const OutputVector& args,
const detail::AttrMap& attrs,
const char* vt)
const char* vt,
const int line_no = -1,
const char* file = "")
: ov::pass::pattern::op::Pattern(args),
m_type_info(type_info),
m_attrs(attrs),
Expand All @@ -758,6 +780,12 @@ class GenericPattern : public ov::pass::pattern::op::Pattern {
sep = ",";
}
ss << ")";
if (line_no != -1) {
// add the code line no to the log:
// O P752<opset1::Multiply>(P736,P745)@fuse_rotary_positional_embeddings.cpp:551 vs ...
auto filename = strrchr(file, '/') ? strrchr(file, '/') + 1 : file;
ss << "@" << filename << ":" << line_no;
}
m_signature = ss.str();
set_friendly_name(std::string("P") + std::to_string(id));
}
Expand All @@ -776,7 +804,13 @@ class GenericPattern : public ov::pass::pattern::op::Pattern {
// strictly requires pattern & graph value to come from output port with same index,
// this is absolute necessary when pattern contains split node connections.
if (pattern_value.get_index() != graph_value.get_index()) {
_VERBOSE_LOG(level, "X output index mismatch: ", pattern_value.get_index(), "!=", graph_value.get_index());
_VERBOSE_LOG(level,
"X output index mismatch:(",
m_signature,
"): ",
pattern_value.get_index(),
"!=",
graph_value.get_index());
return false;
}

Expand Down Expand Up @@ -1018,15 +1052,18 @@ template <class T>
std::shared_ptr<Node> makePattern(const std::vector<detail::PatternNode>& inputs,
detail::AttrMap attrmap = {},
const char* vt = nullptr,
const char* friendly_name = nullptr) {
const char* friendly_name = nullptr,
int line_no = CURRENT_LINE_NO,
const char* file = CURRENT_FILE) {
OutputVector args;
for (auto& in : inputs)
args.push_back(in.get_output());

// pattern nodes are better for pattern matching because
// - it can be generic/incomplete, so normal OP node is not working properly
// - it has predicate to correctly decide which branch to take (in Or pattern)
auto pattern_node = std::make_shared<detail::GenericPattern>(T::get_type_info_static(), args, attrmap, vt);
auto pattern_node =
std::make_shared<detail::GenericPattern>(T::get_type_info_static(), args, attrmap, vt, line_no, file);

if (friendly_name)
pattern_node->set_friendly_name(friendly_name);
Expand Down Expand Up @@ -1120,7 +1157,9 @@ inline std::shared_ptr<Node> GenStridedSlice(detail::PatternNode data,
detail::PatternNode start,
detail::PatternNode stop,
detail::PatternNode step,
size_t axis) {
size_t axis,
int line_no = CURRENT_LINE_NO,
const char* file = CURRENT_FILE) {
std::vector<int64_t> begin_mask(axis + 1, 1);
std::vector<int64_t> end_mask(axis + 1, 1);
std::vector<int64_t> new_axis_mask;
Expand All @@ -1135,12 +1174,27 @@ inline std::shared_ptr<Node> GenStridedSlice(detail::PatternNode data,
{"end_mask", end_mask},
{"new_axis_mask", new_axis_mask},
{"shrink_axis_mask", shrink_axis_mask},
{"ellipsis_mask", ellipsis_mask}});
{"ellipsis_mask", ellipsis_mask}},
nullptr,
nullptr,
line_no,
file);
return opt2;
}

inline std::shared_ptr<Node> GenSlice(detail::PatternNode data, Symbol start, Symbol stop, Symbol step, size_t axis) {
auto opt1 = makePattern<opset8::Slice>({data, {start}, {stop}, {step}, {static_cast<int>(axis)}});
inline std::shared_ptr<Node> GenSlice(detail::PatternNode data,
Symbol start,
Symbol stop,
Symbol step,
size_t axis,
int line_no = CURRENT_LINE_NO,
const char* file = CURRENT_FILE) {
auto opt1 = makePattern<opset8::Slice>({data, {start}, {stop}, {step}, {static_cast<int>(axis)}},
{},
nullptr,
nullptr,
line_no,
file);

std::vector<Symbol> vbegin(axis + 1, Symbol(0));
std::vector<Symbol> vend(axis + 1, Symbol(0));
Expand Down Expand Up @@ -1168,7 +1222,11 @@ inline std::shared_ptr<Node> GenSlice(detail::PatternNode data, Symbol start, Sy
{"end_mask", end_mask},
{"new_axis_mask", new_axis_mask},
{"shrink_axis_mask", shrink_axis_mask},
{"ellipsis_mask", ellipsis_mask}});
{"ellipsis_mask", ellipsis_mask}},
nullptr,
nullptr,
line_no,
file);
return opt1 | opt2;
}

Expand Down Expand Up @@ -1329,7 +1387,9 @@ class PatternValidator {
auto id = sym.get_id();
if (symbol_value_map.count(id)) {
if (symbol_value_map[id] != value) {
_VERBOSE_LOG(" in-consistency between multiple references of same symbol : ",
_VERBOSE_LOG(" in-consistency between multiple references of same symbol(",
sym.get_name(),
"): ",
symbol_value_map[id],
" != ",
value);
Expand All @@ -1345,7 +1405,12 @@ class PatternValidator {
if (sym.is_literal_const()) {
auto literal = sym.eval(symbol_value_map);
if (literal != value) {
_VERBOSE_LOG(" mismatch between literal symbol & value : ", literal, " != ", value);
_VERBOSE_LOG(" mismatch between literal symbol & value(",
sym.get_name(),
"): ",
literal,
" != ",
value);
return false;
}
// no need to put literal into value map to eval them.
Expand Down Expand Up @@ -1373,7 +1438,9 @@ class PatternValidator {
}
}
if (!is_match) {
_VERBOSE_LOG(" mismatch between derived & value : ",
_VERBOSE_LOG(" mismatch between derived & value(",
sym.get_name(),
"): ",
std::setprecision(std::numeric_limits<float>::max_digits10),
derived,
" != ",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include <openvino/opsets/opset13.hpp>
#include <openvino/opsets/opset6.hpp>
#include <openvino/opsets/opset8.hpp>
#include <openvino/pass/manager.hpp>
#include <openvino/pass/pattern/op/or.hpp>
#include <openvino/pass/pattern/op/wrap_type.hpp>
#include <transformations/utils/gen_pattern.hpp>
Expand All @@ -20,7 +21,12 @@
#include "itt.hpp"
#include "openvino/opsets/opset1.hpp"
#include "ov_ops/type_relaxed.hpp"
#include "transformations/common_optimizations/simplify_shape_of_sub_graph.hpp"
#include "transformations/cpu_opset/common/op/sdpa.hpp"
#include "transformations/cpu_opset/x64/pass/sdpa_fuse_transpose_reshape.hpp"
#include "transformations/defs.hpp"
#include "transformations/op_conversions/convert_broadcast3.hpp"
#include "transformations/transpose_sinking/ts_shape_of.hpp"
using namespace ov::gen_pattern;

namespace ov {
Expand Down Expand Up @@ -56,8 +62,9 @@ StatefulSDPAFusion::StatefulSDPAFusion() {
std::shared_ptr<Node> reshape_k, reshape_v, unsqueeze_k, unsqueeze_v;
std::shared_ptr<Node> computed_bcst_k, computed_bcst_v, multiply_k, multiply_v;
std::shared_ptr<Node> mq_reshape_k, mq_reshape_v;
std::shared_ptr<Node> computed_bcst3_k, computed_bcst3_v;
auto multi_query_bcst = [](const std::shared_ptr<Node>& kv) {
auto reshape_kv = wrap_type<opset6::Reshape>({kv, any_input()});
auto reshape_kv = makePattern<opset6::Reshape>({kv, any_input()});
auto unsqueeze_kv = makePattern<opset1::Unsqueeze>({kv, any_input()});

auto check_one = [](Output<Node> output) -> bool {
Expand All @@ -73,13 +80,17 @@ StatefulSDPAFusion::StatefulSDPAFusion() {
makePattern<opset1::Broadcast>({wrap_type<opset1::Constant>(check_one), any_input(), any_input()},
{{"mode", "numpy"}});

auto multiply_kv = wrap_type<opset6::Multiply>({reshape_kv | unsqueeze_kv, constant_bcst | computed_bcst});
auto result = wrap_type<opset6::Reshape>({multiply_kv, any_input()});
return std::make_tuple(result, reshape_kv, unsqueeze_kv, computed_bcst, multiply_kv);
auto multiply_kv = makePattern<opset6::Multiply>({reshape_kv | unsqueeze_kv, constant_bcst | computed_bcst});
auto computed_bcst3 = makePattern<opset3::Broadcast>({unsqueeze_kv, any_input()}, {{"mode", "bidirectional"}});

auto result = makePattern<opset6::Reshape>({multiply_kv | computed_bcst3, any_input()});
return std::make_tuple(result, reshape_kv, unsqueeze_kv, computed_bcst, multiply_kv, computed_bcst3);
};

std::tie(mq_reshape_k, reshape_k, unsqueeze_k, computed_bcst_k, multiply_k) = multi_query_bcst(concat_k);
std::tie(mq_reshape_v, reshape_v, unsqueeze_v, computed_bcst_v, multiply_v) = multi_query_bcst(concat_v);
std::tie(mq_reshape_k, reshape_k, unsqueeze_k, computed_bcst_k, multiply_k, computed_bcst3_k) =
multi_query_bcst(concat_k);
std::tie(mq_reshape_v, reshape_v, unsqueeze_v, computed_bcst_v, multiply_v, computed_bcst3_v) =
multi_query_bcst(concat_v);
auto present_k = concat_k | mq_reshape_k;
auto present_v = concat_v | mq_reshape_v;

Expand Down Expand Up @@ -178,15 +189,19 @@ StatefulSDPAFusion::StatefulSDPAFusion() {

opset6::Assign *assign_k_node = nullptr, *assign_v_node = nullptr;
opset1::Convert *assign_cvt_k_node = nullptr, *assign_cvt_v_node = nullptr;
if (!find_assign(concat_k_node, assign_k_node, assign_cvt_k_node))
if (!find_assign(concat_k_node, assign_k_node, assign_cvt_k_node)) {
return false;
if (past_k_node->get_variable_id() != assign_k_node->get_variable_id())
}
if (past_k_node->get_variable_id() != assign_k_node->get_variable_id()) {
return false;
}

if (!find_assign(concat_v_node, assign_v_node, assign_cvt_v_node))
if (!find_assign(concat_v_node, assign_v_node, assign_cvt_v_node)) {
return false;
if (past_v_node->get_variable_id() != assign_v_node->get_variable_id())
}
if (past_v_node->get_variable_id() != assign_v_node->get_variable_id()) {
return false;
}

auto is_optional_one_child = [&pattern_map](const std::vector<std::shared_ptr<Node>>& nodes) {
for (auto&& node : nodes) {
Expand All @@ -212,7 +227,9 @@ StatefulSDPAFusion::StatefulSDPAFusion() {
computed_bcst_v,
multiply_v,
mq_reshape_k,
mq_reshape_v})) {
mq_reshape_v,
computed_bcst3_k,
computed_bcst3_v})) {
return false;
}

Expand Down Expand Up @@ -284,5 +301,19 @@ StatefulSDPAFusion::StatefulSDPAFusion() {
this->register_matcher(m, callback);
}

// Runs the SDPA-related fusion pipeline on the model as a single grouped
// model pass, so SDPA+Concat fusion can happen early (before
// CommonOptimizations — see PR #28189).
bool SDPASubgraphFusion::run_on_model(const std::shared_ptr<ov::Model>& f) {
RUN_ON_FUNCTION_SCOPE(SDPASubgraphFusion);
ov::pass::Manager manager("SDPASubgraphFusion");

// Preparatory simplifications that expose the stateful-SDPA pattern:
// fold ShapeOf subgraphs and sink transposes through ShapeOf first.
CPU_REGISTER_PASS_COMMON(manager, ov::pass::SimplifyGatherShapeOf);
CPU_REGISTER_PASS_COMMON(manager, ov::pass::transpose_sinking::TSShapeOfForward);
CPU_REGISTER_PASS_COMMON(manager, StatefulSDPAFusion);
// TODO: remove the following after snippets support patterns with dynamic shapes
CPU_REGISTER_PASS_X64(manager, ov::intel_cpu::SDPAFuseTransposeReshape);

manager.run_passes(f);
// Always reports "not changed"; the nested manager applies the actual
// modifications.
return false;
}

} // namespace intel_cpu
} // namespace ov
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,12 @@ class StatefulSDPAFusion : public ov::pass::MatcherPass {
StatefulSDPAFusion();
};

// Model pass that groups the SDPA-related fusions (shape simplification,
// transpose sinking, StatefulSDPAFusion and, on x64, transpose/reshape
// fusion) so they can be scheduled as one unit early in the CPU pipeline.
class SDPASubgraphFusion : public ov::pass::ModelPass {
public:
OPENVINO_RTTI("SDPASubgraphFusion", "0");

bool run_on_model(const std::shared_ptr<ov::Model>& f) override;
};

} // namespace intel_cpu
} // namespace ov
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@
* Description: SDPA fuse transpose and reshape.
* Original pattern Fused pattern
*
* input1 input2 input3
* input1 readvalue readvalue
* | | |
* q_reshape k_reshape v_reshape
* | | | (qkv transpose and reshape's orders)
* q_transpose k_transpose v_transpose |
* \ | / input1 input2 input3 |
* \ | / \ | / /
* q_transpose k_transpose v_transpose |
* \ | / input1 ReadValue ReadValue |
* \ | / \ | / /
* ScaledDotProductAttention ---------> SDPAWithTransposeReshape
* | |
* out_transpose |
Expand All @@ -41,8 +41,8 @@ intel_cpu::SDPAFuseTransposeReshape::SDPAFuseTransposeReshape() {
MATCHER_SCOPE(SDPAFuseTransposeReshape);

auto q_reshape_node = wrap_type<op::v1::Reshape>({any_input(), any_input()});
auto k_reshape_node = wrap_type<op::v1::Reshape>({any_input(), any_input()});
auto v_reshape_node = wrap_type<op::v1::Reshape>({any_input(), any_input()});
auto k_reshape_node = wrap_type<op::v1::Reshape>({wrap_type<op::v6::ReadValue>(), any_input()});
auto v_reshape_node = wrap_type<op::v1::Reshape>({wrap_type<op::v6::ReadValue>(), any_input()});

auto q_transpose_order_node = wrap_type<op::v0::Constant>();
auto k_transpose_order_node = wrap_type<op::v0::Constant>();
Expand Down
Loading

0 comments on commit 552ba66

Please sign in to comment.