Dynamic concat gpu support #5032

Draft
turneram wants to merge 5 commits into develop from dynamic-concat-pointwise

turneram commented Jul 2, 2026

Contributor

Motivation

Dynamic concat is required to run a dynamic kv-cache.

Technical Details

Adds the changes needed to run concat with dynamic-shape inputs on the GPU.
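
To make the operation concrete, here is a minimal host-side sketch of concatenation along an axis when the sizes are only known at run time, as happens when a kv-cache grows by one step. It is not the GPU kernel from this PR: the `tensor` struct and the `concat` function are hypothetical names used only for illustration, and the tensors are assumed to be packed row-major.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <numeric>
#include <vector>

// Hypothetical host-side tensor: packed row-major data plus run-time lens.
struct tensor
{
    std::vector<std::size_t> lens;
    std::vector<float> data;
};

// Concatenate tensors along `axis`. All inputs are assumed to agree on every
// dimension except `axis`; the length of `axis` is read at call time.
tensor concat(const std::vector<tensor>& inputs, std::size_t axis)
{
    assert(!inputs.empty());
    const auto& first = inputs.front().lens;
    assert(axis < first.size());

    // outer: product of the dims before axis; inner: product of the dims after it.
    const std::size_t outer = std::accumulate(
        first.begin(), first.begin() + axis, std::size_t{1}, std::multiplies<>{});
    const std::size_t inner = std::accumulate(
        first.begin() + axis + 1, first.end(), std::size_t{1}, std::multiplies<>{});

    tensor out;
    out.lens       = first;
    out.lens[axis] = 0;
    for(const auto& t : inputs)
        out.lens[axis] += t.lens[axis];
    out.data.reserve(outer * out.lens[axis] * inner);

    // For each outer index, append one contiguous chunk from every input in order.
    for(std::size_t o = 0; o < outer; ++o)
    {
        for(const auto& t : inputs)
        {
            const std::size_t chunk = t.lens[axis] * inner;
            const auto start        = t.data.begin() + o * chunk;
            out.data.insert(out.data.end(), start, start + chunk);
        }
    }
    return out;
}
```

For example, concatenating a past-key tensor with lens {1, 2, 2, 1} and a new key with lens {1, 2, 1, 1} along axis 2 gives lens {1, 2, 3, 1}; the axis length comes from the inputs at call time rather than being fixed at compile time.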

Changelog Category

Add a CHANGELOG.md entry for any option other than Not Applicable

    • Added: New functionality.
    • Changed: Changes to existing functionality.
    • Removed: Functionality or support that has been removed. (Compared to a previous release)
    • Optimized: Component performance that has been optimized or improved.
    • Resolved Issues: Known issues from a previous version that have been resolved.
    • Not Applicable: This PR is not to be included in the changelog.

}

static std::vector<argument> ensure_gpu_kernel_args(const std::vector<argument>& args,
pmr::vector<argument>& temps)

[format.py] reported by reviewdog 🐶

Suggested change
pmr::vector<argument>& temps)
pmr::vector<argument>& temps)

Comment on lines +91 to +92
const std::size_t num_concat =
v.get("num_concat_inputs", inputs.size());

[format.py] reported by reviewdog 🐶

Suggested change
- const std::size_t num_concat =
- v.get("num_concat_inputs", inputs.size());
+ const std::size_t num_concat = v.get("num_concat_inputs", inputs.size());

Comment on lines +94 to +97
concat_shapes.assign(inputs.begin(),
inputs.begin() + std::min(num_concat, inputs.size()));
shape output_shape = v.contains("output_shape") ? from_value<shape>(v.at("output_shape"))
: inputs.back();

[format.py] reported by reviewdog 🐶

Suggested change
- concat_shapes.assign(inputs.begin(),
- inputs.begin() + std::min(num_concat, inputs.size()));
- shape output_shape = v.contains("output_shape") ? from_value<shape>(v.at("output_shape"))
- : inputs.back();
+ concat_shapes.assign(inputs.begin(), inputs.begin() + std::min(num_concat, inputs.size()));
+ shape output_shape =
+ v.contains("output_shape") ? from_value<shape>(v.at("output_shape")) : inputs.back();
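
The clamp in this hunk matters because `inputs.begin() + num_concat` is undefined behavior whenever `num_concat` exceeds `inputs.size()`. Below is a minimal standalone sketch of the same selection logic, assuming a shape can be stood in for by its lens. `lens_t`, `concat_layout`, and `split_inputs` are hypothetical names, not MIGraphX API, and only the fallback branch for the output shape (the last input) is modeled.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for a shape: just its lens.
using lens_t = std::vector<std::size_t>;

struct concat_layout
{
    std::vector<lens_t> concat_shapes; // operands that are actually concatenated
    lens_t output_shape;               // shape of the output buffer
};

// Take the first num_concat inputs as concat operands, clamped to the number
// of inputs available, and fall back to the last input as the output shape.
concat_layout split_inputs(const std::vector<lens_t>& inputs, std::size_t num_concat)
{
    assert(!inputs.empty());
    const std::size_t n = std::min(num_concat, inputs.size());
    concat_layout result;
    result.concat_shapes.assign(inputs.begin(), inputs.begin() + n);
    result.output_shape = inputs.back();
    return result;
}
```

With three inputs and `num_concat` equal to 2, the first two shapes become concat operands and the third is used as the output shape; passing a `num_concat` larger than the input count simply selects every input.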

Comment on lines +99 to +100
options.inputs = inputs;
options.output = output_shape;

[format.py] reported by reviewdog 🐶

Suggested change
options.inputs = inputs;
options.output = output_shape;
options.inputs = inputs;
options.output = output_shape;

auto args = v.at("args");

// normalize() rewrites axis into reduced-dim space; kernel concat<Axis> uses full tensors.
std::size_t fast_axis = kernel_axis;

[format.py] reported by reviewdog 🐶

Suggested change
std::size_t fast_axis = kernel_axis;
std::size_t fast_axis = kernel_axis;


const std::size_t nelem =
output_shape.dynamic() ? output_shape.element_space() : output_shape.elements();
auto nelements_per_op = nelem / op_names.size();

[format.py] reported by reviewdog 🐶

Suggested change
auto nelements_per_op = nelem / op_names.size();
auto nelements_per_op = nelem / op_names.size();

auto psl = var("psl", {1, 64});
using dd = migraphx::shape::dynamic_dimension;

migraphx::shape past_shape{migraphx::shape::half_type, {dd{1, 1}, dd{5, 5}, dd{psl}, dd{64, 64}}};

[format.py] reported by reviewdog 🐶

Suggested change
- migraphx::shape past_shape{migraphx::shape::half_type, {dd{1, 1}, dd{5, 5}, dd{psl}, dd{64, 64}}};
+ migraphx::shape past_shape{migraphx::shape::half_type,
+ {dd{1, 1}, dd{5, 5}, dd{psl}, dd{64, 64}}};

Comment on lines +69 to +70
auto* mm = p.get_main_module();
auto past_key = mm->add_parameter("past_key_values.0.key", past_shape);

[format.py] reported by reviewdog 🐶

Suggested change
auto* mm = p.get_main_module();
auto past_key = mm->add_parameter("past_key_values.0.key", past_shape);
auto* mm = p.get_main_module();
auto past_key = mm->add_parameter("past_key_values.0.key", past_shape);

Comment on lines +78 to +79
return {{"past_key_values.0.key",
migraphx::shape{migraphx::shape::half_type, {1, 5, 1, 64}}}};

[format.py] reported by reviewdog 🐶

Suggested change
- return {{"past_key_values.0.key",
- migraphx::shape{migraphx::shape::half_type, {1, 5, 1, 64}}}};
+ return {
+ {"past_key_values.0.key", migraphx::shape{migraphx::shape::half_type, {1, 5, 1, 64}}}};

Comment on lines +88 to +91
auto n = var("n", {2, 3});
auto d0 = var("d0", {2, 4});
auto d1 = var("d1", {3, 4});
auto d2 = var("d2", {1, 5});

[format.py] reported by reviewdog 🐶

Suggested change
auto n = var("n", {2, 3});
auto d0 = var("d0", {2, 4});
auto d1 = var("d1", {3, 4});
auto d2 = var("d2", {1, 5});
auto n = var("n", {2, 3});
auto d0 = var("d0", {2, 4});
auto d1 = var("d1", {3, 4});
auto d2 = var("d2", {1, 5});
