Skip to content

Commit

Permalink
[Snippets] [ARM]: Fixed bug in PReLU emitter, enabled PReLU, Sqrt, Ro…
Browse files Browse the repository at this point in the history
…und tokenization (openvinotoolkit#28223)

### Details:
 - Fixed a bug in PReLU emitter
 - Enabled PReLU, Sqrt, Round tokenization
 - All local tests pass

### Tickets:
 - openvinotoolkit#28161
  • Loading branch information
0xfedcafe authored Jan 6, 2025
1 parent ffeff3e commit 1d955cd
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include "common/utils.hpp"
#include "emitters/utils.hpp"
#include "openvino/core/type/element_type.hpp"
#include "transformations/cpu_opset/common/op/swish_cpu.hpp"

namespace ov {
Expand Down Expand Up @@ -2128,7 +2129,7 @@ size_t jit_prelu_emitter::get_aux_vecs_count() const {

std::set<std::vector<element::Type>> jit_prelu_emitter::get_supported_precisions(
const std::shared_ptr<ov::Node>& node) {
return {{element::f32}};
return {{element::f32, element::f32}};
}

void jit_prelu_emitter::emit_impl(const std::vector<size_t>& in_vec_idxs,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,13 @@
#include "emitters/snippets/cpu_runtime_configurator.hpp"
#include "emitters/utils.hpp"
#include "jit_snippets_emitters.hpp"
#include "openvino/core/type.hpp"
#include "openvino/op/prelu.hpp"
#include "openvino/op/round.hpp"
#include "openvino/op/sqrt.hpp"
#include "openvino/opsets/opset13.hpp"
#include "snippets/emitter.hpp"
#include "snippets/lowered/expression.hpp"
#include "snippets/snippets_isa.hpp"
#include "transformations/cpu_opset/common/op/swish_cpu.hpp"
#include "transformations/snippets/common/op/fused_mul_add.hpp"
Expand Down Expand Up @@ -44,7 +50,7 @@ namespace ov {
{ \
[this](const snippets::lowered::ExpressionPtr& expr) -> std::shared_ptr<snippets::Emitter> { \
const auto& n = expr->get_node(); \
const auto& gelu = std::dynamic_pointer_cast<ov::op::v7::Gelu>(n); \
const auto& gelu = ov::as_type_ptr<ov::op::v7::Gelu>(n); \
if (gelu == nullptr) { \
OPENVINO_THROW("Can't cast to ov::op::v7::Gelu"); \
} \
Expand Down Expand Up @@ -73,6 +79,37 @@ namespace ov {
} \
}

#define CREATE_ROUND_V5_EMITTER(e_type_from_zero, e_type_even) \
{ \
[this](const snippets::lowered::ExpressionPtr& expr) -> std::shared_ptr<snippets::Emitter> { \
const auto& n = expr->get_node(); \
const auto& round = ov::as_type_ptr<ov::op::v5::Round>(n); \
if (round == nullptr) { \
OPENVINO_THROW("Can't cast to ov::op::v5::Round"); \
} \
const auto roundingMode = round->get_mode(); \
if (roundingMode == ov::op::v5::Round::RoundMode::HALF_AWAY_FROM_ZERO) { \
return std::make_shared<e_type_from_zero>(h.get(), isa, n); \
} else if (roundingMode == ov::op::v5::Round::RoundMode::HALF_TO_EVEN) { \
return std::make_shared<e_type_even>(h.get(), isa, n); \
} else { \
OPENVINO_THROW("Unsupported Round mode"); \
} \
}, \
[](const std::shared_ptr<ov::Node>& n) -> std::set<std::vector<element::Type>> { \
const auto& round = std::dynamic_pointer_cast<ov::op::v5::Round>(n); \
if (round == nullptr) { \
OPENVINO_THROW("Can't cast to ov::op::v5::Round"); \
} \
if (round->get_mode() == ov::op::v5::Round::RoundMode::HALF_AWAY_FROM_ZERO) { \
return e_type_from_zero::get_supported_precisions(n); \
} else if (round->get_mode() == ov::op::v5::Round::RoundMode::HALF_TO_EVEN) { \
return e_type_even::get_supported_precisions(n); \
} \
OPENVINO_THROW("Unsupported Round mode"); \
} \
}

class jit_snippet : public dnnl::impl::cpu::aarch64::jit_generator {
public:
DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_snippet)
Expand Down Expand Up @@ -155,8 +192,12 @@ CPUTargetMachine::CPUTargetMachine(dnnl::impl::cpu::aarch64::cpu_isa_t host_isa)
CREATE_GELU_V7_EMITTER(jit_gelu_erf_emitter, jit_gelu_tanh_emitter);
jitters[ov::op::v4::HSwish::get_type_info_static()] = CREATE_CPU_EMITTER(jit_hswish_emitter);
jitters[ov::op::v4::Mish::get_type_info_static()] = CREATE_CPU_EMITTER(jit_mish_emitter);
jitters[ov::op::v0::PRelu::get_type_info_static()] = CREATE_CPU_EMITTER(jit_prelu_emitter);
jitters[ov::op::v0::Relu::get_type_info_static()] = CREATE_CPU_EMITTER(jit_relu_emitter);
jitters[ov::op::v5::Round::get_type_info_static()] =
CREATE_ROUND_V5_EMITTER(jit_round_half_away_from_zero_emitter, jit_round_half_to_even_emitter);
jitters[ov::op::v0::Sigmoid::get_type_info_static()] = CREATE_CPU_EMITTER(jit_sigmoid_emitter);
jitters[ov::op::v0::Sqrt::get_type_info_static()] = CREATE_CPU_EMITTER(jit_sqrt_emitter);
jitters[ov::intel_cpu::SwishNode::get_type_info_static()] = CREATE_CPU_EMITTER(jit_swish_emitter);
jitters[ov::op::v0::Tanh::get_type_info_static()] = CREATE_CPU_EMITTER(jit_tanh_emitter);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
#include <ov_ops/gather_compressed.hpp>

#include "openvino/op/paged_attention.hpp"
#include "openvino/op/prelu.hpp"
#include "openvino/op/round.hpp"
#include "openvino/op/sqrt.hpp"
#include "openvino/opsets/opset1.hpp"
#include "openvino/opsets/opset10.hpp"
#include "openvino/opsets/opset2.hpp"
Expand Down Expand Up @@ -1123,16 +1126,17 @@ void Transformations::MainSnippets(void) {
ov::is_type<ov::op::v0::Clamp>(n) || ov::is_type<ov::op::v0::Ceiling>(n) ||
ov::is_type<ov::op::v0::Convert>(n) || ov::is_type<ov::op::v1::Divide>(n) ||
ov::is_type<ov::op::v0::Elu>(n) || ov::is_type<ov::op::v0::Exp>(n) ||
ov::is_type<ov::op::v0::Floor>(n) || ov::is_type<ov::op::v1::FloorMod>(n) ||
ov::is_type<ov::op::v0::Gelu>(n) || ov::is_type<ov::op::v7::Gelu>(n) ||
ov::is_type<ov::op::v4::HSwish>(n) || ov::is_type<ov::op::v1::Maximum>(n) ||
ov::is_type<ov::op::v1::Equal>(n) || ov::is_type<ov::op::v0::Floor>(n) ||
ov::is_type<ov::op::v1::FloorMod>(n) || ov::is_type<ov::op::v0::Gelu>(n) ||
ov::is_type<ov::op::v7::Gelu>(n) || ov::is_type<ov::op::v1::Greater>(n) ||
ov::is_type<ov::op::v1::GreaterEqual>(n) || ov::is_type<ov::op::v4::HSwish>(n) ||
ov::is_type<ov::op::v1::LessEqual>(n) || ov::is_type<ov::op::v1::Maximum>(n) ||
ov::is_type<ov::op::v1::Minimum>(n) || ov::is_type<ov::op::v4::Mish>(n) ||
ov::is_type<ov::op::v1::Mod>(n) || ov::is_type<ov::op::v1::Multiply>(n) ||
ov::is_type<ov::op::v0::Relu>(n) || ov::is_type<ov::op::v0::Sigmoid>(n) ||
ov::is_type<ov::op::v1::Subtract>(n) || ov::is_type<ov::op::v4::Swish>(n) ||
ov::is_type<ov::op::v1::Equal>(n) || ov::is_type<ov::op::v1::Greater>(n) ||
ov::is_type<ov::op::v1::GreaterEqual>(n) || ov::is_type<ov::op::v1::LessEqual>(n) ||
ov::is_type<ov::op::v0::Tanh>(n));
ov::is_type<ov::op::v0::PRelu>(n) || ov::is_type<ov::op::v0::Relu>(n) ||
ov::is_type<ov::op::v5::Round>(n) || ov::is_type<ov::op::v0::Sigmoid>(n) ||
ov::is_type<ov::op::v0::Sqrt>(n) || ov::is_type<ov::op::v1::Subtract>(n) ||
ov::is_type<ov::op::v4::Swish>(n) || ov::is_type<ov::op::v0::Tanh>(n));
#else
// CPU Plugin support Swish in Subgraph via conversion to SwichCPU which assumes second input to be constant,
// and CPU Plugin does not support Mish for x64
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,9 @@ std::string ActivationLayerCPUTest::getPrimitiveType(const utils::ActivationType
(activation_type == utils::ActivationTypes::Sqrt) ||
(activation_type == utils::ActivationTypes::Swish) ||
(activation_type == utils::ActivationTypes::LogicalNot) ||
(activation_type == utils::ActivationTypes::Tanh))) {
(activation_type == utils::ActivationTypes::Tanh) ||
(activation_type == utils::ActivationTypes::RoundHalfAwayFromZero) ||
(activation_type == utils::ActivationTypes::RoundHalfToEven))) {
return "jit";
}

Expand All @@ -209,7 +211,9 @@ std::string ActivationLayerCPUTest::getPrimitiveType(const utils::ActivationType
if ((activation_type == utils::ActivationTypes::Floor) ||
(activation_type == utils::ActivationTypes::Ceiling) ||
(activation_type == utils::ActivationTypes::IsNaN) ||
(activation_type == utils::ActivationTypes::IsFinite)) {
(activation_type == utils::ActivationTypes::IsFinite) ||
(activation_type == utils::ActivationTypes::RoundHalfAwayFromZero) ||
(activation_type == utils::ActivationTypes::RoundHalfToEven)) {
return "ref";
}
return "acl";
Expand Down Expand Up @@ -265,22 +269,26 @@ const std::map<utils::ActivationTypes, std::vector<std::vector<float>>>& activat

const std::map<utils::ActivationTypes, std::vector<std::vector<float>>>& activationTypesSnippets() {
static const std::map<utils::ActivationTypes, std::vector<std::vector<float>>> activationTypes {
{Abs, {{}}},
{Exp, {{}}},
{Ceiling, {{}}},
{Clamp, {{-2.0f, 2.0f}}},
{Elu, {{0.1f}}},
{Floor, {{}}},
{GeluErf, {{}}},
{GeluTanh, {{}}},
{Relu, {{}}},
{HSwish, {{}}},
{Abs, {{}}},
{Exp, {{}}},
{Ceiling, {{}}},
{Clamp, {{-2.0f, 2.0f}}},
{Elu, {{0.1f}}},
{Floor, {{}}},
{GeluErf, {{}}},
{GeluTanh, {{}}},
{Relu, {{}}},
{HSwish, {{}}},
{PReLu, {{-0.01f}}},
{Sqrt, {{}}},
{RoundHalfToEven, {{}}},
{RoundHalfAwayFromZero, {{}}},
#if defined(OPENVINO_ARCH_ARM64)
{Mish, {{}}},
{Mish, {{}}},
#endif
{Sigmoid, {{}}},
{Swish, {{0.1f}}},
{Tanh, {{}}},
{Sigmoid, {{}}},
{Swish, {{0.1f}}},
{Tanh, {{}}},
};

return activationTypes;
Expand Down

0 comments on commit 1d955cd

Please sign in to comment.