summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-emit-spirv.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2024-08-30 12:03:23 -0700
committerGitHub <noreply@github.com>2024-08-30 12:03:23 -0700
commitde83628070614ec37349c9f334ed72a54a6889da (patch)
treebc97f74013a073fc958b75e68089696e14d71412 /source/slang/slang-emit-spirv.cpp
parentf428a058ea48535a323c32d206ebc7e551c3c3e9 (diff)
Support specialization constants. (#4963)
* Support specialization constants. * Fix. * Fix. * Fix. * Fix. * Make sure specialization constants have names. * Clean up and support the dxc [vk::constant_id] syntax. * Fix. * Fix. * Fix.
Diffstat (limited to 'source/slang/slang-emit-spirv.cpp')
-rw-r--r--source/slang/slang-emit-spirv.cpp141
1 files changed, 139 insertions, 2 deletions
diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp
index 8f4501c2b..bb1f1378f 100644
--- a/source/slang/slang-emit-spirv.cpp
+++ b/source/slang/slang-emit-spirv.cpp
@@ -918,6 +918,51 @@ struct SPIRVEmitContext
};
Dictionary<ConstantValueKey<IRIntegerValue>, SpvInst*> m_spvIntConstants;
Dictionary<ConstantValueKey<IRFloatingPointValue>, SpvInst*> m_spvFloatConstants;
+
+ // Get an SpvLiteralBits from an IRConstant.
+ SpvLiteralBits getLiteralBits(IRInst* type, IRInst* inst)
+ {
+ switch (type->getOp())
+ {
+ case kIROp_DoubleType:
+ {
+ if (auto fval = as<IRFloatLit>(inst))
+ return SpvLiteralBits::from64(DoubleAsInt64(fval->getValue()));
+ break;
+ }
+ case kIROp_HalfType:
+ {
+ if (auto fval = as<IRFloatLit>(inst))
+ return SpvLiteralBits::from32(uint32_t(FloatToHalf((float)fval->getValue())));
+ break;
+ }
+ case kIROp_FloatType:
+ {
+ if (auto fval = as<IRFloatLit>(inst))
+ return SpvLiteralBits::from32(FloatAsInt((float)fval->getValue()));
+ break;
+ }
+ case kIROp_Int64Type:
+ case kIROp_UInt64Type:
+#if SLANG_PTR_IS_64
+ case kIROp_PtrType:
+ case kIROp_UIntPtrType:
+#endif
+ {
+ if (auto val = as<IRIntLit>(inst))
+ return SpvLiteralBits::from64(uint64_t(val->getValue()));
+ break;
+ }
+ default:
+ {
+ if (auto val = as<IRIntLit>(inst))
+ return SpvLiteralBits::from32(uint32_t(val->getValue()));
+ break;
+ }
+ }
+ return SpvLiteralBits::from32(0);
+ }
+
SpvInst* emitIntConstant(IRIntegerValue val, IRType* type, IRInst* inst = nullptr)
{
ConstantValueKey<IRIntegerValue> key;
@@ -1269,6 +1314,7 @@ struct SPIRVEmitContext
return SpvStorageClassPhysicalStorageBuffer;
case AddressSpace::Global:
case AddressSpace::MetalObjectData:
+ case AddressSpace::SpecializationConstant:
// msvc is limiting us from putting the UNEXPECTED macro here, so
// just fall out
;
@@ -2403,9 +2449,88 @@ struct SPIRVEmitContext
}
}
+ /// Emit a specialization constant.
+ SpvInst* emitSpecializationConstant(IRGlobalParam* param, IRVarOffsetAttr* offset)
+ {
+ IRInst* defaultVal = nullptr;
+ if (auto defaultValDecor = param->findDecoration<IRDefaultValueDecoration>())
+ {
+ defaultVal = defaultValDecor->getOperand(0);
+ }
+ else
+ {
+ IRBuilder builder(param);
+ builder.setInsertBefore(param);
+ defaultVal = builder.emitDefaultConstruct(param->getDataType());
+ }
+
+ SpvInst* result = nullptr;
+ if (as<IRBoolType>(defaultVal->getDataType()))
+ {
+ bool value = false;
+ if (auto boolLit = as<IRBoolLit>(defaultVal))
+ {
+ value = boolLit->getValue();
+ }
+ if (value)
+ {
+ result = emitOpSpecConstantTrue(
+ getSection(SpvLogicalSectionID::ConstantsAndTypes),
+ param,
+ param->getDataType());
+ }
+ else
+ {
+ result = emitOpSpecConstantFalse(
+ getSection(SpvLogicalSectionID::ConstantsAndTypes),
+ param,
+ param->getDataType());
+ }
+ }
+ else if (auto type = as<IRBasicType>(defaultVal->getDataType()))
+ {
+ result = emitOpSpecConstant(
+ getSection(SpvLogicalSectionID::ConstantsAndTypes),
+ param,
+ param->getDataType(),
+ getLiteralBits(type, defaultVal));
+ }
+ else if (as<IRVectorType>(defaultVal->getDataType()))
+ {
+ List<IRInst*> operands;
+ auto makeVector = as<IRMakeVector>(defaultVal);
+ for (UInt i = 0; i < makeVector->getOperandCount(); i++)
+ {
+ operands.add(makeVector->getOperand(i));
+ }
+ result = emitOpSpecConstantComposite(
+ getSection(SpvLogicalSectionID::ConstantsAndTypes),
+ param,
+ param->getDataType(),
+ operands);
+ }
+
+ emitOpDecorateSpecId(
+ getSection(SpvLogicalSectionID::Annotations),
+ nullptr,
+ result,
+ SpvLiteralInteger::from32((uint32_t)offset->getOffset()));
+
+ maybeEmitName(result, param);
+ return result;
+ }
+
/// Emit a global parameter definition.
SpvInst* emitGlobalParam(IRGlobalParam* param)
{
+ auto layout = getVarLayout(param);
+ if (layout)
+ {
+ if (auto offset = layout->findOffsetAttr(LayoutResourceKind::SpecializationConstant))
+ {
+ return emitSpecializationConstant(param, offset);
+ }
+ }
auto storageClass = SpvStorageClassUniform;
if (auto ptrType = as<IRPtrTypeBase>(param->getDataType()))
{
@@ -2425,7 +2550,7 @@ struct SPIRVEmitContext
storageClass
);
maybeEmitPointerDecoration(varInst, param);
- if (auto layout = getVarLayout(param))
+ if (layout)
emitVarLayout(param, varInst, layout);
emitDecorations(param, getID(varInst));
return varInst;
@@ -3419,8 +3544,20 @@ struct SPIRVEmitContext
}
}
paramsSet.add(spvGlobalInst);
- params.add(spvGlobalInst);
referencedBuiltinIRVars.add(globalInst);
+
+ // Don't add a global param to the interface if it is a specialization constant.
+ switch (spvGlobalInst->opcode)
+ {
+ case SpvOpSpecConstant:
+ case SpvOpSpecConstantFalse:
+ case SpvOpSpecConstantTrue:
+ case SpvOpSpecConstantComposite:
+ break;
+ default:
+ params.add(spvGlobalInst);
+ break;
+ }
}
}
break;