clang  19.0.0git
CGHLSLRuntime.cpp
Go to the documentation of this file.
1 //===----- CGHLSLRuntime.cpp - Interface to HLSL Runtimes -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This provides an abstract class for HLSL code generation. Concrete
10 // subclasses of this implement code generation for specific HLSL
11 // runtime libraries.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "CGHLSLRuntime.h"
16 #include "CGDebugInfo.h"
17 #include "CodeGenModule.h"
18 #include "clang/AST/Decl.h"
20 #include "llvm/IR/Metadata.h"
21 #include "llvm/IR/Module.h"
22 #include "llvm/Support/FormatVariadic.h"
23 
24 using namespace clang;
25 using namespace CodeGen;
26 using namespace clang::hlsl;
27 using namespace llvm;
28 
29 namespace {
30 
31 void addDxilValVersion(StringRef ValVersionStr, llvm::Module &M) {
32  // The validation of ValVersionStr is done at HLSLToolChain::TranslateArgs.
33  // Assume ValVersionStr is legal here.
34  VersionTuple Version;
35  if (Version.tryParse(ValVersionStr) || Version.getBuild() ||
36  Version.getSubminor() || !Version.getMinor()) {
37  return;
38  }
39 
40  uint64_t Major = Version.getMajor();
41  uint64_t Minor = *Version.getMinor();
42 
43  auto &Ctx = M.getContext();
44  IRBuilder<> B(M.getContext());
45  MDNode *Val = MDNode::get(Ctx, {ConstantAsMetadata::get(B.getInt32(Major)),
46  ConstantAsMetadata::get(B.getInt32(Minor))});
47  StringRef DXILValKey = "dx.valver";
48  auto *DXILValMD = M.getOrInsertNamedMetadata(DXILValKey);
49  DXILValMD->addOperand(Val);
50 }
51 void addDisableOptimizations(llvm::Module &M) {
52  StringRef Key = "dx.disable_optimizations";
53  M.addModuleFlag(llvm::Module::ModFlagBehavior::Override, Key, 1);
54 }
55 // cbuffer will be translated into global variable in special address space.
56 // If translate into C,
57 // cbuffer A {
58 // float a;
59 // float b;
60 // }
61 // float foo() { return a + b; }
62 //
63 // will be translated into
64 //
65 // struct A {
66 // float a;
67 // float b;
68 // } cbuffer_A __attribute__((address_space(4)));
69 // float foo() { return cbuffer_A.a + cbuffer_A.b; }
70 //
71 // layoutBuffer will create the struct A type.
72 // replaceBuffer will replace use of global variable a and b with cbuffer_A.a
73 // and cbuffer_A.b.
74 //
75 void layoutBuffer(CGHLSLRuntime::Buffer &Buf, const DataLayout &DL) {
76  if (Buf.Constants.empty())
77  return;
78 
79  std::vector<llvm::Type *> EltTys;
80  for (auto &Const : Buf.Constants) {
81  GlobalVariable *GV = Const.first;
82  Const.second = EltTys.size();
83  llvm::Type *Ty = GV->getValueType();
84  EltTys.emplace_back(Ty);
85  }
86  Buf.LayoutStruct = llvm::StructType::get(EltTys[0]->getContext(), EltTys);
87 }
88 
89 GlobalVariable *replaceBuffer(CGHLSLRuntime::Buffer &Buf) {
90  // Create global variable for CB.
91  GlobalVariable *CBGV = new GlobalVariable(
92  Buf.LayoutStruct, /*isConstant*/ true,
93  GlobalValue::LinkageTypes::ExternalLinkage, nullptr,
94  llvm::formatv("{0}{1}", Buf.Name, Buf.IsCBuffer ? ".cb." : ".tb."),
95  GlobalValue::NotThreadLocal);
96 
97  IRBuilder<> B(CBGV->getContext());
98  Value *ZeroIdx = B.getInt32(0);
99  // Replace Const use with CB use.
100  for (auto &[GV, Offset] : Buf.Constants) {
101  Value *GEP =
102  B.CreateGEP(Buf.LayoutStruct, CBGV, {ZeroIdx, B.getInt32(Offset)});
103 
104  assert(Buf.LayoutStruct->getElementType(Offset) == GV->getValueType() &&
105  "constant type mismatch");
106 
107  // Replace.
108  GV->replaceAllUsesWith(GEP);
109  // Erase GV.
110  GV->removeDeadConstantUsers();
111  GV->eraseFromParent();
112  }
113  return CBGV;
114 }
115 
116 } // namespace
117 
118 llvm::Triple::ArchType CGHLSLRuntime::getArch() {
119  return CGM.getTarget().getTriple().getArch();
120 }
121 
122 void CGHLSLRuntime::addConstant(VarDecl *D, Buffer &CB) {
123  if (D->getStorageClass() == SC_Static) {
124  // For static inside cbuffer, take as global static.
125  // Don't add to cbuffer.
126  CGM.EmitGlobal(D);
127  return;
128  }
129 
130  auto *GV = cast<GlobalVariable>(CGM.GetAddrOfGlobalVar(D));
131  // Add debug info for constVal.
132  if (CGDebugInfo *DI = CGM.getModuleDebugInfo())
133  if (CGM.getCodeGenOpts().getDebugInfo() >=
134  codegenoptions::DebugInfoKind::LimitedDebugInfo)
135  DI->EmitGlobalVariable(cast<GlobalVariable>(GV), D);
136 
137  // FIXME: support packoffset.
138  // See https://github.com/llvm/llvm-project/issues/57914.
139  uint32_t Offset = 0;
140  bool HasUserOffset = false;
141 
142  unsigned LowerBound = HasUserOffset ? Offset : UINT_MAX;
143  CB.Constants.emplace_back(std::make_pair(GV, LowerBound));
144 }
145 
146 void CGHLSLRuntime::addBufferDecls(const DeclContext *DC, Buffer &CB) {
147  for (Decl *it : DC->decls()) {
148  if (auto *ConstDecl = dyn_cast<VarDecl>(it)) {
149  addConstant(ConstDecl, CB);
150  } else if (isa<CXXRecordDecl, EmptyDecl>(it)) {
151  // Nothing to do for this declaration.
152  } else if (isa<FunctionDecl>(it)) {
153  // A function within an cbuffer is effectively a top-level function,
154  // as it only refers to globally scoped declarations.
155  CGM.EmitTopLevelDecl(it);
156  }
157  }
158 }
159 
161  Buffers.emplace_back(Buffer(D));
162  addBufferDecls(D, Buffers.back());
163 }
164 
166  auto &TargetOpts = CGM.getTarget().getTargetOpts();
167  llvm::Module &M = CGM.getModule();
168  Triple T(M.getTargetTriple());
169  if (T.getArch() == Triple::ArchType::dxil)
170  addDxilValVersion(TargetOpts.DxilValidatorVersion, M);
171 
172  generateGlobalCtorDtorCalls();
173  if (CGM.getCodeGenOpts().OptimizationLevel == 0)
174  addDisableOptimizations(M);
175 
176  const DataLayout &DL = M.getDataLayout();
177 
178  for (auto &Buf : Buffers) {
179  layoutBuffer(Buf, DL);
180  GlobalVariable *GV = replaceBuffer(Buf);
181  M.insertGlobalVariable(GV);
182  llvm::hlsl::ResourceClass RC = Buf.IsCBuffer
183  ? llvm::hlsl::ResourceClass::CBuffer
184  : llvm::hlsl::ResourceClass::SRV;
185  llvm::hlsl::ResourceKind RK = Buf.IsCBuffer
186  ? llvm::hlsl::ResourceKind::CBuffer
187  : llvm::hlsl::ResourceKind::TBuffer;
188  addBufferResourceAnnotation(GV, RC, RK, /*IsROV=*/false,
189  llvm::hlsl::ElementType::Invalid, Buf.Binding);
190  }
191 }
192 
194  : Name(D->getName()), IsCBuffer(D->isCBuffer()),
195  Binding(D->getAttr<HLSLResourceBindingAttr>()) {}
196 
197 void CGHLSLRuntime::addBufferResourceAnnotation(llvm::GlobalVariable *GV,
198  llvm::hlsl::ResourceClass RC,
199  llvm::hlsl::ResourceKind RK,
200  bool IsROV,
201  llvm::hlsl::ElementType ET,
202  BufferResBinding &Binding) {
203  llvm::Module &M = CGM.getModule();
204 
205  NamedMDNode *ResourceMD = nullptr;
206  switch (RC) {
207  case llvm::hlsl::ResourceClass::UAV:
208  ResourceMD = M.getOrInsertNamedMetadata("hlsl.uavs");
209  break;
210  case llvm::hlsl::ResourceClass::SRV:
211  ResourceMD = M.getOrInsertNamedMetadata("hlsl.srvs");
212  break;
213  case llvm::hlsl::ResourceClass::CBuffer:
214  ResourceMD = M.getOrInsertNamedMetadata("hlsl.cbufs");
215  break;
216  default:
217  assert(false && "Unsupported buffer type!");
218  return;
219  }
220  assert(ResourceMD != nullptr &&
221  "ResourceMD must have been set by the switch above.");
222 
223  llvm::hlsl::FrontendResource Res(
224  GV, RK, ET, IsROV, Binding.Reg.value_or(UINT_MAX), Binding.Space);
225  ResourceMD->addOperand(Res.getMetadata());
226 }
227 
228 static llvm::hlsl::ElementType
229 calculateElementType(const ASTContext &Context, const clang::Type *ResourceTy) {
230  using llvm::hlsl::ElementType;
231 
232  // TODO: We may need to update this when we add things like ByteAddressBuffer
233  // that don't have a template parameter (or, indeed, an element type).
234  const auto *TST = ResourceTy->getAs<TemplateSpecializationType>();
235  assert(TST && "Resource types must be template specializations");
236  ArrayRef<TemplateArgument> Args = TST->template_arguments();
237  assert(!Args.empty() && "Resource has no element type");
238 
239  // At this point we have a resource with an element type, so we can assume
240  // that it's valid or we would have diagnosed the error earlier.
241  QualType ElTy = Args[0].getAsType();
242 
243  // We should either have a basic type or a vector of a basic type.
244  if (const auto *VecTy = ElTy->getAs<clang::VectorType>())
245  ElTy = VecTy->getElementType();
246 
247  if (ElTy->isSignedIntegerType()) {
248  switch (Context.getTypeSize(ElTy)) {
249  case 16:
250  return ElementType::I16;
251  case 32:
252  return ElementType::I32;
253  case 64:
254  return ElementType::I64;
255  }
256  } else if (ElTy->isUnsignedIntegerType()) {
257  switch (Context.getTypeSize(ElTy)) {
258  case 16:
259  return ElementType::U16;
260  case 32:
261  return ElementType::U32;
262  case 64:
263  return ElementType::U64;
264  }
265  } else if (ElTy->isSpecificBuiltinType(BuiltinType::Half))
266  return ElementType::F16;
268  return ElementType::F32;
269  else if (ElTy->isSpecificBuiltinType(BuiltinType::Double))
270  return ElementType::F64;
271 
272  // TODO: We need to handle unorm/snorm float types here once we support them
273  llvm_unreachable("Invalid element type for resource");
274 }
275 
276 void CGHLSLRuntime::annotateHLSLResource(const VarDecl *D, GlobalVariable *GV) {
277  const Type *Ty = D->getType()->getPointeeOrArrayElementType();
278  if (!Ty)
279  return;
280  const auto *RD = Ty->getAsCXXRecordDecl();
281  if (!RD)
282  return;
283  const auto *Attr = RD->getAttr<HLSLResourceAttr>();
284  if (!Attr)
285  return;
286 
287  llvm::hlsl::ResourceClass RC = Attr->getResourceClass();
288  llvm::hlsl::ResourceKind RK = Attr->getResourceKind();
289  bool IsROV = Attr->getIsROV();
290  llvm::hlsl::ElementType ET = calculateElementType(CGM.getContext(), Ty);
291 
292  BufferResBinding Binding(D->getAttr<HLSLResourceBindingAttr>());
293  addBufferResourceAnnotation(GV, RC, RK, IsROV, ET, Binding);
294 }
295 
297  HLSLResourceBindingAttr *Binding) {
298  if (Binding) {
299  llvm::APInt RegInt(64, 0);
300  Binding->getSlot().substr(1).getAsInteger(10, RegInt);
301  Reg = RegInt.getLimitedValue();
302  llvm::APInt SpaceInt(64, 0);
303  Binding->getSpace().substr(5).getAsInteger(10, SpaceInt);
304  Space = SpaceInt.getLimitedValue();
305  } else {
306  Space = 0;
307  }
308 }
309 
311  const FunctionDecl *FD, llvm::Function *Fn) {
312  const auto *ShaderAttr = FD->getAttr<HLSLShaderAttr>();
313  assert(ShaderAttr && "All entry functions must have a HLSLShaderAttr");
314  const StringRef ShaderAttrKindStr = "hlsl.shader";
315  Fn->addFnAttr(ShaderAttrKindStr,
316  ShaderAttr->ConvertShaderTypeToStr(ShaderAttr->getType()));
317  if (HLSLNumThreadsAttr *NumThreadsAttr = FD->getAttr<HLSLNumThreadsAttr>()) {
318  const StringRef NumThreadsKindStr = "hlsl.numthreads";
319  std::string NumThreadsStr =
320  formatv("{0},{1},{2}", NumThreadsAttr->getX(), NumThreadsAttr->getY(),
321  NumThreadsAttr->getZ());
322  Fn->addFnAttr(NumThreadsKindStr, NumThreadsStr);
323  }
324 }
325 
326 static Value *buildVectorInput(IRBuilder<> &B, Function *F, llvm::Type *Ty) {
327  if (const auto *VT = dyn_cast<FixedVectorType>(Ty)) {
328  Value *Result = PoisonValue::get(Ty);
329  for (unsigned I = 0; I < VT->getNumElements(); ++I) {
330  Value *Elt = B.CreateCall(F, {B.getInt32(I)});
331  Result = B.CreateInsertElement(Result, Elt, I);
332  }
333  return Result;
334  }
335  return B.CreateCall(F, {B.getInt32(0)});
336 }
337 
338 llvm::Value *CGHLSLRuntime::emitInputSemantic(IRBuilder<> &B,
339  const ParmVarDecl &D,
340  llvm::Type *Ty) {
341  assert(D.hasAttrs() && "Entry parameter missing annotation attribute!");
342  if (D.hasAttr<HLSLSV_GroupIndexAttr>()) {
343  llvm::Function *DxGroupIndex =
344  CGM.getIntrinsic(Intrinsic::dx_flattened_thread_id_in_group);
345  return B.CreateCall(FunctionCallee(DxGroupIndex));
346  }
347  if (D.hasAttr<HLSLSV_DispatchThreadIDAttr>()) {
348  llvm::Function *ThreadIDIntrinsic =
349  CGM.getIntrinsic(getThreadIdIntrinsic());
350  return buildVectorInput(B, ThreadIDIntrinsic, Ty);
351  }
352  assert(false && "Unhandled parameter attribute");
353  return nullptr;
354 }
355 
357  llvm::Function *Fn) {
358  llvm::Module &M = CGM.getModule();
359  llvm::LLVMContext &Ctx = M.getContext();
360  auto *EntryTy = llvm::FunctionType::get(llvm::Type::getVoidTy(Ctx), false);
361  Function *EntryFn =
362  Function::Create(EntryTy, Function::ExternalLinkage, FD->getName(), &M);
363 
364  // Copy function attributes over, we have no argument or return attributes
365  // that can be valid on the real entry.
366  AttributeList NewAttrs = AttributeList::get(Ctx, AttributeList::FunctionIndex,
367  Fn->getAttributes().getFnAttrs());
368  EntryFn->setAttributes(NewAttrs);
369  setHLSLEntryAttributes(FD, EntryFn);
370 
371  // Set the called function as internal linkage.
372  Fn->setLinkage(GlobalValue::InternalLinkage);
373 
374  BasicBlock *BB = BasicBlock::Create(Ctx, "entry", EntryFn);
375  IRBuilder<> B(BB);
377  // FIXME: support struct parameters where semantics are on members.
378  // See: https://github.com/llvm/llvm-project/issues/57874
379  unsigned SRetOffset = 0;
380  for (const auto &Param : Fn->args()) {
381  if (Param.hasStructRetAttr()) {
382  // FIXME: support output.
383  // See: https://github.com/llvm/llvm-project/issues/57874
384  SRetOffset = 1;
385  Args.emplace_back(PoisonValue::get(Param.getType()));
386  continue;
387  }
388  const ParmVarDecl *PD = FD->getParamDecl(Param.getArgNo() - SRetOffset);
389  Args.push_back(emitInputSemantic(B, *PD, Param.getType()));
390  }
391 
392  CallInst *CI = B.CreateCall(FunctionCallee(Fn), Args);
393  (void)CI;
394  // FIXME: Handle codegen for return type semantics.
395  // See: https://github.com/llvm/llvm-project/issues/57875
396  B.CreateRetVoid();
397 }
398 
399 static void gatherFunctions(SmallVectorImpl<Function *> &Fns, llvm::Module &M,
400  bool CtorOrDtor) {
401  const auto *GV =
402  M.getNamedGlobal(CtorOrDtor ? "llvm.global_ctors" : "llvm.global_dtors");
403  if (!GV)
404  return;
405  const auto *CA = dyn_cast<ConstantArray>(GV->getInitializer());
406  if (!CA)
407  return;
408  // The global_ctor array elements are a struct [Priority, Fn *, COMDat].
409  // HLSL neither supports priorities or COMDat values, so we will check those
410  // in an assert but not handle them.
411 
413  for (const auto &Ctor : CA->operands()) {
414  if (isa<ConstantAggregateZero>(Ctor))
415  continue;
416  ConstantStruct *CS = cast<ConstantStruct>(Ctor);
417 
418  assert(cast<ConstantInt>(CS->getOperand(0))->getValue() == 65535 &&
419  "HLSL doesn't support setting priority for global ctors.");
420  assert(isa<ConstantPointerNull>(CS->getOperand(2)) &&
421  "HLSL doesn't support COMDat for global ctors.");
422  Fns.push_back(cast<Function>(CS->getOperand(1)));
423  }
424 }
425 
427  llvm::Module &M = CGM.getModule();
428  SmallVector<Function *> CtorFns;
429  SmallVector<Function *> DtorFns;
430  gatherFunctions(CtorFns, M, true);
431  gatherFunctions(DtorFns, M, false);
432 
433  // Insert a call to the global constructor at the beginning of the entry block
434  // to externally exported functions. This is a bit of a hack, but HLSL allows
435  // global constructors, but doesn't support driver initialization of globals.
436  for (auto &F : M.functions()) {
437  if (!F.hasFnAttribute("hlsl.shader"))
438  continue;
439  IRBuilder<> B(&F.getEntryBlock(), F.getEntryBlock().begin());
440  for (auto *Fn : CtorFns)
441  B.CreateCall(FunctionCallee(Fn));
442 
443  // Insert global dtors before the terminator of the last instruction
444  B.SetInsertPoint(F.back().getTerminator());
445  for (auto *Fn : DtorFns)
446  B.CreateCall(FunctionCallee(Fn));
447  }
448 
449  // No need to keep global ctors/dtors for non-lib profile after call to
450  // ctors/dtors added for entry.
451  Triple T(M.getTargetTriple());
452  if (T.getEnvironment() != Triple::EnvironmentType::Library) {
453  if (auto *GV = M.getNamedGlobal("llvm.global_ctors"))
454  GV->eraseFromParent();
455  if (auto *GV = M.getNamedGlobal("llvm.global_dtors"))
456  GV->eraseFromParent();
457  }
458 }
static llvm::hlsl::ElementType calculateElementType(const ASTContext &Context, const clang::Type *ResourceTy)
static Value * buildVectorInput(IRBuilder<> &B, Function *F, llvm::Type *Ty)
static void gatherFunctions(SmallVectorImpl< Function * > &Fns, llvm::Module &M, bool CtorOrDtor)
unsigned Offset
Definition: Format.cpp:2978
static std::string getName(const CallEvent &Call)
Defines the clang::TargetOptions class.
Holds long-lived AST nodes (such as types and decls) that can be referred to throughout the semantic ...
Definition: ASTContext.h:185
uint64_t getTypeSize(QualType T) const
Return the size of the specified (complete) type T, in bits.
Definition: ASTContext.h:2355
Attr - This represents one attribute.
Definition: Attr.h:46
This class gathers all debug information during compilation and is responsible for emitting to llvm g...
Definition: CGDebugInfo.h:55
void setHLSLEntryAttributes(const FunctionDecl *FD, llvm::Function *Fn)
void emitEntryFunction(const FunctionDecl *FD, llvm::Function *Fn)
llvm::Value * emitInputSemantic(llvm::IRBuilder<> &B, const ParmVarDecl &D, llvm::Type *Ty)
void annotateHLSLResource(const VarDecl *D, llvm::GlobalVariable *GV)
void addBuffer(const HLSLBufferDecl *D)
llvm::Module & getModule() const
ASTContext & getContext() const
llvm::Function * getIntrinsic(unsigned IID, ArrayRef< llvm::Type * > Tys=std::nullopt)
DeclContext - This is used only as base class of specific decl types that can act as declaration cont...
Definition: DeclBase.h:1436
decl_range decls() const
decls_begin/decls_end - Iterate over the declarations stored in this context.
Definition: DeclBase.h:2322
Decl - This represents one declaration (or definition), e.g.
Definition: DeclBase.h:86
bool hasAttrs() const
Definition: DeclBase.h:524
bool hasAttr() const
Definition: DeclBase.h:583
T * getAttr() const
Definition: DeclBase.h:579
Represents a function declaration or definition.
Definition: Decl.h:1972
const ParmVarDecl * getParamDecl(unsigned i) const
Definition: Decl.h:2709
HLSLBufferDecl - Represent a cbuffer or tbuffer declaration.
Definition: Decl.h:4943
StringRef getName() const
Get the name of identifier for this declaration as a StringRef.
Definition: Decl.h:276
Represents a parameter to a function.
Definition: Decl.h:1762
A (possibly-)qualified type.
Definition: Type.h:940
Represents a type template specialization; the template must be a class template, a type alias templa...
Definition: Type.h:6101
The base class of the type hierarchy.
Definition: Type.h:1813
CXXRecordDecl * getAsCXXRecordDecl() const
Retrieves the CXXRecordDecl that this type refers to, either because the type is a RecordType or beca...
Definition: Type.cpp:1881
const Type * getPointeeOrArrayElementType() const
If this is a pointer type, return the pointee type.
Definition: Type.h:8117
bool isSignedIntegerType() const
Return true if this is an integer type that is signed, according to C99 6.2.5p4 [char,...
Definition: Type.cpp:2145
bool isSpecificBuiltinType(unsigned K) const
Test for a particular builtin type.
Definition: Type.h:7908
bool isUnsignedIntegerType() const
Return true if this is an integer type that is unsigned, according to C99 6.2.5p6 [which returns true...
Definition: Type.cpp:2195
const T * getAs() const
Member-template getAs<specific type>'.
Definition: Type.h:8160
QualType getType() const
Definition: Decl.h:718
Represents a variable declaration or definition.
Definition: Decl.h:919
StorageClass getStorageClass() const
Returns the storage class as written in the source.
Definition: Decl.h:1156
Represents a GCC generic vector type.
Definition: Type.h:3981
#define UINT_MAX
Definition: limits.h:64
llvm::APInt APInt
Definition: Integral.h:29
bool Const(InterpState &S, CodePtr OpPC, const T &Arg)
Definition: Interp.h:940
The JSON file list parser is used to communicate input to InstallAPI.
@ SC_Static
Definition: Specifiers.h:249
const FunctionProtoType * T
unsigned long uint64_t
Diagnostic wrappers for TextAPI types for error reporting.
Definition: Dominators.h:30
BufferResBinding(HLSLResourceBindingAttr *Attr)
std::vector< std::pair< llvm::GlobalVariable *, unsigned > > Constants
Definition: CGHLSLRuntime.h:99