Dynamic concat GPU support #5032
Conversation
| } | ||
|
|
||
| static std::vector<argument> ensure_gpu_kernel_args(const std::vector<argument>& args, | ||
| pmr::vector<argument>& temps) |
There was a problem hiding this comment.
[format.py] reported by reviewdog 🐶
| pmr::vector<argument>& temps) | |
| pmr::vector<argument>& temps) |
| const std::size_t num_concat = | ||
| v.get("num_concat_inputs", inputs.size()); |
There was a problem hiding this comment.
[format.py] reported by reviewdog 🐶
| const std::size_t num_concat = | |
| v.get("num_concat_inputs", inputs.size()); | |
| const std::size_t num_concat = v.get("num_concat_inputs", inputs.size()); |
| concat_shapes.assign(inputs.begin(), | ||
| inputs.begin() + std::min(num_concat, inputs.size())); | ||
| shape output_shape = v.contains("output_shape") ? from_value<shape>(v.at("output_shape")) | ||
| : inputs.back(); |
There was a problem hiding this comment.
[format.py] reported by reviewdog 🐶
| concat_shapes.assign(inputs.begin(), | |
| inputs.begin() + std::min(num_concat, inputs.size())); | |
| shape output_shape = v.contains("output_shape") ? from_value<shape>(v.at("output_shape")) | |
| : inputs.back(); | |
| concat_shapes.assign(inputs.begin(), inputs.begin() + std::min(num_concat, inputs.size())); | |
| shape output_shape = | |
| v.contains("output_shape") ? from_value<shape>(v.at("output_shape")) : inputs.back(); |
| options.inputs = inputs; | ||
| options.output = output_shape; |
There was a problem hiding this comment.
[format.py] reported by reviewdog 🐶
| options.inputs = inputs; | |
| options.output = output_shape; | |
| options.inputs = inputs; | |
| options.output = output_shape; |
| auto args = v.at("args"); | ||
|
|
||
| // normalize() rewrites axis into reduced-dim space; kernel concat<Axis> uses full tensors. | ||
| std::size_t fast_axis = kernel_axis; |
There was a problem hiding this comment.
[format.py] reported by reviewdog 🐶
| std::size_t fast_axis = kernel_axis; | |
| std::size_t fast_axis = kernel_axis; |
|
|
||
| const std::size_t nelem = | ||
| output_shape.dynamic() ? output_shape.element_space() : output_shape.elements(); | ||
| auto nelements_per_op = nelem / op_names.size(); |
There was a problem hiding this comment.
[format.py] reported by reviewdog 🐶
| auto nelements_per_op = nelem / op_names.size(); | |
| auto nelements_per_op = nelem / op_names.size(); |
| auto psl = var("psl", {1, 64}); | ||
| using dd = migraphx::shape::dynamic_dimension; | ||
|
|
||
| migraphx::shape past_shape{migraphx::shape::half_type, {dd{1, 1}, dd{5, 5}, dd{psl}, dd{64, 64}}}; |
There was a problem hiding this comment.
[format.py] reported by reviewdog 🐶
| migraphx::shape past_shape{migraphx::shape::half_type, {dd{1, 1}, dd{5, 5}, dd{psl}, dd{64, 64}}}; | |
| migraphx::shape past_shape{migraphx::shape::half_type, | |
| {dd{1, 1}, dd{5, 5}, dd{psl}, dd{64, 64}}}; |
| auto* mm = p.get_main_module(); | ||
| auto past_key = mm->add_parameter("past_key_values.0.key", past_shape); |
There was a problem hiding this comment.
[format.py] reported by reviewdog 🐶
| auto* mm = p.get_main_module(); | |
| auto past_key = mm->add_parameter("past_key_values.0.key", past_shape); | |
| auto* mm = p.get_main_module(); | |
| auto past_key = mm->add_parameter("past_key_values.0.key", past_shape); |
| return {{"past_key_values.0.key", | ||
| migraphx::shape{migraphx::shape::half_type, {1, 5, 1, 64}}}}; |
There was a problem hiding this comment.
[format.py] reported by reviewdog 🐶
| return {{"past_key_values.0.key", | |
| migraphx::shape{migraphx::shape::half_type, {1, 5, 1, 64}}}}; | |
| return { | |
| {"past_key_values.0.key", migraphx::shape{migraphx::shape::half_type, {1, 5, 1, 64}}}}; |
| auto n = var("n", {2, 3}); | ||
| auto d0 = var("d0", {2, 4}); | ||
| auto d1 = var("d1", {3, 4}); | ||
| auto d2 = var("d2", {1, 5}); |
There was a problem hiding this comment.
[format.py] reported by reviewdog 🐶
| auto n = var("n", {2, 3}); | |
| auto d0 = var("d0", {2, 4}); | |
| auto d1 = var("d1", {3, 4}); | |
| auto d2 = var("d2", {1, 5}); | |
| auto n = var("n", {2, 3}); | |
| auto d0 = var("d0", {2, 4}); | |
| auto d1 = var("d1", {3, 4}); | |
| auto d2 = var("d2", {1, 5}); |
Motivation
Dynamic concat is required to run a dynamic KV cache.
Technical Details
Adds the changes needed to run concat with dynamic-shape inputs on the GPU.
Changelog Category
Add a CHANGELOG.md entry for any option other than Not Applicable