20 namespace ext::intel::esimd::detail {
/// Maps an object type \p BaseTy and a target element type \p EltTy to the
/// 1D region type used to view \p BaseTy's data as \p EltTy elements
/// (exposed as the nested `type`).
///
/// Declared only: the definitions come from the partial specializations for
/// `SimdT<Ty, N>` and `simd_view<BaseTy, RegionTy>`. Any other \p BaseTy
/// leaves this incomplete.
template <typename BaseTy, typename EltTy> struct compute_format_type;
24 template <
typename Ty,
int N,
typename EltTy>
struct compute_format_type_impl {
25 static constexpr
int Size =
sizeof(Ty) * N /
sizeof(EltTy);
26 static constexpr
int Stride = 1;
27 using type = region1d_t<EltTy, Size, Stride>;
30 template <
typename Ty,
int N,
typename EltTy,
31 template <
typename,
int>
class SimdT>
32 struct compute_format_type<SimdT<Ty, N>, EltTy>
33 : compute_format_type_impl<Ty, N, EltTy> {};
35 template <
typename BaseTy,
typename RegionTy,
typename EltTy>
36 struct compute_format_type<simd_view<BaseTy, RegionTy>, EltTy> {
37 using ShapeTy =
typename shape_type<RegionTy>::type;
38 static constexpr
int Size = ShapeTy::Size_in_bytes /
sizeof(EltTy);
39 static constexpr
int Stride = 1;
40 using type = region1d_t<EltTy, Size, Stride>;
43 template <
typename Ty,
typename EltTy>
44 using compute_format_type_t =
typename compute_format_type<Ty, EltTy>::type;
/// 2D counterpart of compute_format_type. Maps \p BaseTy to the region type
/// used to view its data as a \p Height x \p Width block of \p EltTy elements
/// (exposed as the nested `type`).
///
/// Declared only: the definitions come from the partial specializations for
/// `SimdT<Ty, N>` and `simd_view<BaseTy, RegionTy>`.
template <typename BaseTy, typename EltTy, int Height, int Width>
struct compute_format_type_2d;
50 template <
typename Ty,
int N,
typename EltTy,
int Height,
int W
idth>
51 struct compute_format_type_2d_impl {
52 static constexpr
int Prod =
sizeof(Ty) * N /
sizeof(EltTy);
53 static_assert(Prod == Width * Height,
"size mismatch");
55 static constexpr
int SizeX = Width;
56 static constexpr
int StrideX = 1;
57 static constexpr
int SizeY = Height;
58 static constexpr
int StrideY = 1;
59 using type = region2d_t<EltTy, SizeY, StrideY, SizeX, StrideX>;
62 template <
typename Ty,
int N,
typename EltTy,
int Height,
int Width,
63 template <
typename,
int>
class SimdT>
64 struct compute_format_type_2d<SimdT<Ty, N>, EltTy, Height, Width>
65 : compute_format_type_2d_impl<Ty, N, EltTy, Height, Width> {};
67 template <
typename BaseTy,
typename RegionTy,
typename EltTy,
int Height,
69 struct compute_format_type_2d<simd_view<BaseTy, RegionTy>, EltTy, Height,
71 using ShapeTy =
typename shape_type<RegionTy>::type;
72 static constexpr
int Prod = ShapeTy::Size_in_bytes /
sizeof(EltTy);
73 static_assert(Prod == Width * Height,
"size mismatch");
75 static constexpr
int SizeX = Width;
76 static constexpr
int StrideX = 1;
77 static constexpr
int SizeY = Height;
78 static constexpr
int StrideY = 1;
79 using type = region2d_t<EltTy, SizeY, StrideY, SizeX, StrideX>;
82 template <
typename Ty,
typename EltTy,
int Height,
int W
idth>
83 using compute_format_type_2d_t =
84 typename compute_format_type_2d<Ty, EltTy, Height, Width>::type;