DPC++ Runtime
Runtime libraries for oneAPI DPC++
sycl_mem_obj_allocator.hpp
Go to the documentation of this file.
1 //==------- sycl_mem_obj_allocator.hpp - SYCL standard header file ---------==//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #pragma once
10 
11 #include <sycl/detail/aligned_allocator.hpp> // for aligned_allocator
12 
13 #include <algorithm> // for max
14 #include <cstddef> // for size_t
15 #include <type_traits> // for enable_if_t
16 
17 namespace sycl {
18 inline namespace _V1 {
19 namespace detail {
20 
21 template <typename DataT>
23 
25 
26 protected:
27  virtual void *getAllocatorImpl() = 0;
28 
29 public:
30  virtual ~SYCLMemObjAllocator() = default;
31  virtual void *allocate(std::size_t) = 0;
32  virtual void deallocate(void *, std::size_t) = 0;
33  virtual std::size_t getValueSize() const = 0;
34  virtual void setAlignment(std::size_t RequiredAlign) = 0;
35  template <typename AllocatorT> AllocatorT getAllocator() {
36  return *reinterpret_cast<AllocatorT *>(getAllocatorImpl());
37  }
38 };
39 
40 template <typename AllocatorT, typename OwnerDataT>
42 public:
43  SYCLMemObjAllocatorHolder(AllocatorT Allocator)
44  : MAllocator(Allocator),
45  MValueSize(sizeof(typename AllocatorT::value_type)) {}
46 
48  : MAllocator(AllocatorT()),
49  MValueSize(sizeof(typename AllocatorT::value_type)) {}
50 
52 
53  virtual void *allocate(std::size_t Count) override {
54  return reinterpret_cast<void *>(MAllocator.allocate(Count));
55  }
56 
57  virtual void deallocate(void *Ptr, std::size_t Count) override {
58  MAllocator.deallocate(
59  reinterpret_cast<typename AllocatorT::value_type *>(Ptr), Count);
60  }
61 
62  void setAlignment(std::size_t RequiredAlign) override {
63  setAlignImpl(RequiredAlign);
64  }
65 
66  virtual std::size_t getValueSize() const override { return MValueSize; }
67 
68 protected:
69  virtual void *getAllocatorImpl() override { return &MAllocator; }
70 
71 private:
72  template <typename T>
73  using EnableIfDefaultAllocator = std::enable_if_t<
74  std::is_same_v<T, sycl_memory_object_allocator<OwnerDataT>>>;
75 
76  template <typename T>
77  using EnableIfNonDefaultAllocator = std::enable_if_t<
78  !std::is_same_v<T, sycl_memory_object_allocator<OwnerDataT>>>;
79 
80  template <typename T = AllocatorT>
81  EnableIfNonDefaultAllocator<T> setAlignImpl(std::size_t) {
82  // Do nothing in case of user's allocator.
83  }
84 
85  template <typename T = AllocatorT>
86  EnableIfDefaultAllocator<T> setAlignImpl(std::size_t RequiredAlign) {
87  MAllocator.setAlignment(std::max<size_t>(RequiredAlign, 64));
88  }
89 
90  AllocatorT MAllocator;
91  std::size_t MValueSize;
92 };
93 } // namespace detail
94 } // namespace _V1
95 } // namespace sycl
virtual void deallocate(void *Ptr, std::size_t Count) override
void setAlignment(std::size_t RequiredAlign) override
virtual void * allocate(std::size_t Count) override
virtual std::size_t getValueSize() const override
virtual void deallocate(void *, std::size_t)=0
virtual void * allocate(std::size_t)=0
virtual std::size_t getValueSize() const =0
virtual void setAlignment(std::size_t RequiredAlign)=0
const void value_type
Definition: multi_ptr.hpp:457
Definition: access.hpp:18