summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2024-06-10 12:32:17 -0700
committerGitHub <noreply@github.com>2024-06-10 12:32:17 -0700
commit38c0baccac70ca36a2c90218d6a92b8c036b1a5e (patch)
tree0bc29bad69dc988777e1275b707c1d92737dcd7f /source
parentb5cdd8322bd51603c217dfb7662306628b144c78 (diff)
Fix SPIRV emit for `Flat` decoration and TessLevel builtin. (#4318)
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-emit-spirv.cpp176
-rw-r--r--source/slang/slang-ir-glsl-legalize.cpp33
2 files changed, 118 insertions, 91 deletions
diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp
index 1ef3a31e0..d7c6c64a5 100644
--- a/source/slang/slang-emit-spirv.cpp
+++ b/source/slang/slang-emit-spirv.cpp
@@ -2100,7 +2100,7 @@ struct SPIRVEmitContext
return result;
}
- bool _maybeEmitInterpolationModifierDecoration(IRInterpolationMode mode, SpvInst* varInst)
+ bool _maybeEmitInterpolationModifierDecoration(IRInterpolationMode mode, SpvId varInst)
{
switch (mode)
{
@@ -2234,16 +2234,7 @@ struct SPIRVEmitContext
SpvLiteralInteger::from32(int32_t(0)));
}
- bool anyModifiers = false;
- for (auto dd : var->getDecorations())
- {
- if (dd->getOp() != kIROp_InterpolationModeDecoration)
- continue;
-
- auto decoration = (IRInterpolationModeDecoration*)dd;
-
- anyModifiers |= _maybeEmitInterpolationModifierDecoration(decoration->getMode(), varInst);
- }
+ bool anyModifiers = (var->findDecoration<IRInterpolationModeDecoration>() != nullptr);
// If the user didn't explicitly qualify a varying
// with integer type, then we need to explicitly
@@ -2341,7 +2332,7 @@ struct SPIRVEmitContext
const auto kind = (SpvBuiltIn)(getIntVal(spvAsmBuiltinVar->getOperand(0)));
IRBuilder builder(spvAsmBuiltinVar);
builder.setInsertBefore(spvAsmBuiltinVar);
- auto varInst = getBuiltinGlobalVar(builder.getPtrType(kIROp_PtrType, spvAsmBuiltinVar->getDataType(), SpvStorageClassInput), kind);
+ auto varInst = getBuiltinGlobalVar(builder.getPtrType(kIROp_PtrType, spvAsmBuiltinVar->getDataType(), SpvStorageClassInput), kind, spvAsmBuiltinVar);
registerInst(spvAsmBuiltinVar, varInst);
return varInst;
}
@@ -3538,6 +3529,11 @@ struct SPIRVEmitContext
requireSPIRVCapability(SpvCapabilityRayQueryKHR);
isRayTracingObject = true;
break;
+ case kIROp_InterpolationModeDecoration:
+ _maybeEmitInterpolationModifierDecoration(
+ (IRInterpolationMode)getIntVal(decoration->getOperand(0)),
+ dstID);
+ break;
case kIROp_MemoryQualifierSetDecoration:
{
auto collection = as<IRMemoryQualifierSetDecoration>(decoration);
@@ -3872,7 +3868,45 @@ struct SPIRVEmitContext
};
Dictionary<BuiltinSpvVarKey, SpvInst*> m_builtinGlobalVars;
- SpvInst* getBuiltinGlobalVar(IRType* type, SpvBuiltIn builtinVal)
+
+ bool isInstUsedInStage(IRInst* inst, Stage s)
+ {
+ auto* referencingEntryPoints = m_referencingEntryPoints.tryGetValue(inst);
+ if (!referencingEntryPoints)
+ return false;
+ for (auto entryPoint : *referencingEntryPoints)
+ {
+ if (auto entryPointDecor = entryPoint->findDecoration<IREntryPointDecoration>())
+ {
+ if (entryPointDecor->getProfile().getStage() == s)
+ return true;
+ }
+ }
+ return false;
+ }
+
+ void maybeEmitFlatDecorationForBuiltinVar(IRInst* irInst, SpvInst* spvInst)
+ {
+ if (!irInst)
+ return;
+ if (irInst->getOp() != kIROp_GlobalVar &&
+ irInst->getOp() != kIROp_GlobalParam)
+ return;
+ auto ptrType = as<IRPtrType>(irInst->getDataType());
+ if (!ptrType)
+ return;
+ auto addrSpace = ptrType->getAddressSpace();
+ if (addrSpace == SpvStorageClassInput)
+ {
+ if (isIntegralScalarOrCompositeType(ptrType->getValueType()))
+ {
+ if (isInstUsedInStage(irInst, Stage::Fragment))
+ _maybeEmitInterpolationModifierDecoration(IRInterpolationMode::NoInterpolation, getID(spvInst));
+ }
+ }
+ }
+
+ SpvInst* getBuiltinGlobalVar(IRType* type, SpvBuiltIn builtinVal, IRInst* irInst)
{
SpvInst* result = nullptr;
auto ptrType = as<IRPtrTypeBase>(type);
@@ -3898,6 +3932,9 @@ struct SPIRVEmitContext
builtinVal
);
m_builtinGlobalVars[key] = varInst;
+
+ maybeEmitFlatDecorationForBuiltinVar(irInst, varInst);
+
return varInst;
}
@@ -3915,9 +3952,9 @@ struct SPIRVEmitContext
{
auto importDecor = inst->findDecoration<IRImportDecoration>();
if (importDecor->getMangledName() == "gl_FragCoord")
- return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInFragCoord);
+ return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInFragCoord, inst);
else
- return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInPosition);
+ return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInPosition, inst);
}
else if (semanticName == "sv_target")
{
@@ -3929,75 +3966,75 @@ struct SPIRVEmitContext
}
else if (semanticName == "sv_clipdistance")
{
- return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInClipDistance);
+ return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInClipDistance, inst);
}
else if (semanticName == "sv_culldistance")
{
requireSPIRVCapability(SpvCapabilityCullDistance);
- return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInCullDistance);
+ return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInCullDistance, inst);
}
else if (semanticName == "sv_coverage")
{
- return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInSampleMask);
+ return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInSampleMask, inst);
}
else if (semanticName == "sv_innercoverage")
{
requireSPIRVCapability(SpvCapabilityFragmentFullyCoveredEXT);
ensureExtensionDeclaration(UnownedStringSlice("SPV_EXT_fragment_fully_covered"));
- return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInFullyCoveredEXT);
+ return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInFullyCoveredEXT, inst);
}
else if (semanticName == "sv_depth")
{
- return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInFragDepth);
+ return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInFragDepth, inst);
}
else if (semanticName == "sv_depthgreaterequal")
{
- return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInFragDepth);
+ return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInFragDepth, inst);
}
else if (semanticName == "sv_depthlessequal")
{
- return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInFragDepth);
+ return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInFragDepth, inst);
}
else if (semanticName == "sv_dispatchthreadid")
{
- return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInGlobalInvocationId);
+ return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInGlobalInvocationId, inst);
}
else if (semanticName == "sv_domainlocation")
{
- return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInTessCoord);
+ return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInTessCoord, inst);
}
else if (semanticName == "sv_groupid")
{
- return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInWorkgroupId);
+ return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInWorkgroupId, inst);
}
else if (semanticName == "sv_groupindex")
{
- return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInLocalInvocationIndex);
+ return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInLocalInvocationIndex, inst);
}
else if (semanticName == "sv_groupthreadid")
{
- return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInLocalInvocationId);
+ return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInLocalInvocationId, inst);
}
else if (semanticName == "sv_gsinstanceid")
{
- return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInInvocationId);
+ return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInInvocationId, inst);
}
else if (semanticName == "sv_instanceid")
{
- return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInInstanceIndex);
+ return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInInstanceIndex, inst);
}
else if (semanticName == "sv_isfrontface")
{
- return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInFrontFacing);
+ return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInFrontFacing, inst);
}
else if (semanticName == "sv_outputcontrolpointid")
{
- return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInInvocationId);
+ return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInInvocationId, inst);
}
else if (semanticName == "sv_pointsize")
{
// float in hlsl & glsl
- return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInPointSize);
+ return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInPointSize, inst);
}
else if (semanticName == "sv_primitiveid")
{
@@ -4032,7 +4069,7 @@ struct SPIRVEmitContext
}
if (needGeometryCapability)
requireSPIRVCapability(SpvCapabilityGeometry);
- return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInPrimitiveId);
+ return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInPrimitiveId, inst);
}
else if (semanticName == "sv_rendertargetarrayindex")
{
@@ -4040,37 +4077,42 @@ struct SPIRVEmitContext
requireSPIRVCapability(SpvCapabilityShaderLayer);
else
requireSPIRVCapability(SpvCapabilityGeometry);
- return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInLayer);
+ return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInLayer, inst);
}
else if (semanticName == "sv_sampleindex")
{
requireSPIRVCapability(SpvCapabilitySampleRateShading);
- return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInSampleId);
+ return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInSampleId, inst);
}
else if (semanticName == "sv_stencilref")
{
requireSPIRVCapability(SpvCapabilityStencilExportEXT);
ensureExtensionDeclaration(UnownedStringSlice("SPV_EXT_shader_stencil_export"));
- return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInFragStencilRefEXT);
+ return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInFragStencilRefEXT, inst);
}
else if (semanticName == "sv_tessfactor")
{
requireSPIRVCapability(SpvCapabilityTessellation);
- return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInTessLevelOuter);
+ return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInTessLevelOuter, inst);
+ }
+ else if (semanticName == "sv_insidetessfactor")
+ {
+ requireSPIRVCapability(SpvCapabilityTessellation);
+ return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInTessLevelInner, inst);
}
else if (semanticName == "sv_vertexid")
{
- return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInVertexIndex);
+ return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInVertexIndex, inst);
}
else if (semanticName == "sv_viewid")
{
requireSPIRVCapability(SpvCapabilityMultiView);
- return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInViewIndex);
+ return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInViewIndex, inst);
}
else if (semanticName == "sv_viewportarrayindex")
{
requireSPIRVCapability(SpvCapabilityShaderViewportIndex);
- return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInViewportIndex);
+ return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInViewportIndex, inst);
}
else if (semanticName == "nv_x_right")
{
@@ -4080,13 +4122,13 @@ struct SPIRVEmitContext
{
requireSPIRVCapability(SpvCapabilityPerViewAttributesNV);
ensureExtensionDeclaration(UnownedStringSlice("SPV_NV_mesh_shader"));
- return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInViewportMaskPerViewNV);
+ return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInViewportMaskPerViewNV, inst);
}
else if (semanticName == "sv_barycentrics")
{
requireSPIRVCapability(SpvCapabilityFragmentBarycentricKHR);
ensureExtensionDeclaration(UnownedStringSlice("SPV_KHR_fragment_shader_barycentric"));
- return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInBaryCoordKHR);
+ return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInBaryCoordKHR, inst);
// TODO: There is also the `gl_BaryCoordNoPerspNV` builtin, which
// we ought to use if the `noperspective` modifier has been
@@ -4096,7 +4138,7 @@ struct SPIRVEmitContext
{
requireSPIRVCapability(SpvCapabilityMeshShadingEXT);
ensureExtensionDeclaration(UnownedStringSlice("SPV_EXT_mesh_shader"));
- return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInCullPrimitiveEXT);
+ return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInCullPrimitiveEXT, inst);
}
else if (semanticName == "sv_shadingrate")
{
@@ -4104,9 +4146,9 @@ struct SPIRVEmitContext
ensureExtensionDeclaration(UnownedStringSlice("SPV_KHR_fragment_shading_rate"));
auto importDecor = inst->findDecoration<IRImportDecoration>();
if (importDecor && importDecor->getMangledName() == "gl_PrimitiveShadingRateEXT")
- return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInPrimitiveShadingRateKHR);
+ return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInPrimitiveShadingRateKHR, inst);
else
- return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInShadingRateKHR);
+ return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInShadingRateKHR, inst);
}
SLANG_UNREACHABLE("Unimplemented system value in spirv emit.");
}
@@ -4121,11 +4163,11 @@ struct SPIRVEmitContext
{
const auto name = linkageDecoration->getMangledName();
if(name == "gl_PrimitiveTriangleIndicesEXT")
- return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInPrimitiveTriangleIndicesEXT);
+ return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInPrimitiveTriangleIndicesEXT, inst);
if(name == "gl_PrimitiveLineIndicesEXT")
- return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInPrimitiveLineIndicesEXT);
+ return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInPrimitiveLineIndicesEXT, inst);
if(name == "gl_PrimitivePointIndicesEXT")
- return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInPrimitivePointIndicesEXT);
+ return getBuiltinGlobalVar(inst->getFullType(), SpvBuiltInPrimitivePointIndicesEXT, inst);
}
return nullptr;
@@ -6250,22 +6292,6 @@ struct SPIRVEmitContext
}
};
-bool isInstUsedInStage(SPIRVEmitContext& context, IRInst* inst, Stage s)
-{
- auto* referencingEntryPoints = context.m_referencingEntryPoints.tryGetValue(inst);
- if (!referencingEntryPoints)
- return false;
- for (auto entryPoint : *referencingEntryPoints)
- {
- if (auto entryPointDecor = entryPoint->findDecoration<IREntryPointDecoration>())
- {
- if (entryPointDecor->getProfile().getStage() == s)
- return true;
- }
- }
- return false;
-}
-
SlangResult emitSPIRVFromIR(
CodeGenContext* codeGenContext,
IRModule* irModule,
@@ -6330,28 +6356,6 @@ SlangResult emitSPIRVFromIR(
context.ensureInst(irEntryPoint);
}
- // Declare integral input builtins as Flat if necessary.
- for (auto globalInst : context.m_irModule->getGlobalInsts())
- {
- if (globalInst->getOp() != kIROp_GlobalVar &&
- globalInst->getOp() != kIROp_GlobalParam)
- continue;
- auto spvVar = context.m_mapIRInstToSpvInst.tryGetValue(globalInst);
- if (!spvVar)
- continue;
- auto ptrType = as<IRPtrType>(globalInst->getDataType());
- if (!ptrType)
- continue;
- auto addrSpace = ptrType->getAddressSpace();
- if (addrSpace == SpvStorageClassInput)
- {
- if (isIntegralScalarOrCompositeType(ptrType->getValueType()))
- {
- if (isInstUsedInStage(context, globalInst, Stage::Fragment))
- context._maybeEmitInterpolationModifierDecoration(IRInterpolationMode::NoInterpolation, *spvVar);
- }
- }
- }
// Move forward delcared pointers to the end.
do
diff --git a/source/slang/slang-ir-glsl-legalize.cpp b/source/slang/slang-ir-glsl-legalize.cpp
index 9b30b6072..b3be21282 100644
--- a/source/slang/slang-ir-glsl-legalize.cpp
+++ b/source/slang/slang-ir-glsl-legalize.cpp
@@ -852,6 +852,13 @@ GLSLSystemValueInfo* getGLSLSystemValueInfo(
// float[4] on glsl
requiredType = builder->getArrayType(builder->getBasicType(BaseType::Float), builder->getIntValue(builder->getIntType(), 4));
}
+ else if (semanticName == "sv_insidetessfactor")
+ {
+ name = "gl_TessLevelInner";
+
+ // float[2] on glsl
+ requiredType = builder->getArrayType(builder->getBasicType(BaseType::Float), builder->getIntValue(builder->getIntType(), 2));
+ }
else if (semanticName == "sv_vertexid")
{
// uint in hlsl, int in glsl (https://www.khronos.org/opengl/wiki/Built-in_Variable_(GLSL))
@@ -1071,12 +1078,30 @@ ScalarizedVal createSimpleGLSLGlobalVarying(
&systemValueInfoStorage);
IRType* type = inType;
+ IRType* peeledRequiredType = nullptr;
// A system-value semantic might end up needing to override the type
// that the user specified.
if( systemValueInfo && systemValueInfo->requiredType )
{
type = systemValueInfo->requiredType;
+ peeledRequiredType = type;
+ // Unpeel `type` using declarators so that it matches `inType`.
+ for (auto dd = declarator; dd; dd = dd->next)
+ {
+ switch (dd->flavor)
+ {
+ case GlobalVaryingDeclarator::Flavor::array:
+ {
+ if (auto arrayType = as<IRArrayTypeBase>(type))
+ {
+ type = arrayType->getElementType();
+ peeledRequiredType = type;
+ }
+ break;
+ }
+ }
+ }
}
// If we have a declarator, we just use the normal logic, as that seems to work correctly
@@ -1237,16 +1262,14 @@ ScalarizedVal createSimpleGLSLGlobalVarying(
if (systemValueInfo)
{
- if (auto fromType = systemValueInfo->requiredType)
+ if (systemValueInfo->requiredType)
{
// We may need to adapt from the declared type to/from
// the actual type of the GLSL global.
- auto toType = inType;
-
- if (!isTypeEqual(fromType, toType))
+ if (!isTypeEqual(peeledRequiredType, inType))
{
RefPtr<ScalarizedTypeAdapterValImpl> typeAdapter = new ScalarizedTypeAdapterValImpl;
- typeAdapter->actualType = systemValueInfo->requiredType;
+ typeAdapter->actualType = peeledRequiredType;
typeAdapter->pretendType = inType;
typeAdapter->val = val;