From 7c162eba5329eae7755e55298a455a144fcb0dce Mon Sep 17 00:00:00 2001 From: sriramm-nv <85252063+sriramm-nv@users.noreply.github.com> Date: Fri, 19 Apr 2024 09:12:56 -0700 Subject: Enable NonUniformResourceIndex support for glsl, hlsl and spirv (#3899) Fixes #387676* ForceInline SampleLevel to allow decorations to apply * explicitly add all the SPIRVAsmOperand Insts in non-differentiable list, which might get inadvertently processed when these functions are inlined into the main shader * Support NonUniformResourceIndex for SPIR-V target Fixes #3876 * add a new IR instruction for NonUniformResourceIndex * slang ir emitter for nonuniform resource index * update the hlsl meta slang * Add test cases for NonUniformResourceIndex access for buffers and textures, with/without cast, nested access etc. * add default c-like emitter for nonuniformresourceinfo * added hlsl emitter * added glsl emitter * requisites for spirv enabling - new decorator for nonuniformresourceindex - emitter for nonuniformresourceindex signature change * add hasResourceType checker * add rwStructBuffType in resourcetype checker * add a case for nonuniformres in emitDecorations * DO NOT COMMIT: This change adds special handling for RWStructBuf within the isResourceType function, if it is a pointer to this resource, return true to make it work with nonuniformres test * spirv emitter for decorations - update the emitLocalInst to perform decorations at the end * added main spirv emitter code * slang emit spirv bugfix * hacky way of supporting Call Inst * move code to cleanup nonuniform inst into helper function * remove stale code from test * add spirv decoration for nonuniform * update test to remove global variables * update coherent-2 test * update comment for special handling * update the spirv legalize to handle nested nonuniforms improved logic that handles call ops, rwstructbuf, nested nonuniforms etc. 
* update nonuniform-array-of-tex test * missed removing nonuniform inst causing duplicate decorations * add glsl and hlsl variants of nonuniform tests * repurpose the hasResource function into something specific for nonuniform inst decoration helper * clean up comments and code around spirv-legalization to emit nonuniform inst by recursively looking into the inst * use the helper canDecorateNonUniformInst to convert `nonUniformResourceInfo` inst to decoration * converted compute/unbounded-array-of-array cross compile test into a simple check test * update contains Resource helper function to be more generic * clean up the case for opcall handling with nonuniform resource inst * update ptr to struct buffer check to be more explicit and rename the function to check for ptr to resource type * update comments and fix the test for coherent * fix typos * update logic on spirv legalize to delete dead instructions - for some reason this doesn't automatically happen * add comments to declarations * add NonuniformResourceIndex to the non-differential inst list --- source/slang/hlsl.meta.slang | 53 +--------- source/slang/slang-emit-c-like.cpp | 7 ++ source/slang/slang-emit-glsl.cpp | 15 +++ source/slang/slang-emit-hlsl.cpp | 12 +++ source/slang/slang-emit-spirv.cpp | 175 +++++++++++++++++++++---------- source/slang/slang-ir-autodiff-fwd.cpp | 1 + source/slang/slang-ir-inst-defs.h | 14 ++- source/slang/slang-ir-insts.h | 24 +++++ source/slang/slang-ir-spirv-legalize.cpp | 124 ++++++++++++++++++++++ source/slang/slang-ir-util.h | 6 ++ source/slang/slang-ir.cpp | 12 +++ source/slang/slang-legalize-types.cpp | 15 +++ 12 files changed, 354 insertions(+), 104 deletions(-) (limited to 'source') diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index 3576e46e3..cdd08b5d7 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -6990,55 +6990,12 @@ T __copyObject(T v) } } - -__glsl_extension(GL_EXT_nonuniform_qualifier) -[__readNone] 
-[ForceInline] -uint NonUniformResourceIndex(uint index) -{ - __target_switch - { - case hlsl: - __intrinsic_asm "NonUniformResourceIndex"; - case glsl: - __intrinsic_asm "nonuniformEXT"; - case spirv: - var indexCopy = __copyObject(index); - spirv_asm - { - OpCapability ShaderNonUniform; - OpDecorate $indexCopy NonUniform; - }; - return indexCopy; - default: - return index; - } -} - -__glsl_extension(GL_EXT_nonuniform_qualifier) +/// `NonUniformResourceIndex` function is used to indicate if the resource index is +/// divergent, and ensure scalarization happens correctly for each divergent lane. [__readNone] -[ForceInline] -[NonUniformReturn] -int NonUniformResourceIndex(int index) -{ - __target_switch - { - case hlsl: - __intrinsic_asm "NonUniformResourceIndex"; - case glsl: - __intrinsic_asm "nonuniformEXT"; - case spirv: - var indexCopy = __copyObject(index); - spirv_asm - { - OpCapability ShaderNonUniform; - OpDecorate $indexCopy NonUniform; - }; - return indexCopy; - default: - return index; - } -} +__generic +__intrinsic_op($(kIROp_NonUniformResourceIndex)) +T NonUniformResourceIndex(T index); /// HLSL allows NonUniformResourceIndex around non int/uint types. /// It's effect is presumably to ignore it, which the following implementation does. 
diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp index 7cb4871be..1926cbdcb 100644 --- a/source/slang/slang-emit-c-like.cpp +++ b/source/slang/slang-emit-c-like.cpp @@ -1300,6 +1300,9 @@ bool CLikeSourceEmitter::shouldFoldInstIntoUseSites(IRInst* inst) case kIROp_GetVulkanRayTracingPayloadLocation: return true; + + case kIROp_NonUniformResourceIndex: + return true; } // Layouts and attributes are only present to annotate other @@ -2371,6 +2374,10 @@ void CLikeSourceEmitter::defaultEmitInstExpr(IRInst* inst, const EmitOpInfo& inO m_writer->emit("GroupMemoryBarrierWithGroupSync()"); break; + case kIROp_NonUniformResourceIndex: + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); // Directly emit NonUniformResourceIndex Operand0; + break; + case kIROp_getNativeStr: { auto prec = getInfo(EmitOp::Postfix); diff --git a/source/slang/slang-emit-glsl.cpp b/source/slang/slang-emit-glsl.cpp index 1ae178aa8..293bcb891 100644 --- a/source/slang/slang-emit-glsl.cpp +++ b/source/slang/slang-emit-glsl.cpp @@ -2068,6 +2068,21 @@ bool GLSLSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOu maybeCloseParens(assignNeedsClose); return true; } + case kIROp_NonUniformResourceIndex: + { + // Need to emit as a Function call for GLSL + m_writer->emit("nonuniformEXT"); + m_writer->emit("("); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(")"); + + // Forcibly enabling the GL extension when using 'implicit-sized' arrays + // with the qualifier. This may not be advisable. 
+ _requireGLSLExtension(UnownedStringSlice::fromLiteral("GL_EXT_nonuniform_qualifier")); + + // Handled + return true; + } default: break; } diff --git a/source/slang/slang-emit-hlsl.cpp b/source/slang/slang-emit-hlsl.cpp index 03de2108b..9fcdef106 100644 --- a/source/slang/slang-emit-hlsl.cpp +++ b/source/slang/slang-emit-hlsl.cpp @@ -725,6 +725,18 @@ bool HLSLSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOu return false; } break; + case kIROp_NonUniformResourceIndex: + { + // Need to emit as a Function call for HLSL + m_writer->emit("NonUniformResourceIndex"); + m_writer->emit("("); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(")"); + + // Handled + return true; + } + break; default: break; } diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp index b5a94cf0d..3a551f301 100644 --- a/source/slang/slang-emit-spirv.cpp +++ b/source/slang/slang-emit-spirv.cpp @@ -2588,6 +2588,7 @@ struct SPIRVEmitContext /// Emit an instruction that is local to the body of the given `parent`. 
SpvInst* emitLocalInst(SpvInstParent* parent, IRInst* inst) { + SpvInst* result = nullptr; switch( inst->getOp() ) { default: @@ -2600,65 +2601,91 @@ struct SPIRVEmitContext } case kIROp_Specialize: case kIROp_MissingReturn: - return nullptr; + break; case kIROp_Var: - return emitVar(parent, inst); + result = emitVar(parent, inst); + break; case kIROp_Call: - return emitCall(parent, static_cast(inst)); + result = emitCall(parent, static_cast(inst)); + break; case kIROp_FieldAddress: - return emitFieldAddress(parent, as(inst)); + result = emitFieldAddress(parent, as(inst)); + break; case kIROp_FieldExtract: - return emitFieldExtract(parent, as(inst)); + result = emitFieldExtract(parent, as(inst)); + break; case kIROp_GetElementPtr: - return emitGetElementPtr(parent, as(inst)); + result = emitGetElementPtr(parent, as(inst)); + break; case kIROp_GetOffsetPtr: - return emitGetOffsetPtr(parent, inst); + result = emitGetOffsetPtr(parent, inst); + break; case kIROp_GetElement: - return emitGetElement(parent, as(inst)); + result = emitGetElement(parent, as(inst)); + break; case kIROp_MakeStruct: - return emitCompositeConstruct(parent, inst); + result = emitCompositeConstruct(parent, inst); + break; case kIROp_MakeArrayFromElement: - return emitMakeArrayFromElement(parent, inst); + result = emitMakeArrayFromElement(parent, inst); + break; case kIROp_MakeMatrixFromScalar: - return emitMakeMatrixFromScalar(parent, inst); + result = emitMakeMatrixFromScalar(parent, inst); + break; case kIROp_MakeMatrix: - return emitMakeMatrix(parent, inst); + result = emitMakeMatrix(parent, inst); + break; case kIROp_Load: - return emitLoad(parent, as(inst)); + result = emitLoad(parent, as(inst)); + break; case kIROp_Store: - return emitStore(parent, as(inst)); + result = emitStore(parent, as(inst)); + break; case kIROp_SwizzledStore: - return emitSwizzledStore(parent, as(inst)); + result = emitSwizzledStore(parent, as(inst)); + break; case kIROp_swizzleSet: - return emitSwizzleSet(parent, 
as(inst)); + result = emitSwizzleSet(parent, as(inst)); + break; case kIROp_RWStructuredBufferGetElementPtr: - return emitStructuredBufferGetElementPtr(parent, inst); + result = emitStructuredBufferGetElementPtr(parent, inst); + break; case kIROp_StructuredBufferGetDimensions: - return emitStructuredBufferGetDimensions(parent, inst); + result = emitStructuredBufferGetDimensions(parent, inst); + break; case kIROp_swizzle: - return emitSwizzle(parent, as(inst)); + result = emitSwizzle(parent, as(inst)); + break; case kIROp_IntCast: - return emitIntCast(parent, as(inst)); + result = emitIntCast(parent, as(inst)); + break; case kIROp_FloatCast: - return emitFloatCast(parent, as(inst)); + result = emitFloatCast(parent, as(inst)); + break; case kIROp_CastIntToFloat: - return emitIntToFloatCast(parent, as(inst)); + result = emitIntToFloatCast(parent, as(inst)); + break; case kIROp_CastFloatToInt: - return emitFloatToIntCast(parent, as(inst)); + result = emitFloatToIntCast(parent, as(inst)); + break; case kIROp_CastPtrToInt: - return emitCastPtrToInt(parent, inst); + result = emitCastPtrToInt(parent, inst); + break; case kIROp_CastPtrToBool: - return emitCastPtrToBool(parent, inst); + result = emitCastPtrToBool(parent, inst); + break; case kIROp_CastIntToPtr: - return emitCastIntToPtr(parent, inst); + result = emitCastIntToPtr(parent, inst); + break; case kIROp_PtrCast: case kIROp_BitCast: - return emitOpBitcast( + result = emitOpBitcast( parent, inst, inst->getDataType(), inst->getOperand(0) ); + break; case kIROp_Add: case kIROp_Sub: case kIROp_Mul: @@ -2681,12 +2708,14 @@ struct SPIRVEmitContext case kIROp_Geq: case kIROp_Rsh: case kIROp_Lsh: - return emitArithmetic(parent, inst); + result = emitArithmetic(parent, inst); + break; case kIROp_GlobalValueRef: { auto inner = ensureInst(inst->getOperand(0)); registerInst(inst, inner); - return inner; + result = inner; + break; } case kIROp_GetVulkanRayTracingPayloadLocation: { @@ -2700,15 +2729,18 @@ struct SPIRVEmitContext 
} auto inner = ensureInst(location); registerInst(inst, inner); - return inner; + result = inner; + break; } case kIROp_Return: if (as(inst)->getVal()->getOp() == kIROp_VoidLit) - return emitOpReturn(parent, inst); + result = emitOpReturn(parent, inst); else - return emitOpReturnValue(parent, inst, as(inst)->getVal()); + result = emitOpReturnValue(parent, inst, as(inst)->getVal()); + break; case kIROp_discard: - return emitOpKill(parent, inst); + result = emitOpKill(parent, inst); + break; case kIROp_unconditionalBranch: { // If we are jumping to the main block of a loop, @@ -2719,7 +2751,8 @@ struct SPIRVEmitContext if (isLoopTargetBlock(targetBlock, loopInst)) return emitOpBranch(parent, inst, getIRInstSpvID(loopInst)); // Otherwise, emit a normal branch inst into the target block. - return emitOpBranch(parent, inst, getIRInstSpvID(targetBlock)); + result = emitOpBranch(parent, inst, getIRInstSpvID(targetBlock)); + break; } case kIROp_loop: { @@ -2735,7 +2768,8 @@ struct SPIRVEmitContext // from the actual loop target block) are emitted first. emitOpBranch(parent, nullptr, blockId); - return block; + result = block; + break; } case kIROp_ifElse: { @@ -2743,7 +2777,7 @@ struct SPIRVEmitContext auto afterBlockID = getIRInstSpvID(ifelseInst->getAfterBlock()); emitOpSelectionMerge(parent, nullptr, afterBlockID, SpvSelectionControlMaskNone); auto falseLabel = ifelseInst->getFalseBlock(); - return emitOpBranchConditional( + result = emitOpBranchConditional( parent, inst, ifelseInst->getCondition(), @@ -2751,13 +2785,14 @@ struct SPIRVEmitContext falseLabel ? 
getID(ensureInst(falseLabel)) : afterBlockID, makeArray() ); + break; } case kIROp_Switch: { auto switchInst = as(inst); auto mergeBlockID = getIRInstSpvID(switchInst->getBreakLabel()); emitOpSelectionMerge(parent, nullptr, mergeBlockID, SpvSelectionControlMaskNone); - return emitInstCustomOperandFunc(parent, inst, SpvOpSwitch, [&]() { + result = emitInstCustomOperandFunc(parent, inst, SpvOpSwitch, [&]() { emitOperand(switchInst->getCondition()); auto defaultLabel = switchInst->getDefaultLabel(); emitOperand(defaultLabel ? getID(ensureInst(defaultLabel)) : mergeBlockID); @@ -2771,13 +2806,17 @@ struct SPIRVEmitContext emitOperand(caseLabel ? getID(ensureInst(caseLabel)) : mergeBlockID); } }); + break; } case kIROp_Unreachable: - return emitOpUnreachable(parent, inst); + result = emitOpUnreachable(parent, inst); + break; case kIROp_conditionalBranch: SLANG_UNEXPECTED("Unstructured branching is not supported by SPIRV."); + break; case kIROp_MakeVector: - return emitConstruct(parent, inst); + result = emitConstruct(parent, inst); + break; case kIROp_MakeVectorFromScalar: { const auto scalar = inst->getOperand(0); @@ -2785,45 +2824,62 @@ struct SPIRVEmitContext SLANG_ASSERT(vecTy); const auto numElems = as(vecTy->getElementCount()); SLANG_ASSERT(numElems); - return emitSplat(parent, inst, scalar, numElems->getValue()); + result = emitSplat(parent, inst, scalar, numElems->getValue()); } + break; case kIROp_MakeArray: - return emitConstruct(parent, inst); + result = emitConstruct(parent, inst); + break; case kIROp_Select: - return emitInst(parent, inst, SpvOpSelect, inst->getFullType(), kResultID, OperandsOf(inst)); + result = emitInst(parent, inst, SpvOpSelect, inst->getFullType(), kResultID, OperandsOf(inst)); + break; case kIROp_DebugLine: - return emitDebugLine(parent, as(inst)); + result = emitDebugLine(parent, as(inst)); + break; case kIROp_DebugVar: - return emitDebugVar(parent, as(inst)); + result = emitDebugVar(parent, as(inst)); + break; case kIROp_DebugValue: 
- return emitDebugValue(parent, as(inst)); + result = emitDebugValue(parent, as(inst)); + break; case kIROp_GetStringHash: - return emitGetStringHash(inst); + result = emitGetStringHash(inst); + break; case kIROp_undefined: - return emitOpUndef(parent, inst, inst->getDataType()); + result = emitOpUndef(parent, inst, inst->getDataType()); + break; case kIROp_SPIRVAsm: - return emitSPIRVAsm(parent, as(inst)); + result = emitSPIRVAsm(parent, as(inst)); + break; case kIROp_ImageLoad: - return emitImageLoad(parent, as(inst)); + result = emitImageLoad(parent, as(inst)); + break; case kIROp_ImageStore: - return emitImageStore(parent, as(inst)); + result = emitImageStore(parent, as(inst)); + break; case kIROp_ImageSubscript: - return emitImageSubscript(parent, as(inst)); + result = emitImageSubscript(parent, as(inst)); + break; case kIROp_AtomicCounterIncrement: { IRBuilder builder{inst}; const auto memoryScope = emitIntConstant(IRIntegerValue{SpvScopeDevice}, builder.getUIntType()); const auto memorySemantics = emitIntConstant(IRIntegerValue{SpvMemorySemanticsMaskNone}, builder.getUIntType()); - return emitOpAtomicIIncrement(parent, inst, inst->getFullType(), inst->getOperand(0), memoryScope, memorySemantics); + result = emitOpAtomicIIncrement(parent, inst, inst->getFullType(), inst->getOperand(0), memoryScope, memorySemantics); } + break; case kIROp_AtomicCounterDecrement: { IRBuilder builder{inst}; const auto memoryScope = emitIntConstant(IRIntegerValue{SpvScopeDevice}, builder.getUIntType()); const auto memorySemantics = emitIntConstant(IRIntegerValue{SpvMemorySemanticsMaskNone}, builder.getUIntType()); - return emitOpAtomicIDecrement(parent, inst, inst->getFullType(), inst->getOperand(0), memoryScope, memorySemantics); + result = emitOpAtomicIDecrement(parent, inst, inst->getFullType(), inst->getOperand(0), memoryScope, memorySemantics); } + break; } + if (result) + emitDecorations(inst, getID(result)); + return result; } SpvInst* emitImageLoad(SpvInstParent* parent, 
IRImageLoad* load) @@ -2925,7 +2981,6 @@ struct SPIRVEmitContext } } - SpvExecutionMode getDepthOutputExecutionMode(IRInst* builtinVar) { SpvExecutionMode result = SpvExecutionModeMax; @@ -3301,6 +3356,18 @@ struct SPIRVEmitContext } break; + case kIROp_SPIRVNonUniformResourceDecoration: + { + requireSPIRVCapability(SpvCapabilityShaderNonUniform); + emitOpDecorate( + getSection(SpvLogicalSectionID::Annotations), + decoration, + dstID, + SpvDecorationNonUniform + ); + } + break; + case kIROp_OutputTopologyDecoration: { const auto o = cast(decoration); diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index 858d94514..0e934eeb5 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -1908,6 +1908,7 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst* case kIROp_RWStructuredBufferLoadStatus: case kIROp_RWStructuredBufferStore: case kIROp_RWStructuredBufferGetElementPtr: + case kIROp_NonUniformResourceIndex: case kIROp_IsType: case kIROp_ImageSubscript: case kIROp_ImageLoad: diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index 2612d2ac7..ca4c8359a 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -480,6 +480,9 @@ INST(StructuredBufferAppend, StructuredBufferAppend, 1, 0) INST(StructuredBufferConsume, StructuredBufferConsume, 1, 0) INST(StructuredBufferGetDimensions, StructuredBufferGetDimensions, 1, 0) +// Resource qualifiers for dynamically varying index +INST(NonUniformResourceIndex, nonUniformResourceIndex, 1, 0) + INST(AtomicCounterIncrement, AtomicCounterIncrement, 1, 0) INST(AtomicCounterDecrement, AtomicCounterDecrement, 1, 0) @@ -690,7 +693,7 @@ INST_RANGE(BindingQuery, GetRegisterIndex, GetRegisterSpace) /* Decoration */ -INST(HighLevelDeclDecoration, highLevelDecl, 1, 0) + INST(HighLevelDeclDecoration, highLevelDecl, 1, 0) INST(LayoutDecoration, layout, 1, 0) 
INST(BranchDecoration, branch, 0, 0) INST(FlattenDecoration, flatten, 0, 0) @@ -1005,6 +1008,13 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0) /// Recognized by SPIRV-emit pass so we can emit a SPIRV `Block` decoration. INST(SPIRVBlockDecoration, spvBlock, 0, 0) + /// Decorates a SPIRV-inst as `NonUniformResource` to guarantee non-uniform index lookup of + /// - a resource within an array of resources via IRGetElement. + /// - an IRLoad that takes a pointer within a memory buffer via IRGetElementPtr. + /// - an IRIntCast to a resource that is cast from signed to unsigned or vice versa. + /// - an IRGetElementPtr itself when using the pointer on an intrinsic operation. + INST(SPIRVNonUniformResourceDecoration, NonUniformResource, 0, 0) + // Stores flag bits of which memory qualifiers an object has INST(MemoryQualifierSetDecoration, MemoryQualifierSetDecoration, 1, 0) @@ -1012,7 +1022,7 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0) /// backing value key, width and offset INST(BitFieldAccessorDecoration, BitFieldAccessorDecoration, 3, 0) - INST_RANGE(Decoration, HighLevelDeclDecoration, BitFieldAccessorDecoration) + INST_RANGE(Decoration, HighLevelDeclDecoration, BitFieldAccessorDecoration) // diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 60730f135..d1615e89c 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -327,6 +327,18 @@ struct IRRequireGLSLVersionDecoration : IRDecoration } }; +struct IRSPIRVNonUniformResourceDecoration : IRDecoration +{ + enum { kOp = kIROp_SPIRVNonUniformResourceDecoration }; + IR_LEAF_ISA(SPIRVNonUniformResourceDecoration) + + IRConstant* getSPIRVNonUniformResourceOperand() { return cast(getOperand(0)); } + IntegerLiteralValue getSPIRVNonUniformResource() + { + return getSPIRVNonUniformResourceOperand()->value.intVal; + } +}; + struct IRRequireSPIRVVersionDecoration : IRDecoration { enum { kOp = kIROp_RequireSPIRVVersionDecoration }; @@ -2336,6 +2348,11 @@ struct 
IRStructuredBufferGetDimensions : IRInst IRInst* getBuffer() { return getOperand(0); } }; +struct IRNonUniformResourceIndex : IRInst +{ + IR_LEAF_ISA(NonUniformResourceIndex); +}; + struct IRLoadReverseGradient : IRInst { IR_LEAF_ISA(LoadReverseGradient) @@ -4229,6 +4246,8 @@ public: IRInst* emitGenericAsm(UnownedStringSlice asmText); IRInst* emitRWStructuredBufferGetElementPtr(IRInst* structuredBuffer, IRInst* index); + + IRInst* emitNonUniformResourceIndexInst(IRInst* val); // // Decorations // @@ -4455,6 +4474,11 @@ public: addDecoration(value, kIROp_RequireSPIRVVersionDecoration, getIntValue(getBasicType(BaseType::UInt64), intValue)); } + void addSPIRVNonUniformResourceDecoration(IRInst* value) + { + addDecoration(value, kIROp_SPIRVNonUniformResourceDecoration); + } + void addRequireCUDASMVersionDecoration(IRInst* value, const SemanticVersion& version) { SemanticVersion::IntegerType intValue = version.toInteger(); diff --git a/source/slang/slang-ir-spirv-legalize.cpp b/source/slang/slang-ir-spirv-legalize.cpp index a1126104a..65ad29a9a 100644 --- a/source/slang/slang-ir-spirv-legalize.cpp +++ b/source/slang/slang-ir-spirv-legalize.cpp @@ -1167,6 +1167,127 @@ struct SPIRVLegalizationContext : public SourceEmitterBase addUsersToWorkList(newStore); } + void processNonUniformResourceIndex(IRInst* nonUniformResourceIndexInst) + { + // implement the translation to spirv by walking up the use-def chain + // from nonUniformResource inst of an index to an array of buffer or + // texture def all the way to the leaf operations. To be precise: + // - go through GEP and see if it calls an intrinsic function, + // then decorate the address itself (GetElementPtr) + // - go through GEP to identify the pointer access and the Loads that it + // accesses (GetElementPtr -> Load), then decorate the load instruction. 
+ // - go through IntCasts to deal with u32 -> i32 / vice-versa (IntCast) + List resWorkList; + + // Handle cases when `nonUniformResourceIndexInst` inst is wrapped around + // an index in a nested fashion, i.e. nonUniform(nonUniform(index)) by + // only adding the inner-most inst in the worklist, and work our way out. + auto insti = nonUniformResourceIndexInst; + while (insti->getOp() == kIROp_NonUniformResourceIndex) + { + if (resWorkList.getCount() != 0) + resWorkList.removeLast(); + resWorkList.add(insti); + insti = insti->getOperand(0); + } + + // For all the users of a `nonUniformResourceIndexInst`, make them directly + // use the underlying base inst that is wrapped by `nonUniformResourceIndex` + // and finally wrap them with a `nonUniformResourceIndex`, and add back to the + // worklist, and keep bubbling them up until it can. + for (Index i = 0; i < resWorkList.getCount(); i++) + { + auto inst = resWorkList[i]; + traverseUses(inst, [&](IRUse* use) + { + auto user = use->getUser(); + IRBuilder builder(user); + builder.setInsertBefore(user); + + IRInst* newUser = nullptr; + switch (user->getOp()) + { + case kIROp_IntCast: + // Replace intCast(nonUniformRes(x)), into nonUniformRes(intCast(x)) + newUser = builder.emitCast(user->getFullType(), inst->getOperand(0)); + break; + case kIROp_GetElementPtr: + // Ignore when `NonUniformResourceIndex` is not on the index + if (user->getOperand(0)->getOp() != kIROp_NonUniformResourceIndex) + { + // Replace gep(pArray, nonUniformRes(x)), into nonUniformRes(gep(pArray, x)) + newUser = builder.emitElementAddress(user->getFullType(), user->getOperand(0), inst->getOperand(0)); + } + break; + case kIROp_NonUniformResourceIndex: + // Replace nonUniformRes(nonUniformRes(x)), into nonUniformRes(x) + newUser = inst->getOperand(0); + break; + case kIROp_Load: + // Replace load(nonUniformRes(x)), into nonUniformRes(load(x)) + newUser = builder.emitLoad(user->getFullType(), inst->getOperand(0)); + break; + default: + // Ignore for 
all other unknown insts. + break; + }; + + // Early exit when we could not process the `NonUniformResourceIndex` inst. + if (!newUser) + return; + + auto nonuniformUser = builder.emitNonUniformResourceIndexInst(newUser); + user->replaceUsesWith(nonuniformUser); + + // Update the worklist with the newly added `NonUniformResourceIndex` inst, based on + // the base inst it was constructed around, in case we need to further bubble up + // the `NonUniformResourceIndex` inst. + switch (user->getOp()) + { + case kIROp_IntCast: + case kIROp_GetElementPtr: + case kIROp_Load: + case kIROp_NonUniformResourceIndex: + resWorkList.add(nonuniformUser); + break; + }; + + // Clean up the base inst from the IR module, to avoid duplicate decorations. + user->removeAndDeallocate(); + }); + } + + // Once all the `NonUniformResourceIndex` insts are visited, and the inst type is bubbled up + // to the parent, a decoration is added to the operands of the insts. + for (int i = 0; i < resWorkList.getCount(); ++i) + { + // It is only required to decorate the base inst, if the `NonUniformResourceIndex` inst + // around it has any active uses. + auto inst = resWorkList[i]; + if (!inst->hasUses()) + { + inst->removeAndDeallocate(); + continue; + } + // For each of the `NonUniformResourceIndex` inst that remain, decorate the base inst + // with a [NonUniformResource] decoration, which is the operand0 of the inst, only + // when the type is a resource type, or a pointer to a resource type, or a pointer + // in the Physical Storage buffer address space. 
+ auto operand = inst->getOperand(0); + auto type = operand->getDataType(); + if (isResourceType(type) || + isPointerToResourceType(type)) + { + IRBuilder builder(operand); + builder.addSPIRVNonUniformResourceDecoration(operand); + inst->replaceUsesWith(operand); + inst->removeAndDeallocate(); + } + } + nonUniformResourceIndexInst->removeFromParent(); + m_instsToRemove.add(nonUniformResourceIndexInst); + } + void processImageSubscript(IRImageSubscript* subscript) { if (auto ptrType = as(subscript->getDataType())) @@ -1822,6 +1943,9 @@ struct SPIRVLegalizationContext : public SourceEmitterBase case kIROp_RWStructuredBufferStore: processRWStructuredBufferStore(inst); break; + case kIROp_NonUniformResourceIndex: + processNonUniformResourceIndex(inst); + break; case kIROp_loop: processLoop(as(inst)); break; diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h index 94ae4bc9f..fba1e784e 100644 --- a/source/slang/slang-ir-util.h +++ b/source/slang/slang-ir-util.h @@ -78,6 +78,12 @@ bool isComInterfaceType(IRType* type); // If `type` is a vector, returns its element type. Otherwise, return `type`. IRType* getVectorElementType(IRType* type); +// True if type is a resource backing memory +bool isResourceType(IRType* type); + +// True if type is a pointer to a resource +bool isPointerToResourceType(IRType* type); + IROp getTypeStyle(IROp op); IROp getTypeStyle(BaseType op); diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 3a6dc667b..7cdc8ecfa 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -6132,6 +6132,18 @@ namespace Slang return i; } + // IR emitter for a dedicated instruction to represent NonUniformResourceIndex qualifier. 
+ IRInst* IRBuilder::emitNonUniformResourceIndexInst(IRInst* val) + { + const auto i = createInst( + this, + kIROp_NonUniformResourceIndex, + val->getFullType(), + val); + addInst(i); + return i; + } + // // Decorations // diff --git a/source/slang/slang-legalize-types.cpp b/source/slang/slang-legalize-types.cpp index 792429a01..fc0947c62 100644 --- a/source/slang/slang-legalize-types.cpp +++ b/source/slang/slang-legalize-types.cpp @@ -203,6 +203,21 @@ bool isResourceType(IRType* type) return false; } +// Helper wrapper function around isResourceType that checks if the given +// type is a pointer to a resource type or a physical storage buffer. +bool isPointerToResourceType(IRType* type) +{ + while (auto ptrType = as(type)) + { + if (ptrType->getAddressSpace() == SpvStorageClassStorageBuffer || + ptrType->getAddressSpace() == SpvStorageClassPhysicalStorageBufferEXT) + return true; + type = ptrType->getValueType(); + } + + return isResourceType(type); +} + ModuleDecl* findModuleForDecl( Decl* decl) { -- cgit v1.2.3