diff --git a/libclc/clc/lib/generic/workitem/clc_get_sub_group_size.cl b/libclc/clc/lib/generic/workitem/clc_get_sub_group_size.cl index 7944486aac0f..7f96fc8c3171 100644 --- a/libclc/clc/lib/generic/workitem/clc_get_sub_group_size.cl +++ b/libclc/clc/lib/generic/workitem/clc_get_sub_group_size.cl @@ -6,21 +6,21 @@ // //===----------------------------------------------------------------------===// +#include "clc/shared/clc_min.h" +#include "clc/workitem/clc_get_local_linear_id.h" #include "clc/workitem/clc_get_local_size.h" #include "clc/workitem/clc_get_max_sub_group_size.h" -#include "clc/workitem/clc_get_num_sub_groups.h" -#include "clc/workitem/clc_get_sub_group_id.h" #include "clc/workitem/clc_get_sub_group_size.h" _CLC_OVERLOAD _CLC_DEF uint __clc_get_sub_group_size() { - if (__clc_get_sub_group_id() != __clc_get_num_sub_groups() - 1) { - return __clc_get_max_sub_group_size(); - } - size_t size_x = __clc_get_local_size(0); - size_t size_y = __clc_get_local_size(1); - size_t size_z = __clc_get_local_size(2); - size_t linear_size = size_z * size_y * size_x; - size_t uniform_groups = __clc_get_num_sub_groups() - 1; - size_t uniform_size = __clc_get_max_sub_group_size() * uniform_groups; - return linear_size - uniform_size; + uint local_linear_size = (uint)__clc_get_local_size(0) * + (uint)__clc_get_local_size(1) * + (uint)__clc_get_local_size(2); + uint max_sg_size = __clc_get_max_sub_group_size(); + // Assume max_sg_size is power of 2. + uint remainder = local_linear_size & (max_sg_size - 1); + if (remainder == 0) + return max_sg_size; + uint lid = (uint)__clc_get_local_linear_id(); + return __clc_min(max_sg_size, local_linear_size - (lid & ~(max_sg_size - 1))); }