diff options
| author | Yong He <yonghe@outlook.com> | 2024-08-30 12:03:23 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-08-30 12:03:23 -0700 |
| commit | de83628070614ec37349c9f334ed72a54a6889da (patch) | |
| tree | bc97f74013a073fc958b75e68089696e14d71412 /source/slang/slang-emit-spirv.cpp | |
| parent | f428a058ea48535a323c32d206ebc7e551c3c3e9 (diff) | |
Support specialization constants. (#4963)
* Support specialization constants.
* Fix.
* Fix.
* Fix.
* Fix.
* Make sure specialization constants have names.
* Clean up and support the dxc [vk::constant_id] syntax.
* Fix.
* Fix.
* Fix.
Diffstat (limited to 'source/slang/slang-emit-spirv.cpp')
| -rw-r--r-- | source/slang/slang-emit-spirv.cpp | 141 |
1 files changed, 139 insertions, 2 deletions
diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp index 8f4501c2b..bb1f1378f 100644 --- a/source/slang/slang-emit-spirv.cpp +++ b/source/slang/slang-emit-spirv.cpp @@ -918,6 +918,51 @@ struct SPIRVEmitContext }; Dictionary<ConstantValueKey<IRIntegerValue>, SpvInst*> m_spvIntConstants; Dictionary<ConstantValueKey<IRFloatingPointValue>, SpvInst*> m_spvFloatConstants; + + // Get an SpvLiteralBits from an IRConstant. + SpvLiteralBits getLiteralBits(IRInst* type, IRInst* inst) + { + switch (type->getOp()) + { + case kIROp_DoubleType: + { + if (auto fval = as<IRFloatLit>(inst)) + return SpvLiteralBits::from64(DoubleAsInt64(fval->getValue())); + break; + } + case kIROp_HalfType: + { + if (auto fval = as<IRFloatLit>(inst)) + return SpvLiteralBits::from32(uint32_t(FloatToHalf((float)fval->getValue()))); + break; + } + case kIROp_FloatType: + { + if (auto fval = as<IRFloatLit>(inst)) + return SpvLiteralBits::from32(FloatAsInt((float)fval->getValue())); + break; + } + case kIROp_Int64Type: + case kIROp_UInt64Type: +#if SLANG_PTR_IS_64 + case kIROp_PtrType: + case kIROp_UIntPtrType: +#endif + { + if (auto val = as<IRIntLit>(inst)) + return SpvLiteralBits::from64(uint64_t(val->getValue())); + break; + } + default: + { + if (auto val = as<IRIntLit>(inst)) + return SpvLiteralBits::from32(uint32_t(val->getValue())); + break; + } + } + return SpvLiteralBits::from32(0); + } + SpvInst* emitIntConstant(IRIntegerValue val, IRType* type, IRInst* inst = nullptr) { ConstantValueKey<IRIntegerValue> key; @@ -1269,6 +1314,7 @@ struct SPIRVEmitContext return SpvStorageClassPhysicalStorageBuffer; case AddressSpace::Global: case AddressSpace::MetalObjectData: + case AddressSpace::SpecializationConstant: // msvc is limiting us from putting the UNEXPECTED macro here, so // just fall out ; @@ -2403,9 +2449,88 @@ struct SPIRVEmitContext } } + /// Emit a specialization constant. + SpvInst* emitSpecializationConstant(IRGlobalParam* param, IRVarOffsetAttr* offset) + { + IRInst* defaultVal = nullptr; + if (auto defaultValDecor = param->findDecoration<IRDefaultValueDecoration>()) + { + defaultVal = defaultValDecor->getOperand(0); + } + else + { + IRBuilder builder(param); + builder.setInsertBefore(param); + defaultVal = builder.emitDefaultConstruct(param->getDataType()); + } + + SpvInst* result = nullptr; + if (as<IRBoolType>(defaultVal->getDataType())) + { + bool value = false; + if (auto boolLit = as<IRBoolLit>(defaultVal)) + { + value = boolLit->getValue(); + } + if (value) + { + result = emitOpSpecConstantTrue( + getSection(SpvLogicalSectionID::ConstantsAndTypes), + param, + param->getDataType()); + } + else + { + result = emitOpSpecConstantFalse( + getSection(SpvLogicalSectionID::ConstantsAndTypes), + param, + param->getDataType()); + } + } + else if (auto type = as<IRBasicType>(defaultVal->getDataType())) + { + result = emitOpSpecConstant( + getSection(SpvLogicalSectionID::ConstantsAndTypes), + param, + param->getDataType(), + getLiteralBits(type, defaultVal)); + } + else if (as<IRVectorType>(defaultVal->getDataType())) + { + List<IRInst*> operands; + auto makeVector = as<IRMakeVector>(defaultVal); + for (UInt i = 0; i < makeVector->getOperandCount(); i++) + { + operands.add(makeVector->getOperand(i)); + } + result = emitOpSpecConstantComposite( + getSection(SpvLogicalSectionID::ConstantsAndTypes), + param, + param->getDataType(), + operands); + } + + emitOpDecorateSpecId( + getSection(SpvLogicalSectionID::Annotations), + nullptr, + result, + SpvLiteralInteger::from32((uint32_t)offset->getOffset())); + + maybeEmitName(result, param); + return result; + } + /// Emit a global parameter definition. SpvInst* emitGlobalParam(IRGlobalParam* param) { + auto layout = getVarLayout(param); + if (layout) + { + if (auto offset = layout->findOffsetAttr(LayoutResourceKind::SpecializationConstant)) + { + return emitSpecializationConstant(param, offset); + } + } auto storageClass = SpvStorageClassUniform; if (auto ptrType = as<IRPtrTypeBase>(param->getDataType())) { @@ -2425,7 +2550,7 @@ struct SPIRVEmitContext storageClass ); maybeEmitPointerDecoration(varInst, param); - if (auto layout = getVarLayout(param)) + if (layout) emitVarLayout(param, varInst, layout); emitDecorations(param, getID(varInst)); return varInst; @@ -3419,8 +3544,20 @@ struct SPIRVEmitContext } } paramsSet.add(spvGlobalInst); - params.add(spvGlobalInst); referencedBuiltinIRVars.add(globalInst); + + // Don't add a global param to the interface if it is a specialization constant. + switch (spvGlobalInst->opcode) + { + case SpvOpSpecConstant: + case SpvOpSpecConstantFalse: + case SpvOpSpecConstantTrue: + case SpvOpSpecConstantComposite: + break; + default: + params.add(spvGlobalInst); + break; + } } } break; |
