template<uint32_t wg_tile_n_, uint32_t wg_tile_m_, uint32_t sg_tile_n_, uint32_t sg_tile_m_ = 1, uint32_t wg_num_m_ = 1, uint32_t wg_num_n_ = 1, uint32_t chunk_size_ = 1>
struct gpu::xetla::kernel::layer_norm_attr_t< wg_tile_n_, wg_tile_m_, sg_tile_n_, sg_tile_m_, wg_num_m_, wg_num_n_, chunk_size_ >
Sets up the attributes of the layer norm.
- Template Parameters
-
| wg_tile_n_ | Number of columns processed by one workgroup. Should be equal to matrix_n in the current design. |
| wg_tile_m_ | Number of rows processed by one workgroup in each inner loop. Mainly used for row reduction in the BWD path. |
| sg_tile_n_ | Number of columns processed by one subgroup. Must satisfy wg_tile_n % sg_tile_n == 0. |
| sg_tile_m_ | Number of rows processed by one subgroup in each inner loop. Mainly used for row reduction in the BWD path. |
| wg_num_m_ | Total number of workgroups launched in the y direction. Used in static persistent thread mode. |
| wg_num_n_ | Total number of workgroups launched in the x direction. Currently, it should be 1. |
| chunk_size_ | Chunk size used when processing the n dimension. Must satisfy sg_tile_n % chunk_size == 0. Used only in the FWD pass. Use it when a kernel has spills. |
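
The divisibility constraints listed above can be checked at compile time. The following is a minimal, self-contained sketch and not the XeTLA definition: `layer_norm_attr_sketch_t`, its member names, the derived `sg_num_n`, and the example tile sizes are all hypothetical and chosen only for illustration. It mirrors the template parameter list documented here and encodes the stated rules as `static_assert`s.

```cpp
#include <cstdint>

// Hypothetical stand-in that mirrors the documented template parameters.
// It is NOT the real gpu::xetla::kernel::layer_norm_attr_t.
template <uint32_t wg_tile_n_, uint32_t wg_tile_m_, uint32_t sg_tile_n_,
          uint32_t sg_tile_m_ = 1, uint32_t wg_num_m_ = 1,
          uint32_t wg_num_n_ = 1, uint32_t chunk_size_ = 1>
struct layer_norm_attr_sketch_t {
    static constexpr uint32_t wg_tile_n = wg_tile_n_;
    static constexpr uint32_t wg_tile_m = wg_tile_m_;
    static constexpr uint32_t sg_tile_n = sg_tile_n_;
    static constexpr uint32_t sg_tile_m = sg_tile_m_;
    static constexpr uint32_t wg_num_m = wg_num_m_;
    static constexpr uint32_t wg_num_n = wg_num_n_;
    static constexpr uint32_t chunk_size = chunk_size_;

    // Documented rule: wg_tile_n % sg_tile_n == 0.
    static_assert(sg_tile_n_ != 0 && wg_tile_n_ % sg_tile_n_ == 0,
                  "wg_tile_n must be a multiple of sg_tile_n");
    // Documented rule: sg_tile_n % chunk_size == 0.
    static_assert(chunk_size_ != 0 && sg_tile_n_ % chunk_size_ == 0,
                  "sg_tile_n must be a multiple of chunk_size");
    // Documented rule: wg_num_n should currently be 1.
    static_assert(wg_num_n_ == 1, "wg_num_n should currently be 1");

    // Derived value (an assumption of this sketch): subgroups along n
    // needed to cover one workgroup tile.
    static constexpr uint32_t sg_num_n = wg_tile_n_ / sg_tile_n_;
};

// Example values, assumed for illustration: matrix_n = 1024, so
// wg_tile_n = 1024; four subgroups of 256 columns each; chunks of 64.
using example_attr = layer_norm_attr_sketch_t<1024, 4, 256, 1, 8, 1, 64>;
```

With these example values, a configuration such as `layer_norm_attr_sketch_t<1024, 4, 300>` would fail to compile, because 1024 is not a multiple of 300.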