summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2024-06-11 00:01:12 -0700
committerGitHub <noreply@github.com>2024-06-11 00:01:12 -0700
commit6d5ef9b650a9db35f7774ca09d9225d0c30849e4 (patch)
treed8f812a52d937709efa23b17c1e36c51ee1b66f4
parent21bbebb19dfdbbee107b9fd9830e18d5fb6a573a (diff)
Fix `GetAttributeAtVertex` for spirv and glsl targets. (#4334)
-rw-r--r--source/slang/hlsl.meta.slang12
-rw-r--r--source/slang/slang-check-expr.cpp40
-rw-r--r--source/slang/slang-check-impl.h2
-rw-r--r--source/slang/slang-diagnostic-defs.h2
-rw-r--r--source/slang/slang-emit-glsl.cpp5
-rw-r--r--source/slang/slang-emit-spirv.cpp9
-rw-r--r--source/slang/slang-ir-glsl-legalize.cpp93
-rw-r--r--source/slang/slang-ir-inst-defs.h8
-rw-r--r--source/slang/slang-ir-insts.h1
-rw-r--r--tests/diagnostics/get-vertex-attribute.slang13
-rw-r--r--tests/spirv/get-vertex-attribute.slang15
11 files changed, 192 insertions, 8 deletions
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang
index e215fd93b..abc178317 100644
--- a/source/slang/hlsl.meta.slang
+++ b/source/slang/hlsl.meta.slang
@@ -8243,6 +8243,9 @@ matrix<T, N, M> fwidth(matrix<T, N, M> x)
}
}
+__intrinsic_op($(kIROp_GetPerVertexInputArray))
+Array<T, 3> __GetPerVertexInputArray<T>(T attribute);
+
/// Get the value of a vertex attribute at a specific vertex.
///
/// The `GetAttributeAtVertex()` function can be used in a fragment shader
@@ -8260,6 +8263,8 @@ __generic<T : __BuiltinType>
__glsl_version(450)
__glsl_extension(GL_EXT_fragment_shader_barycentric)
[require(glsl_hlsl_spirv, getattributeatvertex)]
+[KnownBuiltin("GetAttributeAtVertex")]
+[__unsafeForceInlineEarly]
T GetAttributeAtVertex(T attribute, uint vertexIndex)
{
__target_switch
@@ -8267,13 +8272,8 @@ T GetAttributeAtVertex(T attribute, uint vertexIndex)
case hlsl:
__intrinsic_asm "GetAttributeAtVertex";
case glsl:
- __intrinsic_asm "$0[$1]";
case spirv:
- return spirv_asm {
- %_ptr_Input_T = OpTypePointer Input $$T;
- %addr = OpAccessChain %_ptr_Input_T $attribute $vertexIndex;
- result:$$T = OpLoad %addr;
- };
+ return __GetPerVertexInputArray(attribute)[vertexIndex];
}
}
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index 065a74c77..f7ad5bdbf 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -2616,6 +2616,9 @@ namespace Slang
auto checkedExpr = CheckInvokeExprWithCheckedOperands(expr);
+ // Perform additional validation for known built-in functions.
+ maybeCheckKnownBuiltinInvocation(checkedExpr);
+
if (m_parentDifferentiableAttr)
{
FunctionDifferentiableLevel callerDiffLevel = FunctionDifferentiableLevel::None;
@@ -3401,6 +3404,43 @@ namespace Slang
return expr;
}
+ void SemanticsExprVisitor::maybeCheckKnownBuiltinInvocation(Expr* invokeExpr)
+ {
+ auto checkedInvokeExpr = as<InvokeExpr>(invokeExpr);
+ if (!checkedInvokeExpr)
+ return;
+ auto declRefFuncExpr = as<DeclRefExpr>(checkedInvokeExpr->functionExpr);
+ if (!declRefFuncExpr)
+ return;
+ auto callee = declRefFuncExpr->declRef.getDecl();
+ if (!callee)
+ return;
+ auto knownBuiltinAttr = callee->findModifier<KnownBuiltinAttribute>();
+ if (!knownBuiltinAttr)
+ return;
+ if (knownBuiltinAttr->name == "GetAttributeAtVertex")
+ {
+ if (checkedInvokeExpr->arguments.getCount() != 2)
+ return;
+ auto vertexAttributeArg = checkedInvokeExpr->arguments[0];
+ auto vertexAttributeArgDeclRefExpr = as<DeclRefExpr>(vertexAttributeArg);
+ if (!vertexAttributeArgDeclRefExpr)
+ {
+ getSink()->diagnose(invokeExpr, Diagnostics::getAttributeAtVertexMustReferToPerVertexInput);
+ return;
+ }
+ auto vertexAttributeArgDecl = vertexAttributeArgDeclRefExpr->declRef.getDecl();
+ if (!vertexAttributeArgDecl)
+ return;
+ if (!vertexAttributeArgDecl->findModifier<PerVertexModifier>() &&
+ !vertexAttributeArgDecl->findModifier<HLSLNoInterpolationModifier>())
+ {
+ getSink()->diagnose(vertexAttributeArgDeclRefExpr, Diagnostics::getAttributeAtVertexMustReferToPerVertexInput);
+ return;
+ }
+ }
+ }
+
Expr* SemanticsVisitor::MaybeDereference(Expr* inExpr)
{
Expr* expr = inExpr;
diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h
index ab6bc6585..d7a4827fb 100644
--- a/source/slang/slang-check-impl.h
+++ b/source/slang/slang-check-impl.h
@@ -2676,6 +2676,8 @@ namespace Slang
Expr* visitAsTypeExpr(AsTypeExpr* expr);
+ void maybeCheckKnownBuiltinInvocation(Expr* invokeExpr);
+
//
// Some syntax nodes should not occur in the concrete input syntax,
// and will only appear *after* checking is complete. We need to
diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h
index 7fd31a46e..dd95b862f 100644
--- a/source/slang/slang-diagnostic-defs.h
+++ b/source/slang/slang-diagnostic-defs.h
@@ -692,6 +692,8 @@ DIAGNOSTIC(39024, Warning, cannotInferVulkanBindingWithoutRegisterModifier, "sha
DIAGNOSTIC(39025, Error, conflictingVulkanInferredBindingForParameter, "conflicting vulkan inferred binding for parameter '$0' overlap is $1 and $2")
DIAGNOSTIC(39026, Error, matrixLayoutModifierOnNonMatrixType, "matrix layout modifier cannot be used on non-matrix type '$0'.")
+
+DIAGNOSTIC(39027, Error, getAttributeAtVertexMustReferToPerVertexInput, "'GetAttributeAtVertex' must reference a vertex input directly, and the vertex input must be decorated with 'pervertex' or 'nointerpolation'.")
//
// 4xxxx - IL code generation.
diff --git a/source/slang/slang-emit-glsl.cpp b/source/slang/slang-emit-glsl.cpp
index d13bc96d1..936dc15ff 100644
--- a/source/slang/slang-emit-glsl.cpp
+++ b/source/slang/slang-emit-glsl.cpp
@@ -2755,6 +2755,11 @@ void GLSLSourceEmitter::emitVarDecorationsImpl(IRInst* varDecl)
{
prefix = toSlice("hitAttribute");
}
+ else if (as<IRPerVertexDecoration>(decoration))
+ {
+ _requireGLSLExtension(toSlice("GL_EXT_fragment_shader_barycentric"));
+ prefix = toSlice("pervertex");
+ }
else
{
IRIntegerValue locationValue = -1;
diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp
index 1db323993..4f7410f00 100644
--- a/source/slang/slang-emit-spirv.cpp
+++ b/source/slang/slang-emit-spirv.cpp
@@ -3535,6 +3535,15 @@ struct SPIRVEmitContext
(IRInterpolationMode)getIntVal(decoration->getOperand(0)),
dstID);
break;
+ case kIROp_PerVertexDecoration:
+ ensureExtensionDeclaration(UnownedStringSlice("SPV_KHR_fragment_shader_barycentric"));
+ requireSPIRVCapability(SpvCapabilityFragmentBarycentricKHR);
+ emitOpDecorate(
+ getSection(SpvLogicalSectionID::Annotations),
+ decoration,
+ dstID,
+ SpvDecorationPerVertexKHR);
+ break;
case kIROp_MemoryQualifierSetDecoration:
{
auto collection = as<IRMemoryQualifierSetDecoration>(decoration);
diff --git a/source/slang/slang-ir-glsl-legalize.cpp b/source/slang/slang-ir-glsl-legalize.cpp
index b3be21282..5be700be1 100644
--- a/source/slang/slang-ir-glsl-legalize.cpp
+++ b/source/slang/slang-ir-glsl-legalize.cpp
@@ -8,6 +8,8 @@
#include "slang-ir-inst-pass-base.h"
#include "slang-ir-specialize-function-call.h"
#include "slang-ir-util.h"
+#include "slang-ir-clone.h"
+
#include "slang-glsl-extension-tracker.h"
#include "../../external/spirv-headers/include/spirv/unified1/spirv.h"
@@ -417,6 +419,10 @@ struct GLSLLegalizationContext
// Currently only used for special cases of semantics which map to global variables
Dictionary<UnownedStringSlice, SystemSemanticGlobal> systemNameToGlobalMap;
+ // Map from a input parameter in fragment shader to its corresponding per-vertex array
+ // to support the `GetAttributeAtVertex` intrinsic.
+ Dictionary<IRInst*, IRInst*> mapVertexInputToPerVertexArray;
+
void requireGLSLExtension(const UnownedStringSlice& name)
{
glslExtensionTracker->requireExtension(name);
@@ -2443,6 +2449,91 @@ static void legalizeMeshOutputParam(
g->removeAndDeallocate();
}
+IRInst* getOrCreatePerVertexInputArray(
+ GLSLLegalizationContext* context,
+ IRInst* inputVertexAttr)
+{
+ IRInst* arrayInst = nullptr;
+ if (context->mapVertexInputToPerVertexArray.tryGetValue(inputVertexAttr, arrayInst))
+ return arrayInst;
+ IRBuilder builder(inputVertexAttr);
+ builder.setInsertBefore(inputVertexAttr);
+ auto arrayType = builder.getArrayType(inputVertexAttr->getDataType(), builder.getIntValue(builder.getIntType(), 3));
+ arrayInst = builder.createGlobalParam(arrayType);
+ context->mapVertexInputToPerVertexArray[inputVertexAttr] = arrayInst;
+ builder.addDecoration(arrayInst, kIROp_PerVertexDecoration);
+
+ // Clone decorations from original input.
+ for (auto decoration : inputVertexAttr->getDecorations())
+ {
+ switch (decoration->getOp())
+ {
+ case kIROp_InterpolationModeDecoration:
+ continue;
+ default:
+ cloneDecoration(decoration, arrayInst);
+ break;
+ }
+ }
+ return arrayInst;
+}
+
+void tryReplaceUsesOfStageInput(
+ GLSLLegalizationContext* context,
+ ScalarizedVal val,
+ IRInst* originalVal)
+{
+ switch (val.flavor)
+ {
+ case ScalarizedVal::Flavor::value:
+ {
+ traverseUses(originalVal, [&](IRUse* use)
+ {
+ auto user = use->getUser();
+ if (user->getOp() == kIROp_GetPerVertexInputArray)
+ {
+ auto arrayInst = getOrCreatePerVertexInputArray(context, val.irValue);
+ user->replaceUsesWith(arrayInst);
+ user->removeAndDeallocate();
+ }
+ else
+ {
+ IRBuilder builder(user);
+ builder.setInsertBefore(user);
+ builder.replaceOperand(use, val.irValue);
+ }
+ });
+ }
+ break;
+ case ScalarizedVal::Flavor::tuple:
+ {
+ auto tupleVal = as<ScalarizedTupleValImpl>(val.impl);
+ traverseUses(originalVal, [&](IRUse* use)
+ {
+ auto user = use->getUser();
+ if (auto fieldExtract = as<IRFieldExtract>(user))
+ {
+ auto fieldKey = fieldExtract->getField();
+ ScalarizedVal fieldVal;
+ for (auto element : tupleVal->elements)
+ {
+ if (element.key == fieldKey)
+ {
+ fieldVal = element.val;
+ break;
+ }
+ }
+ if (fieldVal.flavor != ScalarizedVal::Flavor::none)
+ {
+ tryReplaceUsesOfStageInput(context, fieldVal, user);
+ }
+ }
+ });
+ }
+ break;
+ }
+}
+
void legalizeEntryPointParameterForGLSL(
GLSLLegalizationContext* context,
CodeGenContext* codeGenContext,
@@ -2751,6 +2842,8 @@ void legalizeEntryPointParameterForGLSL(
codeGenContext,
builder, paramType, paramLayout, LayoutResourceKind::VaryingInput, stage, pp);
+ tryReplaceUsesOfStageInput(context, globalValue, pp);
+
// we have a simple struct which represents all materialized GlobalParams, this
// struct will replace the no longer needed global variable which proxied as a
// GlobalParam.
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index 74990ed55..ab67dc4bf 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -366,7 +366,7 @@ INST(InterfaceRequirementEntry, interface_req_entry, 2, GLOBAL)
// An inst to represent the workgroup size of the calling entry point.
// We will materialize this inst during `translateGLSLGlobalVar`.
-INST(GetWorkGroupSize, kIROp_GetWorkGroupSize, 0, HOISTABLE)
+INST(GetWorkGroupSize, GetWorkGroupSize, 0, HOISTABLE)
INST(Param, param, 0, 0)
INST(StructField, field, 2, 0)
@@ -678,7 +678,9 @@ INST(GetOptiXSbtDataPtr, getOptiXSbtDataPointer, 0, 0)
INST(GetVulkanRayTracingPayloadLocation, GetVulkanRayTracingPayloadLocation, 1, 0)
-INST(GetLegalizedSPIRVGlobalParamAddr, kIROp_GetLegalizedSPIRVGlobalParamAddr, 1, 0)
+INST(GetLegalizedSPIRVGlobalParamAddr, GetLegalizedSPIRVGlobalParamAddr, 1, 0)
+
+INST(GetPerVertexInputArray, GetPerVertexInputArray, 1, 0)
INST(ForceVarIntoStructTemporarily, ForceVarIntoStructTemporarily, 1, 0)
@@ -904,6 +906,8 @@ INST_RANGE(BindingQuery, GetRegisterIndex, GetRegisterSpace)
// Marks an inst that represents the gl_Position input.
INST(GLPositionInputDecoration, PositionInput, 0, 0)
+ // Marks a fragment shader input as per-vertex.
+ INST(PerVertexDecoration, PerVertex, 0, 0)
/* StageAccessDecoration */
INST(StageReadAccessDecoration, stageReadAccess, 0, 0)
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index 9e377fe73..a4e9906f8 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -326,6 +326,7 @@ IR_SIMPLE_DECORATION(VulkanHitAttributesDecoration)
/// to it.
IR_SIMPLE_DECORATION(VulkanHitObjectAttributesDecoration)
+IR_SIMPLE_DECORATION(PerVertexDecoration)
struct IRRequireGLSLVersionDecoration : IRDecoration
{
diff --git a/tests/diagnostics/get-vertex-attribute.slang b/tests/diagnostics/get-vertex-attribute.slang
new file mode 100644
index 000000000..613a5a7c8
--- /dev/null
+++ b/tests/diagnostics/get-vertex-attribute.slang
@@ -0,0 +1,13 @@
+//TEST:SIMPLE(filecheck=CHECK): -target spirv
+
+struct VertexOutput
+{
+ float color;
+}
+
+[shader("fragment")]
+float4 fsmain(VertexOutput vout) : SV_Target
+{
+ // CHECK: ([[# @LINE+1]]): error 39027
+ return GetAttributeAtVertex(vout.color, 0); // error: color must be decorated with `nointerpolation`.
+}
diff --git a/tests/spirv/get-vertex-attribute.slang b/tests/spirv/get-vertex-attribute.slang
new file mode 100644
index 000000000..655b7ad03
--- /dev/null
+++ b/tests/spirv/get-vertex-attribute.slang
@@ -0,0 +1,15 @@
+//TEST:SIMPLE(filecheck=CHECK): -target spirv
+//TEST:SIMPLE(filecheck=CHECK): -target spirv -emit-spirv-via-glsl
+
+// CHECK: OpDecorate %vout_vertexID{{.*}} PerVertexKHR
+
+struct VertexOutput
+{
+ nointerpolation int vertexID;
+}
+
+[shader("fragment")]
+float4 fsmain(VertexOutput vout) : SV_Target
+{
+ return GetAttributeAtVertex(vout.vertexID, 0);
+}