diff options
| author | Yong He <yonghe@outlook.com> | 2023-05-09 09:44:33 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-05-09 09:44:33 -0700 |
| commit | 38ed03a7203baacf36fca62539ac74fd45ed42d2 (patch) | |
| tree | 9648daee25c0a2aaac2fa8cd7d91908fd2aeef2f | |
| parent | 89a1234964a1927c4936a2758f72b7d6c9d0bc73 (diff) | |
Fix function side-effectness prop logic. (#2875)
31 files changed, 912 insertions, 438 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index a2ed1d1df..b992def6e 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -112,6 +112,12 @@ interface __BuiltinSignedArithmeticType : __BuiltinArithmeticType {} interface __BuiltinIntegerType : __BuiltinArithmeticType {} +__attributeTarget(AggTypeDecl) +attribute_syntax [__NonCopyableType] : NonCopyableTypeAttribute; + +__attributeTarget(FunctionDeclBase) +attribute_syntax [__NoSideEffect] : NoSideEffectAttribute; + /// Marks a function for forward-mode differentiation. /// i.e. the compiler will automatically generate a new function /// that computes the jacobian-vector product of the original. diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index 84a72a425..0725103da 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -32,16 +32,16 @@ __intrinsic_type($(kIROp_TensorViewType)) struct TensorView { __target_intrinsic(cuda, "$0.data_ptr<$G0>()") - [__readNone] + [__NoSideEffect] Ptr<T> data_ptr(); __target_intrinsic(cuda, "$0.data_ptr_at<$G0>($1)") - [__readNone] + [__NoSideEffect] Ptr<T> data_ptr_at(uint index); __generic<let N: int> __target_intrinsic(cuda, "$0.data_ptr_at<$G0>($1)") - [__readNone] + [__NoSideEffect] Ptr<T> data_ptr_at(vector<uint, N> index); __implicit_conversion($(kConversionCost_ImplicitDereference)) @@ -49,19 +49,19 @@ struct TensorView __init(TorchTensor<T> t); __target_intrinsic(cuda, "$0.load<$G0>($1)") - [__readNone] + [__NoSideEffect] T load(uint x); __target_intrinsic(cuda, "$0.load<$G0>($1, $2)") - [__readNone] + [__NoSideEffect] T load(uint x, uint y); __target_intrinsic(cuda, "$0.load<$G0>($1, $2, $3)") - [__readNone] + [__NoSideEffect] T load(uint x, uint y, uint z); __target_intrinsic(cuda, "$0.load<$G0>($1, $2, $3, $4)") - [__readNone] + [__NoSideEffect] T load(uint x, uint y, uint z, uint w); __target_intrinsic(cuda, "$0.load<$G0>($1, $2, $3, $4, $5)") - [__readNone] + [__NoSideEffect] T load(uint i0, uint i1, uint i2, uint i3, uint i4); __target_intrinsic(cuda, "$0.store<$G0>($1, $2)") @@ -96,59 +96,67 @@ struct TensorView __subscript(uint index) -> T { - [ForceInline] [__readNone] get { return load(index); } + [ForceInline] [__NoSideEffect] get { return load(index); } [ForceInline] set { store(index, newValue); } __target_intrinsic(cuda, "$0.load<$G0>($1)") + [__NoSideEffect] ref; } __subscript(uint i1, uint i2) -> T { - [ForceInline] [__readNone] get { return load(i1, i2); } + [ForceInline] [__NoSideEffect] get { return load(i1, i2); } [ForceInline] set { store(i1, i2, newValue); } __target_intrinsic(cuda, "$0.load<$G0>($1, $2)") + [__NoSideEffect] ref; } __subscript(uint2 i) -> T { - [ForceInline] [__readNone] get { return load(i.x, i.y); } + [ForceInline] [__NoSideEffect] get { return load(i.x, i.y); } [ForceInline] set { store(i.x, i.y, newValue); } __target_intrinsic(cuda, "$0.load<$G0>($1.x, $1.y)") + [__NoSideEffect] ref; } __subscript(uint i1, uint i2, uint i3) -> T { - [ForceInline] [__readNone] get { return load(i1, i2, i3); } + [ForceInline] [__NoSideEffect] get { return load(i1, i2, i3); } [ForceInline] set { store(i1, i2, i3, newValue); } __target_intrinsic(cuda, "$0.load<$G0>($1, $2, $3)") + [__NoSideEffect] ref; } __subscript(uint3 i) -> T { - [ForceInline] [__readNone] get { return load(i.x, i.y, i.z); } + [ForceInline] [__NoSideEffect] get { return load(i.x, i.y, i.z); } [ForceInline] set { store(i.x, i.y, i.z, newValue); } __target_intrinsic(cuda, "$0.load<$G0>($1.x, $1.y, $1.z)") + [__NoSideEffect] ref; } __subscript(uint i1, uint i2, uint i3, uint i4) -> T { - [ForceInline] [__readNone] get { return load(i1, i2, i3, i4); } + [ForceInline] [__NoSideEffect] get { return load(i1, i2, i3, i4); } [ForceInline] set { store(i1, i2, i3, i4, newValue); } __target_intrinsic(cuda, "$0.load<$G0>($1, $2, $3, $4)") + [__NoSideEffect] ref; } __subscript(uint4 i) -> T { - [__readNone][ForceInline] get { return load(i.x, i.y, i.z, i.w); } + [__NoSideEffect][ForceInline] get { return load(i.x, i.y, i.z, i.w); } [ForceInline] set { store(i.x, i.y, i.z, i.w, newValue); } __target_intrinsic(cuda, "$0.load<$G0>($1.x, $1.y, $1.z, $1.w)") + [__NoSideEffect] ref; } __subscript(uint i1, uint i2, uint i3, uint i4, uint i5) -> T { - [ForceInline] [__readNone] get { return load(i1, i2, i3, i4, i5); } + [ForceInline] [__NoSideEffect] get { return load(i1, i2, i3, i4, i5); } [ForceInline] set { store(i1, i2, i3, i4, i5, newValue); } __target_intrinsic(cuda, "$0.load<$G0>($1, $2, $3, $4, $5)") + [__NoSideEffect] ref; } } diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index 1580a7a23..37af3ef5a 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -300,25 +300,34 @@ struct $(item.name) out uint dim); __target_intrinsic(glsl, "$0._data[$1/4]") + [__NoSideEffect] uint Load(int location); + [__NoSideEffect] uint Load(int location, out uint status); __target_intrinsic(glsl, "uvec2($0._data[$1/4], $0._data[$1/4+1])") + [__NoSideEffect] uint2 Load2(int location); + [__NoSideEffect] uint2 Load2(int location, out uint status); __target_intrinsic(glsl, "uvec3($0._data[$1/4], $0._data[$1/4+1], $0._data[$1/4+2])") + [__NoSideEffect] uint3 Load3(int location); + [__NoSideEffect] uint3 Load3(int location, out uint status); __target_intrinsic(glsl, "uvec4($0._data[$1/4], $0._data[$1/4+1], $0._data[$1/4+2], $0._data[$1/4+3])") + [__NoSideEffect] uint4 Load4(int location); + [__NoSideEffect] uint4 Load4(int location, out uint status); + [__NoSideEffect] T Load<T>(int location) { return __byteAddressBufferLoad<T>(this, location); @@ -713,13 +722,16 @@ struct $(item.name) __target_intrinsic(glsl, "$0._data[$1]") __target_intrinsic(spirv_direct, "%addr = OpAccessChain resultType*StorageBuffer resultId _0 const(int, 0) _1; OpLoad resultType resultId %addr;") + [__NoSideEffect] T Load(int location); + [__NoSideEffect] T Load(int location, out uint status); __subscript(uint index) -> T { __target_intrinsic(glsl, "$0._data[$1]") __target_intrinsic(spirv_direct, "*StorageBuffer OpAccessChain resultType resultId _0 const(int, 0) _1") + [__NoSideEffect] ref; } }; @@ -5685,13 +5697,9 @@ __target_intrinsic(hlsl, RayQuery) __target_intrinsic(glsl, rayQueryEXT) __glsl_extension(GL_EXT_ray_query) __glsl_version(460) +[__NonCopyableType] struct RayQuery <let rayFlagsGeneric : RAY_FLAG = RAY_FLAG_NONE> { - // Initialize the query object in a "fresh" state. - // - __intrinsic_op($(kIROp_DefaultConstruct)) - __init(); - // Initialize a ray-tracing query. // // This method may be called on a "fresh" ray query, or @@ -5705,6 +5713,8 @@ struct RayQuery <let rayFlagsGeneric : RAY_FLAG = RAY_FLAG_NONE> // must obey any API-imposed restrictions. // __target_intrinsic(hlsl) + [__NoSideEffect] + [mutating] void TraceRayInline( RaytracingAccelerationStructure accelerationStructure, RAY_FLAG rayFlags, @@ -5725,6 +5735,7 @@ struct RayQuery <let rayFlagsGeneric : RAY_FLAG = RAY_FLAG_NONE> [__unsafeForceInlineEarly] __specialized_for_target(glsl) + [__NoSideEffect] void TraceRayInline( RaytracingAccelerationStructure accelerationStructure, RAY_FLAG rayFlags, @@ -5758,6 +5769,8 @@ struct RayQuery <let rayFlagsGeneric : RAY_FLAG = RAY_FLAG_NONE> __target_intrinsic(glsl, rayQueryProceedEXT) __glsl_extension(GL_EXT_ray_query) __glsl_version(460) + [__NoSideEffect] + [mutating] bool Proceed(); // Causes the ray query to terminate. @@ -5769,6 +5782,8 @@ struct RayQuery <let rayFlagsGeneric : RAY_FLAG = RAY_FLAG_NONE> __target_intrinsic(glsl, rayQueryTerminateEXT) __glsl_extension(GL_EXT_ray_query) __glsl_version(460) + [__NoSideEffect] + [mutating] void Abort(); // Get the type of candidate hit being considered. @@ -5783,6 +5798,7 @@ struct RayQuery <let rayFlagsGeneric : RAY_FLAG = RAY_FLAG_NONE> __target_intrinsic(glsl, "rayQueryGetIntersectionTypeEXT($0, false)") __glsl_extension(GL_EXT_ray_query) __glsl_version(460) + [__readNone] CANDIDATE_TYPE CandidateType(); // Access properties of a candidate hit. @@ -5790,46 +5806,55 @@ struct RayQuery <let rayFlagsGeneric : RAY_FLAG = RAY_FLAG_NONE> __target_intrinsic(glsl, "transpose(rayQueryGetIntersectionObjectToWorldEXT($0, false))") __glsl_extension(GL_EXT_ray_query) __glsl_version(460) + [__readNone] float3x4 CandidateObjectToWorld3x4(); __target_intrinsic(glsl, "rayQueryGetIntersectionObjectToWorldEXT($0, false)") __glsl_extension(GL_EXT_ray_query) __glsl_version(460) + [__readNone] float4x3 CandidateObjectToWorld4x3(); __target_intrinsic(glsl, "transpose(rayQueryGetIntersectionWorldToObjectEXT($0, false))") __glsl_extension(GL_EXT_ray_query) __glsl_version(460) + [__readNone] float3x4 CandidateWorldToObject3x4(); __target_intrinsic(glsl, "rayQueryGetIntersectionWorldToObjectEXT($0, false)") __glsl_extension(GL_EXT_ray_query) __glsl_version(460) + [__readNone] float4x3 CandidateWorldToObject4x3(); __target_intrinsic(glsl, "rayQueryGetIntersectionInstanceIdEXT($0, false)") __glsl_extension(GL_EXT_ray_query) __glsl_version(460) + [__readNone] uint CandidateInstanceIndex(); __target_intrinsic(glsl, "rayQueryGetIntersectionInstanceCustomIndexEXT($0, false)") __glsl_extension(GL_EXT_ray_query) __glsl_version(460) + [__readNone] uint CandidateInstanceID(); __target_intrinsic(glsl, "rayQueryGetIntersectionGeometryIndexEXT($0, false)") __glsl_extension(GL_EXT_ray_query) __glsl_version(460) + [__readNone] uint CandidateGeometryIndex(); __target_intrinsic(glsl, "rayQueryGetIntersectionPrimitiveIndexEXT($0, false)") __glsl_extension(GL_EXT_ray_query) __glsl_version(460) + [__readNone] uint CandidatePrimitiveIndex(); __target_intrinsic(glsl, "rayQueryGetIntersectionInstanceShaderBindingTableRecordOffsetEXT($0, false)") __glsl_extension(GL_EXT_ray_query) __glsl_version(460) + [__readNone] uint CandidateInstanceContributionToHitGroupIndex(); // Access properties of the ray being traced @@ -5838,11 +5863,13 @@ struct RayQuery <let rayFlagsGeneric : RAY_FLAG = RAY_FLAG_NONE> __target_intrinsic(glsl, "rayQueryGetIntersectionObjectRayOriginEXT($0, false)") __glsl_extension(GL_EXT_ray_query) __glsl_version(460) + [__readNone] float3 CandidateObjectRayOrigin(); __target_intrinsic(glsl, "rayQueryGetIntersectionObjectRayDirectionEXT($0, false)") __glsl_extension(GL_EXT_ray_query) __glsl_version(460) + [__readNone] float3 CandidateObjectRayDirection(); // Access properties of a candidate procedural primitive hit. @@ -5850,6 +5877,7 @@ struct RayQuery <let rayFlagsGeneric : RAY_FLAG = RAY_FLAG_NONE> __target_intrinsic(glsl, "rayQueryGetIntersectionCandidateAABBOpaqueEXT($0, false)") __glsl_extension(GL_EXT_ray_query) __glsl_version(460) + [__readNone] bool CandidateProceduralPrimitiveNonOpaque(); // Access properties of a candidate non-opaque triangle hit. @@ -5857,34 +5885,42 @@ struct RayQuery <let rayFlagsGeneric : RAY_FLAG = RAY_FLAG_NONE> __target_intrinsic(glsl, "rayQueryGetIntersectionFrontFaceEXT($0, false)") __glsl_extension(GL_EXT_ray_query) __glsl_version(460) + [__readNone] bool CandidateTriangleFrontFace(); __target_intrinsic(glsl, "rayQueryGetIntersectionBarycentricsEXT($0, false)") __glsl_extension(GL_EXT_ray_query) __glsl_version(460) + [__readNone] float2 CandidateTriangleBarycentrics(); __target_intrinsic(glsl, "rayQueryGetIntersectionTEXT($0, false)") __glsl_extension(GL_EXT_ray_query) __glsl_version(460) + [__readNone] float CandidateTriangleRayT(); // Commit the current non-opaque triangle hit. __target_intrinsic(glsl, rayQueryConfirmIntersectionEXT) __glsl_extension(GL_EXT_ray_query) __glsl_version(460) + [__NoSideEffect] + [mutating] void CommitNonOpaqueTriangleHit(); // Commit the current procedural primitive hit, with hit time `t`. __target_intrinsic(glsl, rayQueryGenerateIntersectionEXT) __glsl_extension(GL_EXT_ray_query) __glsl_version(460) + [__NoSideEffect] + [mutating] void CommitProceduralPrimitiveHit(float t); // Get the status of the committed (closest) hit, if any. __target_intrinsic(glsl, "rayQueryGetIntersectionTypeEXT($0, true)") __glsl_extension(GL_EXT_ray_query) __glsl_version(460) + [__readNone] COMMITTED_STATUS CommittedStatus(); // Access properties of the committed hit. @@ -5892,51 +5928,61 @@ struct RayQuery <let rayFlagsGeneric : RAY_FLAG = RAY_FLAG_NONE> __target_intrinsic(glsl, "transpose(rayQueryGetIntersectionObjectToWorldEXT($0, true))") __glsl_extension(GL_EXT_ray_query) __glsl_version(460) + [__readNone] float3x4 CommittedObjectToWorld3x4(); __target_intrinsic(glsl, "rayQueryGetIntersectionObjectToWorldEXT($0, true)") __glsl_extension(GL_EXT_ray_query) __glsl_version(460) + [__readNone] float4x3 CommittedObjectToWorld4x3(); __target_intrinsic(glsl, "transpose(rayQueryGetIntersectionWorldToObjectEXT($0, true))") __glsl_extension(GL_EXT_ray_query) __glsl_version(460) + [__readNone] float3x4 CommittedWorldToObject3x4(); __target_intrinsic(glsl, "rayQueryGetIntersectionWorldToObjectEXT($0, true)") __glsl_extension(GL_EXT_ray_query) __glsl_version(460) + [__readNone] float4x3 CommittedWorldToObject4x3(); __target_intrinsic(glsl, "rayQueryGetIntersectionTEXT($0, true)") __glsl_extension(GL_EXT_ray_query) __glsl_version(460) + [__readNone] float CommittedRayT(); __target_intrinsic(glsl, "rayQueryGetIntersectionInstanceIdEXT($0, true)") __glsl_extension(GL_EXT_ray_query) __glsl_version(460) + [__readNone] uint CommittedInstanceIndex(); __target_intrinsic(glsl, "rayQueryGetIntersectionInstanceCustomIndexEXT($0, true)") __glsl_extension(GL_EXT_ray_query) __glsl_version(460) + [__readNone] uint CommittedInstanceID(); __target_intrinsic(glsl, "rayQueryGetIntersectionGeometryIndexEXT($0, true)") __glsl_extension(GL_EXT_ray_query) __glsl_version(460) + [__readNone] uint CommittedGeometryIndex(); __target_intrinsic(glsl, "rayQueryGetIntersectionPrimitiveIndexEXT($0, true)") __glsl_extension(GL_EXT_ray_query) __glsl_version(460) + [__readNone] uint CommittedPrimitiveIndex(); __target_intrinsic(glsl, "rayQueryGetIntersectionInstanceShaderBindingTableRecordOffsetEXT($0, true)") __glsl_extension(GL_EXT_ray_query) __glsl_version(460) + [__readNone] uint CommittedInstanceContributionToHitGroupIndex(); // Access properties of the ray being traced @@ -5945,11 +5991,13 @@ struct RayQuery <let rayFlagsGeneric : RAY_FLAG = RAY_FLAG_NONE> __target_intrinsic(glsl, "rayQueryGetIntersectionObjectRayOriginEXT($0, true)") __glsl_extension(GL_EXT_ray_query) __glsl_version(460) + [__readNone] float3 CommittedObjectRayOrigin(); __target_intrinsic(glsl, "rayQueryGetIntersectionObjectRayDirectionEXT($0, true)") __glsl_extension(GL_EXT_ray_query) __glsl_version(460) + [__readNone] float3 CommittedObjectRayDirection(); // Access properties of a committed triangle hit. @@ -5957,11 +6005,13 @@ struct RayQuery <let rayFlagsGeneric : RAY_FLAG = RAY_FLAG_NONE> __target_intrinsic(glsl, "rayQueryGetIntersectionFrontFaceEXT($0, true)") __glsl_extension(GL_EXT_ray_query) __glsl_version(460) + [__readNone] bool CommittedTriangleFrontFace(); __target_intrinsic(glsl, "rayQueryGetIntersectionBarycentricsEXT($0, true)") __glsl_extension(GL_EXT_ray_query) __glsl_version(460) + [__readNone] float2 CommittedTriangleBarycentrics(); // Access properties of the ray being traced. @@ -5969,21 +6019,25 @@ struct RayQuery <let rayFlagsGeneric : RAY_FLAG = RAY_FLAG_NONE> __target_intrinsic(glsl, rayQueryGetRayFlagsEXT) __glsl_extension(GL_EXT_ray_query) __glsl_version(460) + [__readNone] uint RayFlags(); __target_intrinsic(glsl, rayQueryGetWorldRayOriginEXT) __glsl_extension(GL_EXT_ray_query) __glsl_version(460) + [__readNone] float3 WorldRayOrigin(); __target_intrinsic(glsl, rayQueryGetWorldRayDirectionEXT) __glsl_extension(GL_EXT_ray_query) __glsl_version(460) + [__readNone] float3 WorldRayDirection(); __target_intrinsic(glsl, rayQueryGetRayTMinEXT) __glsl_extension(GL_EXT_ray_query) __glsl_version(460) + [__readNone] float RayTMin(); } diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h index ab66febb9..25761b11c 100644 --- a/source/slang/slang-ast-modifier.h +++ b/source/slang/slang-ast-modifier.h @@ -1294,6 +1294,15 @@ class DeprecatedAttribute : public Attribute String message; }; +class NonCopyableTypeAttribute : public Attribute +{ + SLANG_AST_CLASS(NonCopyableTypeAttribute) +}; + +class NoSideEffectAttribute : public Attribute +{ + SLANG_AST_CLASS(NoSideEffectAttribute) +}; /// A modifier that applies to types rather than declarations. /// /// In most cases, the Slang compiler assumes that a modifier should diff --git a/source/slang/slang-ir-addr-inst-elimination.cpp b/source/slang/slang-ir-addr-inst-elimination.cpp index 16bd67f66..4d44aac1f 100644 --- a/source/slang/slang-ir-addr-inst-elimination.cpp +++ b/source/slang/slang-ir-addr-inst-elimination.cpp @@ -1,5 +1,6 @@ #include "slang-ir-addr-inst-elimination.h" #include "slang-ir-insts.h" +#include "slang-ir-util.h" namespace Slang { @@ -110,7 +111,6 @@ struct AddressInstEliminationContext } SlangResult eliminateAddressInstsImpl( - AddressConversionPolicy* policy, IRFunc* func, DiagnosticSink* inSink) { @@ -123,9 +123,13 @@ struct AddressInstEliminationContext { for (auto inst : block->getChildren()) { - if (as<IRPtrTypeBase>(inst->getDataType())) + if (auto ptrType = as<IRPtrTypeBase>(inst->getDataType())) { - workList.add(inst); + auto valType = unwrapAttributedType(ptrType->getValueType()); + if (!getResolvedInstForDecorations(valType)->findDecoration<IRNonCopyableTypeDecoration>()) + { + workList.add(inst); + } } } } @@ -134,9 +138,6 @@ struct AddressInstEliminationContext { auto addrInst = workList[workListIndex]; - if (!policy->shouldConvertAddrInst(addrInst)) - continue; - for (auto use = addrInst->firstUse; use; ) { auto nextUse = use->nextUse; @@ -174,14 +175,13 @@ struct AddressInstEliminationContext }; SlangResult eliminateAddressInsts( - AddressConversionPolicy* policy, IRFunc* func, DiagnosticSink* sink) { AddressInstEliminationContext ctx; ctx.module = func->getModule(); ctx.sink = sink; - return ctx.eliminateAddressInstsImpl(policy, func, sink); + return ctx.eliminateAddressInstsImpl(func, sink); } } // namespace Slang diff --git a/source/slang/slang-ir-addr-inst-elimination.h b/source/slang/slang-ir-addr-inst-elimination.h index 6c6506bc0..b80372345 100644 --- a/source/slang/slang-ir-addr-inst-elimination.h +++ b/source/slang/slang-ir-addr-inst-elimination.h @@ -7,12 +7,7 @@ namespace Slang { class DiagnosticSink; -struct AddressConversionPolicy -{ - virtual bool shouldConvertAddrInst(IRInst* addrInst) = 0; -}; SlangResult eliminateAddressInsts( - AddressConversionPolicy* policy, IRFunc* func, DiagnosticSink* sink); diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index 444816ff7..e6bfc751c 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -1590,16 +1590,6 @@ void insertTempVarForMutableParams(IRModule* module, IRFunc* func) } } -struct AutoDiffAddressConversionPolicy : public AddressConversionPolicy -{ - DifferentiableTypeConformanceContext* diffTypeContext; - - virtual bool shouldConvertAddrInst(IRInst*) override - { - return true; - } -}; - SlangResult ForwardDiffTranscriber::prepareFuncForForwardDiff(IRFunc* func) { insertTempVarForMutableParams(autoDiffSharedContext->moduleInst->getModule(), func); @@ -1609,9 +1599,7 @@ SlangResult ForwardDiffTranscriber::prepareFuncForForwardDiff(IRFunc* func) initializeLocalVariables(autoDiffSharedContext->moduleInst->getModule(), func); - AutoDiffAddressConversionPolicy cvtPolicty; - cvtPolicty.diffTypeContext = &differentiableTypeConformanceContext; - auto result = eliminateAddressInsts(&cvtPolicty, func, sink); + auto result = eliminateAddressInsts(func, sink); if (SLANG_SUCCEEDED(result)) { diff --git a/source/slang/slang-ir-autodiff-primal-hoist.cpp b/source/slang/slang-ir-autodiff-primal-hoist.cpp index 353d56cfa..0016f25e3 100644 --- a/source/slang/slang-ir-autodiff-primal-hoist.cpp +++ b/source/slang/slang-ir-autodiff-primal-hoist.cpp @@ -1602,20 +1602,6 @@ static CheckpointPreference getCheckpointPreference(IRInst* callee) return CheckpointPreference::None; } -static bool isGlobalMutableAddress(IRInst* inst) -{ - auto root = getRootAddr(inst); - if (root) - { - if (as<IRParameterGroupType>(root->getDataType())) - { - return false; - } - return as<IRModuleInst>(root->getParent()) != nullptr; - } - return false; -} - static bool isInstInPrimalOrTransposedParameterBlocks(IRInst* inst) { auto func = getParentFunc(inst); @@ -1790,7 +1776,7 @@ bool DefaultCheckpointPolicy::canRecompute(UseOrPseudoUse use) // We can't recompute a `load` is if it is a load from a global mutable // variable. - if (isGlobalMutableAddress(ptr)) + if (isGlobalOrUnknownMutableAddress(getParentFunc(load), ptr)) return false; // We can't recompute a 'load' from a mutable function parameter. diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index ffdd2b337..d9036f8bc 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -763,6 +763,9 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0) /// Applie to an IR function and signals that inlining should not be performed unless unavoidable. INST(NoInlineDecoration, noInline, 0, 0) + // Marks a type to be non copyable, causing SSA pass to skip turning variables of the the type into SSA values. + INST(NonCopyableTypeDecoration, nonCopyable, 0, 0) + /// A call to the decorated function should always be folded into its use site. INST(AlwaysFoldIntoUseSiteDecoration, alwaysFold, 0, 0) diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 9c31798e0..3f49be801 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -336,6 +336,7 @@ IR_SIMPLE_DECORATION(RequiresNVAPIDecoration) IR_SIMPLE_DECORATION(NoInlineDecoration) IR_SIMPLE_DECORATION(AlwaysFoldIntoUseSiteDecoration) IR_SIMPLE_DECORATION(StaticRequirementDecoration) +IR_SIMPLE_DECORATION(NonCopyableTypeDecoration) struct IRNVAPIMagicDecoration : IRDecoration { @@ -3958,6 +3959,11 @@ public: addDecoration(value, kIROp_NVAPISlotDecoration, getStringValue(registerName), getStringValue(spaceName)); } + void addNonCopyableTypeDecoration(IRInst* value) + { + addDecoration(value, kIROp_NonCopyableTypeDecoration); + } + /// Add a decoration that indicates that the given `inst` depends on the given `dependency`. /// /// This decoration can be used to ensure that a value that an instruction diff --git a/source/slang/slang-ir-link.cpp b/source/slang/slang-ir-link.cpp index de9071adf..14e79560e 100644 --- a/source/slang/slang-ir-link.cpp +++ b/source/slang/slang-ir-link.cpp @@ -445,6 +445,7 @@ static void cloneExtraDecorationsFromInst( case kIROp_UserDefinedBackwardDerivativeDecoration: case kIROp_PrimalSubstituteDecoration: case kIROp_IntrinsicOpDecoration: + case kIROp_NonCopyableTypeDecoration: if (!clonedInst->findDecorationImpl(decoration->getOp())) { cloneInst(context, builder, decoration); diff --git a/source/slang/slang-ir-lower-witness-lookup.cpp b/source/slang/slang-ir-lower-witness-lookup.cpp index fd86e1d3c..f87633656 100644 --- a/source/slang/slang-ir-lower-witness-lookup.cpp +++ b/source/slang/slang-ir-lower-witness-lookup.cpp @@ -387,7 +387,7 @@ struct WitnessLookupLoweringContext bool changed = false; for (auto bb : func->getBlocks()) { - for (auto inst : bb->getChildren()) + for (auto inst : bb->getModifiableChildren()) { if (auto witnessLookupInst = as<IRLookupWitnessMethod>(inst)) { diff --git a/source/slang/slang-ir-propagate-func-properties.cpp b/source/slang/slang-ir-propagate-func-properties.cpp index 7ce4bfc80..7b29aaf14 100644 --- a/source/slang/slang-ir-propagate-func-properties.cpp +++ b/source/slang/slang-ir-propagate-func-properties.cpp @@ -7,7 +7,112 @@ namespace Slang { -bool propagateFuncProperties(IRModule* module) +class FuncPropertyPropagationContext +{ +public: + virtual bool canProcess(IRFunc* f) = 0; + virtual bool propagate(IRBuilder& builder, IRFunc* func) = 0; +}; + +class ReadNoneFuncPropertyPropagationContext : public FuncPropertyPropagationContext +{ +public: + virtual bool canProcess(IRFunc* f) override + { + // If the func has already been marked with any decorations, skip. + for (auto decoration : f->getDecorations()) + { + switch (decoration->getOp()) + { + case kIROp_ReadNoneDecoration: + case kIROp_TargetIntrinsicDecoration: + return false; + } + } + return true; + } + + virtual bool propagate(IRBuilder& builder, IRFunc* f) override + { + bool hasSideEffectCall = false; + for (auto block : f->getBlocks()) + { + for (auto inst : block->getChildren()) + { + // Is this inst known to not have global side effect/analyzable? + if (inst->mightHaveSideEffects()) + { + switch (inst->getOp()) + { + case kIROp_ifElse: + case kIROp_unconditionalBranch: + case kIROp_Switch: + case kIROp_Return: + case kIROp_loop: + case kIROp_Call: + case kIROp_Param: + case kIROp_Unreachable: + case kIROp_Store: + case kIROp_SwizzledStore: + break; + default: + // We have a inst that has side effect and is not understood by this method. + // e.g. bufferStore, discard, etc. + hasSideEffectCall = true; + break; + } + } + + if (auto call = as<IRCall>(inst)) + { + auto callee = getResolvedInstForDecorations(call->getCallee()); + switch (callee->getOp()) + { + default: + // We are calling an unknown function, so we have to assume + // there are side effects in the call. + hasSideEffectCall = true; + break; + case kIROp_Func: + if (!callee->findDecoration<IRReadNoneDecoration>()) + { + hasSideEffectCall = true; + break; + } + } + } + + // Do any operands defined have pointer type of global or + // unknown source? Passing them into a readNone callee may cause + // side effects that breaks the readNone property. + for (UInt o = 0; o < inst->getOperandCount(); o++) + { + auto operand = inst->getOperand(o); + if (as<IRConstant>(operand)) + continue; + if (as<IRType>(operand)) + continue; + if (isGlobalOrUnknownMutableAddress(f, operand)) + { + hasSideEffectCall = true; + break; + } + break; + } + } + if (hasSideEffectCall) + break; + } + if (!hasSideEffectCall) + { + builder.addDecoration(f, kIROp_ReadNoneDecoration); + return true; + } + return false; + } +}; + +bool propagateFuncPropertiesImpl(IRModule* module, FuncPropertyPropagationContext* context) { bool result = false; List<IRFunc*> workList; @@ -61,7 +166,7 @@ bool propagateFuncProperties(IRModule* module) } if (auto func = as<IRFunc>(inst)) { - if (func->findDecoration<IRReadNoneDecoration>()) + if (context->canProcess(func)) { addCallersToWorkList(func); } @@ -87,101 +192,139 @@ bool propagateFuncProperties(IRModule* module) for (Index i = 0; i < workList.getCount(); i++) { auto f = workList[i]; - bool hasSideEffectCall = false; - if (f->findDecoration<IRReadNoneDecoration>()) + if (!context->canProcess(f)) continue; + // Never propagate to functions without a body. if (f->getFirstBlock() == nullptr) continue; - if (f->findDecoration<IRTargetIntrinsicDecoration>()) - continue; - for (auto block : f->getBlocks()) + + if (context->propagate(builder, f)) + { + addCallersToWorkList(f); + changed = true; + } + } + result |= changed; + if (!changed) + break; + } + return result; +} + +class NoSideEffectFuncPropertyPropagationContext : public FuncPropertyPropagationContext +{ +public: + virtual bool canProcess(IRFunc* f) override + { + // If the func has already been marked with any decorations, skip. + for (auto decoration : f->getDecorations()) + { + switch (decoration->getOp()) { - for (auto inst : block->getChildren()) + case kIROp_ReadNoneDecoration: + case kIROp_NoSideEffectDecoration: + case kIROp_TargetIntrinsicDecoration: + return false; + } + } + return true; + } + + virtual bool propagate(IRBuilder& builder, IRFunc* f) override + { + bool hasSideEffectCall = false; + for (auto block : f->getBlocks()) + { + for (auto inst : block->getChildren()) + { + // Is this inst known to not have global side effect/analyzable? + if (inst->mightHaveSideEffects()) { - // Is this inst known to not have global side effect/analyzable? - if (inst->mightHaveSideEffects()) + switch (inst->getOp()) { - switch (inst->getOp()) - { - case kIROp_ifElse: - case kIROp_unconditionalBranch: - case kIROp_Switch: - case kIROp_Return: - case kIROp_loop: - case kIROp_Store: - case kIROp_Call: - case kIROp_Param: - case kIROp_Unreachable: - break; - default: - // We have a inst that has side effect and is not understood by this method. - // e.g. bufferStore, discard, etc. - hasSideEffectCall = true; - break; - } + case kIROp_ifElse: + case kIROp_unconditionalBranch: + case kIROp_Switch: + case kIROp_Return: + case kIROp_loop: + case kIROp_Call: + case kIROp_Param: + case kIROp_Unreachable: + case kIROp_Store: + case kIROp_SwizzledStore: + break; + default: + // We have a inst that has side effect and is not understood by this method. + // e.g. bufferStore, discard, etc. + hasSideEffectCall = true; + break; } + } + else + { + // A side effect free inst can't generate side effects for the function. + continue; + } - if (auto call = as<IRCall>(inst)) + if (auto call = as<IRCall>(inst)) + { + auto callee = getResolvedInstForDecorations(call->getCallee()); + switch (callee->getOp()) { - auto callee = getResolvedInstForDecorations(call->getCallee()); - switch (callee->getOp()) + default: + // We are calling an unknown function, so we have to assume + // there are side effects in the call. + hasSideEffectCall = true; + break; + case kIROp_Func: + if (!callee->findDecoration<IRReadNoneDecoration>() && + !callee->findDecoration<IRNoSideEffectDecoration>()) { - default: - // We are calling an unknown function, so we have to assume - // there are side effects in the call. hasSideEffectCall = true; break; - case kIROp_Func: - if (!callee->findDecoration<IRReadNoneDecoration>()) - { - hasSideEffectCall = true; - break; - } } } - - // Are any operands defined in global scope? - for (UInt o = 0; o < inst->getOperandCount(); o++) + } + + // Do any operands defined have pointer type of global or + // unknown source? Passing them into a NoSideEffect callee may cause + // side effects that breaks the NoSideEffect property. + for (UInt o = 0; o < inst->getOperandCount(); o++) + { + auto operand = inst->getOperand(o); + if (as<IRConstant>(operand)) + continue; + if (as<IRType>(operand)) + continue; + if (isGlobalOrUnknownMutableAddress(f, operand)) { - auto operand = inst->getOperand(o); - if (getParentFunc(operand) == f) - continue; - if (as<IRConstant>(operand)) - continue; - if (as<IRType>(operand)) - continue; - switch (operand->getOp()) - { - case kIROp_Specialize: - case kIROp_LookupWitness: - case kIROp_StructKey: - case kIROp_WitnessTable: - case kIROp_WitnessTableEntry: - case kIROp_undefined: - case kIROp_Func: - continue; - default: - break; - } hasSideEffectCall = true; break; } - } - if (hasSideEffectCall) break; + } } - if (!hasSideEffectCall) - { - builder.addDecoration(f, kIROp_ReadNoneDecoration); - addCallersToWorkList(f); - changed = true; - } + if (hasSideEffectCall) + break; } - result |= changed; - if (!changed) - break; + if (!hasSideEffectCall) + { + builder.addDecoration(f, kIROp_NoSideEffectDecoration); + return true; + } + return false; } - return result; +}; + +bool propagateFuncProperties(IRModule* module) +{ + ReadNoneFuncPropertyPropagationContext readNoneContext; + bool changed = propagateFuncPropertiesImpl(module, &readNoneContext); + + NoSideEffectFuncPropertyPropagationContext noSideEffectContext; + changed|= propagateFuncPropertiesImpl(module, &noSideEffectContext); + + return changed; } } diff --git a/source/slang/slang-ir-simplify-cfg.cpp b/source/slang/slang-ir-simplify-cfg.cpp index 797e5c9ea..c37284dce 100644 --- a/source/slang/slang-ir-simplify-cfg.cpp +++ b/source/slang/slang-ir-simplify-cfg.cpp @@ -6,6 +6,7 @@ #include "slang-ir-restructure.h" #include "slang-ir-util.h" #include "slang-ir-loop-unroll.h" +#include "slang-ir-reachability.h" namespace Slang { @@ -103,37 +104,59 @@ static bool doesLoopHasSideEffect(IRGlobalValueWithCode* func, IRLoop* loopInst) HashSet<IRBlock*> loopBlocks; for (auto b : blocks) loopBlocks.add(b); - auto addressHasOutOfLoopUses = [&](IRInst* addr) + + ReachabilityContext reachability = {}; + + // Construct a map from a root address to all derived addresses. + Dictionary<IRInst*, List<IRInst*>> relatedAddrMap; + for (auto b : func->getBlocks()) { - // The entire access chain of `addr` must have no uses outside the loop. - // The root variable must be a local var. - for (auto chainNode = addr; chainNode;) + for (auto inst : b->getChildren()) { - if (getParentFunc(chainNode) != func) - return true; - for (auto use = chainNode->firstUse; use; use = use->nextUse) + if (as<IRPtrTypeBase>(inst->getDataType())) { - if (!loopBlocks.contains(as<IRBlock>(use->getUser()->getParent()))) - return true; + auto root = getRootAddr(inst); + if (!root) continue; + auto list = relatedAddrMap.tryGetValue(root); + if (!list) + { + relatedAddrMap.add(root, List<IRInst*>()); + list = relatedAddrMap.tryGetValue(root); + } + list->add(inst); } - switch (chainNode->getOp()) - { - case kIROp_GetElementPtr: - case kIROp_FieldAddress: - chainNode = chainNode->getOperand(0); + } + } + + auto addressHasOutOfLoopUses = [&](IRInst* addr) + { + auto rootAddr = getRootAddr(addr); + if (isGlobalOrUnknownMutableAddress(func, rootAddr)) + return true; + if (as<IRParam>(rootAddr)) + return true; + + // If we can't find the address from our map, we conservatively assume it is an unknown address. + auto relatedAddrs = relatedAddrMap.tryGetValue(getRootAddr(addr)); + if (!relatedAddrs) + return true; + + // For all related address of `addr` that may alias with it, we check their uses. + for (auto relatedAddr : *relatedAddrs) + { + if (!canAddressesPotentiallyAlias(func, relatedAddr, addr)) continue; - case kIROp_Var: - if (auto rate = chainNode->getFullType()->getRate()) + for (auto use = relatedAddr->firstUse; use; use = use->nextUse) + { + if (!loopBlocks.contains(as<IRBlock>(use->getUser()->getParent()))) { - if (!as<IRConstExprRate>(rate)) + // Is this use reachable from the loop header? + if (reachability.isInstReachable(loopInst, use->getUser())) return true; } - break; - default: - return true; } - break; } + return false; }; @@ -267,16 +290,8 @@ static bool isTrivialIfElseBranch(IRIfElse* condBranch, IRBlock* branchBlock) return false; } -static bool arePhiArgsEquivalentInBranches(IRIfElse* ifElse) +static bool arePhiArgsEquivalentInBranchesImpl(IRBlock* branch1, IRBlock* branch2, IRBlock* afterBlock) { - // If one of the branch target is afterBlock itself, and the other branch - // is a trivial block that jumps into the afterBlock, this if-else is trivial. - // In this case the argCount must be 0 because a block with phi parameters can't - // be used as targets in a conditional branch. - auto branch1 = ifElse->getTrueBlock(); - auto branch2 = ifElse->getFalseBlock(); - auto afterBlock = ifElse->getAfterBlock(); - if (branch1 == afterBlock) return true; if (branch2 == afterBlock) return true; @@ -291,7 +306,7 @@ static bool arePhiArgsEquivalentInBranches(IRIfElse* ifElse) // This should never happen, return false now to be safe. return false; } - + for (UInt i = 0; i < branchInst1->getArgCount(); i++) { if (branchInst1->getArg(i) != branchInst2->getArg(i)) @@ -303,6 +318,19 @@ static bool arePhiArgsEquivalentInBranches(IRIfElse* ifElse) return true; } +static bool arePhiArgsEquivalentInBranches(IRIfElse* ifElse) +{ + // If one of the branch target is afterBlock itself, and the other branch + // is a trivial block that jumps into the afterBlock, this if-else is trivial. + // In this case the argCount must be 0 because a block with phi parameters can't + // be used as targets in a conditional branch. + auto branch1 = ifElse->getTrueBlock(); + auto branch2 = ifElse->getFalseBlock(); + auto afterBlock = ifElse->getAfterBlock(); + + return arePhiArgsEquivalentInBranchesImpl(branch1, branch2, afterBlock); +} + static bool isTrivialIfElse(IRIfElse* condBranch, bool& isTrueBranchTrivial, bool& isFalseBranchTrivial) { isTrueBranchTrivial = isTrivialIfElseBranch(condBranch, condBranch->getTrueBlock()); @@ -315,6 +343,62 @@ static bool isTrivialIfElse(IRIfElse* condBranch, bool& isTrueBranchTrivial, boo return false; } +// Return the true of the switch branch block if the branch is a trivial jump +// to after block with no other insts. +static bool isTrivialSwitchBranch(IRSwitch* switchInst, IRBlock* branchBlock) +{ + if (branchBlock != switchInst->getBreakLabel()) + { + if (auto br = as<IRUnconditionalBranch>(branchBlock->getFirstOrdinaryInst())) + { + if (br->getTargetBlock() == switchInst->getBreakLabel() && br->getOp() == kIROp_unconditionalBranch) + { + return true; + } + } + } + else + { + return true; + } + return false; +} + +static bool arePhiArgsEquivalentInBranches(IRSwitch* switchInst) +{ + ShortList<IRBlock*> jumpTargets; + if (switchInst->getDefaultLabel()) + jumpTargets.add(switchInst->getDefaultLabel()); + for (UInt i = 0; i < switchInst->getCaseCount(); i++) + { + jumpTargets.add(switchInst->getCaseLabel(i)); + } + if (jumpTargets.getCount() == 0) + return true; + for (Index i = 1; i < jumpTargets.getCount(); i++) + { + auto branch1 = jumpTargets[0]; + auto branch2 = jumpTargets[i]; + auto afterBlock = switchInst->getBreakLabel(); + + if (!arePhiArgsEquivalentInBranchesImpl(branch1, branch2, afterBlock)) + return false; + } + return true; +} + +static bool isTrivialSwitch(IRSwitch* switchBranch) +{ + for (UInt i = 0; i < switchBranch->getCaseCount(); i++) + { + if (!isTrivialSwitchBranch(switchBranch, switchBranch->getCaseLabel(i))) + return false; + } + if (!isTrivialSwitchBranch(switchBranch, switchBranch->getDefaultLabel())) + return false; + return true; +} + static bool trySimplifyIfElse(IRBuilder& builder, IRIfElse* ifElseInst) { bool isTrueBranchTrivial = false; @@ -338,6 +422,29 @@ static bool trySimplifyIfElse(IRBuilder& builder, IRIfElse* ifElseInst) return false; } +static bool trySimplifySwitch(IRBuilder& builder, IRSwitch* switchInst) +{ + if (!isTrivialSwitch(switchInst)) + return false; + if (switchInst->getCaseCount() == 0) + return false; + + auto termInst = as<IRUnconditionalBranch>(switchInst->getCaseLabel(0)->getTerminator()); + if (!termInst) + return false; + + if (!arePhiArgsEquivalentInBranches(switchInst)) + return false; + + List<IRInst*> args; + for (UInt i = 0; i < termInst->getArgCount(); i++) + args.add(termInst->getArg(i)); + builder.setInsertBefore(switchInst); + builder.emitBranch(switchInst->getBreakLabel(), (Int)args.getCount(), args.getBuffer()); + switchInst->removeAndDeallocate(); + return true; +} + static bool isTrueLit(IRInst* lit) { if (auto boolLit = as<IRBoolLit>(lit)) @@ -582,6 +689,10 @@ static bool processFunc(IRGlobalValueWithCode* func) { changed |= trySimplifyIfElse(builder, condBranch); } + else if (auto switchBranch = as<IRSwitch>(block->getTerminator())) + { + changed |= trySimplifySwitch(builder, switchBranch); + } // If `block` does not end with an unconditional branch, bail. if (block->getTerminator()->getOp() != kIROp_unconditionalBranch) diff --git a/source/slang/slang-ir-specialize-function-call.cpp b/source/slang/slang-ir-specialize-function-call.cpp index b0ad58f8c..9b7cdaea6 100644 --- a/source/slang/slang-ir-specialize-function-call.cpp +++ b/source/slang/slang-ir-specialize-function-call.cpp @@ -822,7 +822,7 @@ struct FunctionParameterSpecializationContext { decoration->removeAndDeallocate(); } - else if (as<IRReadNoneDecoration>(decoration)) + else if (as<IRReadNoneDecoration>(decoration) || as<IRNoSideEffectDecoration>(decoration)) { // After specialization, the function may no longer be side effect free // because the parameter we substituted in maybe a global param. diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp index 506e96c81..ba34a725d 100644 --- a/source/slang/slang-ir-util.cpp +++ b/source/slang/slang-ir-util.cpp @@ -180,6 +180,7 @@ bool isValueType(IRInst* dataType) case kIROp_AnyValueType: case kIROp_ArrayType: case kIROp_FuncType: + case kIROp_RaytracingAccelerationStructureType: return true; default: // Read-only resource handles are considered as Value type. @@ -406,12 +407,19 @@ bool isPtrLikeOrHandleType(IRInst* type) { if (!type) return false; + if (as<IRPointerLikeType>(type)) + return true; + if (as<IRPseudoPtrType>(type)) + return true; + if (as<IRMeshOutputType>(type)) + return true; + if (as<IRHLSLOutputPatchType>(type)) + return true; switch (type->getOp()) { case kIROp_ComPtrType: case kIROp_RawPointerType: case kIROp_RTTIPointerType: - case kIROp_PseudoPtrType: case kIROp_OutType: case kIROp_InOutType: case kIROp_PtrType: @@ -445,7 +453,7 @@ bool canInstHaveSideEffectAtAddress(IRGlobalValueWithCode* func, IRInst* inst, I { auto callee = call->getCallee(); if (callee && - callee->findDecoration<IRReadNoneDecoration>()) + doesCalleeHaveSideEffect(callee)) { // An exception is if the callee is side-effect free and is not reading from // memory. @@ -641,71 +649,100 @@ void setInsertAfterOrdinaryInst(IRBuilder* builder, IRInst* inst) } } -bool isPureFunctionalCall(IRCall* call) +bool areCallArgumentsSideEffectFree(IRCall* call) { - auto callee = getResolvedInstForDecorations(call->getCallee()); - if (callee->findDecoration<IRReadNoneDecoration>()) + // If the function has no side effect and is not writing to any outputs, + // we can safely treat the call as a normal inst. + IRFunc* parentFunc = nullptr; + for (UInt i = 0; i < call->getArgCount(); i++) { - // If the function has no side effect and is not writing to any outputs, - // we can safely treat the call as a normal inst. - IRFunc* parentFunc = nullptr; - for (UInt i = 0; i < call->getArgCount(); i++) - { - auto arg = call->getArg(i); - if (isValueType(arg->getDataType())) - continue; + auto arg = call->getArg(i); + if (isValueType(arg->getDataType())) + continue; - // If the argument type is not a known value type, - // assume it is a pointer or handle through which side effect can take place. + // If the argument type is not a known value type, + // assume it is a pointer or handle through which side effect can take place. + if (!parentFunc) + { + parentFunc = getParentFunc(call); if (!parentFunc) - { - parentFunc = getParentFunc(call); - if (!parentFunc) - return false; - } + return false; + } - if (arg->getOp() == kIROp_Var && getParentFunc(arg) == parentFunc) + if (arg->getOp() == kIROp_Var && getParentFunc(arg) == parentFunc) + { + // If the pointer argument is a local variable (thus can't alias with other addresses) + // and it is never read from in the function, we can safely treat the call as having + // no side-effect. + // This is a conservative test, but is sufficient to detect the most common case where + // a temporary variable is used as the inout argument and the result stored in the temp + // variable isn't being used elsewhere in the parent func. + // + // A more aggresive test can check all other address uses reachable from the call site + // and see if any of them are aliasing with the argument. + for (auto use = arg->firstUse; use; use = use->nextUse) { - // If the pointer argument is a local variable (thus can't alias with other addresses) - // and it is never read from in the function, we can safely treat the call as having - // no side-effect. - // This is a conservative test, but is sufficient to detect the most common case where - // a temporary variable is used as the inout argument and the result stored in the temp - // variable isn't being used elsewhere in the parent func. - // - // A more aggresive test can check all other address uses reachable from the call site - // and see if any of them are aliasing with the argument. - for (auto use = arg->firstUse; use; use = use->nextUse) + if (as<IRDecoration>(use->getUser())) + continue; + switch (use->getUser()->getOp()) { - if (as<IRDecoration>(use->getUser())) - continue; - switch (use->getUser()->getOp()) - { - case kIROp_Store: - // We are fine with stores into the variable, since store operations - // are not dependent on whatever we do in the call here. + case kIROp_Store: + case kIROp_SwizzledStore: + // We are fine with stores into the variable, since store operations + // are not dependent on whatever we do in the call here. + continue; + default: + // Skip the call itself, since we are checking if the call has side effect. + if (use->getUser() == call) continue; - default: - // Skip the call itself, since we are checking if the call has side effect. - if (use->getUser() == call) - continue; - // We have some other unknown use of the variable address, they can - // be loads, or calls using addresses derived from the variable, - // we will treat the call as having side effect to be safe. - return false; - } + // We have some other unknown use of the variable address, they can + // be loads, or calls using addresses derived from the variable, + // we will treat the call as having side effect to be safe. + return false; } } - else - { - return false; - } } - return true; + else + { + return false; + } + } + return true; +} + +bool isPureFunctionalCall(IRCall* call) +{ + auto callee = getResolvedInstForDecorations(call->getCallee()); + if (callee->findDecoration<IRReadNoneDecoration>()) + { + return areCallArgumentsSideEffectFree(call); } return false; } +bool isSideEffectFreeFunctionalCall(IRCall* call) +{ + if (!doesCalleeHaveSideEffect(call->getCallee())) + { + return areCallArgumentsSideEffectFree(call); + } + return false; +} + +bool doesCalleeHaveSideEffect(IRInst* callee) +{ + for (auto decor : getResolvedInstForDecorations(callee)->getDecorations()) + { + switch (decor->getOp()) + { + case kIROp_NoSideEffectDecoration: + case kIROp_ReadNoneDecoration: + return false; + } + } + return true; +} + IRInst* findInterfaceRequirement(IRInterfaceType* type, IRInst* key) { for (UInt i = 0; i < type->getOperandCount(); i++) @@ -779,6 +816,40 @@ int getParamIndexInBlock(IRParam* paramInst) return -1; } +bool isGlobalOrUnknownMutableAddress(IRGlobalValueWithCode* parentFunc, IRInst* inst) +{ + auto root = getRootAddr(inst); + + auto type = unwrapAttributedType(inst->getDataType()); + if (!isPtrLikeOrHandleType(type)) + return false; + + switch (root->getOp()) + { + case kIROp_GlobalVar: + return true; + case kIROp_GlobalParam: + case kIROp_GlobalConstant: + case kIROp_Var: + case kIROp_Param: + break; + default: + // The inst is defined by an unknown inst. + return true; + } + + if (root) + { + if (as<IRParameterGroupType>(root->getDataType())) + { + return false; + } + auto addrInstParent = getParentFunc(root); + return (addrInstParent != parentFunc); + } + return false; +} + struct GenericChildrenMigrationContextImpl { IRCloneEnv cloneEnv; diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h index 492a9f312..9fd6dd972 100644 --- a/source/slang/slang-ir-util.h +++ b/source/slang/slang-ir-util.h @@ -167,10 +167,18 @@ bool canAddressesPotentiallyAlias(IRGlobalValueWithCode* func, IRInst* addr1, IR String dumpIRToString(IRInst* root); -// Returns whether a call insts can be treated as a pure functional inst -// (no writes to memory, no side effects). +// Returns whether a call insts can be treated as a pure functional inst, and thus can be +// DCE'd and deduplicated. +// (no writes to memory, no reads from unknown memory, no side effects). bool isPureFunctionalCall(IRCall* callInst); +// Returns whether a call insts can be treated as a pure functional inst, and thus can be +// DCE'd (but not necessarily deduplicated). +// (no side effects). +bool isSideEffectFreeFunctionalCall(IRCall* call); + +bool doesCalleeHaveSideEffect(IRInst* callee); + bool isPtrLikeOrHandleType(IRInst* type); bool canInstHaveSideEffectAtAddress(IRGlobalValueWithCode* func, IRInst* inst, IRInst* addr); @@ -205,6 +213,7 @@ void removePhiArgs(IRInst* phiParam); int getParamIndexInBlock(IRParam* paramInst); +bool isGlobalOrUnknownMutableAddress(IRGlobalValueWithCode* parentFunc, IRInst* inst); } #endif diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 78b62265b..6adf8ee1c 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -7157,7 +7157,7 @@ namespace Slang // will treat it so, by-passing all other checks. if (call->findDecoration<IRNoSideEffectDecoration>()) return false; - return !isPureFunctionalCall(call); + return !isSideEffectFreeFunctionalCall(call); } break; diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index de2d5aff2..ad338709d 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -7664,7 +7664,11 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> irAggType->moveToEnd(); addTargetIntrinsicDecorations(irAggType, decl); - + for (auto modifier : decl->modifiers) + { + if (as<NonCopyableTypeAttribute>(modifier)) + subBuilder->addNonCopyableTypeDecoration(irAggType); + } return LoweredValInfo::simple(finishOuterGenerics(subBuilder, irAggType, outerGeneric)); } @@ -8779,6 +8783,10 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> { getBuilder()->addSimpleDecoration<IRReadNoneDecoration>(irFunc); } + else if (as<NoSideEffectAttribute>(modifier)) + { + getBuilder()->addSimpleDecoration<IRNoSideEffectDecoration>(irFunc); + } else if (as<EarlyDepthStencilAttribute>(modifier)) { getBuilder()->addSimpleDecoration<IREarlyDepthStencilDecoration>(irFunc); diff --git a/tests/cross-compile/func-resource-param-array.slang b/tests/cross-compile/func-resource-param-array.slang index 7062169dc..d7b7fca99 100644 --- a/tests/cross-compile/func-resource-param-array.slang +++ b/tests/cross-compile/func-resource-param-array.slang @@ -32,7 +32,6 @@ void main(uint3 tid : SV_DispatchThreadID) { uint ii = tid.x; uint jj = tid.y; - uint kk = tid.z; // Can we specialize `f`? // @@ -56,7 +55,7 @@ void main(uint3 tid : SV_DispatchThreadID) // What if the function takes an array, and we pass // in an element of an array-of-arrays? // - tmp += g(c[ii], jj, kk); + tmp += g(c[ii], jj, tid.z); a[ii] = tmp; } diff --git a/tests/cross-compile/func-resource-param-array.slang.glsl b/tests/cross-compile/func-resource-param-array.slang.glsl index e9d1b5a97..d5d1bd08c 100644 --- a/tests/cross-compile/func-resource-param-array.slang.glsl +++ b/tests/cross-compile/func-resource-param-array.slang.glsl @@ -1,93 +1,48 @@ -// func-resource-param-array.slang.glsl #version 450 +layout(row_major) uniform; +layout(row_major) buffer; -#define a a_0 -#define b b_0 -#define c c_0 -#define ii ii_0 -#define jj jj_0 -#define kk kk_0 - -#define f_a f_0 -#define f_b f_1 -#define g_b g_0 -#define g_c g_1 - -#define a_block _S1 -#define b_block _S2 -#define c_block _S3 - -#define f_a_i _S4 -#define f_b_t _S5 -#define f_b_i _S6 -#define g_b_i _S7 -#define g_b_j _S8 -#define g_c_t _S9 -#define g_c_i _S10 -#define g_c_j _S11 - -#define tmp_f_a_ii _S12 -#define tmp_f_a_jj _S13 - -#define tmp_f_b _S14 -#define tmp_g_b _S15 -#define tmp_g_c _S16 - -layout(std430, binding = 0) buffer a_block { +layout(std430, binding = 0) buffer _S1 { int _data[]; -} a; +} a_0; -layout(std430, binding = 1) buffer b_block { +layout(std430, binding = 1) buffer _S2 { int _data[]; -} b[3]; +} b_0[3]; -layout(std430, binding = 2) buffer c_block { +layout(std430, binding = 2) buffer _S3 { int _data[]; -} c[4][3]; +} c_0[4][3]; -int f_a(uint f_a_i) +int f_0(uint _S4) { - return a._data[f_a_i]; + return ((a_0)._data[(_S4)]); } -int f_b(uint f_b_t, uint f_b_i) +int f_1(uint _S5, uint _S6) { - return b[f_b_t]._data[f_b_i]; + return ((b_0[_S5])._data[(_S6)]); } -int g_b(uint g_b_i, uint g_b_j) +int g_0(uint _S7, uint _S8) { - return b[g_b_i]._data[g_b_j]; + return ((b_0[_S7])._data[(_S8)]); } -int g_c(uint g_c_t, uint g_c_i, uint g_c_j) +int g_1(uint _S9, uint _S10, uint _S11) { - return c[g_c_t][g_c_i]._data[g_c_j]; + return ((c_0[_S9][_S10])._data[(_S11)]); } layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in; - void main() { - uint ii = gl_GlobalInvocationID.x; - uint jj = gl_GlobalInvocationID.y; - uint kk = gl_GlobalInvocationID.z; + uint ii_0 = gl_GlobalInvocationID.x; + uint jj_0 = gl_GlobalInvocationID.y; - int tmp_f_a_ii = f_a(ii); - - int tmp_f_a_jj = f_a(jj); - int tmp_0 = tmp_f_a_ii + tmp_f_a_jj; - - int tmp_f_b = f_b(ii, jj); - int tmp_1 = tmp_0 + tmp_f_b; - - int tmp_g_b = g_b(ii, jj); - int tmp_2 = tmp_1 + tmp_g_b; - - int tmp_g_c = g_c(ii, jj, kk); - int tmp_3 = tmp_2 + tmp_g_c; - - a._data[ii] = tmp_3; + int tmp_0 = f_0(ii_0) + f_0(jj_0) + f_1(ii_0, jj_0) + g_0(ii_0, jj_0) + g_1(ii_0, jj_0, gl_GlobalInvocationID.z); + ((a_0)._data[(ii_0)]) = tmp_0; return; } + diff --git a/tests/cross-compile/vk-texture-indexing.slang.glsl b/tests/cross-compile/vk-texture-indexing.slang.glsl index 069e6efc3..73513c623 100644 --- a/tests/cross-compile/vk-texture-indexing.slang.glsl +++ b/tests/cross-compile/vk-texture-indexing.slang.glsl @@ -1,16 +1,16 @@ #version 450 -#extension GL_EXT_nonuniform_qualifier : require #extension GL_EXT_samplerless_texture_functions : require +#extension GL_EXT_nonuniform_qualifier : require +layout(row_major) uniform; +layout(row_major) buffer; layout(binding = 0) uniform texture2D gParams_textures_0[10]; + float fetchData_0(uvec2 coords_0, uint index_0) { - float _S1 = texelFetch( - gParams_textures_0[nonuniformEXT(index_0)], - ivec2(coords_0), - 0).x; + float _S1 = (texelFetch((gParams_textures_0[nonuniformEXT(index_0)]), ivec2((coords_0)), 0).x); return _S1; } @@ -18,12 +18,16 @@ float fetchData_0(uvec2 coords_0, uint index_0) layout(location = 0) out vec4 _S2; + flat layout(location = 0) in uvec3 _S3; + void main() { - float v_0 = fetchData_0(_S3.xy, _S3.z); - _S2 = vec4(v_0); + + _S2 = vec4(fetchData_0(_S3.xy, _S3.z)); + return; } + diff --git a/tests/experimental/liveness/liveness-4.slang.expected b/tests/experimental/liveness/liveness-4.slang.expected index 802388e40..ee7986319 100644 --- a/tests/experimental/liveness/liveness-4.slang.expected +++ b/tests/experimental/liveness/liveness-4.slang.expected @@ -10,25 +10,15 @@ layout(std430, binding = 0) buffer _S1 { int _data[]; } outputBuffer_0; spirv_instruction(id = 256) -void livenessStart_0(spirv_by_reference int _0[2], spirv_literal int _1); - -spirv_instruction(id = 256) -void livenessStart_1(spirv_by_reference int _0, spirv_literal int _1); +void livenessStart_0(spirv_by_reference int _0, spirv_literal int _1); spirv_instruction(id = 257) void livenessEnd_0(spirv_by_reference int _0, spirv_literal int _1); -spirv_instruction(id = 257) -void livenessEnd_1(spirv_by_reference int _0[2], spirv_literal int _1); - int calcThing_0(int offset_0) { - int another_0[2]; - livenessStart_0(another_0, 0); - another_0[0] = 1; - another_0[1] = 2; int k_0; - livenessStart_1(k_0, 0); + livenessStart_0(k_0, 0); k_0 = 0; for(;;) { @@ -41,33 +31,16 @@ int calcThing_0(int offset_0) break; } bool _S2 = (k_0 + 7) % 5 == 4; - int k_1 = k_0 + 1; - int i_0; - livenessStart_1(i_0, 0); - i_0 = 0; - for(;;) - { - if(i_0 < 17) - { - } - else - { - livenessEnd_0(i_0, 0); - break; - } - another_0[i_0 & 1] = another_0[i_0 & 1] + (k_0 + i_0); - i_0 = i_0 + 1; - } + int _S3 = k_0; livenessEnd_0(k_0, 0); + int k_1 = _S3 + 1; if(_S2) { - livenessEnd_1(another_0, 0); return 1; } - livenessStart_1(k_0, 0); + livenessStart_0(k_0, 0); k_0 = k_1; } - livenessEnd_1(another_0, 0); return -2; } @@ -75,8 +48,8 @@ layout(local_size_x = 4, local_size_y = 1, local_size_z = 1) in; void main() { int index_0 = int(gl_GlobalInvocationID.x); - int _S3 = calcThing_0(index_0); - ((outputBuffer_0)._data[(uint(index_0))]) = _S3; + int _S4 = calcThing_0(index_0); + ((outputBuffer_0)._data[(uint(index_0))]) = _S4; return; } diff --git a/tests/ir/dce-rw-buffer-load.slang b/tests/ir/dce-rw-buffer-load.slang new file mode 100644 index 000000000..01a14ee45 --- /dev/null +++ b/tests/ir/dce-rw-buffer-load.slang @@ -0,0 +1,21 @@ +//TEST:SIMPLE(filecheck=CHECK): -target hlsl -profile cs_5_0 -entry computeMain -line-directive-mode none + +// Test that we can DCE load of a rw buffer. + +RWStructuredBuffer<float> gOutputBuffer; + +float test() +{ + return gOutputBuffer[0]; +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID) +{ + test(); +} + +// CHECK: void computeMain +// CHECK-NOT: test +// CHECK: } + diff --git a/tests/pipeline/ray-tracing/acceleration-structure-in-compute-nv.slang.glsl b/tests/pipeline/ray-tracing/acceleration-structure-in-compute-nv.slang.glsl index 5819d657d..0374569fe 100644 --- a/tests/pipeline/ray-tracing/acceleration-structure-in-compute-nv.slang.glsl +++ b/tests/pipeline/ray-tracing/acceleration-structure-in-compute-nv.slang.glsl @@ -1,18 +1,11 @@ #version 460 - -#extension GL_NV_ray_tracing : require - -int helper_0(accelerationStructureNV a_0, int b_0) -{ - return b_0; -} - -layout(binding = 1) -uniform accelerationStructureNV entryPointParams_x_0; +layout(row_major) uniform; +layout(row_major) buffer; layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in; void main() { - int _S1 = helper_0(entryPointParams_x_0, 1); + return; } + diff --git a/tests/pipeline/ray-tracing/acceleration-structure-in-compute.slang.glsl b/tests/pipeline/ray-tracing/acceleration-structure-in-compute.slang.glsl index 83797d2d5..0374569fe 100644 --- a/tests/pipeline/ray-tracing/acceleration-structure-in-compute.slang.glsl +++ b/tests/pipeline/ray-tracing/acceleration-structure-in-compute.slang.glsl @@ -1,18 +1,11 @@ #version 460 -#extension GL_EXT_ray_tracing : require layout(row_major) uniform; layout(row_major) buffer; -int helper_0(accelerationStructureEXT a_0, int b_0) -{ - return b_0; -} - -layout(binding = 1) -uniform accelerationStructureEXT entryPointParams_x_0; layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in; void main() { - int _S1 = helper_0(entryPointParams_x_0, 1); + return; } + diff --git a/tests/pipeline/ray-tracing/ray-query-subroutine.slang b/tests/pipeline/ray-tracing/ray-query-subroutine.slang index 501071717..3279acc12 100644 --- a/tests/pipeline/ray-tracing/ray-query-subroutine.slang +++ b/tests/pipeline/ray-tracing/ray-query-subroutine.slang @@ -7,7 +7,7 @@ RWStructuredBuffer<int> gOutput; RaytracingAccelerationStructure gScene; -int helper<let N : int>(RayQuery<N> q) +float3 helper<let N : int>(RayQuery<N> q) { RayDesc ray; ray.Origin = 0; @@ -20,7 +20,7 @@ int helper<let N : int>(RayQuery<N> q) /* instanceInclusionmask: */ 0xFFFFFFFF, /* ray: */ ray ); - return 1; + return q.WorldRayDirection(); } [shader("compute")] @@ -36,5 +36,5 @@ void computeMain(uint tid : SV_DispatchThreadID) RayQuery<0> rayQuery; let result = helper(rayQuery); - gOutput[tid.x] = result; + gOutput[tid.x] = int(result.x); } diff --git a/tests/pipeline/ray-tracing/ray-query-subroutine.slang.hlsl b/tests/pipeline/ray-tracing/ray-query-subroutine.slang.hlsl index 5906823e5..db97f4278 100644 --- a/tests/pipeline/ray-tracing/ray-query-subroutine.slang.hlsl +++ b/tests/pipeline/ray-tracing/ray-query-subroutine.slang.hlsl @@ -1,26 +1,32 @@ -//TEST_IGNORE_FILE: - RaytracingAccelerationStructure gScene_0 : register(t0); -int helper_0(RayQuery<int(0) > q_0) +RWStructuredBuffer<int > gOutput_0 : register(u0); + +float3 helper_0(RayQuery<int(0) > q_0) { + RayQuery<int(0) > _S1 = q_0; + RayDesc ray_0; - ray_0.Origin = (vector<float,3>) int(0); - ray_0.Direction = (vector<float,3>) int(0); - ray_0.TMin = (float) int(0); - ray_0.TMax = 1000.00000000000000000000; - q_0.TraceRayInline(gScene_0, (uint) int(0), (uint) int(-1), ray_0); - return int(1); + float3 _S2 = (float3)0.0; + + ray_0.Origin = _S2; + ray_0.Direction = _S2; + ray_0.TMin = 0.0; + ray_0.TMax = 1000.0; + _S1.TraceRayInline(gScene_0, 0U, 4294967295U, ray_0); + + return _S1.WorldRayDirection(); } -RWStructuredBuffer<int > gOutput_0 : register(u0); [shader("compute")][numthreads(1, 1, 1)] void computeMain(uint tid_0 : SV_DISPATCHTHREADID) { RayQuery<int(0) > rayQuery_0; - int _S1 = helper_0(rayQuery_0); - gOutput_0[tid_0.x] = _S1; + int _S3 = int(helper_0(rayQuery_0).x); + + gOutput_0[tid_0.x] = _S3; return; } + diff --git a/tests/pipeline/ray-tracing/trace-ray-inline.slang b/tests/pipeline/ray-tracing/trace-ray-inline.slang index d44500e18..e952bb802 100644 --- a/tests/pipeline/ray-tracing/trace-ray-inline.slang +++ b/tests/pipeline/ray-tracing/trace-ray-inline.slang @@ -83,6 +83,8 @@ bool myProceduralIntersection(inout float tHit, inout MyProceduralHitAttrs hitAt return true; } +RWStructuredBuffer<int> resultBuffer; + // In order to kick of tracing we need the properties of a ray // query to trace, so we will pipe those in via a constant buffer. // @@ -169,4 +171,6 @@ void main(uint3 tid : SV_DispatchThreadID) myMiss(payload); break; } + + resultBuffer[index] = payload.value; }
\ No newline at end of file diff --git a/tests/pipeline/ray-tracing/trace-ray-inline.slang.glsl b/tests/pipeline/ray-tracing/trace-ray-inline.slang.glsl index bb605a14a..56926e956 100644 --- a/tests/pipeline/ray-tracing/trace-ray-inline.slang.glsl +++ b/tests/pipeline/ray-tracing/trace-ray-inline.slang.glsl @@ -4,7 +4,7 @@ layout(row_major) uniform; layout(row_major) buffer; -#line 89 "tests/pipeline/ray-tracing/trace-ray-inline.slang" +#line 91 "tests/pipeline/ray-tracing/trace-ray-inline.slang" struct SLANG_ParameterGroup_C_0 { vec3 origin_0; @@ -17,8 +17,8 @@ struct SLANG_ParameterGroup_C_0 }; -#line 89 -layout(binding = 1) +#line 91 +layout(binding = 2) layout(std140) uniform _S1 { vec3 origin_0; @@ -35,6 +35,11 @@ layout(binding = 0) uniform accelerationStructureEXT myAccelerationStructure_0; +#line 86 +layout(std430, binding = 1) buffer _S2 { + int _data[]; +} resultBuffer_0; + #line 59 struct MyProceduralHitAttrs_0 { @@ -94,61 +99,64 @@ void myMiss_0(inout MyRayPayload_0 payload_4) } -#line 103 +#line 105 layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in; void main() { - rayQueryEXT query_0; - +#line 107 + uint index_0 = gl_GlobalInvocationID.x; +#line 112 MyRayPayload_0 payload_5; -#line 110 +#line 112 payload_5.value_1 = -1; +#line 109 + rayQueryEXT query_0; + +#line 114 rayQueryInitializeEXT((query_0), (myAccelerationStructure_0), (C_0.rayFlags_0 | 512), (C_0.instanceMask_0), (C_0.origin_0), (C_0.tMin_0), (C_0.direction_0), (C_0.tMax_0)); -#line 112 +#line 114 MyProceduralHitAttrs_0 committedProceduralAttrs_0; -#line 112 +#line 114 for(;;) { -#line 121 - bool _S2 = rayQueryProceedEXT(query_0); +#line 123 + bool _S3 = rayQueryProceedEXT(query_0); -#line 121 - if(!_S2) +#line 123 + if(!_S3) { -#line 121 +#line 123 break; } - uint _S3 = (rayQueryGetIntersectionTypeEXT((query_0), false)); #line 123 MyProceduralHitAttrs_0 committedProceduralAttrs_1; -#line 123 - switch(_S3) + switch((rayQueryGetIntersectionTypeEXT((query_0), false))) { case 1U: { MyProceduralHitAttrs_0 candidateProceduralAttrs_0; -#line 127 +#line 129 candidateProceduralAttrs_0.value_0 = 0; float tHit_1 = 0.0; bool _S4 = myProceduralIntersection_0(tHit_1, candidateProceduralAttrs_0); -#line 129 +#line 131 if(_S4) { bool _S5 = myProceduralAnyHit_0(payload_5); -#line 131 +#line 133 if(_S5) { rayQueryGenerateIntersectionEXT(query_0, tHit_1); @@ -156,136 +164,132 @@ void main() if(C_0.shouldStopAtFirstHit_0 != 0U) { -#line 136 +#line 138 rayQueryTerminateEXT(query_0); -#line 135 +#line 137 } else { -#line 135 +#line 137 } -#line 135 +#line 137 committedProceduralAttrs_1 = _S6; -#line 135 +#line 137 } else { -#line 135 +#line 137 committedProceduralAttrs_1 = committedProceduralAttrs_0; -#line 135 +#line 137 } -#line 135 +#line 137 } else { -#line 135 +#line 137 committedProceduralAttrs_1 = committedProceduralAttrs_0; -#line 135 +#line 137 } -#line 135 +#line 137 break; } case 0U: { -#line 144 +#line 146 bool _S7 = myTriangleAnyHit_0(payload_5); -#line 144 +#line 146 if(_S7) { rayQueryConfirmIntersectionEXT(query_0); if(C_0.shouldStopAtFirstHit_0 != 0U) { -#line 148 +#line 150 rayQueryTerminateEXT(query_0); -#line 147 +#line 149 } else { -#line 147 +#line 149 } -#line 144 +#line 146 } else { -#line 144 +#line 146 } -#line 144 +#line 146 committedProceduralAttrs_1 = committedProceduralAttrs_0; -#line 144 +#line 146 break; } default: { -#line 144 +#line 146 committedProceduralAttrs_1 = committedProceduralAttrs_0; -#line 144 +#line 146 break; } } -#line 119 +#line 121 committedProceduralAttrs_0 = committedProceduralAttrs_1; -#line 119 +#line 121 } -#line 158 - uint _S8 = (rayQueryGetIntersectionTypeEXT((query_0), true)); - -#line 158 - switch(_S8) +#line 160 + switch((rayQueryGetIntersectionTypeEXT((query_0), true))) { case 1U: { -#line 161 +#line 163 myTriangleClosestHit_0(payload_5); break; } case 2U: { -#line 165 +#line 167 myProceduralClosestHit_0(payload_5, committedProceduralAttrs_0); break; } case 0U: { -#line 169 +#line 171 myMiss_0(payload_5); break; } default: { -#line 170 +#line 172 break; } } - -#line 172 + ((resultBuffer_0)._data[(index_0)]) = payload_5.value_1; return; } diff --git a/tests/pipeline/ray-tracing/trace-ray-inline.slang.hlsl b/tests/pipeline/ray-tracing/trace-ray-inline.slang.hlsl index a85415065..e96cbb8f4 100644 --- a/tests/pipeline/ray-tracing/trace-ray-inline.slang.hlsl +++ b/tests/pipeline/ray-tracing/trace-ray-inline.slang.hlsl @@ -4,6 +4,8 @@ #endif #pragma warning(disable: 3557) + +#line 91 "tests/pipeline/ray-tracing/trace-ray-inline.slang" struct SLANG_ParameterGroup_C_0 { float3 origin_0; @@ -15,150 +17,272 @@ struct SLANG_ParameterGroup_C_0 uint shouldStopAtFirstHit_0; }; + +#line 91 cbuffer C_0 : register(b0) { SLANG_ParameterGroup_C_0 C_0; } + +#line 12 RaytracingAccelerationStructure myAccelerationStructure_0 : register(t0); + +#line 86 +RWStructuredBuffer<int > resultBuffer_0 : register(u0); + + +#line 59 struct MyProceduralHitAttrs_0 { int value_0; }; + +#line 81 bool myProceduralIntersection_0(inout float tHit_0, inout MyProceduralHitAttrs_0 hitAttrs_0) { return true; } + +#line 26 struct MyRayPayload_0 { int value_1; }; + +#line 69 bool myProceduralAnyHit_0(inout MyRayPayload_0 payload_0) { return true; } + +#line 51 bool myTriangleAnyHit_0(inout MyRayPayload_0 payload_1) { return true; } + +#line 40 void myTriangleClosestHit_0(inout MyRayPayload_0 payload_2) { payload_2.value_1 = int(1); return; } + +#line 65 void myProceduralClosestHit_0(inout MyRayPayload_0 payload_3, MyProceduralHitAttrs_0 attrs_0) { payload_3.value_1 = attrs_0.value_0; return; } + +#line 33 void myMiss_0(inout MyRayPayload_0 payload_4) { payload_4.value_1 = int(0); return; } + +#line 105 [shader("compute")][numthreads(1, 1, 1)] void main(uint3 tid_0 : SV_DISPATCHTHREADID) { - RayQuery<int(512) > query_0; +#line 107 + uint index_0 = tid_0.x; + +#line 112 MyRayPayload_0 payload_5; + +#line 112 payload_5.value_1 = int(-1); RayDesc ray_0 = { C_0.origin_0, C_0.tMin_0, C_0.direction_0, C_0.tMax_0 }; + +#line 109 + RayQuery<int(512) > query_0; + +#line 114 query_0.TraceRayInline(myAccelerationStructure_0, C_0.rayFlags_0, C_0.instanceMask_0, ray_0); + +#line 114 MyProceduralHitAttrs_0 committedProceduralAttrs_0; + +#line 114 for(;;) { + +#line 123 bool _S1 = query_0.Proceed(); + +#line 123 if(!_S1) { + +#line 123 break; } - uint _S2 = query_0.CandidateType(); + +#line 123 MyProceduralHitAttrs_0 committedProceduralAttrs_1; - switch(_S2) + + switch(query_0.CandidateType()) { case 1U: { MyProceduralHitAttrs_0 candidateProceduralAttrs_0; + +#line 129 candidateProceduralAttrs_0.value_0 = int(0); float tHit_1 = 0.0; - bool _S3 = myProceduralIntersection_0(tHit_1, candidateProceduralAttrs_0); - if(_S3) + bool _S2 = myProceduralIntersection_0(tHit_1, candidateProceduralAttrs_0); + +#line 131 + if(_S2) { - bool _S4 = myProceduralAnyHit_0(payload_5); - if(_S4) + bool _S3 = myProceduralAnyHit_0(payload_5); + +#line 133 + if(_S3) { query_0.CommitProceduralPrimitiveHit(tHit_1); - MyProceduralHitAttrs_0 _S5 = candidateProceduralAttrs_0; + MyProceduralHitAttrs_0 _S4 = candidateProceduralAttrs_0; if(C_0.shouldStopAtFirstHit_0 != 0U) { + +#line 138 query_0.Abort(); + +#line 137 + } + else + { + +#line 137 } - committedProceduralAttrs_1 = _S5; + +#line 137 + committedProceduralAttrs_1 = _S4; + +#line 137 } else { + +#line 137 committedProceduralAttrs_1 = committedProceduralAttrs_0; + +#line 137 } + +#line 137 } else { + +#line 137 committedProceduralAttrs_1 = committedProceduralAttrs_0; + +#line 137 } + +#line 137 break; } case 0U: { - bool _S6 = myTriangleAnyHit_0(payload_5); - if(_S6) + +#line 146 + bool _S5 = myTriangleAnyHit_0(payload_5); + +#line 146 + if(_S5) { query_0.CommitNonOpaqueTriangleHit(); if(C_0.shouldStopAtFirstHit_0 != 0U) { + +#line 150 query_0.Abort(); + +#line 149 + } + else + { + +#line 149 } + +#line 146 + } + else + { + +#line 146 } + +#line 146 committedProceduralAttrs_1 = committedProceduralAttrs_0; + +#line 146 break; } default: { + +#line 146 committedProceduralAttrs_1 = committedProceduralAttrs_0; + +#line 146 break; } } + +#line 121 committedProceduralAttrs_0 = committedProceduralAttrs_1; + +#line 121 } - uint _S7 = query_0.CommittedStatus(); - switch(_S7) + +#line 160 + switch(query_0.CommittedStatus()) { case 1U: { + +#line 163 myTriangleClosestHit_0(payload_5); break; } case 2U: { + +#line 167 myProceduralClosestHit_0(payload_5, committedProceduralAttrs_0); break; } case 0U: { + +#line 171 myMiss_0(payload_5); break; } default: { + +#line 172 break; } } + resultBuffer_0[index_0] = payload_5.value_1; return; } + |
