diff options
| -rw-r--r-- | source/slang/hlsl.meta.slang | 35 | ||||
| -rw-r--r-- | source/slang/slang-compiler.cpp | 25 | ||||
| -rw-r--r-- | source/slang/slang-diagnostic-defs.h | 3 | ||||
| -rw-r--r-- | source/slang/slang-emit-spirv.cpp | 102 | ||||
| -rw-r--r-- | source/slang/slang-ir-glsl-legalize.cpp | 429 | ||||
| -rw-r--r-- | source/slang/slang-ir-inst-defs.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-specialize-target-switch.cpp | 18 | ||||
| -rw-r--r-- | source/slang/slang-parameter-binding.cpp | 17 | ||||
| -rw-r--r-- | tests/spirv/tessellation.slang | 65 |
9 files changed, 604 insertions, 92 deletions
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index 1fac47588..c03c47703 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -3799,13 +3799,30 @@ struct ConsumeStructuredBuffer } }; +__intrinsic_op($(kIROp_GetElement)) +T __getElement<T, U, I>(U collection, I index); + __generic<T, let N : int> [require(glsl_hlsl_spirv, hull)] __magic_type(HLSLInputPatchType) __intrinsic_type($(kIROp_HLSLInputPatchType)) struct InputPatch { - __subscript(uint index) -> T; + __generic<TIndex : __BuiltinIntegerType> + __subscript(TIndex index)->T + { + [__unsafeForceInlineEarly] + get + { + __target_switch + { + case hlsl: + __intrinsic_asm ".operator[]"; + default: + return __getElement<T>(this, index); + } + } + } }; __generic<T, let N : int> @@ -3814,7 +3831,21 @@ __magic_type(HLSLOutputPatchType) __intrinsic_type($(kIROp_HLSLOutputPatchType)) struct OutputPatch { - __subscript(uint index) -> T; + __generic<TIndex : __BuiltinIntegerType> + __subscript(TIndex index)->T + { + [__unsafeForceInlineEarly] + get + { + __target_switch + { + case hlsl: + __intrinsic_asm ".operator[]"; + default: + return __getElement<T>(this, index); + } + } + } }; ${{{{ diff --git a/source/slang/slang-compiler.cpp b/source/slang/slang-compiler.cpp index 0277bb092..ed208ca37 100644 --- a/source/slang/slang-compiler.cpp +++ b/source/slang/slang-compiler.cpp @@ -2536,12 +2536,27 @@ namespace Slang if (allTargetsCUDARelated && targets.getCount() > 0) continue; - auto numThreadsAttr = funcDecl->findModifier<NumThreadsAttribute>(); - if (numThreadsAttr) - profile.setStage(Stage::Compute); - else + bool canDetermineStage = false; + for (auto modifier : funcDecl->modifiers) + { + if (as<NumThreadsAttribute>(modifier)) + { + if (funcDecl->findModifier<OutputTopologyAttribute>()) + profile.setStage(Stage::Mesh); + else + profile.setStage(Stage::Compute); + canDetermineStage = true; + break; + } + else if (as<PatchConstantFuncAttribute>(modifier)) + { + 
profile.setStage(Stage::Hull); + canDetermineStage = true; + break; + } + } + if (!canDetermineStage) continue; - } RefPtr<EntryPoint> entryPoint = EntryPoint::create( diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index dd95b862f..eb7b5b993 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -694,6 +694,7 @@ DIAGNOSTIC(39025, Error, conflictingVulkanInferredBindingForParameter, "conflict DIAGNOSTIC(39026, Error, matrixLayoutModifierOnNonMatrixType, "matrix layout modifier cannot be used on non-matrix type '$0'.") DIAGNOSTIC(39027, Error, getAttributeAtVertexMustReferToPerVertexInput, "'GetAttributeAtVertex' must reference a vertex input directly, and the vertex input must be decorated with 'pervertex' or 'nointerpolation'.") + // // 4xxxx - IL code generation. @@ -843,6 +844,8 @@ DIAGNOSTIC(56001, Error, unableToAutoMapCUDATypeToHostType, "Could not automatic DIAGNOSTIC(57001, Warning, spirvOptFailed, "spirv-opt failed. 
$0") +DIAGNOSTIC(57002, Error, unknownPatchConstantParameter, "unknown patch constant parameter '$0'.") +DIAGNOSTIC(57003, Error, unknownTessPartitioning, "unknown tessellation partitioning '$0'.") // GLSL Compatibility DIAGNOSTIC(58001, Error, entryPointMustReturnVoidWhenGlobalOutputPresent, "entry point must return 'void' when global output variables are present.") diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp index 4f7410f00..fd4b1d491 100644 --- a/source/slang/slang-emit-spirv.cpp +++ b/source/slang/slang-emit-spirv.cpp @@ -2963,6 +2963,15 @@ struct SPIRVEmitContext result = emitOpAtomicIDecrement(parent, inst, inst->getFullType(), inst->getOperand(0), memoryScope, memorySemantics); } break; + case kIROp_ControlBarrier: + { + IRBuilder builder{ inst }; + const auto executionScope = emitIntConstant(IRIntegerValue{ SpvScopeWorkgroup }, builder.getUIntType()); + const auto memoryScope = emitIntConstant(IRIntegerValue{ SpvScopeInvocation }, builder.getUIntType()); + const auto memorySemantics = emitIntConstant(IRIntegerValue{ SpvMemorySemanticsMaskNone }, builder.getUIntType()); + emitInst(parent, inst, SpvOpControlBarrier, executionScope, memoryScope, memorySemantics); + } + break; } if (result) emitDecorations(inst, getID(result)); @@ -3323,6 +3332,29 @@ struct SPIRVEmitContext requireSPIRVCapability(SpvCapabilityMeshShadingEXT); ensureExtensionDeclaration(UnownedStringSlice("SPV_EXT_mesh_shader")); break; + case Stage::Hull: + { + requireSPIRVCapability(SpvCapabilityTessellation); + + SpvExecutionMode mode = SpvExecutionModeSpacingEqual; + if (auto partitioningDecor = entryPoint->findDecoration<IRPartitioningDecoration>()) + { + auto arg = partitioningDecor->getPartitioning()->getStringSlice(); + if (arg.caseInsensitiveEquals(toSlice("integer"))) + mode = SpvExecutionModeSpacingEqual; + else if (arg.caseInsensitiveEquals(toSlice("fractional_even"))) + mode = SpvExecutionModeSpacingFractionalEven; + else if 
(arg.caseInsensitiveEquals(toSlice("fractional_odd"))) + mode = SpvExecutionModeSpacingFractionalOdd; + else + m_sink->diagnose(partitioningDecor, Diagnostics::unknownTessPartitioning, arg); + } + requireSPIRVExecutionMode(nullptr, getIRInstSpvID(entryPoint), mode); + break; + } + case Stage::Domain: + requireSPIRVCapability(SpvCapabilityTessellation); + break; default: break; } @@ -3463,13 +3495,36 @@ struct SPIRVEmitContext case kIROp_OutputTopologyDecoration: { + auto entryPoint = decoration->getParent(); + IREntryPointDecoration* entryPointDecor = entryPoint ? entryPoint->findDecoration<IREntryPointDecoration>() : nullptr; + const auto o = cast<IROutputTopologyDecoration>(decoration); const auto t = o->getTopology()->getStringSlice(); - const auto m = - t == "triangle" ? SpvExecutionModeOutputTrianglesEXT - : t == "line" ? SpvExecutionModeOutputLinesEXT - : t == "point" ? SpvExecutionModeOutputPoints - : SpvExecutionModeMax; + + SpvExecutionMode m = SpvExecutionModeMax; + if (entryPointDecor) + { + switch (entryPointDecor->getProfile().getStage()) + { + case Stage::Domain: + case Stage::Hull: + if (t == "triangle_cw") + m = SpvExecutionModeVertexOrderCw; + else if (t == "triangle_ccw") + m = SpvExecutionModeVertexOrderCcw; + break; + } + } + if (m == SpvExecutionModeMax) + { + if (t == "triangle") + m = SpvExecutionModeOutputTrianglesEXT; + else if (t == "line") + m = SpvExecutionModeOutputLinesEXT; + else if (t == "point") + m = SpvExecutionModeOutputPoints; + } + SLANG_ASSERT(m != SpvExecutionModeMax); requireSPIRVExecutionMode(decoration, dstID, m); } @@ -3544,6 +3599,31 @@ struct SPIRVEmitContext dstID, SpvDecorationPerVertexKHR); break; + case kIROp_OutputControlPointsDecoration: + requireSPIRVExecutionMode( + decoration, + dstID, + SpvExecutionModeOutputVertices, + SpvLiteralInteger::from32(int32_t(getIntVal(decoration->getOperand(0))))); + break; + case kIROp_DomainDecoration: + { + auto domain = cast<IRDomainDecoration>(decoration); + 
SpvExecutionMode mode = SpvExecutionModeMax; + auto domainName = as<IRStringLit>(domain->getDomain()); + if (!domainName) + break; + auto domainStr = domainName->getStringSlice(); + if (domainStr.startsWithCaseInsensitive(toSlice("tri"))) + mode = SpvExecutionModeTriangles; + else if (domainStr.caseInsensitiveEquals(toSlice("quad"))) + mode = SpvExecutionModeQuads; + else if (domainStr.caseInsensitiveEquals(toSlice("isoline"))) + mode = SpvExecutionModeIsolines; + if (mode != SpvExecutionModeMax) + requireSPIRVExecutionMode(decoration, dstID, mode); + } + break; case kIROp_MemoryQualifierSetDecoration: { auto collection = as<IRMemoryQualifierSetDecoration>(decoration); @@ -3941,6 +4021,18 @@ struct SPIRVEmitContext varInst, builtinVal ); + switch (builtinVal) + { + case SpvBuiltInTessLevelInner: + case SpvBuiltInTessLevelOuter: + emitOpDecorate( + getSection(SpvLogicalSectionID::Annotations), + nullptr, + varInst, + SpvDecorationPatch + ); + break; + } m_builtinGlobalVars[key] = varInst; maybeEmitFlatDecorationForBuiltinVar(irInst, varInst); diff --git a/source/slang/slang-ir-glsl-legalize.cpp b/source/slang/slang-ir-glsl-legalize.cpp index 5be700be1..bcf2d8a4f 100644 --- a/source/slang/slang-ir-glsl-legalize.cpp +++ b/source/slang/slang-ir-glsl-legalize.cpp @@ -9,7 +9,7 @@ #include "slang-ir-specialize-function-call.h" #include "slang-ir-util.h" #include "slang-ir-clone.h" - +#include "slang-ir-single-return.h" #include "slang-glsl-extension-tracker.h" #include "../../external/spirv-headers/include/spirv/unified1/spirv.h" @@ -293,6 +293,10 @@ struct ScalarizedVal RefPtr<ScalarizedValImpl> impl; }; +IRInst* materializeValue( + IRBuilder* builder, + ScalarizedVal const& val); + // This is the case for a value that is a "tuple" of other values struct ScalarizedTupleValImpl : ScalarizedValImpl { @@ -315,9 +319,6 @@ struct ScalarizedTypeAdapterValImpl : ScalarizedValImpl IRType* pretendType; // the type this value pretends to have }; - - - struct 
GlobalVaryingDeclarator { enum class Flavor @@ -404,6 +405,7 @@ struct GLSLLegalizationContext GLSLExtensionTracker* glslExtensionTracker; DiagnosticSink* sink; Stage stage; + IRFunc* entryPointFunc; struct SystemSemanticGlobal { @@ -1056,6 +1058,206 @@ void createVarLayoutForLegalizedGlobalParam( } } +IRInst* getOrCreateBuiltinParamForHullShader(GLSLLegalizationContext* context, UnownedStringSlice builtinSemantic) +{ + IRInst* outputControlPointIdParam = nullptr; + if (context->stage == Stage::Hull) + { + for (auto param : context->entryPointFunc->getParams()) + { + auto layout = findVarLayout(param); + if (!layout) + continue; + auto sysAttr = layout->findSystemValueSemanticAttr(); + if (!sysAttr) + continue; + if (sysAttr->getName().caseInsensitiveEquals(builtinSemantic)) + { + outputControlPointIdParam = param; + break; + } + } + if (!outputControlPointIdParam) + { + IRBuilder builder(context->entryPointFunc); + auto paramType = builder.getIntType(); + builder.setInsertInto(context->entryPointFunc->getFirstBlock()->getFirstOrdinaryInst()); + outputControlPointIdParam = builder.emitParam(paramType); + IRStructTypeLayout::Builder typeBuilder(&builder); + auto typeLayout = typeBuilder.build(); + IRVarLayout::Builder varLayoutBuilder(&builder, typeLayout); + varLayoutBuilder.setSystemValueSemantic(builtinSemantic, 0); + varLayoutBuilder.findOrAddResourceInfo(LayoutResourceKind::VaryingInput); + auto varLayout = varLayoutBuilder.build(); + builder.addLayoutDecoration(outputControlPointIdParam, varLayout); + } + } + return outputControlPointIdParam; +} + +IRTypeLayout* createPatchConstantFuncResultTypeLayout(IRBuilder& irBuilder, IRType* type) +{ + if (auto structType = as<IRStructType>(type)) + { + IRStructTypeLayout::Builder builder(&irBuilder); + for (auto field : structType->getFields()) + { + auto fieldType = field->getFieldType(); + + IRTypeLayout* fieldTypeLayout = createPatchConstantFuncResultTypeLayout(irBuilder, fieldType); + IRVarLayout::Builder 
fieldVarLayoutBuilder(&irBuilder, fieldTypeLayout); + auto decoration = field->getKey()->findDecoration<IRSemanticDecoration>(); + if (decoration) + { + if (decoration->getSemanticName().startsWithCaseInsensitive(toSlice("sv_"))) + fieldVarLayoutBuilder.setSystemValueSemantic(decoration->getSemanticName(), 0); + } + builder.addField(field->getKey(), fieldVarLayoutBuilder.build()); + } + auto typeLayout = builder.build(); + return typeLayout; + } + else if (auto arrayType = as<IRArrayTypeBase>(type)) + { + auto elementTypeLayout = createPatchConstantFuncResultTypeLayout(irBuilder, arrayType->getElementType()); + IRArrayTypeLayout::Builder builder(&irBuilder, elementTypeLayout); + return builder.build(); + } + else + { + IRTypeLayout::Builder builder(&irBuilder); + builder.addResourceUsage(LayoutResourceKind::VaryingOutput, LayoutSize::fromRaw(1)); + return builder.build(); + } +} + +ScalarizedVal legalizeEntryPointReturnValueForGLSL( + GLSLLegalizationContext* context, + CodeGenContext* codeGenContext, + IRBuilder& builder, + IRFunc* func, + IRVarLayout* resultLayout); + +void invokePathConstantFuncInHullShader(GLSLLegalizationContext* context, CodeGenContext* codeGenContext, ScalarizedVal outputPatchVal) +{ + auto entryPoint = context->entryPointFunc; + auto patchConstantFuncDecor = entryPoint->findDecoration<IRPatchConstantFuncDecoration>(); + if (!patchConstantFuncDecor) + return; + IRInst* inputPatchArg = nullptr; + for (auto param : entryPoint->getParams()) + { + if (as<IRHLSLInputPatchType>(param->getDataType())) + { + inputPatchArg = param; + break; + } + } + IRBuilder builder(entryPoint); + builder.setInsertInto(entryPoint); + IRBlock* conditionBlock = builder.emitBlock(); + for (auto block : entryPoint->getBlocks()) + { + if (auto returnInst = as<IRReturn>(block->getTerminator())) + { + builder.setInsertBefore(returnInst); + builder.emitBranch(conditionBlock); + returnInst->removeAndDeallocate(); + } + } + builder.setInsertInto(conditionBlock); + 
builder.emitIntrinsicInst(builder.getVoidType(), kIROp_ControlBarrier, 0, nullptr); + auto index = getOrCreateBuiltinParamForHullShader(context, toSlice("SV_OutputControlPointID")); + auto condition = builder.emitEql(index, builder.getIntValue(builder.getIntType(), 0)); + auto outputPatchArg = materializeValue(&builder, outputPatchVal); + + List<IRInst*> args; + auto constantFunc = as<IRFunc>(patchConstantFuncDecor->getFunc()); + for (auto param : constantFunc->getParams()) + { + if (as<IRHLSLOutputPatchType>(param->getDataType())) + { + if (!outputPatchArg) + { + context->getSink()->diagnose(param->sourceLoc, Diagnostics::unknownPatchConstantParameter, param); + return; + } + param->setFullType(outputPatchArg->getDataType()); + args.add(outputPatchArg); + } + else if (auto inputPatchType = as<IRHLSLInputPatchType>(param->getDataType())) + { + if (!inputPatchArg) + { + context->getSink()->diagnose(param->sourceLoc, Diagnostics::unknownPatchConstantParameter, param); + return; + } + auto arrayType = builder.getArrayType(inputPatchType->getElementType(), inputPatchType->getElementCount()); + param->setFullType(arrayType); + args.add(inputPatchArg); + } + else + { + auto layout = findVarLayout(param); + if (!layout) + { + context->getSink()->diagnose(param->sourceLoc, Diagnostics::unknownPatchConstantParameter, param); + return; + } + auto sysAttr = layout->findSystemValueSemanticAttr(); + if (!sysAttr) + { + context->getSink()->diagnose(param->sourceLoc, Diagnostics::unknownPatchConstantParameter, param); + return; + } + if (sysAttr->getName().caseInsensitiveEquals(toSlice("SV_OutputControlPointID"))) + { + args.add(getOrCreateBuiltinParamForHullShader(context, toSlice("SV_OutputControlPointID"))); + } + else if (sysAttr->getName().caseInsensitiveEquals(toSlice("SV_PrimitiveID"))) + { + args.add(getOrCreateBuiltinParamForHullShader(context, toSlice("SV_PrimitiveID"))); + } + else + { + context->getSink()->diagnose(param->sourceLoc, 
Diagnostics::unknownPatchConstantParameter, param); + return; + } + } + } + + IRBlock* trueBlock; + IRBlock* mergeBlock; + builder.emitIfWithBlocks(condition, trueBlock, mergeBlock); + builder.setInsertInto(trueBlock); + builder.emitCallInst(builder.getVoidType(), constantFunc, args.getArrayView()); + builder.emitBranch(mergeBlock); + builder.setInsertInto(mergeBlock); + builder.emitReturn(); + fixUpFuncType(entryPoint, builder.getVoidType()); + + if (auto readNoneDecor = constantFunc->findDecoration<IRReadNoneDecoration>()) + readNoneDecor->removeAndDeallocate(); + if (auto noSideEffectDecor = constantFunc->findDecoration<IRNoSideEffectDecoration>()) + noSideEffectDecor->removeAndDeallocate(); + + builder.setInsertBefore(constantFunc->getFirstBlock()->getFirstOrdinaryInst()); + + auto constantOutputType = constantFunc->getResultType(); + IRTypeLayout* constantOutputLayout = createPatchConstantFuncResultTypeLayout(builder, constantOutputType); + IRVarLayout::Builder resultVarLayoutBuilder(&builder, constantOutputLayout); + if (auto semanticDecor = constantFunc->findDecoration<IRSemanticDecoration>()) + resultVarLayoutBuilder.setSystemValueSemantic(semanticDecor->getSemanticName(), 0); + + context->entryPointFunc = constantFunc; + context->stage = Stage::Unknown; + legalizeEntryPointReturnValueForGLSL(context, codeGenContext, builder, constantFunc, resultVarLayoutBuilder.build()); + context->entryPointFunc = entryPoint; + context->stage = Stage::Hull; + + fixUpFuncType(constantFunc); +} + ScalarizedVal createSimpleGLSLGlobalVarying( GLSLLegalizationContext* context, CodeGenContext* codeGenContext, @@ -1561,10 +1763,26 @@ ScalarizedVal createGLSLGlobalVaryings( OuterParamInfoLink outerParamInfo; outerParamInfo.next = nullptr; outerParamInfo.outerParam = leafVar; + + GlobalVaryingDeclarator* declarator = nullptr; + GlobalVaryingDeclarator arrayDeclarator; + if (stage == Stage::Hull && kind == LayoutResourceKind::VaryingOutput) + { + // Hull shader's output should be 
materialized into an array. + auto outputControlPointsDecor = context->entryPointFunc->findDecoration<IROutputControlPointsDecoration>(); + if (outputControlPointsDecor) + { + arrayDeclarator.flavor = GlobalVaryingDeclarator::Flavor::array; + arrayDeclarator.next = nullptr; + arrayDeclarator.elementCount = outputControlPointsDecor->getControlPointCount(); + declarator = &arrayDeclarator; + } + } + return createGLSLGlobalVaryingsImpl( context, codeGenContext, - builder, type, layout, layout->getTypeLayout(), kind, stage, bindingIndex, bindingSpace, nullptr, &outerParamInfo, leafVar, namehintSB); + builder, type, layout, layout->getTypeLayout(), kind, stage, bindingIndex, bindingSpace, declarator, &outerParamInfo, leafVar, namehintSB); } ScalarizedVal extractField( @@ -2090,6 +2308,35 @@ static void legalizeMeshPayloadInputParam( specializeFunctionCalls(codeGenContext, builder->getModule(), &condition); } +static void legalizePatchParam( + GLSLLegalizationContext* context, + CodeGenContext* codeGenContext, + IRFunc* func, + IRParam* pp, + IRVarLayout* paramLayout, + IRHLSLPatchType* patchType) +{ + auto builder = context->getBuilder(); + auto elementType = patchType->getElementType(); + auto elementCount = patchType->getElementCount(); + auto arrayType = builder->getArrayType(elementType, elementCount); + + auto globalPatchVal = createGLSLGlobalVaryings( + context, + codeGenContext, + builder, + arrayType, + paramLayout, + LayoutResourceKind::VaryingInput, + Stage::Hull, // Doesn't matter whether we are in Hull or Domain shader. 
+ pp); + + builder->setInsertBefore(func->getFirstBlock()->getFirstOrdinaryInst()); + auto materializedVal = materializeValue(builder, globalPatchVal); + pp->transferDecorationsTo(materializedVal); + pp->replaceUsesWith(materializedVal); +} + static void legalizeMeshOutputParam( GLSLLegalizationContext* context, CodeGenContext* codeGenContext, @@ -2725,6 +2972,10 @@ void legalizeEntryPointParameterForGLSL( { return legalizeMeshOutputParam(context, codeGenContext, func, pp, paramLayout, meshOutputType); } + if (auto patchType = as<IRHLSLPatchType>(valueType)) + { + return legalizePatchParam(context, codeGenContext, func, pp, paramLayout, patchType); + } if(pp->findDecoration<IRHLSLMeshPayloadDecoration>()) { return legalizeMeshPayloadInputParam(context, codeGenContext, pp); @@ -3029,6 +3280,92 @@ void assignRayPayloadHitObjectAttributeLocations(IRModule* module) } } +void rewriteReturnToOutputStore(IRBuilder& builder, IRFunc* func, ScalarizedVal resultGlobal) +{ + for (auto bb = func->getFirstBlock(); bb; bb = bb->getNextBlock()) + { + auto returnInst = as<IRReturn>(bb->getTerminator()); + if (!returnInst) + continue; + + IRInst* returnValue = returnInst->getVal(); + + // Make sure we add these instructions to the right block + builder.setInsertInto(bb); + + // Write to our global variable(s) from the value being returned. + assign(&builder, resultGlobal, ScalarizedVal::value(returnValue)); + + // Emit a `return void_val` to end the block + builder.emitReturn(); + + // Remove the old `returnVal` instruction. + returnInst->removeAndDeallocate(); + } +} + +ScalarizedVal legalizeEntryPointReturnValueForGLSL( + GLSLLegalizationContext* context, + CodeGenContext* codeGenContext, + IRBuilder& builder, + IRFunc* func, + IRVarLayout* resultLayout) +{ + ScalarizedVal result; + auto resultType = func->getResultType(); + if (as<IRVoidType>(resultType)) + { + // In this case, the function doesn't return a value + // so we don't need to transform its `return` sites. 
+ // + // We can also use this opportunity to quickly + // check if the function has any parameters, and if + // it doesn't use the chance to bail out immediately. + if (func->getParamCount() == 0) + { + // This function is already legal for GLSL + // (at least in terms of parameter/result signature), + // so we won't bother doing anything at all. + return result; + } + + // If the function does have parameters, then we need + // to let the logic later in this function handle them. + } + else + { + // Function returns a value, so we need + // to introduce a new global variable + // to hold that value, and then replace + // any `returnVal` instructions with + // code to write to that variable. + + ScalarizedVal resultGlobal = createGLSLGlobalVaryings( + context, + codeGenContext, + &builder, + resultType, + resultLayout, + LayoutResourceKind::VaryingOutput, + context->stage, + func); + result = resultGlobal; + + if (auto entryPointDecor = func->findDecoration<IREntryPointDecoration>()) + { + if (entryPointDecor->getProfile().getStage() == Stage::Hull) + { + builder.setInsertBefore(func->getFirstBlock()->getFirstOrdinaryInst()); + auto index = getOrCreateBuiltinParamForHullShader(context, toSlice("SV_OutputControlPointID")); + resultGlobal = getSubscriptVal(&builder, resultType, resultGlobal, index); + } + } + rewriteReturnToOutputStore(builder, func, resultGlobal); + + } + return result; +} + void legalizeEntryPointForGLSL( Session* session, IRModule* module, @@ -3052,6 +3389,7 @@ void legalizeEntryPointForGLSL( GLSLLegalizationContext context; context.session = session; context.stage = stage; + context.entryPointFunc = func; context.sink = codeGenContext->getSink(); context.glslExtensionTracker = glslExtensionTracker; @@ -3081,6 +3419,14 @@ void legalizeEntryPointForGLSL( break; } + // For hull shaders, we need to convert it to single return form, because + // we need to insert a barrier after the main body, then invoke the + // patch constant function after the 
barrier. + if (stage == Stage::Hull) + { + convertFuncToSingleReturnForm(module, func); + } + // We create a dummy IR builder, since some of // the functions require it. // @@ -3105,75 +3451,14 @@ void legalizeEntryPointForGLSL( // Specifically, we need to check if the function has // a `void` return type, because there is no work // to be done on its return value in that case. - auto resultType = func->getResultType(); - if(as<IRVoidType>(resultType)) - { - // In this case, the function doesn't return a value - // so we don't need to transform its `return` sites. - // - // We can also use this opportunity to quickly - // check if the function has any parameters, and if - // it doesn't use the chance to bail out immediately. - if( func->getParamCount() == 0 ) - { - // This function is already legal for GLSL - // (at least in terms of parameter/result signature), - // so we won't bother doing anything at all. - return; - } + auto scalarizedGlobalOutput = legalizeEntryPointReturnValueForGLSL( + &context, codeGenContext, builder, func, entryPointLayout->getResultLayout()); - // If the function does have parameters, then we need - // to let the logic later in this function handle them. - } - else + // For hull shaders, insert the invocation of the patch constant function + // at the end of the entrypoint now. + if (stage == Stage::Hull) { - // Function returns a value, so we need - // to introduce a new global variable - // to hold that value, and then replace - // any `returnVal` instructions with - // code to write to that variable. - - auto resultGlobal = createGLSLGlobalVaryings( - &context, - codeGenContext, - &builder, - resultType, - entryPointLayout->getResultLayout(), - LayoutResourceKind::VaryingOutput, - stage, - func); - - for( auto bb = func->getFirstBlock(); bb; bb = bb->getNextBlock() ) - { - // TODO: This is silly, because we are looking at every instruction, - // when we know that a `returnVal` should only ever appear as a - // terminator... 
- for( auto ii = bb->getFirstInst(); ii; ii = ii->getNextInst() ) - { - if(ii->getOp() != kIROp_Return) - continue; - - IRReturn* returnInst = (IRReturn*) ii; - IRInst* returnValue = returnInst->getVal(); - - // Make sure we add these instructions to the right block - builder.setInsertInto(bb); - - // Write to our global variable(s) from the value being returned. - assign(&builder, resultGlobal, ScalarizedVal::value(returnValue)); - - // Emit a `return void_val` to end the block - auto returnVoid = builder.emitReturn(); - - // Remove the old `returnVal` instruction. - returnInst->removeAndDeallocate(); - - // Make sure to resume our iteration at an - // appropriate instruciton, since we deleted - // the one we had been using. - ii = returnVoid; - } - } + invokePathConstantFuncInHullShader(&context, codeGenContext, scalarizedGlobalOutput); } // Next we will walk through any parameters of the entry-point function, diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index ab67dc4bf..f639d3343 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -660,6 +660,8 @@ INST(SampleGrad, sampleGrad, 4, 0) INST(GroupMemoryBarrierWithGroupSync, GroupMemoryBarrierWithGroupSync, 0, 0) +INST(ControlBarrier, ControlBarrier, 0, 0) + // GPU_FOREACH loop of the form INST(GpuForeach, gpuForeach, 3, 0) diff --git a/source/slang/slang-ir-specialize-target-switch.cpp b/source/slang/slang-ir-specialize-target-switch.cpp index 46ea51192..e3ef06e18 100644 --- a/source/slang/slang-ir-specialize-target-switch.cpp +++ b/source/slang/slang-ir-specialize-target-switch.cpp @@ -9,6 +9,16 @@ namespace Slang { void specializeTargetSwitch(TargetRequest* target, IRGlobalValueWithCode* code, DiagnosticSink* sink) { + if (auto gen = as<IRGeneric>(code)) + { + auto retVal = findGenericReturnVal(gen); + if (auto innerCode = as<IRGlobalValueWithCode>(retVal)) + { + specializeTargetSwitch(target, innerCode, sink); + return; + } + } + bool 
changed = false; for (auto block : code->getBlocks()) { @@ -76,14 +86,6 @@ namespace Slang if (auto code = as<IRGlobalValueWithCode>(globalInst)) { specializeTargetSwitch(target, code, sink); - if (auto gen = as<IRGeneric>(code)) - { - auto retVal = findGenericReturnVal(gen); - if (auto innerCode = as<IRGlobalValueWithCode>(retVal)) - { - specializeTargetSwitch(target, innerCode, sink); - } - } } } } diff --git a/source/slang/slang-parameter-binding.cpp b/source/slang/slang-parameter-binding.cpp index 56c85e9cb..cfbd7f2c5 100644 --- a/source/slang/slang-parameter-binding.cpp +++ b/source/slang/slang-parameter-binding.cpp @@ -2092,6 +2092,23 @@ static RefPtr<TypeLayout> processEntryPointVaryingParameter( return arrayTypeLayout; } + else if (auto patchType = as<HLSLPatchType>(type)) + { + // Similar to the MeshOutput case, a `InputPatch` or `OutputPatch` type is just like an array. + // + auto elementTypeLayout = processEntryPointVaryingParameter(context, patchType->getElementType(), state, varLayout); + + RefPtr<ArrayTypeLayout> arrayTypeLayout = new ArrayTypeLayout(); + arrayTypeLayout->elementTypeLayout = elementTypeLayout; + arrayTypeLayout->type = arrayType; + + for (auto rr : elementTypeLayout->resourceInfos) + { + arrayTypeLayout->findOrAddResourceInfo(rr.kind)->count = rr.count; + } + + return arrayTypeLayout; + } // Ignore a bunch of types that don't make sense here... 
else if (const auto subpassType = as<SubpassInputType>(type)) { return nullptr; } else if (const auto textureType = as<TextureType>(type)) { return nullptr; } diff --git a/tests/spirv/tessellation.slang b/tests/spirv/tessellation.slang new file mode 100644 index 000000000..deb6ed298 --- /dev/null +++ b/tests/spirv/tessellation.slang @@ -0,0 +1,65 @@ +//TEST:SIMPLE(filecheck=CHECK): -target spirv + +// CHECK-DAG: OpExecutionMode %main SpacingEqual + +// CHECK-DAG: OpExecutionMode %main OutputVertices 4 + +// CHECK-DAG: OpExecutionMode %main VertexOrderCw + +// CHECK-DAG: OpExecutionMode %main Quads + +// CHECK: OpDecorate %gl_TessLevelOuter BuiltIn TessLevelOuter +// CHECK: OpDecorate %gl_TessLevelOuter Patch +// CHECK: OpDecorate %gl_TessLevelInner BuiltIn TessLevelInner +// CHECK: OpDecorate %gl_TessLevelInner Patch + +// CHECK: OpControlBarrier %uint_2 %uint_4 %uint_0 + +// CHECK: OpStore %gl_TessLevelOuter +// CHECK: OpStore %gl_TessLevelInner + +struct VS_OUT +{ + float3 position : POSITION; +}; + +struct HS_OUT +{ + float3 position : POSITION; +}; + +struct HSC_OUT +{ + float EdgeTessFactor[4] : SV_TessFactor; + float InsideTessFactor[2] : SV_InsideTessFactor; +}; + +// Hull Shader (HS) +[domain("quad")] +[partitioning("integer")] +[outputtopology("triangle_cw")] +[outputcontrolpoints(4)] +[patchconstantfunc("constants")] +HS_OUT main(InputPatch<VS_OUT, 4> patch, uint i : SV_OutputControlPointID) +{ + HS_OUT o; + o.position = patch[i].position; + return o; +} + +HSC_OUT constants(InputPatch<VS_OUT, 4> patch) +{ + float3 p0 = patch[0].position; + float3 p1 = patch[1].position; + float3 p2 = patch[2].position; + float3 p3 = patch[3].position; + + HSC_OUT o; + o.EdgeTessFactor[0] = dot(p0, p1); + o.EdgeTessFactor[1] = dot(p0, p3); + o.EdgeTessFactor[2] = dot(p2, p3); + o.EdgeTessFactor[3] = dot(p1, p2); + o.InsideTessFactor[0] = lerp(o.EdgeTessFactor[1], o.EdgeTessFactor[3], 0.5); + o.InsideTessFactor[1] = lerp(o.EdgeTessFactor[0], o.EdgeTessFactor[2], 0.5); + 
return o; +}
\ No newline at end of file |
