summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--source/slang/hlsl.meta.slang35
-rw-r--r--source/slang/slang-compiler.cpp25
-rw-r--r--source/slang/slang-diagnostic-defs.h3
-rw-r--r--source/slang/slang-emit-spirv.cpp102
-rw-r--r--source/slang/slang-ir-glsl-legalize.cpp429
-rw-r--r--source/slang/slang-ir-inst-defs.h2
-rw-r--r--source/slang/slang-ir-specialize-target-switch.cpp18
-rw-r--r--source/slang/slang-parameter-binding.cpp17
-rw-r--r--tests/spirv/tessellation.slang65
9 files changed, 604 insertions, 92 deletions
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang
index 1fac47588..c03c47703 100644
--- a/source/slang/hlsl.meta.slang
+++ b/source/slang/hlsl.meta.slang
@@ -3799,13 +3799,30 @@ struct ConsumeStructuredBuffer
}
};
+__intrinsic_op($(kIROp_GetElement))
+T __getElement<T, U, I>(U collection, I index);
+
__generic<T, let N : int>
[require(glsl_hlsl_spirv, hull)]
__magic_type(HLSLInputPatchType)
__intrinsic_type($(kIROp_HLSLInputPatchType))
struct InputPatch
{
- __subscript(uint index) -> T;
+ __generic<TIndex : __BuiltinIntegerType>
+ __subscript(TIndex index)->T
+ {
+ [__unsafeForceInlineEarly]
+ get
+ {
+ __target_switch
+ {
+ case hlsl:
+ __intrinsic_asm ".operator[]";
+ default:
+ return __getElement<T>(this, index);
+ }
+ }
+ }
};
__generic<T, let N : int>
@@ -3814,7 +3831,21 @@ __magic_type(HLSLOutputPatchType)
__intrinsic_type($(kIROp_HLSLOutputPatchType))
struct OutputPatch
{
- __subscript(uint index) -> T;
+ __generic<TIndex : __BuiltinIntegerType>
+ __subscript(TIndex index)->T
+ {
+ [__unsafeForceInlineEarly]
+ get
+ {
+ __target_switch
+ {
+ case hlsl:
+ __intrinsic_asm ".operator[]";
+ default:
+ return __getElement<T>(this, index);
+ }
+ }
+ }
};
${{{{
diff --git a/source/slang/slang-compiler.cpp b/source/slang/slang-compiler.cpp
index 0277bb092..ed208ca37 100644
--- a/source/slang/slang-compiler.cpp
+++ b/source/slang/slang-compiler.cpp
@@ -2536,12 +2536,27 @@ namespace Slang
if (allTargetsCUDARelated && targets.getCount() > 0)
continue;
- auto numThreadsAttr = funcDecl->findModifier<NumThreadsAttribute>();
- if (numThreadsAttr)
- profile.setStage(Stage::Compute);
- else
+ bool canDetermineStage = false;
+ for (auto modifier : funcDecl->modifiers)
+ {
+ if (as<NumThreadsAttribute>(modifier))
+ {
+ if (funcDecl->findModifier<OutputTopologyAttribute>())
+ profile.setStage(Stage::Mesh);
+ else
+ profile.setStage(Stage::Compute);
+ canDetermineStage = true;
+ break;
+ }
+ else if (as<PatchConstantFuncAttribute>(modifier))
+ {
+ profile.setStage(Stage::Hull);
+ canDetermineStage = true;
+ break;
+ }
+ }
+ if (!canDetermineStage)
continue;
-
}
RefPtr<EntryPoint> entryPoint = EntryPoint::create(
diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h
index dd95b862f..eb7b5b993 100644
--- a/source/slang/slang-diagnostic-defs.h
+++ b/source/slang/slang-diagnostic-defs.h
@@ -694,6 +694,7 @@ DIAGNOSTIC(39025, Error, conflictingVulkanInferredBindingForParameter, "conflict
DIAGNOSTIC(39026, Error, matrixLayoutModifierOnNonMatrixType, "matrix layout modifier cannot be used on non-matrix type '$0'.")
DIAGNOSTIC(39027, Error, getAttributeAtVertexMustReferToPerVertexInput, "'GetAttributeAtVertex' must reference a vertex input directly, and the vertex input must be decorated with 'pervertex' or 'nointerpolation'.")
+
//
// 4xxxx - IL code generation.
@@ -843,6 +844,8 @@ DIAGNOSTIC(56001, Error, unableToAutoMapCUDATypeToHostType, "Could not automatic
DIAGNOSTIC(57001, Warning, spirvOptFailed, "spirv-opt failed. $0")
+DIAGNOSTIC(57002, Error, unknownPatchConstantParameter, "unknown patch constant parameter '$0'.")
+DIAGNOSTIC(57003, Error, unknownTessPartitioning, "unknown tessellation partitioning '$0'.")
// GLSL Compatibility
DIAGNOSTIC(58001, Error, entryPointMustReturnVoidWhenGlobalOutputPresent, "entry point must return 'void' when global output variables are present.")
diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp
index 4f7410f00..fd4b1d491 100644
--- a/source/slang/slang-emit-spirv.cpp
+++ b/source/slang/slang-emit-spirv.cpp
@@ -2963,6 +2963,15 @@ struct SPIRVEmitContext
result = emitOpAtomicIDecrement(parent, inst, inst->getFullType(), inst->getOperand(0), memoryScope, memorySemantics);
}
break;
+ case kIROp_ControlBarrier:
+ {
+ IRBuilder builder{ inst };
+ const auto executionScope = emitIntConstant(IRIntegerValue{ SpvScopeWorkgroup }, builder.getUIntType());
+ const auto memoryScope = emitIntConstant(IRIntegerValue{ SpvScopeInvocation }, builder.getUIntType());
+ const auto memorySemantics = emitIntConstant(IRIntegerValue{ SpvMemorySemanticsMaskNone }, builder.getUIntType());
+ emitInst(parent, inst, SpvOpControlBarrier, executionScope, memoryScope, memorySemantics);
+ }
+ break;
}
if (result)
emitDecorations(inst, getID(result));
@@ -3323,6 +3332,29 @@ struct SPIRVEmitContext
requireSPIRVCapability(SpvCapabilityMeshShadingEXT);
ensureExtensionDeclaration(UnownedStringSlice("SPV_EXT_mesh_shader"));
break;
+ case Stage::Hull:
+ {
+ requireSPIRVCapability(SpvCapabilityTessellation);
+
+ SpvExecutionMode mode = SpvExecutionModeSpacingEqual;
+ if (auto partitioningDecor = entryPoint->findDecoration<IRPartitioningDecoration>())
+ {
+ auto arg = partitioningDecor->getPartitioning()->getStringSlice();
+ if (arg.caseInsensitiveEquals(toSlice("integer")))
+ mode = SpvExecutionModeSpacingEqual;
+ else if (arg.caseInsensitiveEquals(toSlice("fractional_even")))
+ mode = SpvExecutionModeSpacingFractionalEven;
+ else if (arg.caseInsensitiveEquals(toSlice("fractional_odd")))
+ mode = SpvExecutionModeSpacingFractionalOdd;
+ else
+ m_sink->diagnose(partitioningDecor, Diagnostics::unknownTessPartitioning, arg);
+ }
+ requireSPIRVExecutionMode(nullptr, getIRInstSpvID(entryPoint), mode);
+ break;
+ }
+ case Stage::Domain:
+ requireSPIRVCapability(SpvCapabilityTessellation);
+ break;
default:
break;
}
@@ -3463,13 +3495,36 @@ struct SPIRVEmitContext
case kIROp_OutputTopologyDecoration:
{
+ auto entryPoint = decoration->getParent();
+ IREntryPointDecoration* entryPointDecor = entryPoint ? entryPoint->findDecoration<IREntryPointDecoration>() : nullptr;
+
const auto o = cast<IROutputTopologyDecoration>(decoration);
const auto t = o->getTopology()->getStringSlice();
- const auto m =
- t == "triangle" ? SpvExecutionModeOutputTrianglesEXT
- : t == "line" ? SpvExecutionModeOutputLinesEXT
- : t == "point" ? SpvExecutionModeOutputPoints
- : SpvExecutionModeMax;
+
+ SpvExecutionMode m = SpvExecutionModeMax;
+ if (entryPointDecor)
+ {
+ switch (entryPointDecor->getProfile().getStage())
+ {
+ case Stage::Domain:
+ case Stage::Hull:
+ if (t == "triangle_cw")
+ m = SpvExecutionModeVertexOrderCw;
+ else if (t == "triangle_ccw")
+ m = SpvExecutionModeVertexOrderCcw;
+ break;
+ }
+ }
+ if (m == SpvExecutionModeMax)
+ {
+ if (t == "triangle")
+ m = SpvExecutionModeOutputTrianglesEXT;
+ else if (t == "line")
+ m = SpvExecutionModeOutputTrianglesEXT;
+ else if (t == "point")
+ m = SpvExecutionModeOutputPoints;
+ }
+
SLANG_ASSERT(m != SpvExecutionModeMax);
requireSPIRVExecutionMode(decoration, dstID, m);
}
@@ -3544,6 +3599,31 @@ struct SPIRVEmitContext
dstID,
SpvDecorationPerVertexKHR);
break;
+ case kIROp_OutputControlPointsDecoration:
+ requireSPIRVExecutionMode(
+ decoration,
+ dstID,
+ SpvExecutionModeOutputVertices,
+ SpvLiteralInteger::from32(int32_t(getIntVal(decoration->getOperand(0)))));
+ break;
+ case kIROp_DomainDecoration:
+ {
+ auto domain = cast<IRDomainDecoration>(decoration);
+ SpvExecutionMode mode = SpvExecutionModeMax;
+ auto domainName = as<IRStringLit>(domain->getDomain());
+ if (!domainName)
+ break;
+ auto domainStr = domainName->getStringSlice();
+ if (domainStr.startsWithCaseInsensitive(toSlice("tri")))
+ mode = SpvExecutionModeTriangles;
+ else if (domainStr.caseInsensitiveEquals(toSlice("quad")))
+ mode = SpvExecutionModeQuads;
+ else if (domainStr.caseInsensitiveEquals(toSlice("isoline")))
+ mode = SpvExecutionModeIsolines;
+ if (mode != SpvExecutionModeMax)
+ requireSPIRVExecutionMode(decoration, dstID, mode);
+ }
+ break;
case kIROp_MemoryQualifierSetDecoration:
{
auto collection = as<IRMemoryQualifierSetDecoration>(decoration);
@@ -3941,6 +4021,18 @@ struct SPIRVEmitContext
varInst,
builtinVal
);
+ switch (builtinVal)
+ {
+ case SpvBuiltInTessLevelInner:
+ case SpvBuiltInTessLevelOuter:
+ emitOpDecorate(
+ getSection(SpvLogicalSectionID::Annotations),
+ nullptr,
+ varInst,
+ SpvDecorationPatch
+ );
+ break;
+ }
m_builtinGlobalVars[key] = varInst;
maybeEmitFlatDecorationForBuiltinVar(irInst, varInst);
diff --git a/source/slang/slang-ir-glsl-legalize.cpp b/source/slang/slang-ir-glsl-legalize.cpp
index 5be700be1..bcf2d8a4f 100644
--- a/source/slang/slang-ir-glsl-legalize.cpp
+++ b/source/slang/slang-ir-glsl-legalize.cpp
@@ -9,7 +9,7 @@
#include "slang-ir-specialize-function-call.h"
#include "slang-ir-util.h"
#include "slang-ir-clone.h"
-
+#include "slang-ir-single-return.h"
#include "slang-glsl-extension-tracker.h"
#include "../../external/spirv-headers/include/spirv/unified1/spirv.h"
@@ -293,6 +293,10 @@ struct ScalarizedVal
RefPtr<ScalarizedValImpl> impl;
};
+IRInst* materializeValue(
+ IRBuilder* builder,
+ ScalarizedVal const& val);
+
// This is the case for a value that is a "tuple" of other values
struct ScalarizedTupleValImpl : ScalarizedValImpl
{
@@ -315,9 +319,6 @@ struct ScalarizedTypeAdapterValImpl : ScalarizedValImpl
IRType* pretendType; // the type this value pretends to have
};
-
-
-
struct GlobalVaryingDeclarator
{
enum class Flavor
@@ -404,6 +405,7 @@ struct GLSLLegalizationContext
GLSLExtensionTracker* glslExtensionTracker;
DiagnosticSink* sink;
Stage stage;
+ IRFunc* entryPointFunc;
struct SystemSemanticGlobal
{
@@ -1056,6 +1058,206 @@ void createVarLayoutForLegalizedGlobalParam(
}
}
+IRInst* getOrCreateBuiltinParamForHullShader(GLSLLegalizationContext* context, UnownedStringSlice builtinSemantic)
+{
+ IRInst* outputControlPointIdParam = nullptr;
+ if (context->stage == Stage::Hull)
+ {
+ for (auto param : context->entryPointFunc->getParams())
+ {
+ auto layout = findVarLayout(param);
+ if (!layout)
+ continue;
+ auto sysAttr = layout->findSystemValueSemanticAttr();
+ if (!sysAttr)
+ continue;
+ if (sysAttr->getName().caseInsensitiveEquals(builtinSemantic))
+ {
+ outputControlPointIdParam = param;
+ break;
+ }
+ }
+ if (!outputControlPointIdParam)
+ {
+ IRBuilder builder(context->entryPointFunc);
+ auto paramType = builder.getIntType();
+ builder.setInsertInto(context->entryPointFunc->getFirstBlock()->getFirstOrdinaryInst());
+ outputControlPointIdParam = builder.emitParam(paramType);
+ IRStructTypeLayout::Builder typeBuilder(&builder);
+ auto typeLayout = typeBuilder.build();
+ IRVarLayout::Builder varLayoutBuilder(&builder, typeLayout);
+ varLayoutBuilder.setSystemValueSemantic(builtinSemantic, 0);
+ varLayoutBuilder.findOrAddResourceInfo(LayoutResourceKind::VaryingInput);
+ auto varLayout = varLayoutBuilder.build();
+ builder.addLayoutDecoration(outputControlPointIdParam, varLayout);
+ }
+ }
+ return outputControlPointIdParam;
+}
+
+IRTypeLayout* createPatchConstantFuncResultTypeLayout(IRBuilder& irBuilder, IRType* type)
+{
+ if (auto structType = as<IRStructType>(type))
+ {
+ IRStructTypeLayout::Builder builder(&irBuilder);
+ for (auto field : structType->getFields())
+ {
+ auto fieldType = field->getFieldType();
+
+ IRTypeLayout* fieldTypeLayout = createPatchConstantFuncResultTypeLayout(irBuilder, fieldType);
+ IRVarLayout::Builder fieldVarLayoutBuilder(&irBuilder, fieldTypeLayout);
+ auto decoration = field->getKey()->findDecoration<IRSemanticDecoration>();
+ if (decoration)
+ {
+ if (decoration->getSemanticName().startsWithCaseInsensitive(toSlice("sv_")))
+ fieldVarLayoutBuilder.setSystemValueSemantic(decoration->getSemanticName(), 0);
+ }
+ builder.addField(field->getKey(), fieldVarLayoutBuilder.build());
+ }
+ auto typeLayout = builder.build();
+ return typeLayout;
+ }
+ else if (auto arrayType = as<IRArrayTypeBase>(type))
+ {
+ auto elementTypeLayout = createPatchConstantFuncResultTypeLayout(irBuilder, arrayType->getElementType());
+ IRArrayTypeLayout::Builder builder(&irBuilder, elementTypeLayout);
+ return builder.build();
+ }
+ else
+ {
+ IRTypeLayout::Builder builder(&irBuilder);
+ builder.addResourceUsage(LayoutResourceKind::VaryingOutput, LayoutSize::fromRaw(1));
+ return builder.build();
+ }
+}
+
+ScalarizedVal legalizeEntryPointReturnValueForGLSL(
+ GLSLLegalizationContext* context,
+ CodeGenContext* codeGenContext,
+ IRBuilder& builder,
+ IRFunc* func,
+ IRVarLayout* resultLayout);
+
+void invokePathConstantFuncInHullShader(GLSLLegalizationContext* context, CodeGenContext* codeGenContext, ScalarizedVal outputPatchVal)
+{
+ auto entryPoint = context->entryPointFunc;
+ auto patchConstantFuncDecor = entryPoint->findDecoration<IRPatchConstantFuncDecoration>();
+ if (!patchConstantFuncDecor)
+ return;
+ IRInst* inputPatchArg = nullptr;
+ for (auto param : entryPoint->getParams())
+ {
+ if (as<IRHLSLInputPatchType>(param->getDataType()))
+ {
+ inputPatchArg = param;
+ break;
+ }
+ }
+ IRBuilder builder(entryPoint);
+ builder.setInsertInto(entryPoint);
+ IRBlock* conditionBlock = builder.emitBlock();
+ for (auto block : entryPoint->getBlocks())
+ {
+ if (auto returnInst = as<IRReturn>(block->getTerminator()))
+ {
+ builder.setInsertBefore(returnInst);
+ builder.emitBranch(conditionBlock);
+ returnInst->removeAndDeallocate();
+ }
+ }
+ builder.setInsertInto(conditionBlock);
+ builder.emitIntrinsicInst(builder.getVoidType(), kIROp_ControlBarrier, 0, nullptr);
+ auto index = getOrCreateBuiltinParamForHullShader(context, toSlice("SV_OutputControlPointID"));
+ auto condition = builder.emitEql(index, builder.getIntValue(builder.getIntType(), 0));
+ auto outputPatchArg = materializeValue(&builder, outputPatchVal);
+
+ List<IRInst*> args;
+ auto constantFunc = as<IRFunc>(patchConstantFuncDecor->getFunc());
+ for (auto param : constantFunc->getParams())
+ {
+ if (as<IRHLSLOutputPatchType>(param->getDataType()))
+ {
+ if (!outputPatchArg)
+ {
+ context->getSink()->diagnose(param->sourceLoc, Diagnostics::unknownPatchConstantParameter, param);
+ return;
+ }
+ param->setFullType(outputPatchArg->getDataType());
+ args.add(outputPatchArg);
+ }
+ else if (auto inputPatchType = as<IRHLSLInputPatchType>(param->getDataType()))
+ {
+ if (!inputPatchArg)
+ {
+ context->getSink()->diagnose(param->sourceLoc, Diagnostics::unknownPatchConstantParameter, param);
+ return;
+ }
+ auto arrayType = builder.getArrayType(inputPatchType->getElementType(), inputPatchType->getElementCount());
+ param->setFullType(arrayType);
+ args.add(inputPatchArg);
+ }
+ else
+ {
+ auto layout = findVarLayout(param);
+ if (!layout)
+ {
+ context->getSink()->diagnose(param->sourceLoc, Diagnostics::unknownPatchConstantParameter, param);
+ return;
+ }
+ auto sysAttr = layout->findSystemValueSemanticAttr();
+ if (!sysAttr)
+ {
+ context->getSink()->diagnose(param->sourceLoc, Diagnostics::unknownPatchConstantParameter, param);
+ return;
+ }
+ if (sysAttr->getName().caseInsensitiveEquals(toSlice("SV_OutputControlPointID")))
+ {
+ args.add(getOrCreateBuiltinParamForHullShader(context, toSlice("SV_OutputControlPointID")));
+ }
+ else if (sysAttr->getName().caseInsensitiveEquals(toSlice("SV_PrimitiveID")))
+ {
+ args.add(getOrCreateBuiltinParamForHullShader(context, toSlice("SV_PrimitiveID")));
+ }
+ else
+ {
+ context->getSink()->diagnose(param->sourceLoc, Diagnostics::unknownPatchConstantParameter, param);
+ return;
+ }
+ }
+ }
+
+ IRBlock* trueBlock;
+ IRBlock* mergeBlock;
+ builder.emitIfWithBlocks(condition, trueBlock, mergeBlock);
+ builder.setInsertInto(trueBlock);
+ builder.emitCallInst(builder.getVoidType(), constantFunc, args.getArrayView());
+ builder.emitBranch(mergeBlock);
+ builder.setInsertInto(mergeBlock);
+ builder.emitReturn();
+ fixUpFuncType(entryPoint, builder.getVoidType());
+
+ if (auto readNoneDecor = constantFunc->findDecoration<IRReadNoneDecoration>())
+ readNoneDecor->removeAndDeallocate();
+ if (auto noSideEffectDecor = constantFunc->findDecoration<IRNoSideEffectDecoration>())
+ noSideEffectDecor->removeAndDeallocate();
+
+ builder.setInsertBefore(constantFunc->getFirstBlock()->getFirstOrdinaryInst());
+
+ auto constantOutputType = constantFunc->getResultType();
+ IRTypeLayout* constantOutputLayout = createPatchConstantFuncResultTypeLayout(builder, constantOutputType);
+ IRVarLayout::Builder resultVarLayoutBuilder(&builder, constantOutputLayout);
+ if (auto semanticDecor = constantFunc->findDecoration<IRSemanticDecoration>())
+ resultVarLayoutBuilder.setSystemValueSemantic(semanticDecor->getSemanticName(), 0);
+
+ context->entryPointFunc = constantFunc;
+ context->stage = Stage::Unknown;
+ legalizeEntryPointReturnValueForGLSL(context, codeGenContext, builder, constantFunc, resultVarLayoutBuilder.build());
+ context->entryPointFunc = entryPoint;
+ context->stage = Stage::Hull;
+
+ fixUpFuncType(constantFunc);
+}
+
ScalarizedVal createSimpleGLSLGlobalVarying(
GLSLLegalizationContext* context,
CodeGenContext* codeGenContext,
@@ -1561,10 +1763,26 @@ ScalarizedVal createGLSLGlobalVaryings(
OuterParamInfoLink outerParamInfo;
outerParamInfo.next = nullptr;
outerParamInfo.outerParam = leafVar;
+
+ GlobalVaryingDeclarator* declarator = nullptr;
+ GlobalVaryingDeclarator arrayDeclarator;
+ if (stage == Stage::Hull && kind == LayoutResourceKind::VaryingOutput)
+ {
+ // Hull shader's output should be materialized into an array.
+ auto outputControlPointsDecor = context->entryPointFunc->findDecoration<IROutputControlPointsDecoration>();
+ if (outputControlPointsDecor)
+ {
+ arrayDeclarator.flavor = GlobalVaryingDeclarator::Flavor::array;
+ arrayDeclarator.next = nullptr;
+ arrayDeclarator.elementCount = outputControlPointsDecor->getControlPointCount();
+ declarator = &arrayDeclarator;
+ }
+ }
+
return createGLSLGlobalVaryingsImpl(
context,
codeGenContext,
- builder, type, layout, layout->getTypeLayout(), kind, stage, bindingIndex, bindingSpace, nullptr, &outerParamInfo, leafVar, namehintSB);
+ builder, type, layout, layout->getTypeLayout(), kind, stage, bindingIndex, bindingSpace, declarator, &outerParamInfo, leafVar, namehintSB);
}
ScalarizedVal extractField(
@@ -2090,6 +2308,35 @@ static void legalizeMeshPayloadInputParam(
specializeFunctionCalls(codeGenContext, builder->getModule(), &condition);
}
+static void legalizePatchParam(
+ GLSLLegalizationContext* context,
+ CodeGenContext* codeGenContext,
+ IRFunc* func,
+ IRParam* pp,
+ IRVarLayout* paramLayout,
+ IRHLSLPatchType* patchType)
+{
+ auto builder = context->getBuilder();
+ auto elementType = patchType->getElementType();
+ auto elementCount = patchType->getElementCount();
+ auto arrayType = builder->getArrayType(elementType, elementCount);
+
+ auto globalPatchVal = createGLSLGlobalVaryings(
+ context,
+ codeGenContext,
+ builder,
+ arrayType,
+ paramLayout,
+ LayoutResourceKind::VaryingInput,
+ Stage::Hull, // Doesn't matter whether we are in Hull or Domain shader.
+ pp);
+
+ builder->setInsertBefore(func->getFirstBlock()->getFirstOrdinaryInst());
+ auto materializedVal = materializeValue(builder, globalPatchVal);
+ pp->transferDecorationsTo(materializedVal);
+ pp->replaceUsesWith(materializedVal);
+}
+
static void legalizeMeshOutputParam(
GLSLLegalizationContext* context,
CodeGenContext* codeGenContext,
@@ -2725,6 +2972,10 @@ void legalizeEntryPointParameterForGLSL(
{
return legalizeMeshOutputParam(context, codeGenContext, func, pp, paramLayout, meshOutputType);
}
+ if (auto patchType = as<IRHLSLPatchType>(valueType))
+ {
+ return legalizePatchParam(context, codeGenContext, func, pp, paramLayout, patchType);
+ }
if(pp->findDecoration<IRHLSLMeshPayloadDecoration>())
{
return legalizeMeshPayloadInputParam(context, codeGenContext, pp);
@@ -3029,6 +3280,92 @@ void assignRayPayloadHitObjectAttributeLocations(IRModule* module)
}
}
+void rewriteReturnToOutputStore(IRBuilder& builder, IRFunc* func, ScalarizedVal resultGlobal)
+{
+ for (auto bb = func->getFirstBlock(); bb; bb = bb->getNextBlock())
+ {
+ auto returnInst = as<IRReturn>(bb->getTerminator());
+ if (!returnInst)
+ continue;
+
+ IRInst* returnValue = returnInst->getVal();
+
+ // Make sure we add these instructions to the right block
+ builder.setInsertInto(bb);
+
+ // Write to our global variable(s) from the value being returned.
+ assign(&builder, resultGlobal, ScalarizedVal::value(returnValue));
+
+ // Emit a `return void_val` to end the block
+ builder.emitReturn();
+
+ // Remove the old `returnVal` instruction.
+ returnInst->removeAndDeallocate();
+ }
+}
+
+ScalarizedVal legalizeEntryPointReturnValueForGLSL(
+ GLSLLegalizationContext* context,
+ CodeGenContext* codeGenContext,
+ IRBuilder& builder,
+ IRFunc* func,
+ IRVarLayout* resultLayout)
+{
+ ScalarizedVal result;
+ auto resultType = func->getResultType();
+ if (as<IRVoidType>(resultType))
+ {
+ // In this case, the function doesn't return a value
+ // so we don't need to transform its `return` sites.
+ //
+ // We can also use this opportunity to quickly
+ // check if the function has any parameters, and if
+ // it doesn't use the chance to bail out immediately.
+ if (func->getParamCount() == 0)
+ {
+ // This function is already legal for GLSL
+ // (at least in terms of parameter/result signature),
+ // so we won't bother doing anything at all.
+ return result;
+ }
+
+ // If the function does have parameters, then we need
+ // to let the logic later in this function handle them.
+ }
+ else
+ {
+ // Function returns a value, so we need
+ // to introduce a new global variable
+ // to hold that value, and then replace
+ // any `returnVal` instructions with
+ // code to write to that variable.
+
+ ScalarizedVal resultGlobal = createGLSLGlobalVaryings(
+ context,
+ codeGenContext,
+ &builder,
+ resultType,
+ resultLayout,
+ LayoutResourceKind::VaryingOutput,
+ context->stage,
+ func);
+ result = resultGlobal;
+
+ if (auto entryPointDecor = func->findDecoration<IREntryPointDecoration>())
+ {
+ if (entryPointDecor->getProfile().getStage() == Stage::Hull)
+ {
+ builder.setInsertBefore(func->getFirstBlock()->getFirstOrdinaryInst());
+ auto index = getOrCreateBuiltinParamForHullShader(context, toSlice("SV_OutputControlPointID"));
+ resultGlobal = getSubscriptVal(&builder, resultType, resultGlobal, index);
+ }
+ }
+ rewriteReturnToOutputStore(builder, func, resultGlobal);
+
+ }
+ return result;
+}
+
void legalizeEntryPointForGLSL(
Session* session,
IRModule* module,
@@ -3052,6 +3389,7 @@ void legalizeEntryPointForGLSL(
GLSLLegalizationContext context;
context.session = session;
context.stage = stage;
+ context.entryPointFunc = func;
context.sink = codeGenContext->getSink();
context.glslExtensionTracker = glslExtensionTracker;
@@ -3081,6 +3419,14 @@ void legalizeEntryPointForGLSL(
break;
}
+ // For hull shaders, we need to convert it to single return form, because
+ // we need to insert a barrier after the main body, then invoke the
+ // patch constant function after the barrier.
+ if (stage == Stage::Hull)
+ {
+ convertFuncToSingleReturnForm(module, func);
+ }
+
// We create a dummy IR builder, since some of
// the functions require it.
//
@@ -3105,75 +3451,14 @@ void legalizeEntryPointForGLSL(
// Specifically, we need to check if the function has
// a `void` return type, because there is no work
// to be done on its return value in that case.
- auto resultType = func->getResultType();
- if(as<IRVoidType>(resultType))
- {
- // In this case, the function doesn't return a value
- // so we don't need to transform its `return` sites.
- //
- // We can also use this opportunity to quickly
- // check if the function has any parameters, and if
- // it doesn't use the chance to bail out immediately.
- if( func->getParamCount() == 0 )
- {
- // This function is already legal for GLSL
- // (at least in terms of parameter/result signature),
- // so we won't bother doing anything at all.
- return;
- }
+ auto scalarizedGlobalOutput = legalizeEntryPointReturnValueForGLSL(
+ &context, codeGenContext, builder, func, entryPointLayout->getResultLayout());
- // If the function does have parameters, then we need
- // to let the logic later in this function handle them.
- }
- else
+ // For hull shaders, insert the invocation of the patch constant function
+ // at the end of the entrypoint now.
+ if (stage == Stage::Hull)
{
- // Function returns a value, so we need
- // to introduce a new global variable
- // to hold that value, and then replace
- // any `returnVal` instructions with
- // code to write to that variable.
-
- auto resultGlobal = createGLSLGlobalVaryings(
- &context,
- codeGenContext,
- &builder,
- resultType,
- entryPointLayout->getResultLayout(),
- LayoutResourceKind::VaryingOutput,
- stage,
- func);
-
- for( auto bb = func->getFirstBlock(); bb; bb = bb->getNextBlock() )
- {
- // TODO: This is silly, because we are looking at every instruction,
- // when we know that a `returnVal` should only ever appear as a
- // terminator...
- for( auto ii = bb->getFirstInst(); ii; ii = ii->getNextInst() )
- {
- if(ii->getOp() != kIROp_Return)
- continue;
-
- IRReturn* returnInst = (IRReturn*) ii;
- IRInst* returnValue = returnInst->getVal();
-
- // Make sure we add these instructions to the right block
- builder.setInsertInto(bb);
-
- // Write to our global variable(s) from the value being returned.
- assign(&builder, resultGlobal, ScalarizedVal::value(returnValue));
-
- // Emit a `return void_val` to end the block
- auto returnVoid = builder.emitReturn();
-
- // Remove the old `returnVal` instruction.
- returnInst->removeAndDeallocate();
-
- // Make sure to resume our iteration at an
- // appropriate instruciton, since we deleted
- // the one we had been using.
- ii = returnVoid;
- }
- }
+ invokePathConstantFuncInHullShader(&context, codeGenContext, scalarizedGlobalOutput);
}
// Next we will walk through any parameters of the entry-point function,
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index ab67dc4bf..f639d3343 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -660,6 +660,8 @@ INST(SampleGrad, sampleGrad, 4, 0)
INST(GroupMemoryBarrierWithGroupSync, GroupMemoryBarrierWithGroupSync, 0, 0)
+INST(ControlBarrier, ControlBarrier, 0, 0)
+
// GPU_FOREACH loop of the form
INST(GpuForeach, gpuForeach, 3, 0)
diff --git a/source/slang/slang-ir-specialize-target-switch.cpp b/source/slang/slang-ir-specialize-target-switch.cpp
index 46ea51192..e3ef06e18 100644
--- a/source/slang/slang-ir-specialize-target-switch.cpp
+++ b/source/slang/slang-ir-specialize-target-switch.cpp
@@ -9,6 +9,16 @@ namespace Slang
{
void specializeTargetSwitch(TargetRequest* target, IRGlobalValueWithCode* code, DiagnosticSink* sink)
{
+ if (auto gen = as<IRGeneric>(code))
+ {
+ auto retVal = findGenericReturnVal(gen);
+ if (auto innerCode = as<IRGlobalValueWithCode>(retVal))
+ {
+ specializeTargetSwitch(target, innerCode, sink);
+ return;
+ }
+ }
+
bool changed = false;
for (auto block : code->getBlocks())
{
@@ -76,14 +86,6 @@ namespace Slang
if (auto code = as<IRGlobalValueWithCode>(globalInst))
{
specializeTargetSwitch(target, code, sink);
- if (auto gen = as<IRGeneric>(code))
- {
- auto retVal = findGenericReturnVal(gen);
- if (auto innerCode = as<IRGlobalValueWithCode>(retVal))
- {
- specializeTargetSwitch(target, innerCode, sink);
- }
- }
}
}
}
diff --git a/source/slang/slang-parameter-binding.cpp b/source/slang/slang-parameter-binding.cpp
index 56c85e9cb..cfbd7f2c5 100644
--- a/source/slang/slang-parameter-binding.cpp
+++ b/source/slang/slang-parameter-binding.cpp
@@ -2092,6 +2092,23 @@ static RefPtr<TypeLayout> processEntryPointVaryingParameter(
return arrayTypeLayout;
}
+ else if (auto patchType = as<HLSLPatchType>(type))
+ {
+ // Similar to the MeshOutput case, a `InputPatch` or `OutputPatch` type is just like an array.
+ //
+ auto elementTypeLayout = processEntryPointVaryingParameter(context, patchType->getElementType(), state, varLayout);
+
+ RefPtr<ArrayTypeLayout> arrayTypeLayout = new ArrayTypeLayout();
+ arrayTypeLayout->elementTypeLayout = elementTypeLayout;
+ arrayTypeLayout->type = arrayType;
+
+ for (auto rr : elementTypeLayout->resourceInfos)
+ {
+ arrayTypeLayout->findOrAddResourceInfo(rr.kind)->count = rr.count;
+ }
+
+ return arrayTypeLayout;
+ }
// Ignore a bunch of types that don't make sense here...
else if (const auto subpassType = as<SubpassInputType>(type)) { return nullptr; }
else if (const auto textureType = as<TextureType>(type)) { return nullptr; }
diff --git a/tests/spirv/tessellation.slang b/tests/spirv/tessellation.slang
new file mode 100644
index 000000000..deb6ed298
--- /dev/null
+++ b/tests/spirv/tessellation.slang
@@ -0,0 +1,65 @@
+//TEST:SIMPLE(filecheck=CHECK): -target spirv
+
+// CHECK-DAG: OpExecutionMode %main SpacingEqual
+
+// CHECK-DAG: OpExecutionMode %main OutputVertices 4
+
+// CHECK-DAG: OpExecutionMode %main VertexOrderCw
+
+// CHECK-DAG: OpExecutionMode %main Quads
+
+// CHECK: OpDecorate %gl_TessLevelOuter BuiltIn TessLevelOuter
+// CHECK: OpDecorate %gl_TessLevelOuter Patch
+// CHECK: OpDecorate %gl_TessLevelInner BuiltIn TessLevelInner
+// CHECK: OpDecorate %gl_TessLevelInner Patch
+
+// CHECK: OpControlBarrier %uint_2 %uint_4 %uint_0
+
+// CHECK: OpStore %gl_TessLevelOuter
+// CHECK: OpStore %gl_TessLevelInner
+
+struct VS_OUT
+{
+ float3 position : POSITION;
+};
+
+struct HS_OUT
+{
+ float3 position : POSITION;
+};
+
+struct HSC_OUT
+{
+ float EdgeTessFactor[4] : SV_TessFactor;
+ float InsideTessFactor[2] : SV_InsideTessFactor;
+};
+
+// Hull Shader (HS)
+[domain("quad")]
+[partitioning("integer")]
+[outputtopology("triangle_cw")]
+[outputcontrolpoints(4)]
+[patchconstantfunc("constants")]
+HS_OUT main(InputPatch<VS_OUT, 4> patch, uint i : SV_OutputControlPointID)
+{
+ HS_OUT o;
+ o.position = patch[i].position;
+ return o;
+}
+
+HSC_OUT constants(InputPatch<VS_OUT, 4> patch)
+{
+ float3 p0 = patch[0].position;
+ float3 p1 = patch[1].position;
+ float3 p2 = patch[2].position;
+ float3 p3 = patch[3].position;
+
+ HSC_OUT o;
+ o.EdgeTessFactor[0] = dot(p0, p1);
+ o.EdgeTessFactor[1] = dot(p0, p3);
+ o.EdgeTessFactor[2] = dot(p2, p3);
+ o.EdgeTessFactor[3] = dot(p1, p2);
+ o.InsideTessFactor[0] = lerp(o.EdgeTessFactor[1], o.EdgeTessFactor[3], 0.5);
+ o.InsideTessFactor[1] = lerp(o.EdgeTessFactor[0], o.EdgeTessFactor[2], 0.5);
+ return o;
+} \ No newline at end of file