diff options
| author | Yong He <yonghe@outlook.com> | 2024-09-20 20:21:18 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-09-20 20:21:18 -0700 |
| commit | 53684ed919ff2f5f3656aed2e95a111207452392 (patch) | |
| tree | cf6926f21797d99534f22da121beceb085b6035a | |
| parent | c42b5e24b5b9d6b03352d809e0a49485d361154f (diff) | |
Fix handling of pointer logic in wgsl backend. (#5129)
| -rw-r--r-- | source/slang/slang-emit-c-like.cpp | 20 | ||||
| -rw-r--r-- | source/slang/slang-emit-c-like.h | 3 | ||||
| -rw-r--r-- | source/slang/slang-emit-wgsl.cpp | 179 | ||||
| -rw-r--r-- | source/slang/slang-emit-wgsl.h | 42 | ||||
| -rw-r--r-- | source/slang/slang-ir-wgsl-legalize.cpp | 82 | ||||
| -rw-r--r-- | tests/language-feature/atomic-t/atomic-0.slang | 14 | ||||
| -rw-r--r-- | tests/wgsl/inout.slang | 32 |
7 files changed, 183 insertions, 189 deletions
diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp index c60397b85..28eb7da04 100644 --- a/source/slang/slang-emit-c-like.cpp +++ b/source/slang/slang-emit-c-like.cpp @@ -1646,16 +1646,11 @@ bool CLikeSourceEmitter::shouldFoldInstIntoUseSites(IRInst* inst) return true; } -bool CLikeSourceEmitter::isPointerSyntaxRequiredImpl(IRInst* /* inst */) -{ - return doesTargetSupportPtrTypes(); -} - void CLikeSourceEmitter::emitDereferenceOperand(IRInst* inst, EmitOpInfo const& outerPrec) { EmitOpInfo newOuterPrec = outerPrec; - if (isPointerSyntaxRequiredImpl(inst)) + if (doesTargetSupportPtrTypes()) { switch (inst->getOp()) { @@ -1754,7 +1749,7 @@ void CLikeSourceEmitter::emitDereferenceOperand(IRInst* inst, EmitOpInfo const& void CLikeSourceEmitter::emitVarExpr(IRInst* inst, EmitOpInfo const& outerPrec) { - if (isPointerSyntaxRequiredImpl(inst)) + if (doesTargetSupportPtrTypes()) { auto prec = getInfo(EmitOp::Prefix); auto newOuterPrec = outerPrec; @@ -2003,6 +1998,11 @@ void CLikeSourceEmitter::emitIntrinsicCallExprImpl( } } +void CLikeSourceEmitter::emitCallArg(IRInst* inst) +{ + emitOperand(inst, getInfo(EmitOp::General)); +} + void CLikeSourceEmitter::_emitCallArgList(IRCall* inst, int startingOperandIndex) { bool isFirstArg = true; @@ -2023,7 +2023,7 @@ void CLikeSourceEmitter::_emitCallArgList(IRCall* inst, int startingOperandIndex m_writer->emit(", "); else isFirstArg = false; - emitOperand(inst->getOperand(aa), getInfo(EmitOp::General)); + emitCallArg(inst->getOperand(aa)); } m_writer->emit(")"); } @@ -2296,7 +2296,7 @@ void CLikeSourceEmitter::defaultEmitInstExpr(IRInst* inst, const EmitOpInfo& inO IRFieldAddress* ii = (IRFieldAddress*) inst; - if (isPointerSyntaxRequiredImpl(inst)) + if (doesTargetSupportPtrTypes()) { auto prec = getInfo(EmitOp::Prefix); needClose = maybeEmitParens(outerPrec, prec); @@ -4206,7 +4206,7 @@ void CLikeSourceEmitter::emitGlobalParam(IRGlobalParam* varDecl) emitRateQualifiersAndAddressSpace(varDecl); emitVarKeyword(varType, varDecl); - emitGlobalParamType(varType, getName(varDecl)); + emitType(varType, getName(varDecl)); emitSemantics(varDecl); diff --git a/source/slang/slang-emit-c-like.h b/source/slang/slang-emit-c-like.h index ccc25de57..3cccad9e6 100644 --- a/source/slang/slang-emit-c-like.h +++ b/source/slang/slang-emit-c-like.h @@ -257,7 +257,6 @@ public: void emitType(IRType* type); void emitType(IRType* type, Name* name, SourceLoc const& nameLoc); void emitType(IRType* type, NameLoc const& nameAndLoc); - virtual void emitGlobalParamType(IRType* type, String const& name) {emitType(type, name);} bool hasExplicitConstantBufferOffset(IRInst* cbufferType); bool isSingleElementConstantBuffer(IRInst* cbufferType); bool shouldForceUnpackConstantBufferElements(IRInst* cbufferType); @@ -430,7 +429,6 @@ public: void emitGlobalInst(IRInst* inst); virtual void emitGlobalInstImpl(IRInst* inst); - virtual bool isPointerSyntaxRequiredImpl(IRInst* inst); void ensureInstOperand(ComputeEmitActionsContext* ctx, IRInst* inst, EmitAction::Level requiredLevel = EmitAction::Level::Definition); @@ -567,6 +565,7 @@ public: // Emit the argument list (including paranthesis) in a `CallInst` void _emitCallArgList(IRCall* call, int startingOperandIndex = 1); + virtual void emitCallArg(IRInst* arg); String _generateUniqueName(const UnownedStringSlice& slice); diff --git a/source/slang/slang-emit-wgsl.cpp b/source/slang/slang-emit-wgsl.cpp index 833e1c8fe..051cdb820 100644 --- a/source/slang/slang-emit-wgsl.cpp +++ b/source/slang/slang-emit-wgsl.cpp @@ -23,12 +23,13 @@ // 'transpose' calls, or else perform more complicated transformations that // end up duplicating expressions many times. -namespace Slang { +namespace Slang +{ void WGSLSourceEmitter::emitSwitchCaseSelectorsImpl( IRBasicType *const switchConditionType, - const SwitchRegion::Case *const currentCase, const bool isDefault - ) + const SwitchRegion::Case *const currentCase, + const bool isDefault) { // WGSL has special syntax for blocks sharing case labels: // "case 2, 3, 4: ...;" instead of the C-like syntax @@ -80,8 +81,8 @@ void WGSLSourceEmitter::emitSwitchCaseSelectorsImpl( } void WGSLSourceEmitter::emitParameterGroupImpl( - IRGlobalParam* varDecl, IRUniformParameterGroupType* type -) + IRGlobalParam* varDecl, + IRUniformParameterGroupType* type) { auto varLayout = getVarLayout(varDecl); SLANG_RELEASE_ASSERT(varLayout); @@ -140,8 +141,8 @@ void WGSLSourceEmitter::emitParameterGroupImpl( } void WGSLSourceEmitter::emitEntryPointAttributesImpl( - IRFunc* irFunc, IREntryPointDecoration* entryPointDecor - ) + IRFunc* irFunc, + IREntryPointDecoration* entryPointDecor) { auto stage = entryPointDecor->getProfile().getStage(); @@ -238,9 +239,7 @@ static bool isPowerOf2(const uint32_t n) return (n != 0U) && ((n - 1U) & n) == 0U; } -void WGSLSourceEmitter::emitStructFieldAttributes( - IRStructType * structType, IRStructField * field - ) +void WGSLSourceEmitter::emitStructFieldAttributes(IRStructType * structType, IRStructField * field) { // Tint emits errors unless we explicitly spell out the layout in some cases, so emit // offset and align attribtues for all fields. @@ -273,26 +272,6 @@ void WGSLSourceEmitter::emitStructFieldAttributes( m_writer->emit(")"); } -bool WGSLSourceEmitter::isPointerSyntaxRequiredImpl(IRInst* inst) -{ - if (inst->getOp() == kIROp_RWStructuredBufferGetElementPtr) - return false; - - // Don't emit "->" to access fields in resource structs - if (inst->getOp() == kIROp_FieldAddress) - return false; - - // Don't emit "*" to access fields in resource structs - if (inst->getOp() == kIROp_GlobalParam) - return false; - - // Emit 'globalVar' instead of "*&globalVar" - if (inst->getOp() == kIROp_GlobalVar) - return false; - - return true; -} - void WGSLSourceEmitter::emit(const AddressSpace addressSpace) { switch (addressSpace) @@ -325,32 +304,14 @@ void WGSLSourceEmitter::emitSimpleTypeImpl(IRType* type) { case kIROp_HLSLRWStructuredBufferType: - { - auto structuredBufferType = as<IRHLSLStructuredBufferTypeBase>(type); - m_writer->emit("ptr<"); - emit(AddressSpace::StorageBuffer); - m_writer->emit(", "); - m_writer->emit("array"); - m_writer->emit("<"); - emitType(structuredBufferType->getElementType()); - m_writer->emit(">"); - m_writer->emit(", read_write"); - m_writer->emit(">"); - } - break; - case kIROp_HLSLStructuredBufferType: + case kIROp_HLSLRasterizerOrderedStructuredBufferType: { auto structuredBufferType = as<IRHLSLStructuredBufferTypeBase>(type); - m_writer->emit("ptr<"); - emit(AddressSpace::StorageBuffer); - m_writer->emit(", "); m_writer->emit("array"); m_writer->emit("<"); emitType(structuredBufferType->getElementType()); m_writer->emit(">"); - m_writer->emit(", read"); - m_writer->emit(">"); } break; @@ -582,7 +543,8 @@ void WGSLSourceEmitter::emitVarKeywordImpl(IRType * type, IRInst* varDecl) { m_writer->emit("<workgroup>"); } - else if (type->getOp() == kIROp_HLSLRWStructuredBufferType) + else if (type->getOp() == kIROp_HLSLRWStructuredBufferType || + type->getOp() == kIROp_HLSLRasterizerOrderedStructuredBufferType) { m_writer->emit("<"); m_writer->emit("storage, read_write"); @@ -692,9 +654,26 @@ void WGSLSourceEmitter::emitDeclaratorImpl(DeclaratorInfo* declarator) } } +void WGSLSourceEmitter::emitOperandImpl(IRInst* operand, EmitOpInfo const& outerPrec) +{ + if (operand->getOp() == kIROp_Param && as<IRPtrTypeBase>(operand->getDataType())) + { + // If we are emitting a reference to a pointer typed operand, then + // we should dereference it now since we want to treat all the remaining + // part of wgsl as pointer-free target. + m_writer->emit("(*"); + m_writer->emit(getName(operand)); + m_writer->emit(")"); + } + else + { + CLikeSourceEmitter::emitOperandImpl(operand, outerPrec); + } +} + void WGSLSourceEmitter::emitSimpleTypeAndDeclaratorImpl( - IRType* type, DeclaratorInfo* declarator - ) + IRType* type, + DeclaratorInfo* declarator) { if (declarator) { @@ -999,13 +978,29 @@ bool WGSLSourceEmitter::tryEmitInstStmtImpl(IRInst* inst) } } +void WGSLSourceEmitter::emitCallArg(IRInst* inst) +{ + if (as<IRPtrTypeBase>(inst->getDataType())) + { + // If we are calling a function with a pointer-typed argument, we need to + // explicitly prefix the argument with `&` to pass a pointer. + // + m_writer->emit("&("); + emitOperand(inst, getInfo(EmitOp::General)); + m_writer->emit(")"); + } + else + { + emitOperand(inst, getInfo(EmitOp::General)); + } +} + bool WGSLSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOuterPrec) { EmitOpInfo outerPrec = inOuterPrec; switch (inst->getOp()) { - case kIROp_MakeVectorFromScalar: { // In WGSL this is done by calling the vec* overloads listed in [1] @@ -1079,25 +1074,13 @@ bool WGSLSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOu } break; - case kIROp_RWStructuredBufferGetElementPtr: - { - m_writer->emit("(*"); - emitOperand(inst->getOperand(0), leftSide(outerPrec, getInfo(EmitOp::Postfix))); - m_writer->emit(")["); - emitOperand(inst->getOperand(1), getInfo(EmitOp::General)); - m_writer->emit("]"); - return true; - } - break; - case kIROp_StructuredBufferLoad: case kIROp_RWStructuredBufferLoad: + case kIROp_RWStructuredBufferGetElementPtr: { - // Structured buffers are just arrays in WGSL - auto base = inst->getOperand(0); - emitOperand(base, outerPrec); + emitOperand(inst->getOperand(0), leftSide(outerPrec, getInfo(EmitOp::Postfix))); m_writer->emit("["); - emitOperand(inst->getOperand(1), EmitOpInfo()); + emitOperand(inst->getOperand(1), getInfo(EmitOp::General)); m_writer->emit("]"); return true; } @@ -1134,15 +1117,12 @@ bool WGSLSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOu } } break; - } return false; } -void WGSLSourceEmitter::emitVectorTypeNameImpl( - IRType* elementType, IRIntegerValue elementCount - ) +void WGSLSourceEmitter::emitVectorTypeNameImpl(IRType* elementType, IRIntegerValue elementCount) { if (elementCount > 1) @@ -1159,61 +1139,6 @@ void WGSLSourceEmitter::emitVectorTypeNameImpl( } } -void WGSLSourceEmitter::emitOperandImpl(IRInst* inst, const EmitOpInfo& outerPrec) -{ - // In WGSL, the structured buffer types are converted to ptr<AS, array<E>, AM> - // everywhere, except for the global parameter declaration. - // Thus, when these globals are used in expressions, we need an ampersand. - - if (inst->getOp() == kIROp_GlobalParam) - { - switch (inst->getDataType()->getOp()) - { - case kIROp_HLSLStructuredBufferType: - case kIROp_HLSLRWStructuredBufferType: - - m_writer->emit("(&"); - CLikeSourceEmitter::emitOperandImpl(inst, outerPrec); - m_writer->emit(")"); - return; - } - } - - CLikeSourceEmitter::emitOperandImpl(inst, outerPrec); -} - -void WGSLSourceEmitter::emitGlobalParamType(IRType* type, const String& name) -{ - // In WGSL, the structured buffer types are converted to ptr<AS, array<E>, AM> - // everywhere, except for the global parameter declaration. - - switch (type->getOp()) - { - - case kIROp_HLSLStructuredBufferType: - case kIROp_HLSLRWStructuredBufferType: - { - StringSliceLoc nameAndLoc(name.getUnownedSlice()); - NameDeclaratorInfo nameDeclarator(&nameAndLoc); - emitDeclarator(&nameDeclarator); - m_writer->emit(" : "); - auto structuredBufferType = as<IRHLSLStructuredBufferTypeBase>(type); - m_writer->emit("array"); - m_writer->emit("<"); - emitType(structuredBufferType->getElementType()); - m_writer->emit(">"); - } - break; - - default: - - emitType(type, name); - break; - - } - -} - void WGSLSourceEmitter::emitFrontMatterImpl(TargetRequest* /* targetReq */) { if (m_f16ExtensionEnabled) diff --git a/source/slang/slang-emit-wgsl.h b/source/slang/slang-emit-wgsl.h index b3a4efb55..0b4b04b12 100644 --- a/source/slang/slang-emit-wgsl.h +++ b/source/slang/slang-emit-wgsl.h @@ -13,53 +13,36 @@ public: : CLikeSourceEmitter(desc) {} - virtual void emitParameterGroupImpl( - IRGlobalParam* varDecl, IRUniformParameterGroupType* type - ) SLANG_OVERRIDE; - virtual void emitEntryPointAttributesImpl( - IRFunc* irFunc, IREntryPointDecoration* entryPointDecor - ) SLANG_OVERRIDE; + virtual void emitParameterGroupImpl(IRGlobalParam* varDecl, IRUniformParameterGroupType* type) SLANG_OVERRIDE; + virtual void emitEntryPointAttributesImpl(IRFunc* irFunc, IREntryPointDecoration* entryPointDecor) SLANG_OVERRIDE; virtual void emitSimpleTypeImpl(IRType* type) SLANG_OVERRIDE; - virtual void emitVectorTypeNameImpl( - IRType* elementType, IRIntegerValue elementCount - ) SLANG_OVERRIDE; + virtual void emitVectorTypeNameImpl(IRType* elementType, IRIntegerValue elementCount) SLANG_OVERRIDE; virtual void emitFuncHeaderImpl(IRFunc* func) SLANG_OVERRIDE; virtual void emitSimpleValueImpl(IRInst* inst) SLANG_OVERRIDE; - virtual bool tryEmitInstExprImpl( - IRInst* inst, const EmitOpInfo& inOuterPrec - ) SLANG_OVERRIDE; + virtual bool tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOuterPrec) SLANG_OVERRIDE; virtual bool tryEmitInstStmtImpl(IRInst* inst) SLANG_OVERRIDE; virtual void emitSwitchCaseSelectorsImpl( IRBasicType *const switchCondition, const SwitchRegion::Case *const currentCase, - const bool isDefault - ) SLANG_OVERRIDE; - virtual void emitSimpleTypeAndDeclaratorImpl( - IRType* type, DeclaratorInfo* declarator - ) SLANG_OVERRIDE; + const bool isDefault) SLANG_OVERRIDE; + virtual void emitSimpleTypeAndDeclaratorImpl(IRType* type, DeclaratorInfo* declarator) SLANG_OVERRIDE; virtual void emitVarKeywordImpl(IRType * type, IRInst* varDecl) SLANG_OVERRIDE; virtual void emitDeclaratorImpl(DeclaratorInfo* declarator) SLANG_OVERRIDE; + virtual void emitOperandImpl(IRInst* operand, EmitOpInfo const& outerPrec) SLANG_OVERRIDE; virtual void emitStructDeclarationSeparatorImpl() SLANG_OVERRIDE; virtual void emitLayoutQualifiersImpl(IRVarLayout* layout) SLANG_OVERRIDE; virtual void emitSimpleFuncParamImpl(IRParam* param) SLANG_OVERRIDE; virtual void emitParamTypeImpl(IRType* type, const String& name) SLANG_OVERRIDE; - virtual bool isPointerSyntaxRequiredImpl(IRInst* inst) SLANG_OVERRIDE; virtual void _emitType(IRType* type, DeclaratorInfo* declarator) SLANG_OVERRIDE; virtual void emitFrontMatterImpl(TargetRequest* targetReq) SLANG_OVERRIDE; - virtual void emitStructFieldAttributes( - IRStructType * structType, IRStructField * field - ) SLANG_OVERRIDE; - virtual void emitGlobalParamType(IRType* type, const String& name) SLANG_OVERRIDE; - virtual void emitOperandImpl( - IRInst* inst, const EmitOpInfo& outerPrec - ) SLANG_OVERRIDE; + virtual void emitStructFieldAttributes(IRStructType * structType, IRStructField * field) SLANG_OVERRIDE; + virtual void emitCallArg(IRInst* inst) SLANG_OVERRIDE; virtual void emitIntrinsicCallExprImpl( IRCall* inst, UnownedStringSlice intrinsicDefinition, IRInst* intrinsicInst, - EmitOpInfo const& inOuterPrec - ) SLANG_OVERRIDE; + EmitOpInfo const& inOuterPrec) SLANG_OVERRIDE; void emit(const AddressSpace addressSpace); @@ -69,10 +52,9 @@ private: void emitMatrixType( IRType *const elementType, const IRIntegerValue& rowCountWGSL, - const IRIntegerValue& colCountWGSL - ); + const IRIntegerValue& colCountWGSL); - bool m_f16ExtensionEnabled {false}; + bool m_f16ExtensionEnabled = false; }; diff --git a/source/slang/slang-ir-wgsl-legalize.cpp b/source/slang/slang-ir-wgsl-legalize.cpp index e05eba78c..622ceeff5 100644 --- a/source/slang/slang-ir-wgsl-legalize.cpp +++ b/source/slang/slang-ir-wgsl-legalize.cpp @@ -40,17 +40,16 @@ namespace Slang std::optional<SystemValLegalizationWorkItem> makeSystemValWorkItem(IRInst* var); void legalizeSystemValue( - EntryPointInfo entryPoint, SystemValLegalizationWorkItem& workItem - ); + EntryPointInfo entryPoint, SystemValLegalizationWorkItem& workItem); List<SystemValLegalizationWorkItem> collectSystemValFromEntryPoint( - EntryPointInfo entryPoint - ); + EntryPointInfo entryPoint); void legalizeSystemValueParameters(EntryPointInfo entryPoint); void legalizeEntryPointForWGSL(EntryPointInfo entryPoint); IRInst* tryConvertValue(IRBuilder& builder, IRInst* val, IRType* toType); WGSLSystemValueInfo getSystemValueInfo( - String inSemanticName, String* optionalSemanticIndex, IRInst* parentVar - ); + String inSemanticName, String* optionalSemanticIndex, IRInst* parentVar); + void legalizeCall(IRCall* call); + void processInst(IRInst* inst); }; IRInst* LegalizeWGSLEntryPointContext::tryConvertValue( @@ -321,6 +320,74 @@ namespace Slang legalizeSystemValueParameters(entryPoint); } + void LegalizeWGSLEntryPointContext::legalizeCall(IRCall* call) + { + // WGSL does not allow forming a pointer to a sub part of a composite value. + // For example, if we have + // ``` + // struct S { float x; float y; }; + // void foo(inout float v) { v = 1.0f; } + // void main() { S s; foo(s.x); } + // ``` + // The call to `foo(s.x)` is illegal in WGSL because `s.x` is a sub part of `s`. + // And trying to form `&s.x` in WGSL is illegal. + // To work around this, we will create a local variable to hold the sub part of + // the composite value. + // And then pass the local variable to the function. + // After the call, we will write back the local variable to the sub part of the + // composite value. + // + IRBuilder builder(call); + builder.setInsertBefore(call); + struct WritebackPair { IRInst* dest; IRInst* value; }; + ShortList<WritebackPair> pendingWritebacks; + + for (UInt i = 0; i < call->getArgCount(); i++) + { + auto arg = call->getArg(i); + auto ptrType = as<IRPtrTypeBase>(arg->getDataType()); + if (!ptrType) + continue; + switch (arg->getOp()) + { + case kIROp_Var: + case kIROp_Param: + continue; + default: + break; + } + + // Create a local variable to hold the input argument. + auto var = builder.emitVar( + ptrType->getValueType(), + AddressSpace::Function); + + // Store the input argument into the local variable. + builder.emitStore(var, builder.emitLoad(arg)); + builder.replaceOperand(call->getArgs() + i, var); + pendingWritebacks.add({ arg, var }); + } + + // Perform writebacks after the call. + builder.setInsertAfter(call); + for (auto& pair : pendingWritebacks) + { + builder.emitStore(pair.dest, builder.emitLoad(pair.value)); + } + } + + void LegalizeWGSLEntryPointContext::processInst(IRInst* inst) + { + switch (inst->getOp()) + { + case kIROp_Call: + legalizeCall(static_cast<IRCall*>(inst)); + break; + default: + for (auto child : inst->getModifiableChildren()) + processInst(child); + } + } void legalizeIRForWGSL(IRModule* module, DiagnosticSink* sink) { List<EntryPointInfo> entryPoints; @@ -342,6 +409,9 @@ namespace Slang LegalizeWGSLEntryPointContext context(sink, module); for (auto entryPoint : entryPoints) context.legalizeEntryPointForWGSL(entryPoint); + + // Go through every instruction in the module and legalize them as needed. + context.processInst(module->getModuleInst()); } } diff --git a/tests/language-feature/atomic-t/atomic-0.slang b/tests/language-feature/atomic-t/atomic-0.slang index 6f2ce6418..591de490c 100644 --- a/tests/language-feature/atomic-t/atomic-0.slang +++ b/tests/language-feature/atomic-t/atomic-0.slang @@ -12,46 +12,32 @@ void computeMain() bool result = true; if (outputBuffer[0].add(1) != 0) result = false; - AllMemoryBarrierWithGroupSync(); if (outputBuffer[0].sub(1) != 1) result = false; - AllMemoryBarrierWithGroupSync(); if (outputBuffer[0].max(2) != 0) result = false; - AllMemoryBarrierWithGroupSync(); if (outputBuffer[0].min(1) != 2) result = false; - AllMemoryBarrierWithGroupSync(); if (outputBuffer[0].or(3) != 1) result = false; - AllMemoryBarrierWithGroupSync(); if (outputBuffer[0].and(2) != 3) result = false; - AllMemoryBarrierWithGroupSync(); if (outputBuffer[0].xor(3) != 2) result = false; - AllMemoryBarrierWithGroupSync(); if (outputBuffer[0].exchange(4) != 1) result = false; - AllMemoryBarrierWithGroupSync(); if (outputBuffer[0].compareExchange(4, 5) != 4) {} //result = false; // for some reason this fails on Metal Github CI, so disabling. - AllMemoryBarrierWithGroupSync(); if (outputBuffer[0].load() != 5) result = false; - AllMemoryBarrierWithGroupSync(); if (outputBuffer[0].increment() != 5) result = false; - AllMemoryBarrierWithGroupSync(); if (outputBuffer[0].decrement() != 6) result = false; - AllMemoryBarrierWithGroupSync(); // CHECK: 6 outputBuffer[0].store(6); - AllMemoryBarrierWithGroupSync(); if (outputBuffer[0].load() != 6) result = false; - AllMemoryBarrierWithGroupSync(); // CHECK: 1 if (result) outputBuffer[1].store(1); diff --git a/tests/wgsl/inout.slang b/tests/wgsl/inout.slang new file mode 100644 index 000000000..57f55ae3f --- /dev/null +++ b/tests/wgsl/inout.slang @@ -0,0 +1,32 @@ +//TEST:SIMPLE(filecheck=CHECK): -target wgsl + +RWStructuredBuffer<float> outputBuffer; + +// CHECK: fn inner{{.*}}( x{{.*}} : ptr<function, f32>) +// CHECK: (*x{{.*}}) = (*x{{.*}}) + 1.0 +void inner(inout float x) +{ + x = x + 1; +} + +// CHECK: fn test{{.*}}( x{{.*}} : ptr<function, f32>) +void test(inout float x) +{ + inner(x); +} + +struct MyType +{ + float myField[3]; +} + +[numthreads(1,1,1)] +void computeMain(int id : SV_DispatchThreadID) +{ + MyType v; + v.myField[id] = 0.0f; + // CHECK: test{{.*}}(&({{.*}})); + test(v.myField[id]); + v.myField[1] = 2.0; + outputBuffer[0] = v.myField[id]; +} |
