diff options
| author | Yong He <yonghe@outlook.com> | 2024-09-20 15:11:23 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-09-20 15:11:23 -0700 |
| commit | 490834924cc390cb812713c225b9a8227c66cf1f (patch) | |
| tree | 5644e2a18cb085692d5fe9625f42582db07447be /source | |
| parent | b4c851fb1419f869bddaa08487f58376bc0a7144 (diff) | |
Initial `Atomic<T>` type implementation. (#5125)
* Initial Atomic<T> type implementation.
* Update design doc.
* Fix.
* Add test.
* Fixes and add tests.
* Fix WGSL.
* Fix glsl.
* Fix metal.
* experiemnt with github metal.
* experiment github metal 2
* github metal experiment 3
* experiment with github metal 4.
* experiment with metal 5.
* experiment 7.
* metal experiment 8.
* Fix metal tests.
---------
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source')
26 files changed, 1415 insertions, 102 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index 476279ab8..e4e24dddb 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -2721,6 +2721,132 @@ __Addr<T> __getLegalizedSPIRVGlobalParamAddr(T val); __intrinsic_op($(kIROp_RequireComputeDerivative)) void __requireComputeDerivative(); +// Atomic<T> + +enum MemoryOrder +{ + Relaxed = $(kIRMemoryOrder_Relaxed), + Acquire = $(kIRMemoryOrder_Acquire), + Release = $(kIRMemoryOrder_Release), + AcquireRelease = $(kIRMemoryOrder_AcquireRelease), + SeqCst = $(kIRMemoryOrder_SeqCst), +} + +[sealed] interface IAtomicable {} +[sealed] interface IArithmeticAtomicable : IAtomicable, IArithmetic {} +[sealed] interface IBitAtomicable : IArithmeticAtomicable, IInteger {} + +extension int : IBitAtomicable {} +extension uint : IBitAtomicable {} +extension int64_t : IBitAtomicable {} +extension uint64_t : IBitAtomicable {} +extension double : IArithmeticAtomicable {} +extension float : IArithmeticAtomicable {} +extension half : IArithmeticAtomicable {} + +__magic_type(AtomicType) +__intrinsic_type($(kIROp_AtomicType)) +[require(cuda_glsl_hlsl_metal_spirv_wgsl)] +struct Atomic<T : IAtomicable> +{ + __intrinsic_op($(kIROp_AtomicLoad)) + [__ref] T load(MemoryOrder order = MemoryOrder.Relaxed); + + __intrinsic_op($(kIROp_AtomicStore)) + [__ref] void store(T newValue, MemoryOrder order = MemoryOrder.Relaxed); + + __intrinsic_op($(kIROp_AtomicExchange)) + [__ref] T exchange(T newValue, MemoryOrder order = MemoryOrder.Relaxed); // returns old value + + __intrinsic_op($(kIROp_AtomicCompareExchange)) + [__ref] T compareExchange( + T compareValue, + T newValue, + MemoryOrder successOrder = MemoryOrder.Relaxed, + MemoryOrder failOrder = MemoryOrder.Relaxed); +} + +extension<T : IArithmeticAtomicable> Atomic<T> +{ + __intrinsic_op($(kIROp_AtomicAdd)) + [__ref] T add(T value, MemoryOrder order = MemoryOrder.Relaxed); // returns original value + __intrinsic_op($(kIROp_AtomicSub)) + [__ref] T sub(T value, MemoryOrder order = MemoryOrder.Relaxed); // returns original value + __intrinsic_op($(kIROp_AtomicMax)) + [__ref] T max(T value, MemoryOrder order = MemoryOrder.Relaxed); // returns original value + __intrinsic_op($(kIROp_AtomicMin)) + [__ref] T min(T value, MemoryOrder order = MemoryOrder.Relaxed); // returns original value +} + +extension<T : IBitAtomicable> Atomic<T> +{ + __intrinsic_op($(kIROp_AtomicAnd)) + [__ref] T and(T value, MemoryOrder order = MemoryOrder.Relaxed); // returns original value + __intrinsic_op($(kIROp_AtomicOr)) + [__ref] T or(T value, MemoryOrder order = MemoryOrder.Relaxed); // returns original value + __intrinsic_op($(kIROp_AtomicXor)) + [__ref] T xor(T value, MemoryOrder order = MemoryOrder.Relaxed); // returns original value + __intrinsic_op($(kIROp_AtomicInc)) + [__ref] T increment(MemoryOrder order = MemoryOrder.Relaxed); + __intrinsic_op($(kIROp_AtomicDec)) + [__ref] T decrement(MemoryOrder order = MemoryOrder.Relaxed); +} + +__generic<T : IArithmeticAtomicable> +[ForceInline] +T operator +=(__ref Atomic<T> v, T value) +{ + return v.add(value) + value; +} +__generic<T : IArithmeticAtomicable> +[ForceInline] +T operator -=(__ref Atomic<T> v, T value) +{ + return v.sub(value) - value; +} +__generic<T : IBitAtomicable> +[ForceInline] +T operator &=(__ref Atomic<T> v, T value) +{ + return v.and(value) & value; +} +__generic<T : IBitAtomicable> +[ForceInline] +T operator |=(__ref Atomic<T> v, T value) +{ + return v.or(value) | value; +} +__generic<T : IBitAtomicable> +[ForceInline] +T operator ^=(__ref Atomic<T> v, T value) +{ + return v.xor(value) ^ value; +} + +__generic<T : IBitAtomicable> +[ForceInline] +__prefix T operator ++(__ref Atomic<T> v) +{ + return v.increment() + T(1); +} +__generic<T : IBitAtomicable> +[ForceInline] +__postfix T operator ++(__ref Atomic<T> v) +{ + return v.increment(); +} +__generic<T : IBitAtomicable> +[ForceInline] +__prefix T operator --(__ref Atomic<T> v) +{ + return v.decrement() - T(1); +} +__generic<T : IBitAtomicable> +[ForceInline] +__postfix T operator --(__ref Atomic<T> v) +{ + return v.decrement(); +} // Binding Attributes __attributeTarget(DeclBase) diff --git a/source/slang/slang-ast-type.cpp b/source/slang/slang-ast-type.cpp index 1c9f68a48..616c2a67a 100644 --- a/source/slang/slang-ast-type.cpp +++ b/source/slang/slang-ast-type.cpp @@ -328,6 +328,12 @@ bool ArrayExpressionType::isUnsized() return false; } +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! AtomicType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! +Type* AtomicType::getElementType() +{ + return as<Type>(_getGenericTypeArg(this, 0)); +} + // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! TypeType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! void TypeType::_toTextOverride(StringBuilder& out) diff --git a/source/slang/slang-ast-type.h b/source/slang/slang-ast-type.h index 46ea3ea55..1239283d7 100644 --- a/source/slang/slang-ast-type.h +++ b/source/slang/slang-ast-type.h @@ -436,6 +436,13 @@ class ArrayExpressionType : public DeclRefType IntVal* getElementCount(); }; +class AtomicType : public DeclRefType +{ + SLANG_AST_CLASS(AtomicType) + + Type* getElementType(); +}; + // The "type" of an expression that resolves to a type. // For example, in the expression `float(2)` the sub-expression, // `float` would have the type `TypeType(float)`. diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 5233008fd..1b9725a8c 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -2274,10 +2274,14 @@ namespace Slang expr->left = maybeOpenRef(expr->left); auto type = expr->left->type; + if (auto atomicType = as<AtomicType>(type)) + { + type = atomicType->getElementType(); + } auto right = maybeOpenRef(expr->right); expr->right = coerce(CoercionSite::Assignment, type, right); - if (!type.isLeftValue) + if (!expr->left->type.isLeftValue) { if (as<ErrorType>(type)) { diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp index caf3613a7..c60397b85 100644 --- a/source/slang/slang-emit-c-like.cpp +++ b/source/slang/slang-emit-c-like.cpp @@ -500,32 +500,6 @@ void CLikeSourceEmitter::defaultEmitInstStmt(IRInst* inst) { switch (inst->getOp()) { - case kIROp_AtomicCounterIncrement: - { - auto oldValName = getName(inst); - m_writer->emit("int "); - m_writer->emit(oldValName); - m_writer->emit(";\n"); - m_writer->emit("InterlockedAdd("); - emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); - m_writer->emit(", 1, "); - m_writer->emit(oldValName); - m_writer->emit(");\n"); - } - break; - case kIROp_AtomicCounterDecrement: - { - auto oldValName = getName(inst); - m_writer->emit("int "); - m_writer->emit(oldValName); - m_writer->emit(";\n"); - m_writer->emit("InterlockedAdd("); - emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); - m_writer->emit(", -1, "); - m_writer->emit(oldValName); - m_writer->emit(");\n"); - } - break; case kIROp_StructuredBufferGetDimensions: { auto count = _generateUniqueName(UnownedStringSlice("_elementCount")); @@ -1862,8 +1836,7 @@ void CLikeSourceEmitter::emitInstResultDecl(IRInst* inst) emitRateQualifiers(inst); - bool isConstant(as<IRModuleInst>(inst->getParent())); - if(isConstant) + if (as<IRModuleInst>(inst->getParent())) { // "Ordinary" instructions at module scope are constants @@ -1888,7 +1861,7 @@ void CLikeSourceEmitter::emitInstResultDecl(IRInst* inst) } - emitVarKeyword(type, isConstant); + emitVarKeyword(type, inst); emitType(type, getName(inst)); m_writer->emit(" = "); @@ -2920,8 +2893,19 @@ void CLikeSourceEmitter::_emitInst(IRInst* inst) // Insts that needs to be emitted as code blocks. case kIROp_CudaKernelLaunch: - case kIROp_AtomicCounterIncrement: - case kIROp_AtomicCounterDecrement: + case kIROp_AtomicLoad: + case kIROp_AtomicStore: + case kIROp_AtomicInc: + case kIROp_AtomicDec: + case kIROp_AtomicAdd: + case kIROp_AtomicSub: + case kIROp_AtomicAnd: + case kIROp_AtomicOr: + case kIROp_AtomicXor: + case kIROp_AtomicMin: + case kIROp_AtomicMax: + case kIROp_AtomicExchange: + case kIROp_AtomicCompareExchange: case kIROp_StructuredBufferGetDimensions: case kIROp_MetalAtomicCast: emitInstStmt(inst); @@ -3143,7 +3127,7 @@ void CLikeSourceEmitter::_emitStoreImpl(IRStore* store) void CLikeSourceEmitter::_emitInstAsDefaultInitializedVar(IRInst* inst, IRType* type) { - emitVarKeyword(type, /* isConstant */ false); + emitVarKeyword(type, inst); emitType(type, getName(inst)); @@ -3975,7 +3959,7 @@ void CLikeSourceEmitter::emitParameterGroup(IRGlobalParam* varDecl, IRUniformPar emitParameterGroupImpl(varDecl, type); } -void CLikeSourceEmitter::emitVarKeywordImpl(IRType * /* type */, bool /* isConstant */) {} +void CLikeSourceEmitter::emitVarKeywordImpl(IRType * /* type */, IRInst* /* varDecl */) {} void CLikeSourceEmitter::emitVar(IRVar* varDecl) { @@ -4015,7 +3999,7 @@ void CLikeSourceEmitter::emitVar(IRVar* varDecl) #endif emitRateQualifiersAndAddressSpace(varDecl); - emitVarKeyword(varType, /* isConstant */ false); + emitVarKeyword(varType, varDecl); emitType(varType, getName(varDecl)); @@ -4147,7 +4131,7 @@ void CLikeSourceEmitter::emitGlobalVar(IRGlobalVar* varDecl) emitVarModifiers(layout, varDecl, varType); emitRateQualifiersAndAddressSpace(varDecl); - emitVarKeyword(varType, /* isConstant */ true); + emitVarKeyword(varType, varDecl); emitType(varType, getName(varDecl)); // TODO: These shouldn't be needed for ordinary @@ -4221,7 +4205,7 @@ void CLikeSourceEmitter::emitGlobalParam(IRGlobalParam* varDecl) emitDecorationLayoutSemantics(varDecl, "register"); emitRateQualifiersAndAddressSpace(varDecl); - emitVarKeyword(varType, /* isConstant */ false); + emitVarKeyword(varType, varDecl); emitGlobalParamType(varType, getName(varDecl)); emitSemantics(varDecl); diff --git a/source/slang/slang-emit-c-like.h b/source/slang/slang-emit-c-like.h index be769f31f..ccc25de57 100644 --- a/source/slang/slang-emit-c-like.h +++ b/source/slang/slang-emit-c-like.h @@ -498,8 +498,8 @@ public: virtual void emitSimpleTypeAndDeclaratorImpl(IRType* type, DeclaratorInfo* declarator); void emitSimpleTypeAndDeclarator(IRType* type, DeclaratorInfo* declarator) {emitSimpleTypeAndDeclaratorImpl(type, declarator);}; - virtual void emitVarKeywordImpl(IRType * type, bool isConstant); - void emitVarKeyword(IRType * type, bool isConstant) {emitVarKeywordImpl(type, isConstant);} + virtual void emitVarKeywordImpl(IRType * type, IRInst* varDecl); + void emitVarKeyword(IRType * type, IRInst* varDecl) {emitVarKeywordImpl(type, varDecl);} virtual void beforeComputeEmitActions(IRModule* module) { SLANG_UNUSED(module); }; diff --git a/source/slang/slang-emit-cpp.cpp b/source/slang/slang-emit-cpp.cpp index bcb9ed9da..19dc05dcf 100644 --- a/source/slang/slang-emit-cpp.cpp +++ b/source/slang/slang-emit-cpp.cpp @@ -337,6 +337,10 @@ SlangResult CPPSourceEmitter::calcTypeName(IRType* type, CodeGenTarget target, S out << intLit->getValue(); return SLANG_OK; } + case kIROp_AtomicType: + { + return calcTypeName((IRType*)type->getOperand(0), target, out); + } default: { if (isNominalOp(type->getOp())) diff --git a/source/slang/slang-emit-cuda.cpp b/source/slang/slang-emit-cuda.cpp index 05485855c..81bcafeb3 100644 --- a/source/slang/slang-emit-cuda.cpp +++ b/source/slang/slang-emit-cuda.cpp @@ -474,6 +474,129 @@ bool CUDASourceEmitter::tryEmitInstStmtImpl(IRInst* inst) m_writer->emit(");\n"); return true; } + case kIROp_AtomicLoad: + { + emitInstResultDecl(inst); + emitDereferenceOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(";\n"); + return true; + } + case kIROp_AtomicStore: + { + emitDereferenceOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(" = "); + emitOperand(inst->getOperand(1), getInfo(EmitOp::General)); + m_writer->emit(";\n"); + return true; + } + case kIROp_AtomicExchange: + { + emitInstResultDecl(inst); + m_writer->emit("atomicExch("); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(", "); + emitOperand(inst->getOperand(1), getInfo(EmitOp::General)); + m_writer->emit(");\n"); + return true; + } + case kIROp_AtomicCompareExchange: + { + emitInstResultDecl(inst); + m_writer->emit("atomicCAS("); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(", "); + emitOperand(inst->getOperand(1), getInfo(EmitOp::General)); + m_writer->emit(", "); + emitOperand(inst->getOperand(2), getInfo(EmitOp::General)); + m_writer->emit(");\n"); + return true; + } + case kIROp_AtomicAdd: + { + emitInstResultDecl(inst); + m_writer->emit("atomicAdd("); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(", "); + emitOperand(inst->getOperand(1), getInfo(EmitOp::General)); + m_writer->emit(");\n"); + return true; + } + case kIROp_AtomicSub: + { + emitInstResultDecl(inst); + m_writer->emit("atomicAdd("); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(", -("); + emitOperand(inst->getOperand(1), getInfo(EmitOp::General)); + m_writer->emit("));\n"); + return true; + } + case kIROp_AtomicAnd: + { + emitInstResultDecl(inst); + m_writer->emit("atomicAnd("); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(", "); + emitOperand(inst->getOperand(1), getInfo(EmitOp::General)); + m_writer->emit(");\n"); + return true; + } + case kIROp_AtomicOr: + { + emitInstResultDecl(inst); + m_writer->emit("atomicOr("); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(", "); + emitOperand(inst->getOperand(1), getInfo(EmitOp::General)); + m_writer->emit(");\n"); + return true; + } + case kIROp_AtomicXor: + { + emitInstResultDecl(inst); + m_writer->emit("atomicXor("); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(", "); + emitOperand(inst->getOperand(1), getInfo(EmitOp::General)); + m_writer->emit(");\n"); + return true; + } + case kIROp_AtomicMin: + { + emitInstResultDecl(inst); + m_writer->emit("atomicMin("); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(", "); + emitOperand(inst->getOperand(1), getInfo(EmitOp::General)); + m_writer->emit(");\n"); + return true; + } + case kIROp_AtomicMax: + { + emitInstResultDecl(inst); + m_writer->emit("atomicMax("); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(", "); + emitOperand(inst->getOperand(1), getInfo(EmitOp::General)); + m_writer->emit(");\n"); + return true; + } + case kIROp_AtomicInc: + { + emitInstResultDecl(inst); + m_writer->emit("atomicAdd("); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(", 1);\n"); + return true; + } + case kIROp_AtomicDec: + { + emitInstResultDecl(inst); + m_writer->emit("atomicAdd("); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(", -1);\n"); + return true; + } default: return false; } diff --git a/source/slang/slang-emit-glsl.cpp b/source/slang/slang-emit-glsl.cpp index 56113409d..116bf67cd 100644 --- a/source/slang/slang-emit-glsl.cpp +++ b/source/slang/slang-emit-glsl.cpp @@ -2156,44 +2156,152 @@ bool GLSLSourceEmitter::tryEmitInstStmtImpl(IRInst* inst) { switch (inst->getOp()) { - case kIROp_AtomicCounterIncrement: - { - auto oldValName = getName(inst); - m_writer->emit("int "); - m_writer->emit(oldValName); - m_writer->emit(" = "); - m_writer->emit("atomicAdd("); - emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); - m_writer->emit(", 1);\n"); - return true; - } - case kIROp_AtomicCounterDecrement: - { - auto oldValName = getName(inst); - m_writer->emit("int "); - m_writer->emit(oldValName); - m_writer->emit(" = "); - m_writer->emit("atomicAdd("); - emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); - m_writer->emit(", -1);\n"); - return true; - } case kIROp_StructuredBufferGetDimensions: + { + emitInstResultDecl(inst); + m_writer->emit("uvec2("); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit("._data.length(), "); + auto elementType = as<IRHLSLStructuredBufferTypeBase>(inst->getOperand(0)->getDataType())->getElementType(); + IRIntegerValue stride = 0; + if (auto sizeDecor = elementType->findDecoration<IRSizeAndAlignmentDecoration>()) { - emitInstResultDecl(inst); - m_writer->emit("uvec2("); - emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); - m_writer->emit("._data.length(), "); - auto elementType = as<IRHLSLStructuredBufferTypeBase>(inst->getOperand(0)->getDataType())->getElementType(); - IRIntegerValue stride = 0; - if (auto sizeDecor = elementType->findDecoration<IRSizeAndAlignmentDecoration>()) - { - stride = align(sizeDecor->getSize(), (int)sizeDecor->getAlignment()); - } - m_writer->emit(stride); - m_writer->emit(");\n"); - return true; + stride = align(sizeDecor->getSize(), (int)sizeDecor->getAlignment()); } + m_writer->emit(stride); + m_writer->emit(");\n"); + return true; + } + case kIROp_AtomicLoad: + { + emitInstResultDecl(inst); + emitDereferenceOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(";\n"); + return true; + } + case kIROp_AtomicStore: + { + emitInstResultDecl(inst); + emitDereferenceOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(" = "); + emitOperand(inst->getOperand(1), getInfo(EmitOp::General)); + m_writer->emit(";\n"); + return true; + } + case kIROp_AtomicExchange: + { + emitInstResultDecl(inst); + m_writer->emit("atomicExchange("); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(", "); + emitOperand(inst->getOperand(1), getInfo(EmitOp::General)); + m_writer->emit(");\n"); + return true; + } + case kIROp_AtomicCompareExchange: + { + emitInstResultDecl(inst); + m_writer->emit("atomicCompSwap("); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(", "); + emitOperand(inst->getOperand(1), getInfo(EmitOp::General)); + m_writer->emit(", "); + emitOperand(inst->getOperand(2), getInfo(EmitOp::General)); + m_writer->emit(");\n"); + return true; + } + case kIROp_AtomicAdd: + { + emitInstResultDecl(inst); + m_writer->emit("atomicAdd("); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(", "); + emitOperand(inst->getOperand(1), getInfo(EmitOp::General)); + m_writer->emit(");\n"); + return true; + } + case kIROp_AtomicSub: + { + emitInstResultDecl(inst); + m_writer->emit("atomicAdd("); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(", -("); + emitOperand(inst->getOperand(1), getInfo(EmitOp::General)); + m_writer->emit("));\n"); + return true; + } + case kIROp_AtomicAnd: + { + emitInstResultDecl(inst); + m_writer->emit("atomicAnd("); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(", "); + emitOperand(inst->getOperand(1), getInfo(EmitOp::General)); + m_writer->emit(");\n"); + return true; + } + case kIROp_AtomicOr: + { + emitInstResultDecl(inst); + m_writer->emit("atomicOr("); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(", "); + emitOperand(inst->getOperand(1), getInfo(EmitOp::General)); + m_writer->emit(");\n"); + return true; + } + case kIROp_AtomicXor: + { + emitInstResultDecl(inst); + m_writer->emit("atomicXor("); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(", "); + emitOperand(inst->getOperand(1), getInfo(EmitOp::General)); + m_writer->emit(");\n"); + return true; + } + case kIROp_AtomicMin: + { + emitInstResultDecl(inst); + m_writer->emit("atomicMin("); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(", "); + emitOperand(inst->getOperand(1), getInfo(EmitOp::General)); + m_writer->emit(");\n"); + return true; + } + case kIROp_AtomicMax: + { + emitInstResultDecl(inst); + m_writer->emit("atomicMax("); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(", "); + emitOperand(inst->getOperand(1), getInfo(EmitOp::General)); + m_writer->emit(");\n"); + return true; + } + case kIROp_AtomicInc: + { + emitInstResultDecl(inst); + m_writer->emit("atomicAdd("); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(", "); + emitType(inst->getDataType()); + m_writer->emit("(1)"); + m_writer->emit(");\n"); + return true; + } + case kIROp_AtomicDec: + { + emitInstResultDecl(inst); + m_writer->emit("atomicAdd("); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(", "); + emitType(inst->getDataType()); + m_writer->emit("(-1)"); + m_writer->emit(");\n"); + return true; + } default: return false; } @@ -2572,6 +2680,11 @@ void GLSLSourceEmitter::emitSimpleTypeImpl(IRType* type) m_writer->emit("DNV"); return; } + case kIROp_AtomicType: + { + emitSimpleTypeImpl(cast<IRAtomicType>(type)->getElementType()); + return; + } default: break; } diff --git a/source/slang/slang-emit-hlsl.cpp b/source/slang/slang-emit-hlsl.cpp index f9765a555..b45b4c575 100644 --- a/source/slang/slang-emit-hlsl.cpp +++ b/source/slang/slang-emit-hlsl.cpp @@ -496,6 +496,171 @@ void HLSLSourceEmitter::emitEntryPointAttributesImpl(IRFunc* irFunc, IREntryPoin } } +bool HLSLSourceEmitter::tryEmitInstStmtImpl(IRInst* inst) +{ + switch (inst->getOp()) + { + case kIROp_AtomicLoad: + { + emitInstResultDecl(inst); + emitDereferenceOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(";\n"); + return true; + } + case kIROp_AtomicStore: + { + emitDereferenceOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(" = "); + emitOperand(inst->getOperand(1), getInfo(EmitOp::General)); + m_writer->emit(";\n"); + return true; + } + case kIROp_AtomicExchange: + { + emitType(inst->getDataType(), getName(inst)); + m_writer->emit(";\n"); + m_writer->emit("InterlockedExchange("); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(", "); + emitOperand(inst->getOperand(1), getInfo(EmitOp::General)); + m_writer->emit(", "); + m_writer->emit(getName(inst)); + m_writer->emit(");\n"); + return true; + } + case kIROp_AtomicCompareExchange: + { + emitType(inst->getDataType(), getName(inst)); + m_writer->emit(";\n"); + m_writer->emit("InterlockedCompareExchange("); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(", "); + emitOperand(inst->getOperand(1), getInfo(EmitOp::General)); + m_writer->emit(", "); + emitOperand(inst->getOperand(2), getInfo(EmitOp::General)); + m_writer->emit(", "); + m_writer->emit(getName(inst)); + m_writer->emit(");\n"); + return true; + } + case kIROp_AtomicAdd: + { + emitType(inst->getDataType(), getName(inst)); + m_writer->emit(";\n"); + m_writer->emit("InterlockedAdd("); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(", "); + emitOperand(inst->getOperand(1), getInfo(EmitOp::General)); + m_writer->emit(", "); + m_writer->emit(getName(inst)); + m_writer->emit(");\n"); + return true; + } + case kIROp_AtomicSub: + { + emitType(inst->getDataType(), getName(inst)); + m_writer->emit(";\n"); + m_writer->emit("InterlockedAdd("); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(", -("); + emitOperand(inst->getOperand(1), getInfo(EmitOp::General)); + m_writer->emit("), "); + m_writer->emit(getName(inst)); + m_writer->emit(");\n"); + return true; + } + case kIROp_AtomicAnd: + { + emitType(inst->getDataType(), getName(inst)); + m_writer->emit(";\n"); + m_writer->emit("InterlockedAnd("); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(", "); + emitOperand(inst->getOperand(1), getInfo(EmitOp::General)); + m_writer->emit(", "); + m_writer->emit(getName(inst)); + m_writer->emit(");\n"); + return true; + } + case kIROp_AtomicOr: + { + emitType(inst->getDataType(), getName(inst)); + m_writer->emit(";\n"); + m_writer->emit("InterlockedOr("); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(", "); + emitOperand(inst->getOperand(1), getInfo(EmitOp::General)); + m_writer->emit(", "); + m_writer->emit(getName(inst)); + m_writer->emit(");\n"); + return true; + } + case kIROp_AtomicXor: + { + emitType(inst->getDataType(), getName(inst)); + m_writer->emit(";\n"); + m_writer->emit("InterlockedXor("); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(", "); + emitOperand(inst->getOperand(1), getInfo(EmitOp::General)); + m_writer->emit(", "); + m_writer->emit(getName(inst)); + m_writer->emit(");\n"); + return true; + } + case kIROp_AtomicMin: + { + emitType(inst->getDataType(), getName(inst)); + m_writer->emit(";\n"); + m_writer->emit("InterlockedMin("); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(", "); + emitOperand(inst->getOperand(1), getInfo(EmitOp::General)); + m_writer->emit(", "); + m_writer->emit(getName(inst)); + m_writer->emit(");\n"); + return true; + } + case kIROp_AtomicMax: + { + emitType(inst->getDataType(), getName(inst)); + m_writer->emit(";\n"); + m_writer->emit("InterlockedMax("); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(", "); + emitOperand(inst->getOperand(1), getInfo(EmitOp::General)); + m_writer->emit(", "); + m_writer->emit(getName(inst)); + m_writer->emit(");\n"); + return true; + } + case kIROp_AtomicInc: + { + emitType(inst->getDataType(), getName(inst)); + m_writer->emit(";\n"); + m_writer->emit("InterlockedAdd("); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(", 1, "); + m_writer->emit(getName(inst)); + m_writer->emit(");\n"); + return true; + } + case kIROp_AtomicDec: + { + emitType(inst->getDataType(), getName(inst)); + m_writer->emit(";\n"); + m_writer->emit("InterlockedAdd("); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(", -1, "); + m_writer->emit(getName(inst)); + m_writer->emit(");"); + return true; + } + default: + return false; + } +} + bool HLSLSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOuterPrec) { switch (inst->getOp()) @@ -755,7 +920,6 @@ bool HLSLSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOu return true; } break; - default: break; } // Not handled @@ -1030,6 +1194,11 @@ void HLSLSourceEmitter::emitSimpleTypeImpl(IRType* type) m_writer->emit("uint4"); return; } + case kIROp_AtomicType: + { + emitSimpleTypeImpl(cast<IRAtomicType>(type)->getElementType()); + return; + } default: break; } diff --git a/source/slang/slang-emit-hlsl.h b/source/slang/slang-emit-hlsl.h index 4d721a1ac..31fa6b290 100644 --- a/source/slang/slang-emit-hlsl.h +++ b/source/slang/slang-emit-hlsl.h @@ -49,6 +49,7 @@ protected: virtual void emitParamTypeModifier(IRType* type) SLANG_OVERRIDE { emitMatrixLayoutModifiersImpl(type); } virtual bool tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOuterPrec) SLANG_OVERRIDE; + virtual bool tryEmitInstStmtImpl(IRInst* inst) SLANG_OVERRIDE; virtual void emitSimpleValueImpl(IRInst* inst) SLANG_OVERRIDE; virtual void emitLoopControlDecorationImpl(IRLoopControlDecoration* decl) SLANG_OVERRIDE; virtual void emitFuncDecorationImpl(IRDecoration* decoration) SLANG_OVERRIDE; diff --git a/source/slang/slang-emit-metal.cpp b/source/slang/slang-emit-metal.cpp index a0fe220e5..312d06c08 100644 --- a/source/slang/slang-emit-metal.cpp +++ b/source/slang/slang-emit-metal.cpp @@ -246,6 +246,20 @@ void MetalSourceEmitter::ensurePrelude(const char* preludeText) m_requiredPreludes.add(stringLit); } +void MetalSourceEmitter::emitMemoryOrderOperand(IRInst* inst) +{ + auto memoryOrder = (IRMemoryOrder)getIntVal(inst); + switch (memoryOrder) + { + case kIRMemoryOrder_Relaxed: + m_writer->emit("memory_order_relaxed"); + break; + default: + m_writer->emit("memory_order_seq_cst"); + break; + } +} + bool MetalSourceEmitter::tryEmitInstStmtImpl(IRInst* inst) { switch (inst->getOp()) @@ -271,6 +285,164 @@ bool MetalSourceEmitter::tryEmitInstStmtImpl(IRInst* inst) m_writer->emit("));\n"); return true; } + case kIROp_AtomicLoad: + { + emitInstResultDecl(inst); + m_writer->emit("atomic_load_explicit("); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(", "); + emitMemoryOrderOperand(inst->getOperand(1)); + m_writer->emit(");\n"); + return true; + } + case kIROp_AtomicStore: + { + m_writer->emit("atomic_store_explicit("); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(", "); + emitOperand(inst->getOperand(1), getInfo(EmitOp::General)); + m_writer->emit(", "); + emitMemoryOrderOperand(inst->getOperand(2)); + m_writer->emit(");\n"); + return true; + } + case kIROp_AtomicExchange: + { + emitInstResultDecl(inst); + m_writer->emit("atomic_exchange_explicit("); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(", "); + emitOperand(inst->getOperand(1), getInfo(EmitOp::General)); + m_writer->emit(", "); + emitMemoryOrderOperand(inst->getOperand(2)); + m_writer->emit(");\n"); + return true; + } + case kIROp_AtomicCompareExchange: + { + emitType(inst->getDataType(), getName(inst)); + m_writer->emit(";\n{\n"); + emitType(inst->getDataType(), "_metal_cas_comparand"); + m_writer->emit(" = "); + emitOperand(inst->getOperand(1), getInfo(EmitOp::General)); + m_writer->emit(";\n"); + + m_writer->emit(getName(inst)); + m_writer->emit(" = atomic_compare_exchange_weak_explicit("); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(", &_metal_cas_comparand, "); + emitOperand(inst->getOperand(2), getInfo(EmitOp::General)); + m_writer->emit(", "); + emitMemoryOrderOperand(inst->getOperand(3)); + m_writer->emit(", "); + emitMemoryOrderOperand(inst->getOperand(4)); + m_writer->emit(");\n}\n"); + return true; + } + case kIROp_AtomicAdd: + { + emitInstResultDecl(inst); + m_writer->emit("atomic_fetch_add_explicit("); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(", "); + emitOperand(inst->getOperand(1), getInfo(EmitOp::General)); + m_writer->emit(", "); + emitMemoryOrderOperand(inst->getOperand(2)); + m_writer->emit(");\n"); + return true; + } + case kIROp_AtomicSub: + { + emitInstResultDecl(inst); + m_writer->emit("atomic_fetch_sub_explicit("); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(", "); + emitOperand(inst->getOperand(1), getInfo(EmitOp::General)); + m_writer->emit(", "); + emitMemoryOrderOperand(inst->getOperand(2)); + m_writer->emit(");\n"); + return true; + } + case kIROp_AtomicAnd: + { + emitInstResultDecl(inst); + m_writer->emit("atomic_fetch_and_explicit("); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(", "); + emitOperand(inst->getOperand(1), getInfo(EmitOp::General)); + m_writer->emit(", "); + emitMemoryOrderOperand(inst->getOperand(2)); + m_writer->emit(");\n"); + return true; + } + case kIROp_AtomicOr: + { + emitInstResultDecl(inst); + m_writer->emit("atomic_fetch_or_explicit("); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(", "); + emitOperand(inst->getOperand(1), getInfo(EmitOp::General)); + m_writer->emit(", "); + emitMemoryOrderOperand(inst->getOperand(2)); + m_writer->emit(");\n"); + return true; + } + case kIROp_AtomicXor: + { + emitInstResultDecl(inst); + m_writer->emit("atomic_fetch_xor_explicit("); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(", "); + emitOperand(inst->getOperand(1), getInfo(EmitOp::General)); + m_writer->emit(", "); + emitMemoryOrderOperand(inst->getOperand(2)); + m_writer->emit(");\n"); + return true; + } + case kIROp_AtomicMin: + { + emitInstResultDecl(inst); + m_writer->emit("atomic_fetch_min_explicit("); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(", "); + emitOperand(inst->getOperand(1), getInfo(EmitOp::General)); + m_writer->emit(", "); + emitMemoryOrderOperand(inst->getOperand(2)); + m_writer->emit(");\n"); + return true; + } + case kIROp_AtomicMax: + { + emitInstResultDecl(inst); + m_writer->emit("atomic_fetch_max_explicit("); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(", "); + emitOperand(inst->getOperand(1), getInfo(EmitOp::General)); + m_writer->emit(", "); + emitMemoryOrderOperand(inst->getOperand(2)); + m_writer->emit(");\n"); + return true; + } + case kIROp_AtomicInc: + { + emitInstResultDecl(inst); + m_writer->emit("atomic_fetch_add_explicit("); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(", 1, "); + emitMemoryOrderOperand(inst->getOperand(1)); + m_writer->emit(");\n"); + return true; + } + case kIROp_AtomicDec: + { + emitInstResultDecl(inst); + m_writer->emit("atomic_fetch_sub_explicit("); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(", 1, "); + emitMemoryOrderOperand(inst->getOperand(1)); + m_writer->emit(");\n"); + return true; + } } return false; } @@ -664,10 +836,8 @@ void MetalSourceEmitter::emitSimpleTypeImpl(IRType* type) case kIROp_BoolType: case kIROp_Int8Type: case kIROp_IntType: - case kIROp_Int64Type: case kIROp_UInt8Type: case kIROp_UIntType: - case kIROp_UInt64Type: case kIROp_FloatType: case kIROp_DoubleType: case kIROp_HalfType: @@ -675,6 +845,12 @@ void MetalSourceEmitter::emitSimpleTypeImpl(IRType* type) m_writer->emit(getDefaultBuiltinTypeName(type->getOp())); return; } + case kIROp_Int64Type: + m_writer->emit("long"); + return; + case kIROp_UInt64Type: + m_writer->emit("ulong"); + return; case kIROp_Int16Type: m_writer->emit("short"); return; @@ -682,10 +858,10 @@ void MetalSourceEmitter::emitSimpleTypeImpl(IRType* type) m_writer->emit("ushort"); return; case kIROp_IntPtrType: - m_writer->emit("int64_t"); + m_writer->emit("long"); return; case kIROp_UIntPtrType: - m_writer->emit("uint64_t"); + m_writer->emit("ulong"); return; case kIROp_StructType: m_writer->emit(getName(type)); @@ -781,6 +957,13 @@ void MetalSourceEmitter::emitSimpleTypeImpl(IRType* type) m_writer->emit("mesh_grid_properties "); return; } + case kIROp_AtomicType: + { + m_writer->emit("atomic<"); + emitSimpleTypeImpl(cast<IRAtomicType>(type)->getElementType()); + m_writer->emit(">"); + return; + } default: break; } diff --git a/source/slang/slang-emit-metal.h b/source/slang/slang-emit-metal.h index 67aa0d506..8e33eddef 100644 --- a/source/slang/slang-emit-metal.h +++ b/source/slang/slang-emit-metal.h @@ -26,6 +26,7 @@ protected: void ensurePrelude(const char* preludeText); + void emitMemoryOrderOperand(IRInst* inst); virtual void emitParameterGroupImpl(IRGlobalParam* varDecl, IRUniformParameterGroupType* type) SLANG_OVERRIDE; virtual void emitEntryPointAttributesImpl(IRFunc* irFunc, IREntryPointDecoration* entryPointDecor) SLANG_OVERRIDE; diff --git a/source/slang/slang-emit-spirv-ops.h b/source/slang/slang-emit-spirv-ops.h index 3d6bf846f..cea246b20 100644 --- a/source/slang/slang-emit-spirv-ops.h +++ b/source/slang/slang-emit-spirv-ops.h @@ -2354,6 +2354,152 @@ SpvInst* emitOpExecutionModeIdLocalSizeId( } template<typename T1, typename T2, typename T3, typename T4> +SpvInst* emitOpAtomicLoad( + SpvInstParent* parent, + IRInst* inst, + const T1& idResultType, + const T2& pointer, + const T3& memory, + const T4& semantics +) +{ + static_assert(isSingular<T1>); + static_assert(isSingular<T2>); + static_assert(isSingular<T3>); + static_assert(isSingular<T4>); + return emitInst( + parent, + inst, + SpvOpAtomicLoad, + idResultType, + kResultID, + pointer, + memory, + semantics + ); +} + +template<typename T1, typename T2, typename T3, typename T4> +SpvInst* emitOpAtomicStore( + SpvInstParent* parent, + IRInst* inst, + const T1& pointer, + const T2& memory, + const T3& semantics, + const T4& value +) +{ + static_assert(isSingular<T1>); + static_assert(isSingular<T2>); + static_assert(isSingular<T3>); + static_assert(isSingular<T4>); + return emitInst( + parent, + inst, + SpvOpAtomicStore, + pointer, + memory, + semantics, + value + ); +} + +template<typename T1, typename T2, typename T3, typename T4, typename T5> +SpvInst* emitOpAtomicExchange( + SpvInstParent* parent, + IRInst* inst, + const T1& idResultType, + const T2& pointer, + const T3& memory, + const T4& semantics, + const T5& value +) +{ + static_assert(isSingular<T1>); + static_assert(isSingular<T2>); + static_assert(isSingular<T3>); + static_assert(isSingular<T4>); + static_assert(isSingular<T5>); + return emitInst( + parent, + inst, + SpvOpAtomicExchange, + idResultType, + kResultID, + pointer, + memory, + semantics, + value + ); +} + +template<typename T1, typename T2, typename T3, typename T4, typename T5, typename T6, typename T7> +SpvInst* emitOpAtomicCompareExchange( + SpvInstParent* parent, + IRInst* inst, + const T1& idResultType, + const T2& pointer, + const T3& memory, + const T4& semanticsEqual, + const T5& semanticsUnequal, + const T6& value, + const T7& comparator +) +{ + static_assert(isSingular<T1>); + static_assert(isSingular<T2>); + static_assert(isSingular<T3>); + static_assert(isSingular<T4>); + static_assert(isSingular<T5>); + static_assert(isSingular<T6>); + static_assert(isSingular<T7>); + + return emitInst( + parent, + inst, + SpvOpAtomicCompareExchange, + idResultType, + kResultID, + pointer, + memory, + semanticsEqual, + semanticsUnequal, + value, + comparator + ); +} + +template<typename T1, typename T2, typename T3, typename T4, typename T5> +SpvInst* emitOpAtomicOp( + SpvInstParent* parent, + IRInst* inst, + SpvOp op, + const T1& idResultType, + const T2& pointer, + const T3& memory, + const T4& semantics, + const T5& value +) +{ + static_assert(isSingular<T1>); + static_assert(isSingular<T2>); + static_assert(isSingular<T3>); + static_assert(isSingular<T4>); + static_assert(isSingular<T5>); + return emitInst( + parent, + inst, + op, + idResultType, + kResultID, + pointer, + memory, + semantics, + value + ); +} + +template<typename T1, typename T2, typename T3, typename T4> SpvInst* emitOpAtomicIIncrement( SpvInstParent* parent, IRInst* inst, diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp index a2da4801e..ede573581 100644 --- a/source/slang/slang-emit-spirv.cpp +++ b/source/slang/slang-emit-spirv.cpp @@ -1660,6 +1660,12 @@ struct SPIRVEmitContext SpvLiteralInteger::from32(stride)); return arrayType; } + case kIROp_AtomicType: + { + auto result = ensureInst(as<IRAtomicType>(inst)->getElementType()); + registerInst(inst, result); + return result; + } case kIROp_SubpassInputType: return ensureSubpassInputType(inst, cast<IRSubpassInputType>(inst)); case kIROp_TextureType: @@ -2860,6 +2866,115 @@ struct SPIRVEmitContext return (isSpirv16OrLater() || m_useDemoteToHelperInvocationExtension); } + SpvInst* emitMemorySemanticMask(IRInst* inst) + { + IRBuilder builder(inst); + auto memoryOrder = (IRMemoryOrder)getIntVal(inst); + switch (memoryOrder) + { + case kIRMemoryOrder_Relaxed: + return emitIntConstant(IRIntegerValue{ SpvMemorySemanticsMaskNone }, builder.getUIntType()); + case kIRMemoryOrder_Acquire: + return emitIntConstant(IRIntegerValue{ SpvMemorySemanticsAcquireMask }, builder.getUIntType()); + case kIRMemoryOrder_Release: + return emitIntConstant(IRIntegerValue{ SpvMemorySemanticsReleaseMask }, builder.getUIntType()); + case kIRMemoryOrder_AcquireRelease: + return emitIntConstant(IRIntegerValue{ SpvMemorySemanticsAcquireReleaseMask }, builder.getUIntType()); + case kIRMemoryOrder_SeqCst: + return emitIntConstant(IRIntegerValue{ SpvMemorySemanticsSequentiallyConsistentMask }, builder.getUIntType()); + default: + SLANG_UNEXPECTED("unhandled memory order"); + UNREACHABLE_RETURN(nullptr); + } + } + + SpvOp getSpvAtomicOp(IRInst* atomicInst, bool& outNegateOperand) + { + auto typeSelect = [&](SpvOp sop, SpvOp uop, SpvOp fop) + { + auto scalarType = getVectorElementType(atomicInst->getDataType()); + if (isIntegralType(scalarType)) + { + auto intInfo = getIntTypeInfo(scalarType); + if (intInfo.isSigned) + return sop; + return uop; + } + return fop; + }; + outNegateOperand = false; + switch (atomicInst->getOp()) + { + case kIROp_AtomicAdd: + return typeSelect(SpvOpAtomicIAdd, SpvOpAtomicIAdd, SpvOpAtomicFAddEXT); + case kIROp_AtomicSub: + if (isFloatingType(getVectorElementType(atomicInst->getDataType()))) + outNegateOperand = true; + return typeSelect(SpvOpAtomicISub, SpvOpAtomicISub, SpvOpAtomicFAddEXT); + case kIROp_AtomicMin: + return typeSelect(SpvOpAtomicSMin, SpvOpAtomicUMin, SpvOpAtomicFMinEXT); + case kIROp_AtomicMax: + return typeSelect(SpvOpAtomicSMax, SpvOpAtomicUMax, SpvOpAtomicFMaxEXT); + case kIROp_AtomicAnd: + return SpvOpAtomicAnd; + case kIROp_AtomicOr: + return SpvOpAtomicOr; + case kIROp_AtomicXor: + return SpvOpAtomicXor; + default: + SLANG_UNEXPECTED("unhandled atomic op"); + UNREACHABLE_RETURN(SpvOpNop); + } + } + + void ensureAtomicCapability(IRInst* atomicInst, SpvOp op) + { + switch (op) + { + case SpvOpAtomicFAddEXT: + { + auto typeOp = getVectorElementType(atomicInst->getDataType())->getOp(); + switch (typeOp) + { + case kIROp_FloatType: + ensureExtensionDeclaration(toSlice("SPV_EXT_shader_atomic_float_add")); + requireSPIRVCapability(SpvCapabilityAtomicFloat32AddEXT); + break; + case kIROp_DoubleType: + ensureExtensionDeclaration(toSlice("SPV_EXT_shader_atomic_float_add")); + requireSPIRVCapability(SpvCapabilityAtomicFloat64AddEXT); + break; + case kIROp_HalfType: + ensureExtensionDeclaration(toSlice("SPV_EXT_shader_atomic_float16_add")); + requireSPIRVCapability(SpvCapabilityAtomicFloat16AddEXT); + break; + } + } + break; + case SpvOpAtomicFMinEXT: + case SpvOpAtomicFMaxEXT: + { + auto typeOp = getVectorElementType(atomicInst->getDataType())->getOp(); + switch (typeOp) + { + case kIROp_FloatType: + ensureExtensionDeclaration(toSlice("SPV_EXT_shader_atomic_float_min_max")); + requireSPIRVCapability(SpvCapabilityAtomicFloat32MinMaxEXT); + break; + case kIROp_DoubleType: + ensureExtensionDeclaration(toSlice("SPV_EXT_shader_atomic_float_min_max")); + requireSPIRVCapability(SpvCapabilityAtomicFloat64MinMaxEXT); + break; + case kIROp_HalfType: + ensureExtensionDeclaration(toSlice("SPV_EXT_shader_atomic_float_min_max")); + requireSPIRVCapability(SpvCapabilityAtomicFloat16MinMaxEXT); + break; + } + } + break; + } + } + // The instructions that appear inside the basic blocks of // functions are what we will call "local" instructions. // @@ -3200,22 +3315,82 @@ struct SPIRVEmitContext case kIROp_ImageSubscript: result = emitImageSubscript(parent, as<IRImageSubscript>(inst)); break; - case kIROp_AtomicCounterIncrement: + case kIROp_AtomicInc: { IRBuilder builder{inst}; - const auto memoryScope = emitIntConstant(IRIntegerValue{SpvScopeDevice}, builder.getUIntType()); - const auto memorySemantics = emitIntConstant(IRIntegerValue{SpvMemorySemanticsMaskNone}, builder.getUIntType()); + const auto memoryScope = emitIntConstant(IRIntegerValue{SpvScopeDevice}, builder.getUIntType()); + const auto memorySemantics = emitMemorySemanticMask(inst->getOperand(1)); result = emitOpAtomicIIncrement(parent, inst, inst->getFullType(), inst->getOperand(0), memoryScope, memorySemantics); } break; - case kIROp_AtomicCounterDecrement: + case kIROp_AtomicDec: { - IRBuilder builder{inst}; - const auto memoryScope = emitIntConstant(IRIntegerValue{SpvScopeDevice}, builder.getUIntType()); - const auto memorySemantics = emitIntConstant(IRIntegerValue{SpvMemorySemanticsMaskNone}, builder.getUIntType()); + IRBuilder builder{ inst }; + const auto memoryScope = emitIntConstant(IRIntegerValue{ SpvScopeDevice }, builder.getUIntType()); + const auto memorySemantics = emitMemorySemanticMask(inst->getOperand(1)); result = emitOpAtomicIDecrement(parent, inst, inst->getFullType(), inst->getOperand(0), memoryScope, memorySemantics); } break; + case kIROp_AtomicLoad: + { + IRBuilder builder{ inst }; + const auto memoryScope = emitIntConstant(IRIntegerValue{ SpvScopeDevice }, builder.getUIntType()); + const auto memorySemantics = emitMemorySemanticMask(inst->getOperand(1)); + result = emitOpAtomicLoad(parent, inst, inst->getFullType(), inst->getOperand(0), memoryScope, memorySemantics); + } + break; + case kIROp_AtomicStore: + { + IRBuilder builder{ inst }; + const auto memoryScope = emitIntConstant(IRIntegerValue{ SpvScopeDevice }, builder.getUIntType()); + const auto memorySemantics = emitMemorySemanticMask(inst->getOperand(2)); + result = emitOpAtomicStore(parent, inst, inst->getOperand(0), memoryScope, memorySemantics, inst->getOperand(1)); + } + break; + case kIROp_AtomicExchange: + { + IRBuilder builder{ inst }; + const auto memoryScope = emitIntConstant(IRIntegerValue{ SpvScopeDevice }, builder.getUIntType()); + const auto memorySemantics = emitMemorySemanticMask(inst->getOperand(2)); + result = emitOpAtomicExchange(parent, inst, inst->getFullType(), inst->getOperand(0), memoryScope, memorySemantics, inst->getOperand(1)); + } + break; + case kIROp_AtomicCompareExchange: + { + IRBuilder builder{ inst }; + const auto memoryScope = emitIntConstant(IRIntegerValue{ SpvScopeDevice }, builder.getUIntType()); + const auto memorySemanticsEqual = emitMemorySemanticMask(inst->getOperand(3)); + const auto memorySemanticsUnequal = emitMemorySemanticMask(inst->getOperand(4)); + result = emitOpAtomicCompareExchange( + parent, inst, inst->getFullType(), inst->getOperand(0), + memoryScope, memorySemanticsEqual, memorySemanticsUnequal, + inst->getOperand(2), inst->getOperand(1)); + } + break; + case kIROp_AtomicAdd: + case kIROp_AtomicSub: + case kIROp_AtomicMax: + case kIROp_AtomicMin: + case kIROp_AtomicAnd: + case kIROp_AtomicOr: + case kIROp_AtomicXor: + { + IRBuilder builder{ inst }; + const auto memoryScope = emitIntConstant(IRIntegerValue{ SpvScopeDevice }, builder.getUIntType()); + const auto memorySemantics = emitMemorySemanticMask(inst->getOperand(2)); + bool negateOperand = false; + auto spvOp = getSpvAtomicOp(inst, negateOperand); + auto operand = inst->getOperand(1); + if (negateOperand) + { + builder.setInsertBefore(inst); + auto negatedOperand = builder.emitNeg(inst->getDataType(), operand); + operand = negatedOperand; + } + result = emitOpAtomicOp(parent, inst, spvOp, inst->getFullType(), inst->getOperand(0), memoryScope, memorySemantics, operand); + ensureAtomicCapability(inst, spvOp); + } + break; case kIROp_ControlBarrier: { IRBuilder builder{ inst }; diff --git a/source/slang/slang-emit-wgsl.cpp b/source/slang/slang-emit-wgsl.cpp index 40d8ace91..06fea46b4 100644 --- a/source/slang/slang-emit-wgsl.cpp +++ b/source/slang/slang-emit-wgsl.cpp @@ -275,7 +275,6 @@ void WGSLSourceEmitter::emitStructFieldAttributes( bool WGSLSourceEmitter::isPointerSyntaxRequiredImpl(IRInst* inst) { - // Structured buffers are mapped to 'array' types, which don't need dereferencing if (inst->getOp() == kIROp_RWStructuredBufferGetElementPtr) return false; @@ -470,6 +469,14 @@ void WGSLSourceEmitter::emitSimpleTypeImpl(IRType* type) m_writer->emit(">"); return; } + + case kIROp_AtomicType: + { + m_writer->emit("atomic<"); + emitType(cast<IRAtomicType>(type)->getElementType()); + m_writer->emit(">"); + return; + } default: break; @@ -504,13 +511,32 @@ void WGSLSourceEmitter::emitLayoutQualifiersImpl(IRVarLayout* layout) } -void WGSLSourceEmitter::emitVarKeywordImpl(IRType * type, const bool isConstant) +void WGSLSourceEmitter::emitVarKeywordImpl(IRType * type, IRInst* varDecl) { - if (isConstant) - m_writer->emit("const"); - else + switch (varDecl->getOp()) + { + case kIROp_GlobalParam: + case kIROp_GlobalVar: + case kIROp_Var: m_writer->emit("var"); - if (type->getOp() == kIROp_HLSLRWStructuredBufferType) + break; + default: + if (as<IRModuleInst>(varDecl->getParent())) + { + m_writer->emit("const"); + } + else + { + m_writer->emit("var"); + } + break; + } + + if (as<IRGroupSharedRate>(varDecl->getRate())) + { + m_writer->emit("<workgroup>"); + } + else if (type->getOp() == kIROp_HLSLRWStructuredBufferType) { m_writer->emit("<"); m_writer->emit("storage, read_write"); @@ -789,6 +815,144 @@ void WGSLSourceEmitter::emitParamTypeImpl(IRType* type, const String& name) emitType(type, name); } +bool WGSLSourceEmitter::tryEmitInstStmtImpl(IRInst* inst) +{ + switch (inst->getOp()) + { + default: + return false; + case kIROp_AtomicLoad: + { + emitInstResultDecl(inst); + m_writer->emit("atomicLoad(&("); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit("));\n"); + return true; + } + case kIROp_AtomicStore: + { + m_writer->emit("atomicStore(&("); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit("), "); + emitOperand(inst->getOperand(1), getInfo(EmitOp::General)); + m_writer->emit(");\n"); + return true; + } + case kIROp_AtomicExchange: + { + emitInstResultDecl(inst); + m_writer->emit("atomicExchange(&("); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit("), "); + emitOperand(inst->getOperand(1), getInfo(EmitOp::General)); + m_writer->emit(");\n"); + return true; + } + case kIROp_AtomicCompareExchange: + { + emitInstResultDecl(inst); + m_writer->emit("atomicCompareExchangeWeak(&("); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit("), "); + emitOperand(inst->getOperand(1), getInfo(EmitOp::General)); + m_writer->emit(", "); + emitOperand(inst->getOperand(2), getInfo(EmitOp::General)); + m_writer->emit(").old_value;\n"); + return true; + } + case kIROp_AtomicAdd: + { + emitInstResultDecl(inst); + m_writer->emit("atomicAdd(&("); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit("), "); + emitOperand(inst->getOperand(1), getInfo(EmitOp::General)); + m_writer->emit(");\n"); + return true; + } + case kIROp_AtomicSub: + { + emitInstResultDecl(inst); + m_writer->emit("atomicSub(&("); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit("), "); + emitOperand(inst->getOperand(1), getInfo(EmitOp::General)); + m_writer->emit(");\n"); + return true; + } + case kIROp_AtomicAnd: + { + emitInstResultDecl(inst); + m_writer->emit("atomicAnd(&("); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit("), "); + emitOperand(inst->getOperand(1), getInfo(EmitOp::General)); + m_writer->emit(");\n"); + return true; + } + case kIROp_AtomicOr: + { + emitInstResultDecl(inst); + m_writer->emit("atomicOr(&("); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit("), "); + emitOperand(inst->getOperand(1), getInfo(EmitOp::General)); + m_writer->emit(");\n"); + return true; + } + case kIROp_AtomicXor: + { + emitInstResultDecl(inst); + m_writer->emit("atomicXor(&("); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit("), "); + emitOperand(inst->getOperand(1), getInfo(EmitOp::General)); + m_writer->emit(");\n"); + return true; + } + case kIROp_AtomicMin: + { + emitInstResultDecl(inst); + m_writer->emit("atomicMin(&("); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit("), "); + emitOperand(inst->getOperand(1), getInfo(EmitOp::General)); + m_writer->emit(");\n"); + return true; + } + case kIROp_AtomicMax: + { + emitInstResultDecl(inst); + m_writer->emit("atomicMax(&("); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit("), "); + emitOperand(inst->getOperand(1), getInfo(EmitOp::General)); + m_writer->emit(");\n"); + return true; + } + case kIROp_AtomicInc: + { + emitInstResultDecl(inst); + m_writer->emit("atomicAdd(&("); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit("), "); + emitType(inst->getDataType()); + m_writer->emit("(1));\n"); + return true; + } + case kIROp_AtomicDec: + { + emitInstResultDecl(inst); + m_writer->emit("atomicSub(&("); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit("), "); + emitType(inst->getDataType()); + m_writer->emit("(1));\n"); + return true; + } + } +} + bool WGSLSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOuterPrec) { EmitOpInfo outerPrec = inOuterPrec; @@ -869,6 +1033,17 @@ bool WGSLSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOu } break; + case kIROp_RWStructuredBufferGetElementPtr: + { + m_writer->emit("(*"); + emitOperand(inst->getOperand(0), leftSide(outerPrec, getInfo(EmitOp::Postfix))); + m_writer->emit(")["); + emitOperand(inst->getOperand(1), getInfo(EmitOp::General)); + m_writer->emit("]"); + return true; + } + break; + case kIROp_StructuredBufferLoad: case kIROp_RWStructuredBufferLoad: { diff --git a/source/slang/slang-emit-wgsl.h b/source/slang/slang-emit-wgsl.h index d3cf19d91..b3a4efb55 100644 --- a/source/slang/slang-emit-wgsl.h +++ b/source/slang/slang-emit-wgsl.h @@ -28,6 +28,7 @@ public: virtual bool tryEmitInstExprImpl( IRInst* inst, const EmitOpInfo& inOuterPrec ) SLANG_OVERRIDE; + virtual bool tryEmitInstStmtImpl(IRInst* inst) SLANG_OVERRIDE; virtual void emitSwitchCaseSelectorsImpl( IRBasicType *const switchCondition, const SwitchRegion::Case *const currentCase, @@ -36,7 +37,7 @@ public: virtual void emitSimpleTypeAndDeclaratorImpl( IRType* type, DeclaratorInfo* declarator ) SLANG_OVERRIDE; - virtual void emitVarKeywordImpl(IRType * type, const bool isConstant) SLANG_OVERRIDE; + virtual void emitVarKeywordImpl(IRType * type, IRInst* varDecl) SLANG_OVERRIDE; virtual void emitDeclaratorImpl(DeclaratorInfo* declarator) SLANG_OVERRIDE; virtual void emitStructDeclarationSeparatorImpl() SLANG_OVERRIDE; virtual void emitLayoutQualifiersImpl(IRVarLayout* layout) SLANG_OVERRIDE; diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index 0d689660e..dc4f7dff4 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -69,7 +69,9 @@ INST(Nop, nop, 0, 0) INST(TensorViewType, TensorView, 1, HOISTABLE) INST(TorchTensorType, TorchTensor, 0, HOISTABLE) INST(ArrayListType, ArrayListVector, 1, HOISTABLE) - + + INST(AtomicType, Atomic, 1, HOISTABLE) + /* BindExistentialsTypeBase */ // A `BindExistentials<B, T0,w0, T1,w1, ...>` represents @@ -400,6 +402,21 @@ INST(Var, var, 0, 0) INST(Load, load, 1, 0) INST(Store, store, 2, 0) +// Atomic Operations +INST(AtomicLoad, atomicLoad, 1, 0) +INST(AtomicStore, atomicStore, 2, 0) +INST(AtomicExchange, atomicExchange, 2, 0) +INST(AtomicCompareExchange, atomicCompareExchange, 3, 0) +INST(AtomicAdd, atomicAdd, 2, 0) +INST(AtomicSub, atomicSub, 2, 0) +INST(AtomicAnd, atomicAnd, 2, 0) +INST(AtomicOr, atomicOr, 2, 0) +INST(AtomicXor, atomicXor, 2, 0) +INST(AtomicMin, atomicMin, 2, 0) +INST(AtomicMax, atomicMax, 2, 0) +INST(AtomicInc, atomicInc, 1, 0) +INST(AtomicDec, atomicDec, 1, 0) + // Produced and removed during backward auto-diff pass as a temporary placeholder representing the // currently accumulated derivative to pass to some dOut argument in a nested call. INST(LoadReverseGradient, LoadReverseGradient, 1, 0) @@ -515,9 +532,6 @@ INST(StructuredBufferGetDimensions, StructuredBufferGetDimensions, 1, 0) // Resource qualifiers for dynamically varying index INST(NonUniformResourceIndex, nonUniformResourceIndex, 1, 0) -INST(AtomicCounterIncrement, AtomicCounterIncrement, 1, 0) -INST(AtomicCounterDecrement, AtomicCounterDecrement, 1, 0) - INST(GetNaturalStride, getNaturalStride, 1, 0) INST(MeshOutputRef, meshOutputRef, 2, 0) diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index f31a56673..24a779aed 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -2358,6 +2358,14 @@ struct IRLoad : IRInst IRInst* getPtr() { return ptr.get(); } }; +struct IRAtomicLoad : IRInst +{ + IRUse ptr; + IR_LEAF_ISA(AtomicLoad) + + IRInst* getPtr() { return ptr.get(); } +}; + struct IRStore : IRInst { IRUse ptr; @@ -2368,6 +2376,16 @@ struct IRStore : IRInst IRInst* getVal() { return val.get(); } }; +struct IRAtomicStore : IRInst +{ + IRUse ptr; + IRUse val; + IR_LEAF_ISA(AtomicStore) + + IRInst* getPtr() { return ptr.get(); } + IRInst* getVal() { return val.get(); } +}; + struct IRRWStructuredBufferStore : IRInst { IR_LEAF_ISA(RWStructuredBufferStore) @@ -4235,6 +4253,11 @@ public: IRInst* dstPtr, IRInst* srcVal); + IRInst* emitAtomicStore( + IRInst* dstPtr, + IRInst* srcVal, + IRInst* memoryOrder); + IRInst* emitImageLoad( IRType* type, ShortList<IRInst*> params); diff --git a/source/slang/slang-ir-layout.cpp b/source/slang/slang-ir-layout.cpp index 3a2471930..82287f58e 100644 --- a/source/slang/slang-ir-layout.cpp +++ b/source/slang/slang-ir-layout.cpp @@ -211,6 +211,14 @@ case kIROp_##TYPE##Type: \ } break; + case kIROp_AtomicType: + { + auto atomicType = cast<IRAtomicType>(type); + _calcSizeAndAlignment(optionSet, rules, atomicType->getElementType(), outSizeAndAlignment); + return SLANG_OK; + } + break; + case kIROp_UnsizedArrayType: { auto unsizedArrayType = cast<IRUnsizedArrayType>(type); diff --git a/source/slang/slang-ir-lower-append-consume-structured-buffer.cpp b/source/slang/slang-ir-lower-append-consume-structured-buffer.cpp index da6b8d34d..ac8a8b5e7 100644 --- a/source/slang/slang-ir-lower-append-consume-structured-buffer.cpp +++ b/source/slang/slang-ir-lower-append-consume-structured-buffer.cpp @@ -99,7 +99,8 @@ namespace Slang auto counterBuffer = builder.emitFieldExtract(counterBufferType, bufferParam, counterBufferKey); IRInst* getCounterPtrArgs[] = { counterBuffer, builder.getIntValue(builder.getIntType(), 0) }; auto counterBufferPtr = builder.emitIntrinsicInst(builder.getPtrType(builder.getIntType()), kIROp_RWStructuredBufferGetElementPtr, 2, getCounterPtrArgs); - auto oldCounter = builder.emitIntrinsicInst(builder.getIntType(), kIROp_AtomicCounterIncrement, 1, &counterBufferPtr); + IRInst* atomicIncArgs[] = { counterBufferPtr, builder.getIntValue(builder.getIntType(), kIRMemoryOrder_Relaxed) }; + auto oldCounter = builder.emitIntrinsicInst(builder.getIntType(), kIROp_AtomicInc, 2, atomicIncArgs); IRInst* getElementPtrArgs[] = { elementBuffer, oldCounter }; auto elementBufferPtr = builder.emitIntrinsicInst(builder.getPtrType(elementType), kIROp_RWStructuredBufferGetElementPtr, 2, getElementPtrArgs); @@ -122,7 +123,8 @@ namespace Slang auto counterBuffer = builder.emitFieldExtract(counterBufferType, bufferParam, counterBufferKey); IRInst* getCounterPtrArgs[] = { counterBuffer, builder.getIntValue(builder.getIntType(), 0) }; auto counterBufferPtr = builder.emitIntrinsicInst(builder.getPtrType(builder.getIntType()), kIROp_RWStructuredBufferGetElementPtr, 2, getCounterPtrArgs); - auto oldCounter = builder.emitIntrinsicInst(builder.getIntType(), kIROp_AtomicCounterDecrement, 1, &counterBufferPtr); + IRInst* atomicDecArgs[] = { counterBufferPtr, builder.getIntValue(builder.getIntType(), kIRMemoryOrder_Relaxed) }; + auto oldCounter = builder.emitIntrinsicInst(builder.getIntType(), kIROp_AtomicDec, 2, atomicDecArgs); auto index = builder.emitSub(builder.getIntType(), oldCounter, builder.getIntValue(builder.getIntType(), 1)); // Test if index is greater or equal than 0. diff --git a/source/slang/slang-ir-use-uninitialized-values.cpp b/source/slang/slang-ir-use-uninitialized-values.cpp index 56b13aa09..c09077528 100644 --- a/source/slang/slang-ir-use-uninitialized-values.cpp +++ b/source/slang/slang-ir-use-uninitialized-values.cpp @@ -297,6 +297,7 @@ namespace Slang case kIROp_Store: case kIROp_SwizzledStore: case kIROp_SPIRVAsm: + case kIROp_AtomicStore: return Store; case kIROp_SPIRVAsmOperandInst: diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index b89929f55..c12ad0f68 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -5084,6 +5084,23 @@ namespace Slang return inst; } + IRInst* IRBuilder::emitAtomicStore( + IRInst* dstPtr, + IRInst* srcVal, + IRInst* memoryOrder) + { + auto inst = createInst<IRAtomicStore>( + this, + kIROp_AtomicStore, + nullptr, + dstPtr, + srcVal, + memoryOrder); + + addInst(inst); + return inst; + } + /// @param params An ordered list of imageLoad parameters { image, coord, [optional] seperateArrayCoord, [optional] seperateSampleCoord } IRInst* IRBuilder::emitImageLoad(IRType* type, ShortList<IRInst*> params) { diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h index 14dde200f..ee12cff8c 100644 --- a/source/slang/slang-ir.h +++ b/source/slang/slang-ir.h @@ -93,6 +93,15 @@ enum IROpMask : std::underlying_type_t<IROp> kIROpMask_OpMask = 0x3ff, ///< Mask for just opcode }; +enum IRMemoryOrder +{ + kIRMemoryOrder_Relaxed = 0, + kIRMemoryOrder_Acquire = 1, + kIRMemoryOrder_Release = 2, + kIRMemoryOrder_AcquireRelease = 3, + kIRMemoryOrder_SeqCst = 4, +}; + inline int32_t operator&(const IROpMask m, const IROp o) { #if defined(__cpp_lib_bit_cast) @@ -1614,6 +1623,13 @@ struct IRArrayType: IRArrayTypeBase SIMPLE_IR_TYPE(UnsizedArrayType, ArrayTypeBase) +struct IRAtomicType : IRType +{ + IR_LEAF_ISA(AtomicType) + + IRType* getElementType() { return (IRType*)getOperand(0); } +}; + SIMPLE_IR_PARENT_TYPE(Rate, Type) SIMPLE_IR_TYPE(ConstExprRate, Rate) SIMPLE_IR_TYPE(GroupSharedRate, Rate) diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 813467743..d8dbaa812 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -7130,9 +7130,19 @@ top: // The `left` value is just a pointer, so we can emit // a store to it directly. // - builder->emitStore( - left.val, - getSimpleVal(context, right)); + if (as<IRAtomicType>(tryGetPointedToType(builder, left.val->getDataType()))) + { + builder->emitAtomicStore( + left.val, + getSimpleVal(context, right), + builder->getIntValue(builder->getIntType(), kIRMemoryOrder_Relaxed)); + } + else + { + builder->emitStore( + left.val, + getSimpleVal(context, right)); + } } break; diff --git a/source/slang/slang-type-layout.cpp b/source/slang/slang-type-layout.cpp index 2447f5787..57635122e 100644 --- a/source/slang/slang-type-layout.cpp +++ b/source/slang/slang-type-layout.cpp @@ -4427,6 +4427,10 @@ static TypeLayoutResult _createTypeLayout( { return createArrayLikeTypeLayout(context, arrayType, arrayType->getElementType(), arrayType->getElementCount()); } + else if (auto atomicType = as<AtomicType>(type)) + { + return _createTypeLayout(context, atomicType->getElementType()); + } else if (auto ptrType = as<PtrTypeBase>(type)) { RefPtr<PointerTypeLayout> ptrLayout = new PointerTypeLayout(); |
