summaryrefslogtreecommitdiff
path: root/source
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2024-09-20 15:11:23 -0700
committerGitHub <noreply@github.com>2024-09-20 15:11:23 -0700
commit490834924cc390cb812713c225b9a8227c66cf1f (patch)
tree5644e2a18cb085692d5fe9625f42582db07447be /source
parentb4c851fb1419f869bddaa08487f58376bc0a7144 (diff)
Initial `Atomic<T>` type implementation. (#5125)
* Initial Atomic<T> type implementation. * Update design doc. * Fix. * Add test. * Fixes and add tests. * Fix WGSL. * Fix glsl. * Fix metal. * experiemnt with github metal. * experiment github metal 2 * github metal experiment 3 * experiment with github metal 4. * experiment with metal 5. * experiment 7. * metal experiment 8. * Fix metal tests. --------- Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source')
-rw-r--r--source/slang/core.meta.slang126
-rw-r--r--source/slang/slang-ast-type.cpp6
-rw-r--r--source/slang/slang-ast-type.h7
-rw-r--r--source/slang/slang-check-expr.cpp6
-rw-r--r--source/slang/slang-emit-c-like.cpp56
-rw-r--r--source/slang/slang-emit-c-like.h4
-rw-r--r--source/slang/slang-emit-cpp.cpp4
-rw-r--r--source/slang/slang-emit-cuda.cpp123
-rw-r--r--source/slang/slang-emit-glsl.cpp183
-rw-r--r--source/slang/slang-emit-hlsl.cpp171
-rw-r--r--source/slang/slang-emit-hlsl.h1
-rw-r--r--source/slang/slang-emit-metal.cpp191
-rw-r--r--source/slang/slang-emit-metal.h1
-rw-r--r--source/slang/slang-emit-spirv-ops.h146
-rw-r--r--source/slang/slang-emit-spirv.cpp189
-rw-r--r--source/slang/slang-emit-wgsl.cpp187
-rw-r--r--source/slang/slang-emit-wgsl.h3
-rw-r--r--source/slang/slang-ir-inst-defs.h22
-rw-r--r--source/slang/slang-ir-insts.h23
-rw-r--r--source/slang/slang-ir-layout.cpp8
-rw-r--r--source/slang/slang-ir-lower-append-consume-structured-buffer.cpp6
-rw-r--r--source/slang/slang-ir-use-uninitialized-values.cpp1
-rw-r--r--source/slang/slang-ir.cpp17
-rw-r--r--source/slang/slang-ir.h16
-rw-r--r--source/slang/slang-lower-to-ir.cpp16
-rw-r--r--source/slang/slang-type-layout.cpp4
26 files changed, 1415 insertions, 102 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang
index 476279ab8..e4e24dddb 100644
--- a/source/slang/core.meta.slang
+++ b/source/slang/core.meta.slang
@@ -2721,6 +2721,132 @@ __Addr<T> __getLegalizedSPIRVGlobalParamAddr(T val);
__intrinsic_op($(kIROp_RequireComputeDerivative))
void __requireComputeDerivative();
+// Atomic<T>
+
+enum MemoryOrder
+{
+ Relaxed = $(kIRMemoryOrder_Relaxed),
+ Acquire = $(kIRMemoryOrder_Acquire),
+ Release = $(kIRMemoryOrder_Release),
+ AcquireRelease = $(kIRMemoryOrder_AcquireRelease),
+ SeqCst = $(kIRMemoryOrder_SeqCst),
+}
+
+[sealed] interface IAtomicable {}
+[sealed] interface IArithmeticAtomicable : IAtomicable, IArithmetic {}
+[sealed] interface IBitAtomicable : IArithmeticAtomicable, IInteger {}
+
+extension int : IBitAtomicable {}
+extension uint : IBitAtomicable {}
+extension int64_t : IBitAtomicable {}
+extension uint64_t : IBitAtomicable {}
+extension double : IArithmeticAtomicable {}
+extension float : IArithmeticAtomicable {}
+extension half : IArithmeticAtomicable {}
+
+__magic_type(AtomicType)
+__intrinsic_type($(kIROp_AtomicType))
+[require(cuda_glsl_hlsl_metal_spirv_wgsl)]
+struct Atomic<T : IAtomicable>
+{
+ __intrinsic_op($(kIROp_AtomicLoad))
+ [__ref] T load(MemoryOrder order = MemoryOrder.Relaxed);
+
+ __intrinsic_op($(kIROp_AtomicStore))
+ [__ref] void store(T newValue, MemoryOrder order = MemoryOrder.Relaxed);
+
+ __intrinsic_op($(kIROp_AtomicExchange))
+ [__ref] T exchange(T newValue, MemoryOrder order = MemoryOrder.Relaxed); // returns old value
+
+ __intrinsic_op($(kIROp_AtomicCompareExchange))
+ [__ref] T compareExchange(
+ T compareValue,
+ T newValue,
+ MemoryOrder successOrder = MemoryOrder.Relaxed,
+ MemoryOrder failOrder = MemoryOrder.Relaxed);
+}
+
+extension<T : IArithmeticAtomicable> Atomic<T>
+{
+ __intrinsic_op($(kIROp_AtomicAdd))
+ [__ref] T add(T value, MemoryOrder order = MemoryOrder.Relaxed); // returns original value
+ __intrinsic_op($(kIROp_AtomicSub))
+ [__ref] T sub(T value, MemoryOrder order = MemoryOrder.Relaxed); // returns original value
+ __intrinsic_op($(kIROp_AtomicMax))
+ [__ref] T max(T value, MemoryOrder order = MemoryOrder.Relaxed); // returns original value
+ __intrinsic_op($(kIROp_AtomicMin))
+ [__ref] T min(T value, MemoryOrder order = MemoryOrder.Relaxed); // returns original value
+}
+
+extension<T : IBitAtomicable> Atomic<T>
+{
+ __intrinsic_op($(kIROp_AtomicAnd))
+ [__ref] T and(T value, MemoryOrder order = MemoryOrder.Relaxed); // returns original value
+ __intrinsic_op($(kIROp_AtomicOr))
+ [__ref] T or(T value, MemoryOrder order = MemoryOrder.Relaxed); // returns original value
+ __intrinsic_op($(kIROp_AtomicXor))
+ [__ref] T xor(T value, MemoryOrder order = MemoryOrder.Relaxed); // returns original value
+ __intrinsic_op($(kIROp_AtomicInc))
+ [__ref] T increment(MemoryOrder order = MemoryOrder.Relaxed);
+ __intrinsic_op($(kIROp_AtomicDec))
+ [__ref] T decrement(MemoryOrder order = MemoryOrder.Relaxed);
+}
+
+__generic<T : IArithmeticAtomicable>
+[ForceInline]
+T operator +=(__ref Atomic<T> v, T value)
+{
+ return v.add(value) + value;
+}
+__generic<T : IArithmeticAtomicable>
+[ForceInline]
+T operator -=(__ref Atomic<T> v, T value)
+{
+ return v.sub(value) - value;
+}
+__generic<T : IBitAtomicable>
+[ForceInline]
+T operator &=(__ref Atomic<T> v, T value)
+{
+ return v.and(value) & value;
+}
+__generic<T : IBitAtomicable>
+[ForceInline]
+T operator |=(__ref Atomic<T> v, T value)
+{
+ return v.or(value) | value;
+}
+__generic<T : IBitAtomicable>
+[ForceInline]
+T operator ^=(__ref Atomic<T> v, T value)
+{
+ return v.xor(value) ^ value;
+}
+
+__generic<T : IBitAtomicable>
+[ForceInline]
+__prefix T operator ++(__ref Atomic<T> v)
+{
+ return v.increment() + T(1);
+}
+__generic<T : IBitAtomicable>
+[ForceInline]
+__postfix T operator ++(__ref Atomic<T> v)
+{
+ return v.increment();
+}
+__generic<T : IBitAtomicable>
+[ForceInline]
+__prefix T operator --(__ref Atomic<T> v)
+{
+ return v.decrement() - T(1);
+}
+__generic<T : IBitAtomicable>
+[ForceInline]
+__postfix T operator --(__ref Atomic<T> v)
+{
+ return v.decrement();
+}
// Binding Attributes
__attributeTarget(DeclBase)
diff --git a/source/slang/slang-ast-type.cpp b/source/slang/slang-ast-type.cpp
index 1c9f68a48..616c2a67a 100644
--- a/source/slang/slang-ast-type.cpp
+++ b/source/slang/slang-ast-type.cpp
@@ -328,6 +328,12 @@ bool ArrayExpressionType::isUnsized()
return false;
}
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! AtomicType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+Type* AtomicType::getElementType()
+{
+ return as<Type>(_getGenericTypeArg(this, 0));
+}
+
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! TypeType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
void TypeType::_toTextOverride(StringBuilder& out)
diff --git a/source/slang/slang-ast-type.h b/source/slang/slang-ast-type.h
index 46ea3ea55..1239283d7 100644
--- a/source/slang/slang-ast-type.h
+++ b/source/slang/slang-ast-type.h
@@ -436,6 +436,13 @@ class ArrayExpressionType : public DeclRefType
IntVal* getElementCount();
};
+class AtomicType : public DeclRefType
+{
+ SLANG_AST_CLASS(AtomicType)
+
+ Type* getElementType();
+};
+
// The "type" of an expression that resolves to a type.
// For example, in the expression `float(2)` the sub-expression,
// `float` would have the type `TypeType(float)`.
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index 5233008fd..1b9725a8c 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -2274,10 +2274,14 @@ namespace Slang
expr->left = maybeOpenRef(expr->left);
auto type = expr->left->type;
+ if (auto atomicType = as<AtomicType>(type))
+ {
+ type = atomicType->getElementType();
+ }
auto right = maybeOpenRef(expr->right);
expr->right = coerce(CoercionSite::Assignment, type, right);
- if (!type.isLeftValue)
+ if (!expr->left->type.isLeftValue)
{
if (as<ErrorType>(type))
{
diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp
index caf3613a7..c60397b85 100644
--- a/source/slang/slang-emit-c-like.cpp
+++ b/source/slang/slang-emit-c-like.cpp
@@ -500,32 +500,6 @@ void CLikeSourceEmitter::defaultEmitInstStmt(IRInst* inst)
{
switch (inst->getOp())
{
- case kIROp_AtomicCounterIncrement:
- {
- auto oldValName = getName(inst);
- m_writer->emit("int ");
- m_writer->emit(oldValName);
- m_writer->emit(";\n");
- m_writer->emit("InterlockedAdd(");
- emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
- m_writer->emit(", 1, ");
- m_writer->emit(oldValName);
- m_writer->emit(");\n");
- }
- break;
- case kIROp_AtomicCounterDecrement:
- {
- auto oldValName = getName(inst);
- m_writer->emit("int ");
- m_writer->emit(oldValName);
- m_writer->emit(";\n");
- m_writer->emit("InterlockedAdd(");
- emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
- m_writer->emit(", -1, ");
- m_writer->emit(oldValName);
- m_writer->emit(");\n");
- }
- break;
case kIROp_StructuredBufferGetDimensions:
{
auto count = _generateUniqueName(UnownedStringSlice("_elementCount"));
@@ -1862,8 +1836,7 @@ void CLikeSourceEmitter::emitInstResultDecl(IRInst* inst)
emitRateQualifiers(inst);
- bool isConstant(as<IRModuleInst>(inst->getParent()));
- if(isConstant)
+ if (as<IRModuleInst>(inst->getParent()))
{
// "Ordinary" instructions at module scope are constants
@@ -1888,7 +1861,7 @@ void CLikeSourceEmitter::emitInstResultDecl(IRInst* inst)
}
- emitVarKeyword(type, isConstant);
+ emitVarKeyword(type, inst);
emitType(type, getName(inst));
m_writer->emit(" = ");
@@ -2920,8 +2893,19 @@ void CLikeSourceEmitter::_emitInst(IRInst* inst)
// Insts that needs to be emitted as code blocks.
case kIROp_CudaKernelLaunch:
- case kIROp_AtomicCounterIncrement:
- case kIROp_AtomicCounterDecrement:
+ case kIROp_AtomicLoad:
+ case kIROp_AtomicStore:
+ case kIROp_AtomicInc:
+ case kIROp_AtomicDec:
+ case kIROp_AtomicAdd:
+ case kIROp_AtomicSub:
+ case kIROp_AtomicAnd:
+ case kIROp_AtomicOr:
+ case kIROp_AtomicXor:
+ case kIROp_AtomicMin:
+ case kIROp_AtomicMax:
+ case kIROp_AtomicExchange:
+ case kIROp_AtomicCompareExchange:
case kIROp_StructuredBufferGetDimensions:
case kIROp_MetalAtomicCast:
emitInstStmt(inst);
@@ -3143,7 +3127,7 @@ void CLikeSourceEmitter::_emitStoreImpl(IRStore* store)
void CLikeSourceEmitter::_emitInstAsDefaultInitializedVar(IRInst* inst, IRType* type)
{
- emitVarKeyword(type, /* isConstant */ false);
+ emitVarKeyword(type, inst);
emitType(type, getName(inst));
@@ -3975,7 +3959,7 @@ void CLikeSourceEmitter::emitParameterGroup(IRGlobalParam* varDecl, IRUniformPar
emitParameterGroupImpl(varDecl, type);
}
-void CLikeSourceEmitter::emitVarKeywordImpl(IRType * /* type */, bool /* isConstant */) {}
+void CLikeSourceEmitter::emitVarKeywordImpl(IRType * /* type */, IRInst* /* varDecl */) {}
void CLikeSourceEmitter::emitVar(IRVar* varDecl)
{
@@ -4015,7 +3999,7 @@ void CLikeSourceEmitter::emitVar(IRVar* varDecl)
#endif
emitRateQualifiersAndAddressSpace(varDecl);
- emitVarKeyword(varType, /* isConstant */ false);
+ emitVarKeyword(varType, varDecl);
emitType(varType, getName(varDecl));
@@ -4147,7 +4131,7 @@ void CLikeSourceEmitter::emitGlobalVar(IRGlobalVar* varDecl)
emitVarModifiers(layout, varDecl, varType);
emitRateQualifiersAndAddressSpace(varDecl);
- emitVarKeyword(varType, /* isConstant */ true);
+ emitVarKeyword(varType, varDecl);
emitType(varType, getName(varDecl));
// TODO: These shouldn't be needed for ordinary
@@ -4221,7 +4205,7 @@ void CLikeSourceEmitter::emitGlobalParam(IRGlobalParam* varDecl)
emitDecorationLayoutSemantics(varDecl, "register");
emitRateQualifiersAndAddressSpace(varDecl);
- emitVarKeyword(varType, /* isConstant */ false);
+ emitVarKeyword(varType, varDecl);
emitGlobalParamType(varType, getName(varDecl));
emitSemantics(varDecl);
diff --git a/source/slang/slang-emit-c-like.h b/source/slang/slang-emit-c-like.h
index be769f31f..ccc25de57 100644
--- a/source/slang/slang-emit-c-like.h
+++ b/source/slang/slang-emit-c-like.h
@@ -498,8 +498,8 @@ public:
virtual void emitSimpleTypeAndDeclaratorImpl(IRType* type, DeclaratorInfo* declarator);
void emitSimpleTypeAndDeclarator(IRType* type, DeclaratorInfo* declarator) {emitSimpleTypeAndDeclaratorImpl(type, declarator);};
- virtual void emitVarKeywordImpl(IRType * type, bool isConstant);
- void emitVarKeyword(IRType * type, bool isConstant) {emitVarKeywordImpl(type, isConstant);}
+ virtual void emitVarKeywordImpl(IRType * type, IRInst* varDecl);
+ void emitVarKeyword(IRType * type, IRInst* varDecl) {emitVarKeywordImpl(type, varDecl);}
virtual void beforeComputeEmitActions(IRModule* module) { SLANG_UNUSED(module); };
diff --git a/source/slang/slang-emit-cpp.cpp b/source/slang/slang-emit-cpp.cpp
index bcb9ed9da..19dc05dcf 100644
--- a/source/slang/slang-emit-cpp.cpp
+++ b/source/slang/slang-emit-cpp.cpp
@@ -337,6 +337,10 @@ SlangResult CPPSourceEmitter::calcTypeName(IRType* type, CodeGenTarget target, S
out << intLit->getValue();
return SLANG_OK;
}
+ case kIROp_AtomicType:
+ {
+ return calcTypeName((IRType*)type->getOperand(0), target, out);
+ }
default:
{
if (isNominalOp(type->getOp()))
diff --git a/source/slang/slang-emit-cuda.cpp b/source/slang/slang-emit-cuda.cpp
index 05485855c..81bcafeb3 100644
--- a/source/slang/slang-emit-cuda.cpp
+++ b/source/slang/slang-emit-cuda.cpp
@@ -474,6 +474,129 @@ bool CUDASourceEmitter::tryEmitInstStmtImpl(IRInst* inst)
m_writer->emit(");\n");
return true;
}
+ case kIROp_AtomicLoad:
+ {
+ emitInstResultDecl(inst);
+ emitDereferenceOperand(inst->getOperand(0), getInfo(EmitOp::General));
+ m_writer->emit(";\n");
+ return true;
+ }
+ case kIROp_AtomicStore:
+ {
+ emitDereferenceOperand(inst->getOperand(0), getInfo(EmitOp::General));
+ m_writer->emit(" = ");
+ emitOperand(inst->getOperand(1), getInfo(EmitOp::General));
+ m_writer->emit(";\n");
+ return true;
+ }
+ case kIROp_AtomicExchange:
+ {
+ emitInstResultDecl(inst);
+ m_writer->emit("atomicExch(");
+ emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
+ m_writer->emit(", ");
+ emitOperand(inst->getOperand(1), getInfo(EmitOp::General));
+ m_writer->emit(");\n");
+ return true;
+ }
+ case kIROp_AtomicCompareExchange:
+ {
+ emitInstResultDecl(inst);
+ m_writer->emit("atomicCAS(");
+ emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
+ m_writer->emit(", ");
+ emitOperand(inst->getOperand(1), getInfo(EmitOp::General));
+ m_writer->emit(", ");
+ emitOperand(inst->getOperand(2), getInfo(EmitOp::General));
+ m_writer->emit(");\n");
+ return true;
+ }
+ case kIROp_AtomicAdd:
+ {
+ emitInstResultDecl(inst);
+ m_writer->emit("atomicAdd(");
+ emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
+ m_writer->emit(", ");
+ emitOperand(inst->getOperand(1), getInfo(EmitOp::General));
+ m_writer->emit(");\n");
+ return true;
+ }
+ case kIROp_AtomicSub:
+ {
+ emitInstResultDecl(inst);
+ m_writer->emit("atomicAdd(");
+ emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
+ m_writer->emit(", -(");
+ emitOperand(inst->getOperand(1), getInfo(EmitOp::General));
+ m_writer->emit("));\n");
+ return true;
+ }
+ case kIROp_AtomicAnd:
+ {
+ emitInstResultDecl(inst);
+ m_writer->emit("atomicAnd(");
+ emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
+ m_writer->emit(", ");
+ emitOperand(inst->getOperand(1), getInfo(EmitOp::General));
+ m_writer->emit(");\n");
+ return true;
+ }
+ case kIROp_AtomicOr:
+ {
+ emitInstResultDecl(inst);
+ m_writer->emit("atomicOr(");
+ emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
+ m_writer->emit(", ");
+ emitOperand(inst->getOperand(1), getInfo(EmitOp::General));
+ m_writer->emit(");\n");
+ return true;
+ }
+ case kIROp_AtomicXor:
+ {
+ emitInstResultDecl(inst);
+ m_writer->emit("atomicXor(");
+ emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
+ m_writer->emit(", ");
+ emitOperand(inst->getOperand(1), getInfo(EmitOp::General));
+ m_writer->emit(");\n");
+ return true;
+ }
+ case kIROp_AtomicMin:
+ {
+ emitInstResultDecl(inst);
+ m_writer->emit("atomicMin(");
+ emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
+ m_writer->emit(", ");
+ emitOperand(inst->getOperand(1), getInfo(EmitOp::General));
+ m_writer->emit(");\n");
+ return true;
+ }
+ case kIROp_AtomicMax:
+ {
+ emitInstResultDecl(inst);
+ m_writer->emit("atomicMax(");
+ emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
+ m_writer->emit(", ");
+ emitOperand(inst->getOperand(1), getInfo(EmitOp::General));
+ m_writer->emit(");\n");
+ return true;
+ }
+ case kIROp_AtomicInc:
+ {
+ emitInstResultDecl(inst);
+ m_writer->emit("atomicAdd(");
+ emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
+ m_writer->emit(", 1);\n");
+ return true;
+ }
+ case kIROp_AtomicDec:
+ {
+ emitInstResultDecl(inst);
+ m_writer->emit("atomicAdd(");
+ emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
+ m_writer->emit(", -1);\n");
+ return true;
+ }
default:
return false;
}
diff --git a/source/slang/slang-emit-glsl.cpp b/source/slang/slang-emit-glsl.cpp
index 56113409d..116bf67cd 100644
--- a/source/slang/slang-emit-glsl.cpp
+++ b/source/slang/slang-emit-glsl.cpp
@@ -2156,44 +2156,152 @@ bool GLSLSourceEmitter::tryEmitInstStmtImpl(IRInst* inst)
{
switch (inst->getOp())
{
- case kIROp_AtomicCounterIncrement:
- {
- auto oldValName = getName(inst);
- m_writer->emit("int ");
- m_writer->emit(oldValName);
- m_writer->emit(" = ");
- m_writer->emit("atomicAdd(");
- emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
- m_writer->emit(", 1);\n");
- return true;
- }
- case kIROp_AtomicCounterDecrement:
- {
- auto oldValName = getName(inst);
- m_writer->emit("int ");
- m_writer->emit(oldValName);
- m_writer->emit(" = ");
- m_writer->emit("atomicAdd(");
- emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
- m_writer->emit(", -1);\n");
- return true;
- }
case kIROp_StructuredBufferGetDimensions:
+ {
+ emitInstResultDecl(inst);
+ m_writer->emit("uvec2(");
+ emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
+ m_writer->emit("._data.length(), ");
+ auto elementType = as<IRHLSLStructuredBufferTypeBase>(inst->getOperand(0)->getDataType())->getElementType();
+ IRIntegerValue stride = 0;
+ if (auto sizeDecor = elementType->findDecoration<IRSizeAndAlignmentDecoration>())
{
- emitInstResultDecl(inst);
- m_writer->emit("uvec2(");
- emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
- m_writer->emit("._data.length(), ");
- auto elementType = as<IRHLSLStructuredBufferTypeBase>(inst->getOperand(0)->getDataType())->getElementType();
- IRIntegerValue stride = 0;
- if (auto sizeDecor = elementType->findDecoration<IRSizeAndAlignmentDecoration>())
- {
- stride = align(sizeDecor->getSize(), (int)sizeDecor->getAlignment());
- }
- m_writer->emit(stride);
- m_writer->emit(");\n");
- return true;
+ stride = align(sizeDecor->getSize(), (int)sizeDecor->getAlignment());
}
+ m_writer->emit(stride);
+ m_writer->emit(");\n");
+ return true;
+ }
+ case kIROp_AtomicLoad:
+ {
+ emitInstResultDecl(inst);
+ emitDereferenceOperand(inst->getOperand(0), getInfo(EmitOp::General));
+ m_writer->emit(";\n");
+ return true;
+ }
+ case kIROp_AtomicStore:
+ {
+ emitInstResultDecl(inst);
+ emitDereferenceOperand(inst->getOperand(0), getInfo(EmitOp::General));
+ m_writer->emit(" = ");
+ emitOperand(inst->getOperand(1), getInfo(EmitOp::General));
+ m_writer->emit(";\n");
+ return true;
+ }
+ case kIROp_AtomicExchange:
+ {
+ emitInstResultDecl(inst);
+ m_writer->emit("atomicExchange(");
+ emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
+ m_writer->emit(", ");
+ emitOperand(inst->getOperand(1), getInfo(EmitOp::General));
+ m_writer->emit(");\n");
+ return true;
+ }
+ case kIROp_AtomicCompareExchange:
+ {
+ emitInstResultDecl(inst);
+ m_writer->emit("atomicCompSwap(");
+ emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
+ m_writer->emit(", ");
+ emitOperand(inst->getOperand(1), getInfo(EmitOp::General));
+ m_writer->emit(", ");
+ emitOperand(inst->getOperand(2), getInfo(EmitOp::General));
+ m_writer->emit(");\n");
+ return true;
+ }
+ case kIROp_AtomicAdd:
+ {
+ emitInstResultDecl(inst);
+ m_writer->emit("atomicAdd(");
+ emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
+ m_writer->emit(", ");
+ emitOperand(inst->getOperand(1), getInfo(EmitOp::General));
+ m_writer->emit(");\n");
+ return true;
+ }
+ case kIROp_AtomicSub:
+ {
+ emitInstResultDecl(inst);
+ m_writer->emit("atomicAdd(");
+ emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
+ m_writer->emit(", -(");
+ emitOperand(inst->getOperand(1), getInfo(EmitOp::General));
+ m_writer->emit("));\n");
+ return true;
+ }
+ case kIROp_AtomicAnd:
+ {
+ emitInstResultDecl(inst);
+ m_writer->emit("atomicAnd(");
+ emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
+ m_writer->emit(", ");
+ emitOperand(inst->getOperand(1), getInfo(EmitOp::General));
+ m_writer->emit(");\n");
+ return true;
+ }
+ case kIROp_AtomicOr:
+ {
+ emitInstResultDecl(inst);
+ m_writer->emit("atomicOr(");
+ emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
+ m_writer->emit(", ");
+ emitOperand(inst->getOperand(1), getInfo(EmitOp::General));
+ m_writer->emit(");\n");
+ return true;
+ }
+ case kIROp_AtomicXor:
+ {
+ emitInstResultDecl(inst);
+ m_writer->emit("atomicXor(");
+ emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
+ m_writer->emit(", ");
+ emitOperand(inst->getOperand(1), getInfo(EmitOp::General));
+ m_writer->emit(");\n");
+ return true;
+ }
+ case kIROp_AtomicMin:
+ {
+ emitInstResultDecl(inst);
+ m_writer->emit("atomicMin(");
+ emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
+ m_writer->emit(", ");
+ emitOperand(inst->getOperand(1), getInfo(EmitOp::General));
+ m_writer->emit(");\n");
+ return true;
+ }
+ case kIROp_AtomicMax:
+ {
+ emitInstResultDecl(inst);
+ m_writer->emit("atomicMax(");
+ emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
+ m_writer->emit(", ");
+ emitOperand(inst->getOperand(1), getInfo(EmitOp::General));
+ m_writer->emit(");\n");
+ return true;
+ }
+ case kIROp_AtomicInc:
+ {
+ emitInstResultDecl(inst);
+ m_writer->emit("atomicAdd(");
+ emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
+ m_writer->emit(", ");
+ emitType(inst->getDataType());
+ m_writer->emit("(1)");
+ m_writer->emit(");\n");
+ return true;
+ }
+ case kIROp_AtomicDec:
+ {
+ emitInstResultDecl(inst);
+ m_writer->emit("atomicAdd(");
+ emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
+ m_writer->emit(", ");
+ emitType(inst->getDataType());
+ m_writer->emit("(-1)");
+ m_writer->emit(");\n");
+ return true;
+ }
default:
return false;
}
@@ -2572,6 +2680,11 @@ void GLSLSourceEmitter::emitSimpleTypeImpl(IRType* type)
m_writer->emit("DNV");
return;
}
+ case kIROp_AtomicType:
+ {
+ emitSimpleTypeImpl(cast<IRAtomicType>(type)->getElementType());
+ return;
+ }
default: break;
}
diff --git a/source/slang/slang-emit-hlsl.cpp b/source/slang/slang-emit-hlsl.cpp
index f9765a555..b45b4c575 100644
--- a/source/slang/slang-emit-hlsl.cpp
+++ b/source/slang/slang-emit-hlsl.cpp
@@ -496,6 +496,171 @@ void HLSLSourceEmitter::emitEntryPointAttributesImpl(IRFunc* irFunc, IREntryPoin
}
}
+bool HLSLSourceEmitter::tryEmitInstStmtImpl(IRInst* inst)
+{
+ switch (inst->getOp())
+ {
+ case kIROp_AtomicLoad:
+ {
+ emitInstResultDecl(inst);
+ emitDereferenceOperand(inst->getOperand(0), getInfo(EmitOp::General));
+ m_writer->emit(";\n");
+ return true;
+ }
+ case kIROp_AtomicStore:
+ {
+ emitDereferenceOperand(inst->getOperand(0), getInfo(EmitOp::General));
+ m_writer->emit(" = ");
+ emitOperand(inst->getOperand(1), getInfo(EmitOp::General));
+ m_writer->emit(";\n");
+ return true;
+ }
+ case kIROp_AtomicExchange:
+ {
+ emitType(inst->getDataType(), getName(inst));
+ m_writer->emit(";\n");
+ m_writer->emit("InterlockedExchange(");
+ emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
+ m_writer->emit(", ");
+ emitOperand(inst->getOperand(1), getInfo(EmitOp::General));
+ m_writer->emit(", ");
+ m_writer->emit(getName(inst));
+ m_writer->emit(");\n");
+ return true;
+ }
+ case kIROp_AtomicCompareExchange:
+ {
+ emitType(inst->getDataType(), getName(inst));
+ m_writer->emit(";\n");
+ m_writer->emit("InterlockedCompareExchange(");
+ emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
+ m_writer->emit(", ");
+ emitOperand(inst->getOperand(1), getInfo(EmitOp::General));
+ m_writer->emit(", ");
+ emitOperand(inst->getOperand(2), getInfo(EmitOp::General));
+ m_writer->emit(", ");
+ m_writer->emit(getName(inst));
+ m_writer->emit(");\n");
+ return true;
+ }
+ case kIROp_AtomicAdd:
+ {
+ emitType(inst->getDataType(), getName(inst));
+ m_writer->emit(";\n");
+ m_writer->emit("InterlockedAdd(");
+ emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
+ m_writer->emit(", ");
+ emitOperand(inst->getOperand(1), getInfo(EmitOp::General));
+ m_writer->emit(", ");
+ m_writer->emit(getName(inst));
+ m_writer->emit(");\n");
+ return true;
+ }
+ case kIROp_AtomicSub:
+ {
+ emitType(inst->getDataType(), getName(inst));
+ m_writer->emit(";\n");
+ m_writer->emit("InterlockedAdd(");
+ emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
+ m_writer->emit(", -(");
+ emitOperand(inst->getOperand(1), getInfo(EmitOp::General));
+ m_writer->emit("), ");
+ m_writer->emit(getName(inst));
+ m_writer->emit(");\n");
+ return true;
+ }
+ case kIROp_AtomicAnd:
+ {
+ emitType(inst->getDataType(), getName(inst));
+ m_writer->emit(";\n");
+ m_writer->emit("InterlockedAnd(");
+ emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
+ m_writer->emit(", ");
+ emitOperand(inst->getOperand(1), getInfo(EmitOp::General));
+ m_writer->emit(", ");
+ m_writer->emit(getName(inst));
+ m_writer->emit(");\n");
+ return true;
+ }
+ case kIROp_AtomicOr:
+ {
+ emitType(inst->getDataType(), getName(inst));
+ m_writer->emit(";\n");
+ m_writer->emit("InterlockedOr(");
+ emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
+ m_writer->emit(", ");
+ emitOperand(inst->getOperand(1), getInfo(EmitOp::General));
+ m_writer->emit(", ");
+ m_writer->emit(getName(inst));
+ m_writer->emit(");\n");
+ return true;
+ }
+ case kIROp_AtomicXor:
+ {
+ emitType(inst->getDataType(), getName(inst));
+ m_writer->emit(";\n");
+ m_writer->emit("InterlockedXor(");
+ emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
+ m_writer->emit(", ");
+ emitOperand(inst->getOperand(1), getInfo(EmitOp::General));
+ m_writer->emit(", ");
+ m_writer->emit(getName(inst));
+ m_writer->emit(");\n");
+ return true;
+ }
+ case kIROp_AtomicMin:
+ {
+ emitType(inst->getDataType(), getName(inst));
+ m_writer->emit(";\n");
+ m_writer->emit("InterlockedMin(");
+ emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
+ m_writer->emit(", ");
+ emitOperand(inst->getOperand(1), getInfo(EmitOp::General));
+ m_writer->emit(", ");
+ m_writer->emit(getName(inst));
+ m_writer->emit(");\n");
+ return true;
+ }
+ case kIROp_AtomicMax:
+ {
+ emitType(inst->getDataType(), getName(inst));
+ m_writer->emit(";\n");
+ m_writer->emit("InterlockedMax(");
+ emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
+ m_writer->emit(", ");
+ emitOperand(inst->getOperand(1), getInfo(EmitOp::General));
+ m_writer->emit(", ");
+ m_writer->emit(getName(inst));
+ m_writer->emit(");\n");
+ return true;
+ }
+ case kIROp_AtomicInc:
+ {
+ emitType(inst->getDataType(), getName(inst));
+ m_writer->emit(";\n");
+ m_writer->emit("InterlockedAdd(");
+ emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
+ m_writer->emit(", 1, ");
+ m_writer->emit(getName(inst));
+ m_writer->emit(");\n");
+ return true;
+ }
+ case kIROp_AtomicDec:
+ {
+ emitType(inst->getDataType(), getName(inst));
+ m_writer->emit(";\n");
+ m_writer->emit("InterlockedAdd(");
+ emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
+ m_writer->emit(", -1, ");
+ m_writer->emit(getName(inst));
+ m_writer->emit(");");
+ return true;
+ }
+ default:
+ return false;
+ }
+}
+
bool HLSLSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOuterPrec)
{
switch (inst->getOp())
@@ -755,7 +920,6 @@ bool HLSLSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOu
return true;
}
break;
-
default: break;
}
// Not handled
@@ -1030,6 +1194,11 @@ void HLSLSourceEmitter::emitSimpleTypeImpl(IRType* type)
m_writer->emit("uint4");
return;
}
+ case kIROp_AtomicType:
+ {
+ emitSimpleTypeImpl(cast<IRAtomicType>(type)->getElementType());
+ return;
+ }
default: break;
}
diff --git a/source/slang/slang-emit-hlsl.h b/source/slang/slang-emit-hlsl.h
index 4d721a1ac..31fa6b290 100644
--- a/source/slang/slang-emit-hlsl.h
+++ b/source/slang/slang-emit-hlsl.h
@@ -49,6 +49,7 @@ protected:
virtual void emitParamTypeModifier(IRType* type) SLANG_OVERRIDE { emitMatrixLayoutModifiersImpl(type); }
virtual bool tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOuterPrec) SLANG_OVERRIDE;
+ virtual bool tryEmitInstStmtImpl(IRInst* inst) SLANG_OVERRIDE;
virtual void emitSimpleValueImpl(IRInst* inst) SLANG_OVERRIDE;
virtual void emitLoopControlDecorationImpl(IRLoopControlDecoration* decl) SLANG_OVERRIDE;
virtual void emitFuncDecorationImpl(IRDecoration* decoration) SLANG_OVERRIDE;
diff --git a/source/slang/slang-emit-metal.cpp b/source/slang/slang-emit-metal.cpp
index a0fe220e5..312d06c08 100644
--- a/source/slang/slang-emit-metal.cpp
+++ b/source/slang/slang-emit-metal.cpp
@@ -246,6 +246,20 @@ void MetalSourceEmitter::ensurePrelude(const char* preludeText)
m_requiredPreludes.add(stringLit);
}
+void MetalSourceEmitter::emitMemoryOrderOperand(IRInst* inst)
+{
+ auto memoryOrder = (IRMemoryOrder)getIntVal(inst);
+ switch (memoryOrder)
+ {
+ case kIRMemoryOrder_Relaxed:
+ m_writer->emit("memory_order_relaxed");
+ break;
+ default:
+ m_writer->emit("memory_order_seq_cst");
+ break;
+ }
+}
+
bool MetalSourceEmitter::tryEmitInstStmtImpl(IRInst* inst)
{
switch (inst->getOp())
@@ -271,6 +285,164 @@ bool MetalSourceEmitter::tryEmitInstStmtImpl(IRInst* inst)
m_writer->emit("));\n");
return true;
}
+ case kIROp_AtomicLoad:
+ {
+ emitInstResultDecl(inst);
+ m_writer->emit("atomic_load_explicit(");
+ emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
+ m_writer->emit(", ");
+ emitMemoryOrderOperand(inst->getOperand(1));
+ m_writer->emit(");\n");
+ return true;
+ }
+ case kIROp_AtomicStore:
+ {
+ m_writer->emit("atomic_store_explicit(");
+ emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
+ m_writer->emit(", ");
+ emitOperand(inst->getOperand(1), getInfo(EmitOp::General));
+ m_writer->emit(", ");
+ emitMemoryOrderOperand(inst->getOperand(2));
+ m_writer->emit(");\n");
+ return true;
+ }
+ case kIROp_AtomicExchange:
+ {
+ emitInstResultDecl(inst);
+ m_writer->emit("atomic_exchange_explicit(");
+ emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
+ m_writer->emit(", ");
+ emitOperand(inst->getOperand(1), getInfo(EmitOp::General));
+ m_writer->emit(", ");
+ emitMemoryOrderOperand(inst->getOperand(2));
+ m_writer->emit(");\n");
+ return true;
+ }
+ case kIROp_AtomicCompareExchange:
+ {
+ emitType(inst->getDataType(), getName(inst));
+ m_writer->emit(";\n{\n");
+ emitType(inst->getDataType(), "_metal_cas_comparand");
+ m_writer->emit(" = ");
+ emitOperand(inst->getOperand(1), getInfo(EmitOp::General));
+ m_writer->emit(";\n");
+
+ m_writer->emit(getName(inst));
+ m_writer->emit(" = atomic_compare_exchange_weak_explicit(");
+ emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
+ m_writer->emit(", &_metal_cas_comparand, ");
+ emitOperand(inst->getOperand(2), getInfo(EmitOp::General));
+ m_writer->emit(", ");
+ emitMemoryOrderOperand(inst->getOperand(3));
+ m_writer->emit(", ");
+ emitMemoryOrderOperand(inst->getOperand(4));
+ m_writer->emit(");\n}\n");
+ return true;
+ }
+ case kIROp_AtomicAdd:
+ {
+ emitInstResultDecl(inst);
+ m_writer->emit("atomic_fetch_add_explicit(");
+ emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
+ m_writer->emit(", ");
+ emitOperand(inst->getOperand(1), getInfo(EmitOp::General));
+ m_writer->emit(", ");
+ emitMemoryOrderOperand(inst->getOperand(2));
+ m_writer->emit(");\n");
+ return true;
+ }
+ case kIROp_AtomicSub:
+ {
+ emitInstResultDecl(inst);
+ m_writer->emit("atomic_fetch_sub_explicit(");
+ emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
+ m_writer->emit(", ");
+ emitOperand(inst->getOperand(1), getInfo(EmitOp::General));
+ m_writer->emit(", ");
+ emitMemoryOrderOperand(inst->getOperand(2));
+ m_writer->emit(");\n");
+ return true;
+ }
+ case kIROp_AtomicAnd:
+ {
+ emitInstResultDecl(inst);
+ m_writer->emit("atomic_fetch_and_explicit(");
+ emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
+ m_writer->emit(", ");
+ emitOperand(inst->getOperand(1), getInfo(EmitOp::General));
+ m_writer->emit(", ");
+ emitMemoryOrderOperand(inst->getOperand(2));
+ m_writer->emit(");\n");
+ return true;
+ }
+ case kIROp_AtomicOr:
+ {
+ emitInstResultDecl(inst);
+ m_writer->emit("atomic_fetch_or_explicit(");
+ emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
+ m_writer->emit(", ");
+ emitOperand(inst->getOperand(1), getInfo(EmitOp::General));
+ m_writer->emit(", ");
+ emitMemoryOrderOperand(inst->getOperand(2));
+ m_writer->emit(");\n");
+ return true;
+ }
+ case kIROp_AtomicXor:
+ {
+ emitInstResultDecl(inst);
+ m_writer->emit("atomic_fetch_xor_explicit(");
+ emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
+ m_writer->emit(", ");
+ emitOperand(inst->getOperand(1), getInfo(EmitOp::General));
+ m_writer->emit(", ");
+ emitMemoryOrderOperand(inst->getOperand(2));
+ m_writer->emit(");\n");
+ return true;
+ }
+ case kIROp_AtomicMin:
+ {
+ emitInstResultDecl(inst);
+ m_writer->emit("atomic_fetch_min_explicit(");
+ emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
+ m_writer->emit(", ");
+ emitOperand(inst->getOperand(1), getInfo(EmitOp::General));
+ m_writer->emit(", ");
+ emitMemoryOrderOperand(inst->getOperand(2));
+ m_writer->emit(");\n");
+ return true;
+ }
+ case kIROp_AtomicMax:
+ {
+ emitInstResultDecl(inst);
+ m_writer->emit("atomic_fetch_max_explicit(");
+ emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
+ m_writer->emit(", ");
+ emitOperand(inst->getOperand(1), getInfo(EmitOp::General));
+ m_writer->emit(", ");
+ emitMemoryOrderOperand(inst->getOperand(2));
+ m_writer->emit(");\n");
+ return true;
+ }
+ case kIROp_AtomicInc:
+ {
+ emitInstResultDecl(inst);
+ m_writer->emit("atomic_fetch_add_explicit(");
+ emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
+ m_writer->emit(", 1, ");
+ emitMemoryOrderOperand(inst->getOperand(1));
+ m_writer->emit(");\n");
+ return true;
+ }
+ case kIROp_AtomicDec:
+ {
+ emitInstResultDecl(inst);
+ m_writer->emit("atomic_fetch_sub_explicit(");
+ emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
+ m_writer->emit(", 1, ");
+ emitMemoryOrderOperand(inst->getOperand(1));
+ m_writer->emit(");\n");
+ return true;
+ }
}
return false;
}
@@ -664,10 +836,8 @@ void MetalSourceEmitter::emitSimpleTypeImpl(IRType* type)
case kIROp_BoolType:
case kIROp_Int8Type:
case kIROp_IntType:
- case kIROp_Int64Type:
case kIROp_UInt8Type:
case kIROp_UIntType:
- case kIROp_UInt64Type:
case kIROp_FloatType:
case kIROp_DoubleType:
case kIROp_HalfType:
@@ -675,6 +845,12 @@ void MetalSourceEmitter::emitSimpleTypeImpl(IRType* type)
m_writer->emit(getDefaultBuiltinTypeName(type->getOp()));
return;
}
+ case kIROp_Int64Type:
+ m_writer->emit("long");
+ return;
+ case kIROp_UInt64Type:
+ m_writer->emit("ulong");
+ return;
case kIROp_Int16Type:
m_writer->emit("short");
return;
@@ -682,10 +858,10 @@ void MetalSourceEmitter::emitSimpleTypeImpl(IRType* type)
m_writer->emit("ushort");
return;
case kIROp_IntPtrType:
- m_writer->emit("int64_t");
+ m_writer->emit("long");
return;
case kIROp_UIntPtrType:
- m_writer->emit("uint64_t");
+ m_writer->emit("ulong");
return;
case kIROp_StructType:
m_writer->emit(getName(type));
@@ -781,6 +957,13 @@ void MetalSourceEmitter::emitSimpleTypeImpl(IRType* type)
m_writer->emit("mesh_grid_properties ");
return;
}
+ case kIROp_AtomicType:
+ {
+ m_writer->emit("atomic<");
+ emitSimpleTypeImpl(cast<IRAtomicType>(type)->getElementType());
+ m_writer->emit(">");
+ return;
+ }
default:
break;
}
diff --git a/source/slang/slang-emit-metal.h b/source/slang/slang-emit-metal.h
index 67aa0d506..8e33eddef 100644
--- a/source/slang/slang-emit-metal.h
+++ b/source/slang/slang-emit-metal.h
@@ -26,6 +26,7 @@ protected:
void ensurePrelude(const char* preludeText);
+ void emitMemoryOrderOperand(IRInst* inst);
virtual void emitParameterGroupImpl(IRGlobalParam* varDecl, IRUniformParameterGroupType* type) SLANG_OVERRIDE;
virtual void emitEntryPointAttributesImpl(IRFunc* irFunc, IREntryPointDecoration* entryPointDecor) SLANG_OVERRIDE;
diff --git a/source/slang/slang-emit-spirv-ops.h b/source/slang/slang-emit-spirv-ops.h
index 3d6bf846f..cea246b20 100644
--- a/source/slang/slang-emit-spirv-ops.h
+++ b/source/slang/slang-emit-spirv-ops.h
@@ -2354,6 +2354,152 @@ SpvInst* emitOpExecutionModeIdLocalSizeId(
}
template<typename T1, typename T2, typename T3, typename T4>
+SpvInst* emitOpAtomicLoad(
+ SpvInstParent* parent,
+ IRInst* inst,
+ const T1& idResultType,
+ const T2& pointer,
+ const T3& memory,
+ const T4& semantics
+)
+{
+ static_assert(isSingular<T1>);
+ static_assert(isSingular<T2>);
+ static_assert(isSingular<T3>);
+ static_assert(isSingular<T4>);
+ return emitInst(
+ parent,
+ inst,
+ SpvOpAtomicLoad,
+ idResultType,
+ kResultID,
+ pointer,
+ memory,
+ semantics
+ );
+}
+
+template<typename T1, typename T2, typename T3, typename T4>
+SpvInst* emitOpAtomicStore(
+ SpvInstParent* parent,
+ IRInst* inst,
+ const T1& pointer,
+ const T2& memory,
+ const T3& semantics,
+ const T4& value
+)
+{
+ static_assert(isSingular<T1>);
+ static_assert(isSingular<T2>);
+ static_assert(isSingular<T3>);
+ static_assert(isSingular<T4>);
+ return emitInst(
+ parent,
+ inst,
+ SpvOpAtomicStore,
+ pointer,
+ memory,
+ semantics,
+ value
+ );
+}
+
+template<typename T1, typename T2, typename T3, typename T4, typename T5>
+SpvInst* emitOpAtomicExchange(
+ SpvInstParent* parent,
+ IRInst* inst,
+ const T1& idResultType,
+ const T2& pointer,
+ const T3& memory,
+ const T4& semantics,
+ const T5& value
+)
+{
+ static_assert(isSingular<T1>);
+ static_assert(isSingular<T2>);
+ static_assert(isSingular<T3>);
+ static_assert(isSingular<T4>);
+ static_assert(isSingular<T5>);
+ return emitInst(
+ parent,
+ inst,
+ SpvOpAtomicExchange,
+ idResultType,
+ kResultID,
+ pointer,
+ memory,
+ semantics,
+ value
+ );
+}
+
+template<typename T1, typename T2, typename T3, typename T4, typename T5, typename T6, typename T7>
+SpvInst* emitOpAtomicCompareExchange(
+ SpvInstParent* parent,
+ IRInst* inst,
+ const T1& idResultType,
+ const T2& pointer,
+ const T3& memory,
+ const T4& semanticsEqual,
+ const T5& semanticsUnequal,
+ const T6& value,
+ const T7& comparator
+)
+{
+ static_assert(isSingular<T1>);
+ static_assert(isSingular<T2>);
+ static_assert(isSingular<T3>);
+ static_assert(isSingular<T4>);
+ static_assert(isSingular<T5>);
+ static_assert(isSingular<T6>);
+ static_assert(isSingular<T7>);
+
+ return emitInst(
+ parent,
+ inst,
+ SpvOpAtomicCompareExchange,
+ idResultType,
+ kResultID,
+ pointer,
+ memory,
+ semanticsEqual,
+ semanticsUnequal,
+ value,
+ comparator
+ );
+}
+
+template<typename T1, typename T2, typename T3, typename T4, typename T5>
+SpvInst* emitOpAtomicOp(
+ SpvInstParent* parent,
+ IRInst* inst,
+ SpvOp op,
+ const T1& idResultType,
+ const T2& pointer,
+ const T3& memory,
+ const T4& semantics,
+ const T5& value
+)
+{
+ static_assert(isSingular<T1>);
+ static_assert(isSingular<T2>);
+ static_assert(isSingular<T3>);
+ static_assert(isSingular<T4>);
+ static_assert(isSingular<T5>);
+ return emitInst(
+ parent,
+ inst,
+ op,
+ idResultType,
+ kResultID,
+ pointer,
+ memory,
+ semantics,
+ value
+ );
+}
+
+template<typename T1, typename T2, typename T3, typename T4>
SpvInst* emitOpAtomicIIncrement(
SpvInstParent* parent,
IRInst* inst,
diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp
index a2da4801e..ede573581 100644
--- a/source/slang/slang-emit-spirv.cpp
+++ b/source/slang/slang-emit-spirv.cpp
@@ -1660,6 +1660,12 @@ struct SPIRVEmitContext
SpvLiteralInteger::from32(stride));
return arrayType;
}
+ case kIROp_AtomicType:
+ {
+ auto result = ensureInst(as<IRAtomicType>(inst)->getElementType());
+ registerInst(inst, result);
+ return result;
+ }
case kIROp_SubpassInputType:
return ensureSubpassInputType(inst, cast<IRSubpassInputType>(inst));
case kIROp_TextureType:
@@ -2860,6 +2866,115 @@ struct SPIRVEmitContext
return (isSpirv16OrLater() || m_useDemoteToHelperInvocationExtension);
}
+ SpvInst* emitMemorySemanticMask(IRInst* inst)
+ {
+ IRBuilder builder(inst);
+ auto memoryOrder = (IRMemoryOrder)getIntVal(inst);
+ switch (memoryOrder)
+ {
+ case kIRMemoryOrder_Relaxed:
+ return emitIntConstant(IRIntegerValue{ SpvMemorySemanticsMaskNone }, builder.getUIntType());
+ case kIRMemoryOrder_Acquire:
+ return emitIntConstant(IRIntegerValue{ SpvMemorySemanticsAcquireMask }, builder.getUIntType());
+ case kIRMemoryOrder_Release:
+ return emitIntConstant(IRIntegerValue{ SpvMemorySemanticsReleaseMask }, builder.getUIntType());
+ case kIRMemoryOrder_AcquireRelease:
+ return emitIntConstant(IRIntegerValue{ SpvMemorySemanticsAcquireReleaseMask }, builder.getUIntType());
+ case kIRMemoryOrder_SeqCst:
+ return emitIntConstant(IRIntegerValue{ SpvMemorySemanticsSequentiallyConsistentMask }, builder.getUIntType());
+ default:
+ SLANG_UNEXPECTED("unhandled memory order");
+ UNREACHABLE_RETURN(nullptr);
+ }
+ }
+
+ SpvOp getSpvAtomicOp(IRInst* atomicInst, bool& outNegateOperand)
+ {
+ auto typeSelect = [&](SpvOp sop, SpvOp uop, SpvOp fop)
+ {
+ auto scalarType = getVectorElementType(atomicInst->getDataType());
+ if (isIntegralType(scalarType))
+ {
+ auto intInfo = getIntTypeInfo(scalarType);
+ if (intInfo.isSigned)
+ return sop;
+ return uop;
+ }
+ return fop;
+ };
+ outNegateOperand = false;
+ switch (atomicInst->getOp())
+ {
+ case kIROp_AtomicAdd:
+ return typeSelect(SpvOpAtomicIAdd, SpvOpAtomicIAdd, SpvOpAtomicFAddEXT);
+ case kIROp_AtomicSub:
+ if (isFloatingType(getVectorElementType(atomicInst->getDataType())))
+ outNegateOperand = true;
+ return typeSelect(SpvOpAtomicISub, SpvOpAtomicISub, SpvOpAtomicFAddEXT);
+ case kIROp_AtomicMin:
+ return typeSelect(SpvOpAtomicSMin, SpvOpAtomicUMin, SpvOpAtomicFMinEXT);
+ case kIROp_AtomicMax:
+ return typeSelect(SpvOpAtomicSMax, SpvOpAtomicUMax, SpvOpAtomicFMaxEXT);
+ case kIROp_AtomicAnd:
+ return SpvOpAtomicAnd;
+ case kIROp_AtomicOr:
+ return SpvOpAtomicOr;
+ case kIROp_AtomicXor:
+ return SpvOpAtomicXor;
+ default:
+ SLANG_UNEXPECTED("unhandled atomic op");
+ UNREACHABLE_RETURN(SpvOpNop);
+ }
+ }
+
+ void ensureAtomicCapability(IRInst* atomicInst, SpvOp op)
+ {
+ switch (op)
+ {
+ case SpvOpAtomicFAddEXT:
+ {
+ auto typeOp = getVectorElementType(atomicInst->getDataType())->getOp();
+ switch (typeOp)
+ {
+ case kIROp_FloatType:
+ ensureExtensionDeclaration(toSlice("SPV_EXT_shader_atomic_float_add"));
+ requireSPIRVCapability(SpvCapabilityAtomicFloat32AddEXT);
+ break;
+ case kIROp_DoubleType:
+ ensureExtensionDeclaration(toSlice("SPV_EXT_shader_atomic_float_add"));
+ requireSPIRVCapability(SpvCapabilityAtomicFloat64AddEXT);
+ break;
+ case kIROp_HalfType:
+ ensureExtensionDeclaration(toSlice("SPV_EXT_shader_atomic_float16_add"));
+ requireSPIRVCapability(SpvCapabilityAtomicFloat16AddEXT);
+ break;
+ }
+ }
+ break;
+ case SpvOpAtomicFMinEXT:
+ case SpvOpAtomicFMaxEXT:
+ {
+ auto typeOp = getVectorElementType(atomicInst->getDataType())->getOp();
+ switch (typeOp)
+ {
+ case kIROp_FloatType:
+ ensureExtensionDeclaration(toSlice("SPV_EXT_shader_atomic_float_min_max"));
+ requireSPIRVCapability(SpvCapabilityAtomicFloat32MinMaxEXT);
+ break;
+ case kIROp_DoubleType:
+ ensureExtensionDeclaration(toSlice("SPV_EXT_shader_atomic_float_min_max"));
+ requireSPIRVCapability(SpvCapabilityAtomicFloat64MinMaxEXT);
+ break;
+ case kIROp_HalfType:
+ ensureExtensionDeclaration(toSlice("SPV_EXT_shader_atomic_float_min_max"));
+ requireSPIRVCapability(SpvCapabilityAtomicFloat16MinMaxEXT);
+ break;
+ }
+ }
+ break;
+ }
+ }
+
// The instructions that appear inside the basic blocks of
// functions are what we will call "local" instructions.
//
@@ -3200,22 +3315,82 @@ struct SPIRVEmitContext
case kIROp_ImageSubscript:
result = emitImageSubscript(parent, as<IRImageSubscript>(inst));
break;
- case kIROp_AtomicCounterIncrement:
+ case kIROp_AtomicInc:
{
IRBuilder builder{inst};
- const auto memoryScope = emitIntConstant(IRIntegerValue{SpvScopeDevice}, builder.getUIntType());
- const auto memorySemantics = emitIntConstant(IRIntegerValue{SpvMemorySemanticsMaskNone}, builder.getUIntType());
+ const auto memoryScope = emitIntConstant(IRIntegerValue{SpvScopeDevice}, builder.getUIntType());
+ const auto memorySemantics = emitMemorySemanticMask(inst->getOperand(1));
result = emitOpAtomicIIncrement(parent, inst, inst->getFullType(), inst->getOperand(0), memoryScope, memorySemantics);
}
break;
- case kIROp_AtomicCounterDecrement:
+ case kIROp_AtomicDec:
{
- IRBuilder builder{inst};
- const auto memoryScope = emitIntConstant(IRIntegerValue{SpvScopeDevice}, builder.getUIntType());
- const auto memorySemantics = emitIntConstant(IRIntegerValue{SpvMemorySemanticsMaskNone}, builder.getUIntType());
+ IRBuilder builder{ inst };
+ const auto memoryScope = emitIntConstant(IRIntegerValue{ SpvScopeDevice }, builder.getUIntType());
+ const auto memorySemantics = emitMemorySemanticMask(inst->getOperand(1));
result = emitOpAtomicIDecrement(parent, inst, inst->getFullType(), inst->getOperand(0), memoryScope, memorySemantics);
}
break;
+ case kIROp_AtomicLoad:
+ {
+ IRBuilder builder{ inst };
+ const auto memoryScope = emitIntConstant(IRIntegerValue{ SpvScopeDevice }, builder.getUIntType());
+ const auto memorySemantics = emitMemorySemanticMask(inst->getOperand(1));
+ result = emitOpAtomicLoad(parent, inst, inst->getFullType(), inst->getOperand(0), memoryScope, memorySemantics);
+ }
+ break;
+ case kIROp_AtomicStore:
+ {
+ IRBuilder builder{ inst };
+ const auto memoryScope = emitIntConstant(IRIntegerValue{ SpvScopeDevice }, builder.getUIntType());
+ const auto memorySemantics = emitMemorySemanticMask(inst->getOperand(2));
+ result = emitOpAtomicStore(parent, inst, inst->getOperand(0), memoryScope, memorySemantics, inst->getOperand(1));
+ }
+ break;
+ case kIROp_AtomicExchange:
+ {
+ IRBuilder builder{ inst };
+ const auto memoryScope = emitIntConstant(IRIntegerValue{ SpvScopeDevice }, builder.getUIntType());
+ const auto memorySemantics = emitMemorySemanticMask(inst->getOperand(2));
+ result = emitOpAtomicExchange(parent, inst, inst->getFullType(), inst->getOperand(0), memoryScope, memorySemantics, inst->getOperand(1));
+ }
+ break;
+ case kIROp_AtomicCompareExchange:
+ {
+ IRBuilder builder{ inst };
+ const auto memoryScope = emitIntConstant(IRIntegerValue{ SpvScopeDevice }, builder.getUIntType());
+ const auto memorySemanticsEqual = emitMemorySemanticMask(inst->getOperand(3));
+ const auto memorySemanticsUnequal = emitMemorySemanticMask(inst->getOperand(4));
+ result = emitOpAtomicCompareExchange(
+ parent, inst, inst->getFullType(), inst->getOperand(0),
+ memoryScope, memorySemanticsEqual, memorySemanticsUnequal,
+ inst->getOperand(2), inst->getOperand(1));
+ }
+ break;
+ case kIROp_AtomicAdd:
+ case kIROp_AtomicSub:
+ case kIROp_AtomicMax:
+ case kIROp_AtomicMin:
+ case kIROp_AtomicAnd:
+ case kIROp_AtomicOr:
+ case kIROp_AtomicXor:
+ {
+ IRBuilder builder{ inst };
+ const auto memoryScope = emitIntConstant(IRIntegerValue{ SpvScopeDevice }, builder.getUIntType());
+ const auto memorySemantics = emitMemorySemanticMask(inst->getOperand(2));
+ bool negateOperand = false;
+ auto spvOp = getSpvAtomicOp(inst, negateOperand);
+ auto operand = inst->getOperand(1);
+ if (negateOperand)
+ {
+ builder.setInsertBefore(inst);
+ auto negatedOperand = builder.emitNeg(inst->getDataType(), operand);
+ operand = negatedOperand;
+ }
+ result = emitOpAtomicOp(parent, inst, spvOp, inst->getFullType(), inst->getOperand(0), memoryScope, memorySemantics, operand);
+ ensureAtomicCapability(inst, spvOp);
+ }
+ break;
case kIROp_ControlBarrier:
{
IRBuilder builder{ inst };
diff --git a/source/slang/slang-emit-wgsl.cpp b/source/slang/slang-emit-wgsl.cpp
index 40d8ace91..06fea46b4 100644
--- a/source/slang/slang-emit-wgsl.cpp
+++ b/source/slang/slang-emit-wgsl.cpp
@@ -275,7 +275,6 @@ void WGSLSourceEmitter::emitStructFieldAttributes(
bool WGSLSourceEmitter::isPointerSyntaxRequiredImpl(IRInst* inst)
{
- // Structured buffers are mapped to 'array' types, which don't need dereferencing
if (inst->getOp() == kIROp_RWStructuredBufferGetElementPtr)
return false;
@@ -470,6 +469,14 @@ void WGSLSourceEmitter::emitSimpleTypeImpl(IRType* type)
m_writer->emit(">");
return;
}
+
+ case kIROp_AtomicType:
+ {
+ m_writer->emit("atomic<");
+ emitType(cast<IRAtomicType>(type)->getElementType());
+ m_writer->emit(">");
+ return;
+ }
default:
break;
@@ -504,13 +511,32 @@ void WGSLSourceEmitter::emitLayoutQualifiersImpl(IRVarLayout* layout)
}
-void WGSLSourceEmitter::emitVarKeywordImpl(IRType * type, const bool isConstant)
+void WGSLSourceEmitter::emitVarKeywordImpl(IRType * type, IRInst* varDecl)
{
- if (isConstant)
- m_writer->emit("const");
- else
+ switch (varDecl->getOp())
+ {
+ case kIROp_GlobalParam:
+ case kIROp_GlobalVar:
+ case kIROp_Var:
m_writer->emit("var");
- if (type->getOp() == kIROp_HLSLRWStructuredBufferType)
+ break;
+ default:
+ if (as<IRModuleInst>(varDecl->getParent()))
+ {
+ m_writer->emit("const");
+ }
+ else
+ {
+ m_writer->emit("var");
+ }
+ break;
+ }
+
+ if (as<IRGroupSharedRate>(varDecl->getRate()))
+ {
+ m_writer->emit("<workgroup>");
+ }
+ else if (type->getOp() == kIROp_HLSLRWStructuredBufferType)
{
m_writer->emit("<");
m_writer->emit("storage, read_write");
@@ -789,6 +815,144 @@ void WGSLSourceEmitter::emitParamTypeImpl(IRType* type, const String& name)
emitType(type, name);
}
+bool WGSLSourceEmitter::tryEmitInstStmtImpl(IRInst* inst)
+{
+ switch (inst->getOp())
+ {
+ default:
+ return false;
+ case kIROp_AtomicLoad:
+ {
+ emitInstResultDecl(inst);
+ m_writer->emit("atomicLoad(&(");
+ emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
+ m_writer->emit("));\n");
+ return true;
+ }
+ case kIROp_AtomicStore:
+ {
+ m_writer->emit("atomicStore(&(");
+ emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
+ m_writer->emit("), ");
+ emitOperand(inst->getOperand(1), getInfo(EmitOp::General));
+ m_writer->emit(");\n");
+ return true;
+ }
+ case kIROp_AtomicExchange:
+ {
+ emitInstResultDecl(inst);
+ m_writer->emit("atomicExchange(&(");
+ emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
+ m_writer->emit("), ");
+ emitOperand(inst->getOperand(1), getInfo(EmitOp::General));
+ m_writer->emit(");\n");
+ return true;
+ }
+ case kIROp_AtomicCompareExchange:
+ {
+ emitInstResultDecl(inst);
+ m_writer->emit("atomicCompareExchangeWeak(&(");
+ emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
+ m_writer->emit("), ");
+ emitOperand(inst->getOperand(1), getInfo(EmitOp::General));
+ m_writer->emit(", ");
+ emitOperand(inst->getOperand(2), getInfo(EmitOp::General));
+ m_writer->emit(").old_value;\n");
+ return true;
+ }
+ case kIROp_AtomicAdd:
+ {
+ emitInstResultDecl(inst);
+ m_writer->emit("atomicAdd(&(");
+ emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
+ m_writer->emit("), ");
+ emitOperand(inst->getOperand(1), getInfo(EmitOp::General));
+ m_writer->emit(");\n");
+ return true;
+ }
+ case kIROp_AtomicSub:
+ {
+ emitInstResultDecl(inst);
+ m_writer->emit("atomicSub(&(");
+ emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
+ m_writer->emit("), ");
+ emitOperand(inst->getOperand(1), getInfo(EmitOp::General));
+ m_writer->emit(");\n");
+ return true;
+ }
+ case kIROp_AtomicAnd:
+ {
+ emitInstResultDecl(inst);
+ m_writer->emit("atomicAnd(&(");
+ emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
+ m_writer->emit("), ");
+ emitOperand(inst->getOperand(1), getInfo(EmitOp::General));
+ m_writer->emit(");\n");
+ return true;
+ }
+ case kIROp_AtomicOr:
+ {
+ emitInstResultDecl(inst);
+ m_writer->emit("atomicOr(&(");
+ emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
+ m_writer->emit("), ");
+ emitOperand(inst->getOperand(1), getInfo(EmitOp::General));
+ m_writer->emit(");\n");
+ return true;
+ }
+ case kIROp_AtomicXor:
+ {
+ emitInstResultDecl(inst);
+ m_writer->emit("atomicXor(&(");
+ emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
+ m_writer->emit("), ");
+ emitOperand(inst->getOperand(1), getInfo(EmitOp::General));
+ m_writer->emit(");\n");
+ return true;
+ }
+ case kIROp_AtomicMin:
+ {
+ emitInstResultDecl(inst);
+ m_writer->emit("atomicMin(&(");
+ emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
+ m_writer->emit("), ");
+ emitOperand(inst->getOperand(1), getInfo(EmitOp::General));
+ m_writer->emit(");\n");
+ return true;
+ }
+ case kIROp_AtomicMax:
+ {
+ emitInstResultDecl(inst);
+ m_writer->emit("atomicMax(&(");
+ emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
+ m_writer->emit("), ");
+ emitOperand(inst->getOperand(1), getInfo(EmitOp::General));
+ m_writer->emit(");\n");
+ return true;
+ }
+ case kIROp_AtomicInc:
+ {
+ emitInstResultDecl(inst);
+ m_writer->emit("atomicAdd(&(");
+ emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
+ m_writer->emit("), ");
+ emitType(inst->getDataType());
+ m_writer->emit("(1));\n");
+ return true;
+ }
+ case kIROp_AtomicDec:
+ {
+ emitInstResultDecl(inst);
+ m_writer->emit("atomicSub(&(");
+ emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
+ m_writer->emit("), ");
+ emitType(inst->getDataType());
+ m_writer->emit("(1));\n");
+ return true;
+ }
+ }
+}
+
bool WGSLSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOuterPrec)
{
EmitOpInfo outerPrec = inOuterPrec;
@@ -869,6 +1033,17 @@ bool WGSLSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOu
}
break;
+ case kIROp_RWStructuredBufferGetElementPtr:
+ {
+ m_writer->emit("(*");
+ emitOperand(inst->getOperand(0), leftSide(outerPrec, getInfo(EmitOp::Postfix)));
+ m_writer->emit(")[");
+ emitOperand(inst->getOperand(1), getInfo(EmitOp::General));
+ m_writer->emit("]");
+ return true;
+ }
+ break;
+
case kIROp_StructuredBufferLoad:
case kIROp_RWStructuredBufferLoad:
{
diff --git a/source/slang/slang-emit-wgsl.h b/source/slang/slang-emit-wgsl.h
index d3cf19d91..b3a4efb55 100644
--- a/source/slang/slang-emit-wgsl.h
+++ b/source/slang/slang-emit-wgsl.h
@@ -28,6 +28,7 @@ public:
virtual bool tryEmitInstExprImpl(
IRInst* inst, const EmitOpInfo& inOuterPrec
) SLANG_OVERRIDE;
+ virtual bool tryEmitInstStmtImpl(IRInst* inst) SLANG_OVERRIDE;
virtual void emitSwitchCaseSelectorsImpl(
IRBasicType *const switchCondition,
const SwitchRegion::Case *const currentCase,
@@ -36,7 +37,7 @@ public:
virtual void emitSimpleTypeAndDeclaratorImpl(
IRType* type, DeclaratorInfo* declarator
) SLANG_OVERRIDE;
- virtual void emitVarKeywordImpl(IRType * type, const bool isConstant) SLANG_OVERRIDE;
+ virtual void emitVarKeywordImpl(IRType * type, IRInst* varDecl) SLANG_OVERRIDE;
virtual void emitDeclaratorImpl(DeclaratorInfo* declarator) SLANG_OVERRIDE;
virtual void emitStructDeclarationSeparatorImpl() SLANG_OVERRIDE;
virtual void emitLayoutQualifiersImpl(IRVarLayout* layout) SLANG_OVERRIDE;
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index 0d689660e..dc4f7dff4 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -69,7 +69,9 @@ INST(Nop, nop, 0, 0)
INST(TensorViewType, TensorView, 1, HOISTABLE)
INST(TorchTensorType, TorchTensor, 0, HOISTABLE)
INST(ArrayListType, ArrayListVector, 1, HOISTABLE)
-
+
+ INST(AtomicType, Atomic, 1, HOISTABLE)
+
/* BindExistentialsTypeBase */
// A `BindExistentials<B, T0,w0, T1,w1, ...>` represents
@@ -400,6 +402,21 @@ INST(Var, var, 0, 0)
INST(Load, load, 1, 0)
INST(Store, store, 2, 0)
+// Atomic Operations
+INST(AtomicLoad, atomicLoad, 1, 0)
+INST(AtomicStore, atomicStore, 2, 0)
+INST(AtomicExchange, atomicExchange, 2, 0)
+INST(AtomicCompareExchange, atomicCompareExchange, 3, 0)
+INST(AtomicAdd, atomicAdd, 2, 0)
+INST(AtomicSub, atomicSub, 2, 0)
+INST(AtomicAnd, atomicAnd, 2, 0)
+INST(AtomicOr, atomicOr, 2, 0)
+INST(AtomicXor, atomicXor, 2, 0)
+INST(AtomicMin, atomicMin, 2, 0)
+INST(AtomicMax, atomicMax, 2, 0)
+INST(AtomicInc, atomicInc, 1, 0)
+INST(AtomicDec, atomicDec, 1, 0)
+
// Produced and removed during backward auto-diff pass as a temporary placeholder representing the
// currently accumulated derivative to pass to some dOut argument in a nested call.
INST(LoadReverseGradient, LoadReverseGradient, 1, 0)
@@ -515,9 +532,6 @@ INST(StructuredBufferGetDimensions, StructuredBufferGetDimensions, 1, 0)
// Resource qualifiers for dynamically varying index
INST(NonUniformResourceIndex, nonUniformResourceIndex, 1, 0)
-INST(AtomicCounterIncrement, AtomicCounterIncrement, 1, 0)
-INST(AtomicCounterDecrement, AtomicCounterDecrement, 1, 0)
-
INST(GetNaturalStride, getNaturalStride, 1, 0)
INST(MeshOutputRef, meshOutputRef, 2, 0)
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index f31a56673..24a779aed 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -2358,6 +2358,14 @@ struct IRLoad : IRInst
IRInst* getPtr() { return ptr.get(); }
};
+struct IRAtomicLoad : IRInst
+{
+ IRUse ptr;
+ IR_LEAF_ISA(AtomicLoad)
+
+ IRInst* getPtr() { return ptr.get(); }
+};
+
struct IRStore : IRInst
{
IRUse ptr;
@@ -2368,6 +2376,16 @@ struct IRStore : IRInst
IRInst* getVal() { return val.get(); }
};
+struct IRAtomicStore : IRInst
+{
+ IRUse ptr;
+ IRUse val;
+ IR_LEAF_ISA(AtomicStore)
+
+ IRInst* getPtr() { return ptr.get(); }
+ IRInst* getVal() { return val.get(); }
+};
+
struct IRRWStructuredBufferStore : IRInst
{
IR_LEAF_ISA(RWStructuredBufferStore)
@@ -4235,6 +4253,11 @@ public:
IRInst* dstPtr,
IRInst* srcVal);
+ IRInst* emitAtomicStore(
+ IRInst* dstPtr,
+ IRInst* srcVal,
+ IRInst* memoryOrder);
+
IRInst* emitImageLoad(
IRType* type,
ShortList<IRInst*> params);
diff --git a/source/slang/slang-ir-layout.cpp b/source/slang/slang-ir-layout.cpp
index 3a2471930..82287f58e 100644
--- a/source/slang/slang-ir-layout.cpp
+++ b/source/slang/slang-ir-layout.cpp
@@ -211,6 +211,14 @@ case kIROp_##TYPE##Type: \
}
break;
+ case kIROp_AtomicType:
+ {
+ auto atomicType = cast<IRAtomicType>(type);
+ _calcSizeAndAlignment(optionSet, rules, atomicType->getElementType(), outSizeAndAlignment);
+ return SLANG_OK;
+ }
+ break;
+
case kIROp_UnsizedArrayType:
{
auto unsizedArrayType = cast<IRUnsizedArrayType>(type);
diff --git a/source/slang/slang-ir-lower-append-consume-structured-buffer.cpp b/source/slang/slang-ir-lower-append-consume-structured-buffer.cpp
index da6b8d34d..ac8a8b5e7 100644
--- a/source/slang/slang-ir-lower-append-consume-structured-buffer.cpp
+++ b/source/slang/slang-ir-lower-append-consume-structured-buffer.cpp
@@ -99,7 +99,8 @@ namespace Slang
auto counterBuffer = builder.emitFieldExtract(counterBufferType, bufferParam, counterBufferKey);
IRInst* getCounterPtrArgs[] = { counterBuffer, builder.getIntValue(builder.getIntType(), 0) };
auto counterBufferPtr = builder.emitIntrinsicInst(builder.getPtrType(builder.getIntType()), kIROp_RWStructuredBufferGetElementPtr, 2, getCounterPtrArgs);
- auto oldCounter = builder.emitIntrinsicInst(builder.getIntType(), kIROp_AtomicCounterIncrement, 1, &counterBufferPtr);
+ IRInst* atomicIncArgs[] = { counterBufferPtr, builder.getIntValue(builder.getIntType(), kIRMemoryOrder_Relaxed) };
+ auto oldCounter = builder.emitIntrinsicInst(builder.getIntType(), kIROp_AtomicInc, 2, atomicIncArgs);
IRInst* getElementPtrArgs[] = { elementBuffer, oldCounter };
auto elementBufferPtr = builder.emitIntrinsicInst(builder.getPtrType(elementType), kIROp_RWStructuredBufferGetElementPtr, 2, getElementPtrArgs);
@@ -122,7 +123,8 @@ namespace Slang
auto counterBuffer = builder.emitFieldExtract(counterBufferType, bufferParam, counterBufferKey);
IRInst* getCounterPtrArgs[] = { counterBuffer, builder.getIntValue(builder.getIntType(), 0) };
auto counterBufferPtr = builder.emitIntrinsicInst(builder.getPtrType(builder.getIntType()), kIROp_RWStructuredBufferGetElementPtr, 2, getCounterPtrArgs);
- auto oldCounter = builder.emitIntrinsicInst(builder.getIntType(), kIROp_AtomicCounterDecrement, 1, &counterBufferPtr);
+ IRInst* atomicDecArgs[] = { counterBufferPtr, builder.getIntValue(builder.getIntType(), kIRMemoryOrder_Relaxed) };
+ auto oldCounter = builder.emitIntrinsicInst(builder.getIntType(), kIROp_AtomicDec, 2, atomicDecArgs);
auto index = builder.emitSub(builder.getIntType(), oldCounter, builder.getIntValue(builder.getIntType(), 1));
// Test if index is greater or equal than 0.
diff --git a/source/slang/slang-ir-use-uninitialized-values.cpp b/source/slang/slang-ir-use-uninitialized-values.cpp
index 56b13aa09..c09077528 100644
--- a/source/slang/slang-ir-use-uninitialized-values.cpp
+++ b/source/slang/slang-ir-use-uninitialized-values.cpp
@@ -297,6 +297,7 @@ namespace Slang
case kIROp_Store:
case kIROp_SwizzledStore:
case kIROp_SPIRVAsm:
+ case kIROp_AtomicStore:
return Store;
case kIROp_SPIRVAsmOperandInst:
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index b89929f55..c12ad0f68 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -5084,6 +5084,23 @@ namespace Slang
return inst;
}
+ IRInst* IRBuilder::emitAtomicStore(
+ IRInst* dstPtr,
+ IRInst* srcVal,
+ IRInst* memoryOrder)
+ {
+ auto inst = createInst<IRAtomicStore>(
+ this,
+ kIROp_AtomicStore,
+ nullptr,
+ dstPtr,
+ srcVal,
+ memoryOrder);
+
+ addInst(inst);
+ return inst;
+ }
+
/// @param params An ordered list of imageLoad parameters { image, coord, [optional] seperateArrayCoord, [optional] seperateSampleCoord }
IRInst* IRBuilder::emitImageLoad(IRType* type, ShortList<IRInst*> params)
{
diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h
index 14dde200f..ee12cff8c 100644
--- a/source/slang/slang-ir.h
+++ b/source/slang/slang-ir.h
@@ -93,6 +93,15 @@ enum IROpMask : std::underlying_type_t<IROp>
kIROpMask_OpMask = 0x3ff, ///< Mask for just opcode
};
+enum IRMemoryOrder
+{
+ kIRMemoryOrder_Relaxed = 0,
+ kIRMemoryOrder_Acquire = 1,
+ kIRMemoryOrder_Release = 2,
+ kIRMemoryOrder_AcquireRelease = 3,
+ kIRMemoryOrder_SeqCst = 4,
+};
+
inline int32_t operator&(const IROpMask m, const IROp o)
{
#if defined(__cpp_lib_bit_cast)
@@ -1614,6 +1623,13 @@ struct IRArrayType: IRArrayTypeBase
SIMPLE_IR_TYPE(UnsizedArrayType, ArrayTypeBase)
+struct IRAtomicType : IRType
+{
+ IR_LEAF_ISA(AtomicType)
+
+ IRType* getElementType() { return (IRType*)getOperand(0); }
+};
+
SIMPLE_IR_PARENT_TYPE(Rate, Type)
SIMPLE_IR_TYPE(ConstExprRate, Rate)
SIMPLE_IR_TYPE(GroupSharedRate, Rate)
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index 813467743..d8dbaa812 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -7130,9 +7130,19 @@ top:
// The `left` value is just a pointer, so we can emit
// a store to it directly.
//
- builder->emitStore(
- left.val,
- getSimpleVal(context, right));
+ if (as<IRAtomicType>(tryGetPointedToType(builder, left.val->getDataType())))
+ {
+ builder->emitAtomicStore(
+ left.val,
+ getSimpleVal(context, right),
+ builder->getIntValue(builder->getIntType(), kIRMemoryOrder_Relaxed));
+ }
+ else
+ {
+ builder->emitStore(
+ left.val,
+ getSimpleVal(context, right));
+ }
}
break;
diff --git a/source/slang/slang-type-layout.cpp b/source/slang/slang-type-layout.cpp
index 2447f5787..57635122e 100644
--- a/source/slang/slang-type-layout.cpp
+++ b/source/slang/slang-type-layout.cpp
@@ -4427,6 +4427,10 @@ static TypeLayoutResult _createTypeLayout(
{
return createArrayLikeTypeLayout(context, arrayType, arrayType->getElementType(), arrayType->getElementCount());
}
+ else if (auto atomicType = as<AtomicType>(type))
+ {
+ return _createTypeLayout(context, atomicType->getElementType());
+ }
else if (auto ptrType = as<PtrTypeBase>(type))
{
RefPtr<PointerTypeLayout> ptrLayout = new PointerTypeLayout();