summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--source/slang/slang-emit-c-like.cpp20
-rw-r--r--source/slang/slang-emit-c-like.h3
-rw-r--r--source/slang/slang-emit-wgsl.cpp179
-rw-r--r--source/slang/slang-emit-wgsl.h42
-rw-r--r--source/slang/slang-ir-wgsl-legalize.cpp82
-rw-r--r--tests/language-feature/atomic-t/atomic-0.slang14
-rw-r--r--tests/wgsl/inout.slang32
7 files changed, 183 insertions, 189 deletions
diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp
index c60397b85..28eb7da04 100644
--- a/source/slang/slang-emit-c-like.cpp
+++ b/source/slang/slang-emit-c-like.cpp
@@ -1646,16 +1646,11 @@ bool CLikeSourceEmitter::shouldFoldInstIntoUseSites(IRInst* inst)
return true;
}
-bool CLikeSourceEmitter::isPointerSyntaxRequiredImpl(IRInst* /* inst */)
-{
- return doesTargetSupportPtrTypes();
-}
-
void CLikeSourceEmitter::emitDereferenceOperand(IRInst* inst, EmitOpInfo const& outerPrec)
{
EmitOpInfo newOuterPrec = outerPrec;
- if (isPointerSyntaxRequiredImpl(inst))
+ if (doesTargetSupportPtrTypes())
{
switch (inst->getOp())
{
@@ -1754,7 +1749,7 @@ void CLikeSourceEmitter::emitDereferenceOperand(IRInst* inst, EmitOpInfo const&
void CLikeSourceEmitter::emitVarExpr(IRInst* inst, EmitOpInfo const& outerPrec)
{
- if (isPointerSyntaxRequiredImpl(inst))
+ if (doesTargetSupportPtrTypes())
{
auto prec = getInfo(EmitOp::Prefix);
auto newOuterPrec = outerPrec;
@@ -2003,6 +1998,11 @@ void CLikeSourceEmitter::emitIntrinsicCallExprImpl(
}
}
+void CLikeSourceEmitter::emitCallArg(IRInst* inst)
+{
+ emitOperand(inst, getInfo(EmitOp::General));
+}
+
void CLikeSourceEmitter::_emitCallArgList(IRCall* inst, int startingOperandIndex)
{
bool isFirstArg = true;
@@ -2023,7 +2023,7 @@ void CLikeSourceEmitter::_emitCallArgList(IRCall* inst, int startingOperandIndex
m_writer->emit(", ");
else
isFirstArg = false;
- emitOperand(inst->getOperand(aa), getInfo(EmitOp::General));
+ emitCallArg(inst->getOperand(aa));
}
m_writer->emit(")");
}
@@ -2296,7 +2296,7 @@ void CLikeSourceEmitter::defaultEmitInstExpr(IRInst* inst, const EmitOpInfo& inO
IRFieldAddress* ii = (IRFieldAddress*) inst;
- if (isPointerSyntaxRequiredImpl(inst))
+ if (doesTargetSupportPtrTypes())
{
auto prec = getInfo(EmitOp::Prefix);
needClose = maybeEmitParens(outerPrec, prec);
@@ -4206,7 +4206,7 @@ void CLikeSourceEmitter::emitGlobalParam(IRGlobalParam* varDecl)
emitRateQualifiersAndAddressSpace(varDecl);
emitVarKeyword(varType, varDecl);
- emitGlobalParamType(varType, getName(varDecl));
+ emitType(varType, getName(varDecl));
emitSemantics(varDecl);
diff --git a/source/slang/slang-emit-c-like.h b/source/slang/slang-emit-c-like.h
index ccc25de57..3cccad9e6 100644
--- a/source/slang/slang-emit-c-like.h
+++ b/source/slang/slang-emit-c-like.h
@@ -257,7 +257,6 @@ public:
void emitType(IRType* type);
void emitType(IRType* type, Name* name, SourceLoc const& nameLoc);
void emitType(IRType* type, NameLoc const& nameAndLoc);
- virtual void emitGlobalParamType(IRType* type, String const& name) {emitType(type, name);}
bool hasExplicitConstantBufferOffset(IRInst* cbufferType);
bool isSingleElementConstantBuffer(IRInst* cbufferType);
bool shouldForceUnpackConstantBufferElements(IRInst* cbufferType);
@@ -430,7 +429,6 @@ public:
void emitGlobalInst(IRInst* inst);
virtual void emitGlobalInstImpl(IRInst* inst);
- virtual bool isPointerSyntaxRequiredImpl(IRInst* inst);
void ensureInstOperand(ComputeEmitActionsContext* ctx, IRInst* inst, EmitAction::Level requiredLevel = EmitAction::Level::Definition);
@@ -567,6 +565,7 @@ public:
// Emit the argument list (including paranthesis) in a `CallInst`
void _emitCallArgList(IRCall* call, int startingOperandIndex = 1);
+ virtual void emitCallArg(IRInst* arg);
String _generateUniqueName(const UnownedStringSlice& slice);
diff --git a/source/slang/slang-emit-wgsl.cpp b/source/slang/slang-emit-wgsl.cpp
index 833e1c8fe..051cdb820 100644
--- a/source/slang/slang-emit-wgsl.cpp
+++ b/source/slang/slang-emit-wgsl.cpp
@@ -23,12 +23,13 @@
// 'transpose' calls, or else perform more complicated transformations that
// end up duplicating expressions many times.
-namespace Slang {
+namespace Slang
+{
void WGSLSourceEmitter::emitSwitchCaseSelectorsImpl(
IRBasicType *const switchConditionType,
- const SwitchRegion::Case *const currentCase, const bool isDefault
- )
+ const SwitchRegion::Case *const currentCase,
+ const bool isDefault)
{
// WGSL has special syntax for blocks sharing case labels:
// "case 2, 3, 4: ...;" instead of the C-like syntax
@@ -80,8 +81,8 @@ void WGSLSourceEmitter::emitSwitchCaseSelectorsImpl(
}
void WGSLSourceEmitter::emitParameterGroupImpl(
- IRGlobalParam* varDecl, IRUniformParameterGroupType* type
-)
+ IRGlobalParam* varDecl,
+ IRUniformParameterGroupType* type)
{
auto varLayout = getVarLayout(varDecl);
SLANG_RELEASE_ASSERT(varLayout);
@@ -140,8 +141,8 @@ void WGSLSourceEmitter::emitParameterGroupImpl(
}
void WGSLSourceEmitter::emitEntryPointAttributesImpl(
- IRFunc* irFunc, IREntryPointDecoration* entryPointDecor
- )
+ IRFunc* irFunc,
+ IREntryPointDecoration* entryPointDecor)
{
auto stage = entryPointDecor->getProfile().getStage();
@@ -238,9 +239,7 @@ static bool isPowerOf2(const uint32_t n)
return (n != 0U) && ((n - 1U) & n) == 0U;
}
-void WGSLSourceEmitter::emitStructFieldAttributes(
- IRStructType * structType, IRStructField * field
- )
+void WGSLSourceEmitter::emitStructFieldAttributes(IRStructType * structType, IRStructField * field)
{
// Tint emits errors unless we explicitly spell out the layout in some cases, so emit
// offset and align attribtues for all fields.
@@ -273,26 +272,6 @@ void WGSLSourceEmitter::emitStructFieldAttributes(
m_writer->emit(")");
}
-bool WGSLSourceEmitter::isPointerSyntaxRequiredImpl(IRInst* inst)
-{
- if (inst->getOp() == kIROp_RWStructuredBufferGetElementPtr)
- return false;
-
- // Don't emit "->" to access fields in resource structs
- if (inst->getOp() == kIROp_FieldAddress)
- return false;
-
- // Don't emit "*" to access fields in resource structs
- if (inst->getOp() == kIROp_GlobalParam)
- return false;
-
- // Emit 'globalVar' instead of "*&globalVar"
- if (inst->getOp() == kIROp_GlobalVar)
- return false;
-
- return true;
-}
-
void WGSLSourceEmitter::emit(const AddressSpace addressSpace)
{
switch (addressSpace)
@@ -325,32 +304,14 @@ void WGSLSourceEmitter::emitSimpleTypeImpl(IRType* type)
{
case kIROp_HLSLRWStructuredBufferType:
- {
- auto structuredBufferType = as<IRHLSLStructuredBufferTypeBase>(type);
- m_writer->emit("ptr<");
- emit(AddressSpace::StorageBuffer);
- m_writer->emit(", ");
- m_writer->emit("array");
- m_writer->emit("<");
- emitType(structuredBufferType->getElementType());
- m_writer->emit(">");
- m_writer->emit(", read_write");
- m_writer->emit(">");
- }
- break;
-
case kIROp_HLSLStructuredBufferType:
+ case kIROp_HLSLRasterizerOrderedStructuredBufferType:
{
auto structuredBufferType = as<IRHLSLStructuredBufferTypeBase>(type);
- m_writer->emit("ptr<");
- emit(AddressSpace::StorageBuffer);
- m_writer->emit(", ");
m_writer->emit("array");
m_writer->emit("<");
emitType(structuredBufferType->getElementType());
m_writer->emit(">");
- m_writer->emit(", read");
- m_writer->emit(">");
}
break;
@@ -582,7 +543,8 @@ void WGSLSourceEmitter::emitVarKeywordImpl(IRType * type, IRInst* varDecl)
{
m_writer->emit("<workgroup>");
}
- else if (type->getOp() == kIROp_HLSLRWStructuredBufferType)
+ else if (type->getOp() == kIROp_HLSLRWStructuredBufferType ||
+ type->getOp() == kIROp_HLSLRasterizerOrderedStructuredBufferType)
{
m_writer->emit("<");
m_writer->emit("storage, read_write");
@@ -692,9 +654,26 @@ void WGSLSourceEmitter::emitDeclaratorImpl(DeclaratorInfo* declarator)
}
}
+void WGSLSourceEmitter::emitOperandImpl(IRInst* operand, EmitOpInfo const& outerPrec)
+{
+ if (operand->getOp() == kIROp_Param && as<IRPtrTypeBase>(operand->getDataType()))
+ {
+ // If we are emitting a reference to a pointer typed operand, then
+ // we should dereference it now since we want to treat all the remaining
+ // part of wgsl as pointer-free target.
+ m_writer->emit("(*");
+ m_writer->emit(getName(operand));
+ m_writer->emit(")");
+ }
+ else
+ {
+ CLikeSourceEmitter::emitOperandImpl(operand, outerPrec);
+ }
+}
+
void WGSLSourceEmitter::emitSimpleTypeAndDeclaratorImpl(
- IRType* type, DeclaratorInfo* declarator
- )
+ IRType* type,
+ DeclaratorInfo* declarator)
{
if (declarator)
{
@@ -999,13 +978,29 @@ bool WGSLSourceEmitter::tryEmitInstStmtImpl(IRInst* inst)
}
}
+void WGSLSourceEmitter::emitCallArg(IRInst* inst)
+{
+ if (as<IRPtrTypeBase>(inst->getDataType()))
+ {
+ // If we are calling a function with a pointer-typed argument, we need to
+ // explicitly prefix the argument with `&` to pass a pointer.
+ //
+ m_writer->emit("&(");
+ emitOperand(inst, getInfo(EmitOp::General));
+ m_writer->emit(")");
+ }
+ else
+ {
+ emitOperand(inst, getInfo(EmitOp::General));
+ }
+}
+
bool WGSLSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOuterPrec)
{
EmitOpInfo outerPrec = inOuterPrec;
switch (inst->getOp())
{
-
case kIROp_MakeVectorFromScalar:
{
// In WGSL this is done by calling the vec* overloads listed in [1]
@@ -1079,25 +1074,13 @@ bool WGSLSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOu
}
break;
- case kIROp_RWStructuredBufferGetElementPtr:
- {
- m_writer->emit("(*");
- emitOperand(inst->getOperand(0), leftSide(outerPrec, getInfo(EmitOp::Postfix)));
- m_writer->emit(")[");
- emitOperand(inst->getOperand(1), getInfo(EmitOp::General));
- m_writer->emit("]");
- return true;
- }
- break;
-
case kIROp_StructuredBufferLoad:
case kIROp_RWStructuredBufferLoad:
+ case kIROp_RWStructuredBufferGetElementPtr:
{
- // Structured buffers are just arrays in WGSL
- auto base = inst->getOperand(0);
- emitOperand(base, outerPrec);
+ emitOperand(inst->getOperand(0), leftSide(outerPrec, getInfo(EmitOp::Postfix)));
m_writer->emit("[");
- emitOperand(inst->getOperand(1), EmitOpInfo());
+ emitOperand(inst->getOperand(1), getInfo(EmitOp::General));
m_writer->emit("]");
return true;
}
@@ -1134,15 +1117,12 @@ bool WGSLSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOu
}
}
break;
-
}
return false;
}
-void WGSLSourceEmitter::emitVectorTypeNameImpl(
- IRType* elementType, IRIntegerValue elementCount
- )
+void WGSLSourceEmitter::emitVectorTypeNameImpl(IRType* elementType, IRIntegerValue elementCount)
{
if (elementCount > 1)
@@ -1159,61 +1139,6 @@ void WGSLSourceEmitter::emitVectorTypeNameImpl(
}
}
-void WGSLSourceEmitter::emitOperandImpl(IRInst* inst, const EmitOpInfo& outerPrec)
-{
- // In WGSL, the structured buffer types are converted to ptr<AS, array<E>, AM>
- // everywhere, except for the global parameter declaration.
- // Thus, when these globals are used in expressions, we need an ampersand.
-
- if (inst->getOp() == kIROp_GlobalParam)
- {
- switch (inst->getDataType()->getOp())
- {
- case kIROp_HLSLStructuredBufferType:
- case kIROp_HLSLRWStructuredBufferType:
-
- m_writer->emit("(&");
- CLikeSourceEmitter::emitOperandImpl(inst, outerPrec);
- m_writer->emit(")");
- return;
- }
- }
-
- CLikeSourceEmitter::emitOperandImpl(inst, outerPrec);
-}
-
-void WGSLSourceEmitter::emitGlobalParamType(IRType* type, const String& name)
-{
- // In WGSL, the structured buffer types are converted to ptr<AS, array<E>, AM>
- // everywhere, except for the global parameter declaration.
-
- switch (type->getOp())
- {
-
- case kIROp_HLSLStructuredBufferType:
- case kIROp_HLSLRWStructuredBufferType:
- {
- StringSliceLoc nameAndLoc(name.getUnownedSlice());
- NameDeclaratorInfo nameDeclarator(&nameAndLoc);
- emitDeclarator(&nameDeclarator);
- m_writer->emit(" : ");
- auto structuredBufferType = as<IRHLSLStructuredBufferTypeBase>(type);
- m_writer->emit("array");
- m_writer->emit("<");
- emitType(structuredBufferType->getElementType());
- m_writer->emit(">");
- }
- break;
-
- default:
-
- emitType(type, name);
- break;
-
- }
-
-}
-
void WGSLSourceEmitter::emitFrontMatterImpl(TargetRequest* /* targetReq */)
{
if (m_f16ExtensionEnabled)
diff --git a/source/slang/slang-emit-wgsl.h b/source/slang/slang-emit-wgsl.h
index b3a4efb55..0b4b04b12 100644
--- a/source/slang/slang-emit-wgsl.h
+++ b/source/slang/slang-emit-wgsl.h
@@ -13,53 +13,36 @@ public:
: CLikeSourceEmitter(desc)
{}
- virtual void emitParameterGroupImpl(
- IRGlobalParam* varDecl, IRUniformParameterGroupType* type
- ) SLANG_OVERRIDE;
- virtual void emitEntryPointAttributesImpl(
- IRFunc* irFunc, IREntryPointDecoration* entryPointDecor
- ) SLANG_OVERRIDE;
+ virtual void emitParameterGroupImpl(IRGlobalParam* varDecl, IRUniformParameterGroupType* type) SLANG_OVERRIDE;
+ virtual void emitEntryPointAttributesImpl(IRFunc* irFunc, IREntryPointDecoration* entryPointDecor) SLANG_OVERRIDE;
virtual void emitSimpleTypeImpl(IRType* type) SLANG_OVERRIDE;
- virtual void emitVectorTypeNameImpl(
- IRType* elementType, IRIntegerValue elementCount
- ) SLANG_OVERRIDE;
+ virtual void emitVectorTypeNameImpl(IRType* elementType, IRIntegerValue elementCount) SLANG_OVERRIDE;
virtual void emitFuncHeaderImpl(IRFunc* func) SLANG_OVERRIDE;
virtual void emitSimpleValueImpl(IRInst* inst) SLANG_OVERRIDE;
- virtual bool tryEmitInstExprImpl(
- IRInst* inst, const EmitOpInfo& inOuterPrec
- ) SLANG_OVERRIDE;
+ virtual bool tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOuterPrec) SLANG_OVERRIDE;
virtual bool tryEmitInstStmtImpl(IRInst* inst) SLANG_OVERRIDE;
virtual void emitSwitchCaseSelectorsImpl(
IRBasicType *const switchCondition,
const SwitchRegion::Case *const currentCase,
- const bool isDefault
- ) SLANG_OVERRIDE;
- virtual void emitSimpleTypeAndDeclaratorImpl(
- IRType* type, DeclaratorInfo* declarator
- ) SLANG_OVERRIDE;
+ const bool isDefault) SLANG_OVERRIDE;
+ virtual void emitSimpleTypeAndDeclaratorImpl(IRType* type, DeclaratorInfo* declarator) SLANG_OVERRIDE;
virtual void emitVarKeywordImpl(IRType * type, IRInst* varDecl) SLANG_OVERRIDE;
virtual void emitDeclaratorImpl(DeclaratorInfo* declarator) SLANG_OVERRIDE;
+ virtual void emitOperandImpl(IRInst* operand, EmitOpInfo const& outerPrec) SLANG_OVERRIDE;
virtual void emitStructDeclarationSeparatorImpl() SLANG_OVERRIDE;
virtual void emitLayoutQualifiersImpl(IRVarLayout* layout) SLANG_OVERRIDE;
virtual void emitSimpleFuncParamImpl(IRParam* param) SLANG_OVERRIDE;
virtual void emitParamTypeImpl(IRType* type, const String& name) SLANG_OVERRIDE;
- virtual bool isPointerSyntaxRequiredImpl(IRInst* inst) SLANG_OVERRIDE;
virtual void _emitType(IRType* type, DeclaratorInfo* declarator) SLANG_OVERRIDE;
virtual void emitFrontMatterImpl(TargetRequest* targetReq) SLANG_OVERRIDE;
- virtual void emitStructFieldAttributes(
- IRStructType * structType, IRStructField * field
- ) SLANG_OVERRIDE;
- virtual void emitGlobalParamType(IRType* type, const String& name) SLANG_OVERRIDE;
- virtual void emitOperandImpl(
- IRInst* inst, const EmitOpInfo& outerPrec
- ) SLANG_OVERRIDE;
+ virtual void emitStructFieldAttributes(IRStructType * structType, IRStructField * field) SLANG_OVERRIDE;
+ virtual void emitCallArg(IRInst* inst) SLANG_OVERRIDE;
virtual void emitIntrinsicCallExprImpl(
IRCall* inst,
UnownedStringSlice intrinsicDefinition,
IRInst* intrinsicInst,
- EmitOpInfo const& inOuterPrec
- ) SLANG_OVERRIDE;
+ EmitOpInfo const& inOuterPrec) SLANG_OVERRIDE;
void emit(const AddressSpace addressSpace);
@@ -69,10 +52,9 @@ private:
void emitMatrixType(
IRType *const elementType,
const IRIntegerValue& rowCountWGSL,
- const IRIntegerValue& colCountWGSL
- );
+ const IRIntegerValue& colCountWGSL);
- bool m_f16ExtensionEnabled {false};
+ bool m_f16ExtensionEnabled = false;
};
diff --git a/source/slang/slang-ir-wgsl-legalize.cpp b/source/slang/slang-ir-wgsl-legalize.cpp
index e05eba78c..622ceeff5 100644
--- a/source/slang/slang-ir-wgsl-legalize.cpp
+++ b/source/slang/slang-ir-wgsl-legalize.cpp
@@ -40,17 +40,16 @@ namespace Slang
std::optional<SystemValLegalizationWorkItem> makeSystemValWorkItem(IRInst* var);
void legalizeSystemValue(
- EntryPointInfo entryPoint, SystemValLegalizationWorkItem& workItem
- );
+ EntryPointInfo entryPoint, SystemValLegalizationWorkItem& workItem);
List<SystemValLegalizationWorkItem> collectSystemValFromEntryPoint(
- EntryPointInfo entryPoint
- );
+ EntryPointInfo entryPoint);
void legalizeSystemValueParameters(EntryPointInfo entryPoint);
void legalizeEntryPointForWGSL(EntryPointInfo entryPoint);
IRInst* tryConvertValue(IRBuilder& builder, IRInst* val, IRType* toType);
WGSLSystemValueInfo getSystemValueInfo(
- String inSemanticName, String* optionalSemanticIndex, IRInst* parentVar
- );
+ String inSemanticName, String* optionalSemanticIndex, IRInst* parentVar);
+ void legalizeCall(IRCall* call);
+ void processInst(IRInst* inst);
};
IRInst* LegalizeWGSLEntryPointContext::tryConvertValue(
@@ -321,6 +320,74 @@ namespace Slang
legalizeSystemValueParameters(entryPoint);
}
+ void LegalizeWGSLEntryPointContext::legalizeCall(IRCall* call)
+ {
+ // WGSL does not allow forming a pointer to a sub part of a composite value.
+ // For example, if we have
+ // ```
+ // struct S { float x; float y; };
+ // void foo(inout float v) { v = 1.0f; }
+ // void main() { S s; foo(s.x); }
+ // ```
+ // The call to `foo(s.x)` is illegal in WGSL because `s.x` is a sub part of `s`.
+ // And trying to form `&s.x` in WGSL is illegal.
+ // To work around this, we will create a local variable to hold the sub part of
+ // the composite value.
+ // And then pass the local variable to the function.
+ // After the call, we will write back the local variable to the sub part of the
+ // composite value.
+ //
+ IRBuilder builder(call);
+ builder.setInsertBefore(call);
+ struct WritebackPair { IRInst* dest; IRInst* value; };
+ ShortList<WritebackPair> pendingWritebacks;
+
+ for (UInt i = 0; i < call->getArgCount(); i++)
+ {
+ auto arg = call->getArg(i);
+ auto ptrType = as<IRPtrTypeBase>(arg->getDataType());
+ if (!ptrType)
+ continue;
+ switch (arg->getOp())
+ {
+ case kIROp_Var:
+ case kIROp_Param:
+ continue;
+ default:
+ break;
+ }
+
+ // Create a local variable to hold the input argument.
+ auto var = builder.emitVar(
+ ptrType->getValueType(),
+ AddressSpace::Function);
+
+ // Store the input argument into the local variable.
+ builder.emitStore(var, builder.emitLoad(arg));
+ builder.replaceOperand(call->getArgs() + i, var);
+ pendingWritebacks.add({ arg, var });
+ }
+
+ // Perform writebacks after the call.
+ builder.setInsertAfter(call);
+ for (auto& pair : pendingWritebacks)
+ {
+ builder.emitStore(pair.dest, builder.emitLoad(pair.value));
+ }
+ }
+
+ void LegalizeWGSLEntryPointContext::processInst(IRInst* inst)
+ {
+ switch (inst->getOp())
+ {
+ case kIROp_Call:
+ legalizeCall(static_cast<IRCall*>(inst));
+ break;
+ default:
+ for (auto child : inst->getModifiableChildren())
+ processInst(child);
+ }
+ }
void legalizeIRForWGSL(IRModule* module, DiagnosticSink* sink)
{
List<EntryPointInfo> entryPoints;
@@ -342,6 +409,9 @@ namespace Slang
LegalizeWGSLEntryPointContext context(sink, module);
for (auto entryPoint : entryPoints)
context.legalizeEntryPointForWGSL(entryPoint);
+
+ // Go through every instruction in the module and legalize them as needed.
+ context.processInst(module->getModuleInst());
}
}
diff --git a/tests/language-feature/atomic-t/atomic-0.slang b/tests/language-feature/atomic-t/atomic-0.slang
index 6f2ce6418..591de490c 100644
--- a/tests/language-feature/atomic-t/atomic-0.slang
+++ b/tests/language-feature/atomic-t/atomic-0.slang
@@ -12,46 +12,32 @@ void computeMain()
bool result = true;
if (outputBuffer[0].add(1) != 0)
result = false;
- AllMemoryBarrierWithGroupSync();
if (outputBuffer[0].sub(1) != 1)
result = false;
- AllMemoryBarrierWithGroupSync();
if (outputBuffer[0].max(2) != 0)
result = false;
- AllMemoryBarrierWithGroupSync();
if (outputBuffer[0].min(1) != 2)
result = false;
- AllMemoryBarrierWithGroupSync();
if (outputBuffer[0].or(3) != 1)
result = false;
- AllMemoryBarrierWithGroupSync();
if (outputBuffer[0].and(2) != 3)
result = false;
- AllMemoryBarrierWithGroupSync();
if (outputBuffer[0].xor(3) != 2)
result = false;
- AllMemoryBarrierWithGroupSync();
if (outputBuffer[0].exchange(4) != 1)
result = false;
- AllMemoryBarrierWithGroupSync();
if (outputBuffer[0].compareExchange(4, 5) != 4)
{} //result = false; // for some reason this fails on Metal Github CI, so disabling.
- AllMemoryBarrierWithGroupSync();
if (outputBuffer[0].load() != 5)
result = false;
- AllMemoryBarrierWithGroupSync();
if (outputBuffer[0].increment() != 5)
result = false;
- AllMemoryBarrierWithGroupSync();
if (outputBuffer[0].decrement() != 6)
result = false;
- AllMemoryBarrierWithGroupSync();
// CHECK: 6
outputBuffer[0].store(6);
- AllMemoryBarrierWithGroupSync();
if (outputBuffer[0].load() != 6)
result = false;
- AllMemoryBarrierWithGroupSync();
// CHECK: 1
if (result)
outputBuffer[1].store(1);
diff --git a/tests/wgsl/inout.slang b/tests/wgsl/inout.slang
new file mode 100644
index 000000000..57f55ae3f
--- /dev/null
+++ b/tests/wgsl/inout.slang
@@ -0,0 +1,32 @@
+//TEST:SIMPLE(filecheck=CHECK): -target wgsl
+
+RWStructuredBuffer<float> outputBuffer;
+
+// CHECK: fn inner{{.*}}( x{{.*}} : ptr<function, f32>)
+// CHECK: (*x{{.*}}) = (*x{{.*}}) + 1.0
+void inner(inout float x)
+{
+ x = x + 1;
+}
+
+// CHECK: fn test{{.*}}( x{{.*}} : ptr<function, f32>)
+void test(inout float x)
+{
+ inner(x);
+}
+
+struct MyType
+{
+ float myField[3];
+}
+
+[numthreads(1,1,1)]
+void computeMain(int id : SV_DispatchThreadID)
+{
+ MyType v;
+ v.myField[id] = 0.0f;
+ // CHECK: test{{.*}}(&({{.*}}));
+ test(v.myField[id]);
+ v.myField[1] = 2.0;
+ outputBuffer[0] = v.myField[id];
+}