summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2024-03-01 12:42:37 -0800
committerGitHub <noreply@github.com>2024-03-01 12:42:37 -0800
commit0d01b3701aae582b5fe3b6e2c2c718bec568c741 (patch)
tree5525bec1a786ace62263b68b89d287105c590df6
parent3ade07303783605ddef8c0f0c5237952c903798d (diff)
Various SPIRV fixes. (#3655)
* Various SPIRV fixes. * Fix debugValue.
-rw-r--r--source/slang/hlsl.meta.slang7
-rw-r--r--source/slang/slang-emit-spirv-ops-debug-info-ext.h9
-rw-r--r--source/slang/slang-emit-spirv.cpp121
-rw-r--r--source/slang/slang-ir-insert-debug-value-store.cpp36
-rw-r--r--source/slang/slang-ir-spirv-legalize.cpp1
-rw-r--r--tests/spirv/debug-value-dynamic-index.slang29
6 files changed, 189 insertions, 14 deletions
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang
index 84ba11cad..5c97de525 100644
--- a/source/slang/hlsl.meta.slang
+++ b/source/slang/hlsl.meta.slang
@@ -1596,6 +1596,13 @@ extension __TextureImpl<T,Shape,isArray,1,sampleCount,0,isShadow,isCombined,form
}
}
+ [__readNone]
+ [ForceInline]
+ T Load(vector<int, Shape.dimensions + isArray + 1> locationAndSampleIndex)
+ {
+ return Load(__vectorReshape<Shape.dimensions + isArray>(locationAndSampleIndex), locationAndSampleIndex[Shape.dimensions + isArray]);
+ }
+
__glsl_extension(GL_EXT_samplerless_texture_functions)
[__readNone]
[ForceInline]
diff --git a/source/slang/slang-emit-spirv-ops-debug-info-ext.h b/source/slang/slang-emit-spirv-ops-debug-info-ext.h
index fcf931f0a..cd4f274f8 100644
--- a/source/slang/slang-emit-spirv-ops-debug-info-ext.h
+++ b/source/slang/slang-emit-spirv-ops-debug-info-ext.h
@@ -10,12 +10,19 @@ SpvInst* emitOpDebugCompilationUnit(SpvInstParent* parent, IRInst* inst, const T
// https://github.com/KhronosGroup/SPIRV-Registry/blob/main/nonsemantic/NonSemantic.Shader.DebugInfo.100.asciidoc#DebugSource
template<typename T>
-SpvInst* emitOpDebugSource(SpvInstParent* parent, IRInst* inst, const T& idResultType, SpvInst* set, IRInst* file, IRInst* text)
+SpvInst* emitOpDebugSource(SpvInstParent* parent, IRInst* inst, const T& idResultType, SpvInst* set, IRInst* file, SpvInst* text)
{
static_assert(isSingular<T>);
return emitInst(parent, inst, SpvOpExtInst, idResultType, kResultID, set, SpvWord(35), file, text);
}
+template<typename T>
+SpvInst* emitOpDebugSourceContinued(SpvInstParent* parent, IRInst* inst, const T& idResultType, SpvInst* set, SpvInst* text)
+{
+ static_assert(isSingular<T>);
+ return emitInst(parent, inst, SpvOpExtInst, idResultType, kResultID, set, SpvWord(102), text);
+}
+
// https://github.com/KhronosGroup/SPIRV-Registry/blob/main/nonsemantic/NonSemantic.Shader.DebugInfo.100.asciidoc#DebugLine
template<typename T>
SpvInst* emitOpDebugLine(SpvInstParent* parent, IRInst* inst, const T& idResultType, SpvInst* set, IRInst* source, IRInst* lineStart, IRInst* lineEnd, IRInst* colStart, IRInst* colEnd)
diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp
index 7cb263b61..d81ceae81 100644
--- a/source/slang/slang-emit-spirv.cpp
+++ b/source/slang/slang-emit-spirv.cpp
@@ -1599,14 +1599,40 @@ struct SPIRVEmitContext
case kIROp_DebugSource:
{
ensureExtensionDeclaration(UnownedStringSlice("SPV_KHR_non_semantic_info"));
+ // SPIRV does not allow string lits longer than 65535, so we need to split the source string
+ // in OpDebugSourceContinued instructions.
auto debugSource = as<IRDebugSource>(inst);
+ auto sourceStr = as<IRStringLit>(debugSource->getSource())->getStringSlice();
+ auto sourceStrHead = sourceStr.getLength() > 65535 ? sourceStr.head(65535) : sourceStr;
+ auto spvStrHead = emitInst(
+ getSection(SpvLogicalSectionID::DebugStringsAndSource),
+ nullptr,
+ SpvOpString,
+ kResultID,
+ SpvLiteralBits::fromUnownedStringSlice(sourceStrHead));
+
auto result = emitOpDebugSource(
getSection(SpvLogicalSectionID::ConstantsAndTypes),
inst,
inst->getFullType(),
getNonSemanticDebugInfoExtInst(),
debugSource->getFileName(),
- debugSource->getSource());
+ spvStrHead);
+
+ for (Index start = 65535; start < sourceStr.getLength(); start += 65535)
+ {
+ auto slice = sourceStr.tail(start);
+ slice = slice.getLength() > 65535 ? slice.head(65535) : slice;
+ auto sliceSpvStr = emitInst(
+ getSection(SpvLogicalSectionID::DebugStringsAndSource),
+ nullptr,
+ SpvOpString,
+ kResultID,
+ SpvLiteralBits::fromUnownedStringSlice(slice));
+ emitOpDebugSourceContinued(getSection(SpvLogicalSectionID::ConstantsAndTypes),
+ nullptr, m_voidType, getNonSemanticDebugInfoExtInst(), sliceSpvStr);
+ }
+
auto moduleInst = inst->getModule()->getModuleInst();
if (!m_defaultDebugSource)
m_defaultDebugSource = debugSource;
@@ -1629,6 +1655,8 @@ struct SPIRVEmitContext
}
case kIROp_GetStringHash:
return emitGetStringHash(inst);
+ case kIROp_AttributedType:
+ return ensureInst(as<IRAttributedType>(inst)->getBaseType());
case kIROp_AllocateOpaqueHandle:
return nullptr;
case kIROp_HLSLTriangleStreamType:
@@ -1760,7 +1788,7 @@ struct SPIRVEmitContext
//
- const auto sampledType = inst->getElementType();
+ IRInst* sampledType = inst->getElementType();
SpvDim dim = SpvDim1D; // Silence uninitialized warnings from msvc...
switch(inst->GetBaseShape())
{
@@ -1800,6 +1828,57 @@ struct SPIRVEmitContext
}
SpvImageFormat format = getSpvImageFormat(inst);
+ // If format is unknown, we need to deduce the format if there is
+ // unorm or snorm attributes on the sampled type.
+ if (auto attribType = as<IRAttributedType>(sampledType))
+ {
+ sampledType = unwrapAttributedType(sampledType);
+ if (format == SpvImageFormatUnknown)
+ {
+ IRIntegerValue vectorSize = 1;
+ if (auto vecType = as<IRVectorType>(sampledType))
+ vectorSize = getIntVal(vecType->getElementCount());
+
+ for (auto attr : attribType->getAllAttrs())
+ {
+ switch (attr->getOp())
+ {
+ case kIROp_UNormAttr:
+ switch (vectorSize)
+ {
+ case 1:
+ format = SpvImageFormatR8;
+ break;
+ case 2:
+ format = SpvImageFormatRg8;
+ break;
+ case 3:
+ format = SpvImageFormatRgba8;
+ break;
+ case 4:
+ format = SpvImageFormatRgba8;
+ break;
+ }
+ case kIROp_SNormAttr:
+ switch (vectorSize)
+ {
+ case 1:
+ format = SpvImageFormatR8Snorm;
+ break;
+ case 2:
+ format = SpvImageFormatRg8Snorm;
+ break;
+ case 3:
+ format = SpvImageFormatRgba8Snorm;
+ break;
+ case 4:
+ format = SpvImageFormatRgba8Snorm;
+ break;
+ }
+ }
+ }
+ }
+ }
//
// Capabilities, according to section 3.8
@@ -1856,7 +1935,7 @@ struct SPIRVEmitContext
{
auto imageType = emitOpTypeImage(
nullptr,
- dropVector(sampledType),
+ dropVector((IRType*)sampledType),
dim,
SpvLiteralInteger::from32(depth),
SpvLiteralInteger::from32(arrayed),
@@ -1871,7 +1950,7 @@ struct SPIRVEmitContext
return emitOpTypeImage(
assignee,
- dropVector(sampledType),
+ dropVector((IRType*)sampledType),
dim,
SpvLiteralInteger::from32(depth),
SpvLiteralInteger::from32(arrayed),
@@ -3327,19 +3406,39 @@ struct SPIRVEmitContext
}
}
- Dictionary<SpvBuiltIn, SpvInst*> m_builtinGlobalVars;
+ struct BuiltinSpvVarKey
+ {
+ SpvBuiltIn builtinName;
+ SpvStorageClass storageClass = SpvStorageClassInput;
+ BuiltinSpvVarKey() = default;
+ BuiltinSpvVarKey(SpvBuiltIn builtin, SpvStorageClass storageClass)
+ : builtinName(builtin), storageClass(storageClass)
+ {
+ }
+ bool operator==(const BuiltinSpvVarKey& other) const
+ {
+ return builtinName == other.builtinName && storageClass == other.storageClass;
+ }
+ HashCode getHashCode() const
+ {
+ return combineHash(Slang::getHashCode(builtinName), Slang::getHashCode(storageClass));
+ }
+ };
+ Dictionary<BuiltinSpvVarKey, SpvInst*> m_builtinGlobalVars;
SpvInst* getBuiltinGlobalVar(IRType* type, SpvBuiltIn builtinVal)
{
SpvInst* result = nullptr;
- if (m_builtinGlobalVars.tryGetValue(builtinVal, result))
+ auto ptrType = as<IRPtrTypeBase>(type);
+ SLANG_ASSERT(ptrType && "`getBuiltinGlobalVar`: `type` must be ptr type.");
+ auto storageClass = static_cast<SpvStorageClass>(ptrType->getAddressSpace());
+ auto key = BuiltinSpvVarKey(builtinVal, storageClass);
+ if (m_builtinGlobalVars.tryGetValue(key, result))
{
return result;
}
IRBuilder builder(m_irModule);
builder.setInsertBefore(type);
- auto ptrType = as<IRPtrTypeBase>(type);
- SLANG_ASSERT(ptrType && "`getBuiltinGlobalVar`: `type` must be ptr type.");
auto varInst = emitOpVariable(
getSection(SpvLogicalSectionID::GlobalVariables),
nullptr,
@@ -3352,7 +3451,7 @@ struct SPIRVEmitContext
varInst,
builtinVal
);
- m_builtinGlobalVars[builtinVal] = varInst;
+ m_builtinGlobalVars[key] = varInst;
return varInst;
}
@@ -5449,9 +5548,9 @@ struct SPIRVEmitContext
// Otherwise, we are truncating a vector to a smaller vector
else
{
- const auto toVector = cast<IRVectorType>(toType);
+ const auto toVector = cast<IRVectorType>(unwrapAttributedType(toType));
const auto toVectorSize = getIntVal(toVector->getElementCount());
- const auto fromVector = cast<IRVectorType>(fromType);
+ const auto fromVector = cast<IRVectorType>(unwrapAttributedType(fromType));
const auto fromVectorSize = getIntVal(fromVector->getElementCount());
if(toVectorSize > fromVectorSize)
m_sink->diagnose(inst, Diagnostics::spirvInvalidTruncate);
diff --git a/source/slang/slang-ir-insert-debug-value-store.cpp b/source/slang/slang-ir-insert-debug-value-store.cpp
index 8aad3e417..6243e6c09 100644
--- a/source/slang/slang-ir-insert-debug-value-store.cpp
+++ b/source/slang/slang-ir-insert-debug-value-store.cpp
@@ -80,6 +80,38 @@ namespace Slang
}
// Collect all stores and insert debug value insts to update debug vars.
+
+ // Helper func to insert debugValue updates.
+ auto setDebugValue = [&](IRInst* debugVar, IRInst* rootVar, IRInst* newValue, ArrayView<IRInst*> accessChain)
+ {
+ // SPIRV does not allow dynamic indices in DebugValue,
+ // so we need to stop the access chain at the first dynamic index.
+ Index i = 0;
+ for (; i < accessChain.getCount(); i++)
+ {
+ if (auto key = as<IRStructKey>(accessChain[i]))
+ {
+ continue;
+ }
+ if (as<IRIntLit>(accessChain[i]))
+ {
+ continue;
+ }
+ break;
+ }
+ // If everything is static on the access chain, we can simply emit a DebugValue.
+ if (i == accessChain.getCount())
+ {
+ builder.emitDebugValue(debugVar, newValue, accessChain);
+ return;
+ }
+
+ // Otherwise we need to load the entire composite value starting at the dynamic index access chain
+ // and set it.
+ auto compositePtr = builder.emitElementAddress(rootVar, accessChain.head(i));
+ auto compositeVal = builder.emitLoad(compositePtr);
+ builder.emitDebugValue(debugVar, compositeVal, accessChain.head(i));
+ };
for (auto block : func->getBlocks())
{
IRInst* nextInst = nullptr;
@@ -95,7 +127,7 @@ namespace Slang
if (mapVarToDebugVar.tryGetValue(varInst, debugVar))
{
builder.setInsertAfter(storeInst);
- builder.emitDebugValue(debugVar, storeInst->getVal(), accessChain.getArrayView());
+ setDebugValue(debugVar, varInst, storeInst->getVal(), accessChain.getArrayView());
}
}
else if (auto callInst = as<IRCall>(inst))
@@ -115,7 +147,7 @@ namespace Slang
{
builder.setInsertAfter(callInst);
auto loadVal = builder.emitLoad(arg);
- builder.emitDebugValue(debugVar, loadVal, accessChain.getArrayView());
+ setDebugValue(debugVar, varInst, loadVal, accessChain.getArrayView());
}
}
}
diff --git a/source/slang/slang-ir-spirv-legalize.cpp b/source/slang/slang-ir-spirv-legalize.cpp
index 8652127da..bc571cba5 100644
--- a/source/slang/slang-ir-spirv-legalize.cpp
+++ b/source/slang/slang-ir-spirv-legalize.cpp
@@ -1503,6 +1503,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
case kIROp_Or:
case kIROp_Not:
case kIROp_Neg:
+ case kIROp_Div:
case kIROp_FieldExtract:
case kIROp_FieldAddress:
case kIROp_GetElement:
diff --git a/tests/spirv/debug-value-dynamic-index.slang b/tests/spirv/debug-value-dynamic-index.slang
new file mode 100644
index 000000000..591b1f6da
--- /dev/null
+++ b/tests/spirv/debug-value-dynamic-index.slang
@@ -0,0 +1,29 @@
+//TEST:SIMPLE(filecheck=CHECK):-target spirv -entry main -stage compute -g2 -emit-spirv-directly
+
+struct TestType
+{
+ float memberA;
+ float3 memberB;
+ float arrayVal[10];
+ RWStructuredBuffer<float> memberC;
+ float getValue()
+ {
+ return memberA;
+ }
+}
+RWStructuredBuffer<float> result;
+void main(int id : SV_DispatchThreadID)
+{
+ TestType t;
+ t.memberA = 1.0;
+ t.arrayVal[id] = 2;
+ result[0] = t.arrayVal[id];
+}
+
+// CHECK: OpExtInst %void {{.*}} DebugExpression
+// CHECK: DebugTypeMember
+// CHECK: DebugTypeComposite
+// CHECK: DebugFunctionDefinition
+// CHECK: DebugScope
+// CHECK: DebugLine
+// CHECK: DebugValue