diff options
| author | Yong He <yonghe@outlook.com> | 2023-08-14 16:23:19 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-08-14 16:23:19 -0700 |
| commit | 661d6198bbb9857d3fdc6df477e0742ed0b0765c (patch) | |
| tree | 974a57cfa2e43624e91502e9e652a0cc78105b3a /source | |
| parent | 0403e0556b470f6b316153caea2dc6f5c314da5b (diff) | |
Support per field matrix layout (#3101)
* Support per field matrix layout
* Fix warnings.
* Fix.
* Fix tests.
* Fix spiv gen.
* Fix.
* More test fixes.
* Fix.
* Run only GPU tests on self-hosted servers.
* Remove -use-glsl-matrix-layout-modifier.
* Fix.
---------
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source')
36 files changed, 1223 insertions, 160 deletions
diff --git a/source/slang-glslang/slang-glslang.cpp b/source/slang-glslang/slang-glslang.cpp index bbf154a69..32834232f 100644 --- a/source/slang-glslang/slang-glslang.cpp +++ b/source/slang-glslang/slang-glslang.cpp @@ -509,7 +509,7 @@ static spv_target_env _getUniversalTargetEnv(glslang::EShTargetLanguageVersion i return SPV_ENV_UNIVERSAL_1_2; } -static int glslang_compileGLSLToSPIRV(const glslang_CompileRequest_1_2& request) +static int glslang_compileGLSLToSPIRV(glslang_CompileRequest_1_2 request) { // Check that the encoding matches assert(glslang::EShTargetSpv_1_4 == _makeTargetLanguageVersion(1, 4)); diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index 257d5930b..ebf4b5d76 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -894,8 +894,11 @@ struct vector __init(vector<T,N> value); } +const int kRowMajorMatrixLayout = $(SLANG_MATRIX_LAYOUT_ROW_MAJOR); +const int kColumnMajorMatrixLayout = $(SLANG_MATRIX_LAYOUT_COLUMN_MAJOR); + /// A matrix with `R` rows and `C` columns, with elements of type `T`. 
-__generic<T = float, let R : int = 4, let C : int = 4> +__generic<T = float, let R : int = 4, let C : int = 4, let L : int = $(SLANG_MATRIX_LAYOUT_MODE_UNKNOWN)> __magic_type(MatrixExpressionType) struct matrix { @@ -1111,7 +1114,7 @@ for (int tt = 0; tt < kBaseTypeCount; ++tt) for( int R = 2; R <= 4; ++R ) for( int C = 2; C <= 4; ++C ) { - sb << "__generic<T> __extension matrix<T, " << R << "," << C << ">\n{\n"; + sb << "__generic<T, let L:int> __extension matrix<T, " << R << "," << C << ", L>\n{\n"; // initialize from R*C scalars sb << "__intrinsic_op(" << int(kIROp_MakeMatrix) << ") __init("; @@ -1137,7 +1140,7 @@ for( int C = 2; C <= 4; ++C ) for( int cc = C; cc <= 4; ++cc ) { if(rr == R && cc == C) continue; - sb << "__intrinsic_op(" << int(kIROp_MatrixReshape) << ") __init(matrix<T," << rr << "," << cc << "> value);\n"; + sb << "__intrinsic_op(" << int(kIROp_MatrixReshape) << ") __init(matrix<T," << rr << "," << cc << ", L> value);\n"; } sb << "}\n"; } @@ -1207,16 +1210,16 @@ extension vector<T, N> : IDifferentiable } } -__generic<T:__BuiltinFloatingPointType, let R: int, let C: int> -extension matrix<T, R, C> : IDifferentiable +__generic<T:__BuiltinFloatingPointType, let R: int, let C: int, let L : int> +extension matrix<T, R, C, L> : IDifferentiable { - typedef matrix<T, R, C> Differential; + typedef matrix<T, R, C, L> Differential; [__unsafeForceInlineEarly] [BackwardDifferentiable] static Differential dzero() { - return matrix<T, R, C>(__slang_noop_cast<T>(T.dzero())); + return matrix<T, R, C, L>(__slang_noop_cast<T>(T.dzero())); } [__unsafeForceInlineEarly] @@ -2425,9 +2428,9 @@ vector<T,N> operator$(op.name)(in out vector<T,N> value) {$(fixity.bodyPrefix) value = value $(op.binOp) T(1); return $(fixity.returnVal); } $(fixity.qual) -__generic<T : __BuiltinArithmeticType, let R : int, let C : int> +__generic<T : __BuiltinArithmeticType, let R : int, let C : int, let L : int> [__unsafeForceInlineEarly] -matrix<T,R,C> operator$(op.name)(in out 
matrix<T,R,C> value) +matrix<T,R,C> operator$(op.name)(in out matrix<T,R,C,L> value) {$(fixity.bodyPrefix) value = value $(op.binOp) T(1); return $(fixity.returnVal); } $(fixity.qual) @@ -2609,9 +2612,9 @@ __generic<L: __BuiltinIntegerType, R: __BuiltinIntegerType, let N : int, let M : __intrinsic_op($(info.op)) matrix<L,N,M> operator$(info.name)(matrix<L,N,M> left, matrix<R,N,M> right); -__generic<L: __BuiltinIntegerType, R: __BuiltinIntegerType, let N : int, let M : int> +__generic<L: __BuiltinIntegerType, R: __BuiltinIntegerType, let N : int, let M : int, let Layout : int> [__unsafeForceInlineEarly] -matrix<L, N, M> operator$(info.name)=(in out matrix<L, N, M> left, matrix<R, N, M> right) +matrix<L, N, M> operator$(info.name)=(in out matrix<L, N, M, Layout> left, matrix<R, N, M> right) { left = left $(info.name) right; return left; @@ -2641,9 +2644,9 @@ __generic<L: __BuiltinIntegerType, R: __BuiltinIntegerType, let N : int, let M : __intrinsic_op($(info.op)) matrix<L,N,M> operator$(info.name)(matrix<L,N,M> left, R right); -__generic<L: __BuiltinIntegerType, R: __BuiltinIntegerType, let N : int, let M : int> +__generic<L: __BuiltinIntegerType, R: __BuiltinIntegerType, let N : int, let M : int, let Layout : int> [__unsafeForceInlineEarly] -matrix<L,N,M> operator$(info.name)=(in out matrix<L,N,M> left, R right) +matrix<L,N,M> operator$(info.name)=(in out matrix<L,N,M, Layout> left, R right) { left = left $(info.name) right; return left; @@ -2696,17 +2699,17 @@ ${{{{ return left; } - __generic<T : $(op.interface), let R : int, let C : int> + __generic<T : $(op.interface), let R : int, let C : int, let Layout : int> [__unsafeForceInlineEarly] - matrix<T,R,C> operator$(op.name)=(in out matrix<T,R,C> left, matrix<T,R,C> right) + matrix<T,R,C> operator$(op.name)=(in out matrix<T,R,C,Layout> left, matrix<T,R,C> right) { left = left $(op.name) right; return left; } - __generic<T : $(op.interface), let R : int, let C : int> + __generic<T : $(op.interface), let R : int, 
let C : int, let Layout : int> [__unsafeForceInlineEarly] - matrix<T,R,C> operator$(op.name)=(in out matrix<T,R,C> left, T right) + matrix<T,R,C> operator$(op.name)=(in out matrix<T,R,C, Layout> left, T right) { left = left $(op.name) right; return left; diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index 3e381e55d..ce0e72d34 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -1403,11 +1403,11 @@ void __sincos_impl(vector<T, N> x, out vector<T, N> s, out vector<T, N> c) c = cos(x); } -__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> +__generic<T : __BuiltinFloatingPointType, let N : int, let M : int, let L1 : int, let L2 : int> [BackwardDifferentiable] [PrimalSubstituteOf(sincos)] [PreferRecompute] -void __sincos_impl(matrix<T, N, M> x, out matrix<T, N, M> s, out matrix<T, N, M> c) +void __sincos_impl(matrix<T, N, M> x, out matrix<T, N, M, L1> s, out matrix<T, N, M, L2> c) { s = sin(x); c = cos(x); diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index 9de14ad96..015b1c4ad 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -235,19 +235,20 @@ struct StructuredBuffer out uint numStructs, out uint stride); - __target_intrinsic(glsl, "$0._data[$1]") - __target_intrinsic(spirv_direct, "%addr = OpAccessChain resultType*StorageBuffer resultId _0 const(int, 0) _1; OpLoad resultType resultId %addr;") - [__readNone] + //__target_intrinsic(glsl, "$0._data[$1]") + //__target_intrinsic(spirv_direct, "%addr = OpAccessChain resultType*StorageBuffer resultId _0 const(int, 0) _1; OpLoad resultType resultId %addr;") + __intrinsic_op($(kIROp_StructuredBufferLoad)) T Load(int location); - [__readNone] + __intrinsic_op($(kIROp_StructuredBufferLoadStatus)) T Load(int location, out uint status); __subscript(uint index) -> T { - __target_intrinsic(glsl, "$0._data[$1]") - __target_intrinsic(spirv_direct, "%addr = OpAccessChain resultType*StorageBuffer resultId _0 
const(int, 0) _1; OpLoad resultType resultId %addr;") + //__target_intrinsic(glsl, "$0._data[$1]") + //__target_intrinsic(spirv_direct, "%addr = OpAccessChain resultType*StorageBuffer resultId _0 const(int, 0) _1; OpLoad resultType resultId %addr;") [__readNone] + __intrinsic_op($(kIROp_StructuredBufferLoad)) get; }; }; @@ -709,7 +710,6 @@ static const struct { for(auto item : kMutableStructuredBufferCases) { }}}} - __generic<T> __magic_type(HLSL$(item.name)Type) __intrinsic_type($(item.op)) @@ -724,18 +724,22 @@ struct $(item.name) uint IncrementCounter(); - __target_intrinsic(glsl, "$0._data[$1]") - __target_intrinsic(spirv_direct, "%addr = OpAccessChain resultType*StorageBuffer resultId _0 const(int, 0) _1; OpLoad resultType resultId %addr;") + //__target_intrinsic(glsl, "$0._data[$1]") + //__target_intrinsic(spirv_direct, "%addr = OpAccessChain resultType*StorageBuffer resultId _0 const(int, 0) _1; OpLoad resultType resultId %addr;") [__NoSideEffect] + __intrinsic_op($(kIROp_RWStructuredBufferLoad)) T Load(int location); + [__NoSideEffect] + __intrinsic_op($(kIROp_RWStructuredBufferLoadStatus)) T Load(int location, out uint status); __subscript(uint index) -> T { - __target_intrinsic(glsl, "$0._data[$1]") - __target_intrinsic(spirv_direct, "*StorageBuffer OpAccessChain resultType resultId _0 const(int, 0) _1") + //__target_intrinsic(glsl, "$0._data[$1]") + //__target_intrinsic(spirv_direct, "*StorageBuffer OpAccessChain resultType resultId _0 const(int, 0) _1") [__NoSideEffect] + __intrinsic_op($(kIROp_RWStructuredBufferGetElementPtr)) ref; } }; @@ -2248,10 +2252,10 @@ vector<T, N> frexp(vector<T, N> x, out vector<T, N> exp) VECTOR_MAP_BINARY(T, N, frexp, x, exp); } -__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> +__generic<T : __BuiltinFloatingPointType, let N : int, let M : int, let L : int> __target_intrinsic(hlsl) [__readNone] -matrix<T, N, M> frexp(matrix<T, N, M> x, out matrix<T, N, M> exp) +matrix<T, N, M> frexp(matrix<T, N, M> x, 
out matrix<T, N, M, L> exp) { MATRIX_MAP_BINARY(T, N, M, frexp, x, exp); } @@ -2924,10 +2928,10 @@ vector<T,N> modf(vector<T,N> x, out vector<T,N> ip) VECTOR_MAP_BINARY(T, N, modf, x, ip); } -__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> +__generic<T : __BuiltinFloatingPointType, let N : int, let M : int, let L : int> __target_intrinsic(hlsl) [__readNone] -matrix<T,N,M> modf(matrix<T,N,M> x, out matrix<T,N,M> ip) +matrix<T,N,M> modf(matrix<T,N,M> x, out matrix<T,N,M,L> ip) { MATRIX_MAP_BINARY(T, N, M, modf, x, ip); } @@ -3664,10 +3668,10 @@ void sincos(vector<T,N> x, out vector<T,N> s, out vector<T,N> c) c = cos(x); } -__generic<T : __BuiltinFloatingPointType, let N : int, let M : int> +__generic<T : __BuiltinFloatingPointType, let N : int, let M : int, let L1: int, let L2 : int> __target_intrinsic(hlsl) [__readNone] -void sincos(matrix<T,N,M> x, out matrix<T,N,M> s, out matrix<T,N,M> c) +void sincos(matrix<T,N,M> x, out matrix<T,N,M,L1> s, out matrix<T,N,M,L2> c) { s = sin(x); c = cos(x); diff --git a/source/slang/slang-ast-builder.cpp b/source/slang/slang-ast-builder.cpp index bb4f53433..a76f6e07f 100644 --- a/source/slang/slang-ast-builder.cpp +++ b/source/slang/slang-ast-builder.cpp @@ -383,6 +383,21 @@ VectorExpressionType* ASTBuilder::getVectorType( return as<VectorExpressionType>(getSpecializedBuiltinType(makeArrayView(args), "VectorExpressionType")); } +MatrixExpressionType* ASTBuilder::getMatrixType(Type* elementType, IntVal* rowCount, IntVal* colCount, IntVal* layout) +{ + // Canonicalize constant size arguments to int. 
+ if (auto rowCountConstantInt = as<ConstantIntVal>(rowCount)) + { + rowCount = getIntVal(getIntType(), rowCountConstantInt->getValue()); + } + if (auto colCountConstantInt = as<ConstantIntVal>(colCount)) + { + colCount = getIntVal(getIntType(), colCountConstantInt->getValue()); + } + Val* args[] = { elementType, rowCount, colCount, layout }; + return as<MatrixExpressionType>(getSpecializedBuiltinType(makeArrayView(args), "MatrixExpressionType")); +} + DifferentialPairType* ASTBuilder::getDifferentialPairType( Type* valueType, Witness* primalIsDifferentialWitness) diff --git a/source/slang/slang-ast-builder.h b/source/slang/slang-ast-builder.h index 5c0e74851..2674ef1b2 100644 --- a/source/slang/slang-ast-builder.h +++ b/source/slang/slang-ast-builder.h @@ -424,6 +424,8 @@ public: VectorExpressionType* getVectorType(Type* elementType, IntVal* elementCount); + MatrixExpressionType* getMatrixType(Type* elementType, IntVal* rowCount, IntVal* colCount, IntVal* layout); + ConstantBufferType* getConstantBufferType(Type* elementType); ParameterBlockType* getParameterBlockType(Type* elementType); diff --git a/source/slang/slang-ast-support-types.h b/source/slang/slang-ast-support-types.h index 9140be967..b1d3a34a2 100644 --- a/source/slang/slang-ast-support-types.h +++ b/source/slang/slang-ast-support-types.h @@ -80,6 +80,9 @@ namespace Slang // No conversion at all kConversionCost_None = 0, + // Convert between matrices of different layout + kConversionCost_MatrixLayout = 5, + // Conversion from a buffer to the type it carries needs to add a minimal // extra cost, just so we can distinguish an overload on `ConstantBuffer<Foo>` // from one on `Foo` diff --git a/source/slang/slang-ast-type.cpp b/source/slang/slang-ast-type.cpp index 7ea5e8ed1..6d4a52cae 100644 --- a/source/slang/slang-ast-type.cpp +++ b/source/slang/slang-ast-type.cpp @@ -238,6 +238,11 @@ IntVal* MatrixExpressionType::getColumnCount() return as<IntVal>(_getGenericTypeArg(this, 2)); } +IntVal* 
MatrixExpressionType::getLayout() +{ + return as<IntVal>(_getGenericTypeArg(this, 3)); +} + Type* MatrixExpressionType::getRowType() { if (!rowType) diff --git a/source/slang/slang-ast-type.h b/source/slang/slang-ast-type.h index f3dc975a3..c2a13542d 100644 --- a/source/slang/slang-ast-type.h +++ b/source/slang/slang-ast-type.h @@ -435,7 +435,7 @@ class VectorExpressionType : public ArithmeticExpressionType IntVal* getElementCount(); }; -// A matrix type, e.g., `matrix<T,R,C>` +// A matrix type, e.g., `matrix<T,R,C,L>` class MatrixExpressionType : public ArithmeticExpressionType { SLANG_AST_CLASS(MatrixExpressionType) @@ -443,6 +443,7 @@ class MatrixExpressionType : public ArithmeticExpressionType Type* getElementType(); IntVal* getRowCount(); IntVal* getColumnCount(); + IntVal* getLayout(); Type* getRowType(); diff --git a/source/slang/slang-check-conversion.cpp b/source/slang/slang-check-conversion.cpp index 09a783fb4..5a9c8df12 100644 --- a/source/slang/slang-check-conversion.cpp +++ b/source/slang/slang-check-conversion.cpp @@ -894,6 +894,28 @@ namespace Slang } return true; } + // matrix type with different layouts are convertible + if (auto fromMatrixType = as<MatrixExpressionType>(fromType)) + { + if (auto toMatrixType = as<MatrixExpressionType>(toType)) + { + if (fromMatrixType->getElementType()->equals(toMatrixType->getElementType()) && + fromMatrixType->getRowCount()->equals(toMatrixType->getRowCount()) && + fromMatrixType->getColumnCount()->equals(toMatrixType->getColumnCount())) + { + if (outCost) + { + *outCost = kConversionCost_MatrixLayout; + } + if (outToExpr) + { + *outToExpr = fromExpr; + } + return true; + } + } + + } // A type is always convertible to any of its supertypes. 
// diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 4e8a00907..6385e5f57 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -1069,6 +1069,23 @@ namespace Slang validateArraySizeForVariable(varDecl); } + // If there is a matrix layout modifier, we will modify the matrix type now. + if (auto matrixType = as<MatrixExpressionType>(varDecl->type.type)) + { + if (auto matrixLayoutModifier = varDecl->findModifier<MatrixLayoutModifier>()) + { + auto matrixLayout = as<ColumnMajorLayoutModifier>(matrixLayoutModifier) ? SLANG_MATRIX_LAYOUT_COLUMN_MAJOR : SLANG_MATRIX_LAYOUT_ROW_MAJOR; + auto newMatrixType = getASTBuilder()->getMatrixType( + matrixType->getElementType(), + matrixType->getRowCount(), + matrixType->getColumnCount(), + getASTBuilder()->getIntVal(getASTBuilder()->getIntType(), matrixLayout)); + varDecl->type.type = newMatrixType; + if (varDecl->initExpr) + varDecl->initExpr = coerce(CoercionSite::Initializer, varDecl->type, varDecl->initExpr); + } + } + checkMeshOutputDecl(varDecl); // The NVAPI library allows user code to express extended operations diff --git a/source/slang/slang-compiler.h b/source/slang/slang-compiler.h index a6231d959..e42e97b89 100755 --- a/source/slang/slang-compiler.h +++ b/source/slang/slang-compiler.h @@ -1783,7 +1783,6 @@ namespace Slang char const* value); SlangResult setMatrixLayoutMode( SlangMatrixLayoutMode mode); - /// Create an initially-empty linkage Linkage(Session* session, ASTBuilder* astBuilder, Linkage* builtinLinkage); @@ -2672,6 +2671,7 @@ namespace Slang virtual SLANG_NO_THROW void SLANG_MCALL setDebugInfoFormat(SlangDebugInfoFormat format) SLANG_OVERRIDE; virtual SLANG_NO_THROW void SLANG_MCALL setReportDownstreamTime(bool value) SLANG_OVERRIDE; virtual SLANG_NO_THROW void SLANG_MCALL setReportPerfBenchmark(bool value) SLANG_OVERRIDE; + void setHLSLToVulkanLayoutOptions(int targetIndex, HLSLToVulkanLayoutOptions* vulkanLayoutOptions); 
EndToEndCompileRequest( diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp index 0cc3a196b..95c691d8b 100644 --- a/source/slang/slang-emit-c-like.cpp +++ b/source/slang/slang-emit-c-like.cpp @@ -2146,6 +2146,52 @@ void CLikeSourceEmitter::defaultEmitInstExpr(IRInst* inst, const EmitOpInfo& inO } break; + case kIROp_StructuredBufferLoad: + case kIROp_RWStructuredBufferLoad: + { + auto base = inst->getOperand(0); + emitOperand(base, outerPrec); + m_writer->emit(".Load("); + emitOperand(inst->getOperand(1), EmitOpInfo()); + m_writer->emit(")"); + } + break; + + case kIROp_StructuredBufferLoadStatus: + case kIROp_RWStructuredBufferLoadStatus: + { + auto base = inst->getOperand(0); + emitOperand(base, outerPrec); + m_writer->emit(".Load("); + emitOperand(inst->getOperand(1), EmitOpInfo()); + m_writer->emit(", "); + emitOperand(inst->getOperand(2), EmitOpInfo()); + m_writer->emit(")"); + } + break; + + case kIROp_RWStructuredBufferGetElementPtr: + { + auto base = inst->getOperand(0); + emitOperand(base, outerPrec); + m_writer->emit("["); + emitOperand(inst->getOperand(1), EmitOpInfo()); + m_writer->emit("]"); + } + break; + + case kIROp_RWStructuredBufferStore: + { + auto base = inst->getOperand(0); + emitOperand(base, EmitOpInfo()); + m_writer->emit(".Store("); + emitOperand(inst->getOperand(1), EmitOpInfo()); + m_writer->emit(", "); + emitOperand(inst->getOperand(2), EmitOpInfo()); + m_writer->emit(")"); + } + break; + case kIROp_Call: { emitCallExpr((IRCall*)inst, outerPrec); diff --git a/source/slang/slang-emit-cpp.cpp b/source/slang/slang-emit-cpp.cpp index 83af4ad47..dee8e7197 100644 --- a/source/slang/slang-emit-cpp.cpp +++ b/source/slang/slang-emit-cpp.cpp @@ -1391,13 +1391,24 @@ bool CPPSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOut auto outerPrec = getInfo(EmitOp::General); auto prec = getInfo(EmitOp::Postfix); emitOperand(baseInst, leftSide(outerPrec, prec)); - m_writer->emit("->rows + "); + 
m_writer->emit("->rows + ("); emitOperand(getElementInst->getIndex(), getInfo(EmitOp::General)); - m_writer->emit(")"); + m_writer->emit("))"); return true; } return false; } + case kIROp_RWStructuredBufferGetElementPtr: + { + m_writer->emit("(&("); + auto base = inst->getOperand(0); + auto outerPrec = getInfo(EmitOp::General); + emitOperand(base, outerPrec); + m_writer->emit("["); + emitOperand(inst->getOperand(1), EmitOpInfo()); + m_writer->emit("]))"); + return true; + } case kIROp_swizzle: { // For C++ we don't need to emit a swizzle function diff --git a/source/slang/slang-emit-glsl.cpp b/source/slang/slang-emit-glsl.cpp index 22a0c323b..0920c236c 100644 --- a/source/slang/slang-emit-glsl.cpp +++ b/source/slang/slang-emit-glsl.cpp @@ -1959,6 +1959,10 @@ bool GLSLSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOu return true; } case kIROp_StructuredBufferLoad: + case kIROp_StructuredBufferLoadStatus: + case kIROp_RWStructuredBufferLoad: + case kIROp_RWStructuredBufferLoadStatus: + case kIROp_RWStructuredBufferGetElementPtr: { auto outerPrec = inOuterPrec; auto prec = getInfo(EmitOp::Postfix); @@ -1972,7 +1976,7 @@ bool GLSLSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOu maybeCloseParens(needClose); return true; } - case kIROp_StructuredBufferStore: + case kIROp_RWStructuredBufferStore: { auto outerPrec = inOuterPrec; diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp index c16ee6a90..bafbb79f1 100644 --- a/source/slang/slang-emit-spirv.cpp +++ b/source/slang/slang-emit-spirv.cpp @@ -1619,6 +1619,15 @@ struct SPIRVEmitContext return emitLoad(parent, as<IRLoad>(inst)); case kIROp_Store: return emitStore(parent, as<IRStore>(inst)); + case kIROp_StructuredBufferLoad: + case kIROp_StructuredBufferLoadStatus: + case kIROp_RWStructuredBufferLoad: + case kIROp_RWStructuredBufferLoadStatus: + return emitStructuredBufferLoad(parent, inst); + case kIROp_RWStructuredBufferStore: + return 
emitStructuredBufferStore(parent, inst); + case kIROp_RWStructuredBufferGetElementPtr: + return emitStructuredBufferGetElementPtr(parent, inst); case kIROp_swizzle: return emitSwizzle(parent, as<IRSwizzle>(inst)); case kIROp_IntCast: @@ -2645,6 +2654,30 @@ struct SPIRVEmitContext return emitInst(parent, inst, SpvOpStore, inst->getPtr(), inst->getVal()); } + SpvInst* emitStructuredBufferLoad(SpvInstParent* parent, IRInst* inst) + { + //"%addr = OpAccessChain resultType*StorageBuffer resultId _0 const(int, 0) _1; OpLoad resultType resultId %addr;" + IRBuilder builder(inst); + auto addr = emitInst(parent, inst, SpvOpAccessChain, inst->getOperand(0)->getDataType(), kResultID, inst->getOperand(0), emitIntConstant(0, builder.getIntType()), inst->getOperand(1)); + return emitInst(parent, inst, SpvOpLoad, inst->getFullType(), kResultID, addr); + } + + SpvInst* emitStructuredBufferStore(SpvInstParent* parent, IRInst* inst) + { + //"%addr = OpAccessChain resultType*StorageBuffer resultId _0 const(int, 0) _1; OpStore %addr _2;" + IRBuilder builder(inst); + auto addr = emitInst(parent, inst, SpvOpAccessChain, inst->getOperand(0)->getDataType(), kResultID, inst->getOperand(0), emitIntConstant(0, builder.getIntType()), inst->getOperand(1)); + return emitInst(parent, inst, SpvOpStore, addr, inst->getOperand(2)); + } + + SpvInst* emitStructuredBufferGetElementPtr(SpvInstParent* parent, IRInst* inst) + { + //"%addr = OpAccessChain resultType*StorageBuffer resultId _0 const(int, 0) _1;" + IRBuilder builder(inst); + auto addr = emitInst(parent, inst, SpvOpAccessChain, inst->getDataType(), kResultID, inst->getOperand(0), emitIntConstant(0, builder.getIntType()), inst->getOperand(1)); + return addr; + } + SpvInst* emitSwizzle(SpvInstParent* parent, IRSwizzle* inst) { if (inst->getElementCount() == 1) diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index 343c18916..4b95ca9c8 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -49,6 +49,7 
@@ #include "slang-ir-specialize-arrays.h" #include "slang-ir-specialize-buffer-load-arg.h" #include "slang-ir-specialize-resources.h" +#include "slang-ir-specialize-matrix-layout.h" #include "slang-ir-ssa.h" #include "slang-ir-ssa-simplification.h" #include "slang-ir-strip-cached-dict.h" @@ -59,6 +60,7 @@ #include "slang-ir-liveness.h" #include "slang-ir-glsl-liveness.h" #include "slang-ir-legalize-uniform-buffer-load.h" +#include "slang-ir-lower-buffer-element-type.h" #include "slang-ir-string-hash.h" #include "slang-ir-simplify-for-emit.h" #include "slang-ir-pytorch-cpp-binding.h" @@ -353,6 +355,9 @@ Result linkAndOptimizeIR( simplifyIR(irModule, sink); + // Fill in default matrix layout into matrix types that left layout unspecified. + specializeMatrixLayout(codeGenContext->getTargetReq(), irModule); + // It's important that this takes place before defunctionalization as we // want to be able to easily discover the cooperate and fallback funcitons // being passed to saturated_cooperation @@ -588,10 +593,6 @@ Result linkAndOptimizeIR( } eliminateDeadCode(irModule); - // Rewrite functions that return arrays to return them via `out` parameter, - // since our target languages doesn't allow returning arrays. - legalizeArrayReturnType(irModule); - #if 0 dumpIRIfEnabled(codeGenContext, irModule, "AFTER RESOURCE SPECIALIZATION"); #endif @@ -857,6 +858,14 @@ Result linkAndOptimizeIR( // arrays that the emitters can deal with. legalizeMeshOutputTypes(irModule); + // We need to lower any types used in a buffer resource (e.g. ContantBuffer or StructuredBuffer) into + // a simple storage type that has target independent layout. + lowerBufferElementTypeToStorageType(targetRequest, irModule); + + // Rewrite functions that return arrays to return them via `out` parameter, + // since our target languages doesn't allow returning arrays. 
+ legalizeArrayReturnType(irModule); + if (isKhronosTarget(targetRequest) || target == CodeGenTarget::HLSL) { legalizeUniformBufferLoad(irModule); diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index f1099547c..efaaec906 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -1858,7 +1858,10 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst* case kIROp_ByteAddressBufferLoad: case kIROp_ByteAddressBufferStore: case kIROp_StructuredBufferLoad: - case kIROp_StructuredBufferStore: + case kIROp_RWStructuredBufferLoad: + case kIROp_RWStructuredBufferLoadStatus: + case kIROp_RWStructuredBufferStore: + case kIROp_RWStructuredBufferGetElementPtr: case kIROp_Reinterpret: case kIROp_IsType: case kIROp_ImageSubscript: diff --git a/source/slang/slang-ir-byte-address-legalize.cpp b/source/slang/slang-ir-byte-address-legalize.cpp index 14f985c3b..a6d0acee7 100644 --- a/source/slang/slang-ir-byte-address-legalize.cpp +++ b/source/slang/slang-ir-byte-address-legalize.cpp @@ -212,6 +212,15 @@ struct ByteAddressBufferLegalizationContext return getNaturalOffset(target, field, outOffset); } + SlangResult getSizeAndAlignment(TargetRequest* target, IRType* type, IRSizeAndAlignment* outSizeAlignment) + { + if (target->getHLSLToVulkanLayoutOptions() && target->getHLSLToVulkanLayoutOptions()->shouldUseGLLayout()) + { + return getStd430SizeAndAlignment(target, type, outSizeAlignment); + } + return getNaturalSizeAndAlignment(target, type, outSizeAlignment); + } + // The core workhorse routine for the load case is `emitLegalLoad`, // which tries to emit load operations that read a value of the // given `type` from the given `buffer` at the required `baseOffset` @@ -312,23 +321,42 @@ struct ByteAddressBufferLegalizationContext // small detail that we need to construct the row type // that we expect to load for each "element." 
// - // TODO: The logic here assumes row-major layout, because - // the row-vs-column-major information has been dropped - // by this point in the IR. - // - // In order to allow both row- and column-major matrices - // to be loaded from byte-address buffers, we would need - // to make row-vs-column-major-ness be part of the IR - // type system so that IR layout can take it into account. - // - // For now we have to live with the "natural" layout of - // matrices always being row-major. - // - auto rowCountInst = as<IRIntLit>(matType->getRowCount()); - if( rowCountInst ) + if (getIntVal(matType->getLayout()) != SLANG_MATRIX_LAYOUT_COLUMN_MAJOR) { - auto rowType = m_builder.getVectorType(matType->getElementType(), matType->getColumnCount()); - return emitLegalSequenceLoad(type, buffer, baseOffset, immediateOffset, kIROp_MakeMatrix, rowType, rowCountInst->getValue()); + auto rowCountInst = as<IRIntLit>(matType->getRowCount()); + if( rowCountInst ) + { + auto rowType = m_builder.getVectorType(matType->getElementType(), matType->getColumnCount()); + return emitLegalSequenceLoad(type, buffer, baseOffset, immediateOffset, kIROp_MakeMatrix, rowType, rowCountInst->getValue()); + } + } + else + { + List<IRInst*> elements; + auto colCount = (Index)getIntVal(matType->getColumnCount()); + auto rowCount = (Index)getIntVal(matType->getRowCount()); + auto colVectorType = m_builder.getVectorType(matType->getElementType(), rowCount); + IRSizeAndAlignment colVectorSizeAlignment; + getSizeAndAlignment(m_target, colVectorType, &colVectorSizeAlignment); + for (Index c = 0; c < colCount; c++) + { + auto colVector = emitLegalLoad(colVectorType, buffer, baseOffset, immediateOffset); + for (Index r = 0; r < rowCount; r++) + { + elements.add(m_builder.emitElementExtract(colVector, (IRIntegerValue)r)); + } + immediateOffset += colVectorSizeAlignment.getStride(); + } + List<IRInst*> args; + for (Index r = 0; r < rowCount; r++) + { + for (Index c = 0; c < colCount; c++) + { + auto index = c * 
rowCount + r; + args.add(elements[index]); + } + } + return m_builder.emitMakeMatrix(matType, (UInt)args.getCount(), args.getBuffer()); } } else if( auto vecType = as<IRVectorType>(type) ) @@ -832,14 +860,40 @@ struct ByteAddressBufferLegalizationContext } else if( auto matType = as<IRMatrixType>(type) ) { - // Matrix storesget the same caveat as the load case: - // we are only supporting row-major layout for now. - // - auto rowCountInst = as<IRIntLit>(matType->getRowCount()); - if( rowCountInst ) + auto layout = getIntVal(matType->getLayout()); + if (layout != SLANG_MATRIX_LAYOUT_COLUMN_MAJOR) { - auto rowType = m_builder.getVectorType(matType->getElementType(), matType->getColumnCount()); - return emitLegalSequenceStore(buffer, baseOffset, immediateOffset, value, rowType, rowCountInst->getValue()); + auto rowCountInst = as<IRIntLit>(matType->getRowCount()); + if( rowCountInst ) + { + auto rowType = m_builder.getVectorType(matType->getElementType(), matType->getColumnCount()); + return emitLegalSequenceStore(buffer, baseOffset, immediateOffset, value, rowType, rowCountInst->getValue()); + } + } + else + { + auto colCount = (Index)getIntVal(matType->getColumnCount()); + auto rowCount = (Index)getIntVal(matType->getRowCount()); + List<IRInst*> srcRows; + for (Index r = 0; r < rowCount; r++) + srcRows.add(m_builder.emitElementExtract(value, (IRIntegerValue)r)); + for (Index c = 0; c < colCount; c++) + { + List<IRInst*> colVectorArgs; + for (Index r = 0; r < rowCount; r++) + { + auto rowVector = srcRows[r]; + auto element = m_builder.emitElementExtract(rowVector, (IRIntegerValue)c); + colVectorArgs.add(element); + } + auto colVectorType = m_builder.getVectorType(matType->getElementType(), rowCount); + auto colVector = m_builder.emitMakeVector(colVectorType, colVectorArgs); + IRSizeAndAlignment colVectorSizeAlignment; + getSizeAndAlignment(m_target, colVectorType, &colVectorSizeAlignment); + emitLegalStore(colVectorType, buffer, baseOffset, immediateOffset, 
colVector); + immediateOffset += colVectorSizeAlignment.getStride(); + } + return SLANG_OK; } } else if( auto vecType = as<IRVectorType>(type) ) @@ -903,7 +957,7 @@ struct ByteAddressBufferLegalizationContext auto index = m_builder.emitIntrinsicInst(indexType, kIROp_Div, 2, divArgs); IRInst* args[] = { structuredBuffer, index, value }; - m_builder.emitIntrinsicInst(type, kIROp_StructuredBufferStore, 3, args); + m_builder.emitIntrinsicInst(type, kIROp_RWStructuredBufferStore, 3, args); return SLANG_OK; } diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index 062da75e8..c3c92f9ba 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -52,7 +52,7 @@ INST(Nop, nop, 0, 0) INST(BasicBlockType, BasicBlock, 0, HOISTABLE) INST(VectorType, Vec, 2, HOISTABLE) - INST(MatrixType, Mat, 3, HOISTABLE) + INST(MatrixType, Mat, 4, HOISTABLE) INST(ConjunctionType, Conjunction, 0, HOISTABLE) INST(AttributedType, Attributed, 0, HOISTABLE) @@ -427,6 +427,9 @@ INST(ByteAddressBufferStore, byteAddressBufferStore, 3, 0) // - `dst` is a value of type T // INST(StructuredBufferLoad, structuredBufferLoad, 2, 0) +INST(StructuredBufferLoadStatus, structuredBufferLoadStatus, 3, 0) +INST(RWStructuredBufferLoad, rwstructuredBufferLoad, 2, 0) +INST(RWStructuredBufferLoadStatus, rwstructuredBufferLoadStatus, 3, 0) // Store data to a structured buffer // @@ -437,7 +440,9 @@ INST(StructuredBufferLoad, structuredBufferLoad, 2, 0) // - `offset` is an `int` // - `src` is a value of type T // -INST(StructuredBufferStore, structuredBufferStore, 3, 0) +INST(RWStructuredBufferStore, rwstructuredBufferStore, 3, 0) + +INST(RWStructuredBufferGetElementPtr, rwstructuredBufferGetElementPtr, 2, 0) INST(MeshOutputRef, meshOutputRef, 2, 0) diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 2d8393698..dfd662be3 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -1979,6 +1979,14 @@ struct 
IRStore : IRInst IRInst* getVal() { return val.get(); } }; +struct IRRWStructuredBufferStore : IRInst +{ + IR_LEAF_ISA(RWStructuredBufferStore) + IRInst* getStructuredBuffer() { return getOperand(0); } + IRInst* getIndex() { return getOperand(1); } + IRInst* getVal() { return getOperand(2); } +}; + struct IRFieldExtract : IRInst { IRUse base; @@ -2015,6 +2023,13 @@ struct IRGetElementPtr : IRInst IRInst* getIndex() { return getOperand(1); } }; +struct IRRWStructuredBufferGetElementPtr : IRInst +{ + IR_LEAF_ISA(RWStructuredBufferGetElementPtr); + IRInst* getBase() { return getOperand(0); } + IRInst* getIndex() { return getOperand(1); } +}; + struct IRLoadReverseGradient : IRInst { IR_LEAF_ISA(LoadReverseGradient) @@ -2957,7 +2972,8 @@ public: IRMatrixType* getMatrixType( IRType* elementType, IRInst* rowCount, - IRInst* columnCount); + IRInst* columnCount, + IRInst* layout); IRArrayListType* getArrayListType(IRType* elementType); IRTensorViewType* getTensorViewType(IRType* elementType); diff --git a/source/slang/slang-ir-lower-buffer-element-type.cpp b/source/slang/slang-ir-lower-buffer-element-type.cpp new file mode 100644 index 000000000..3ef94d415 --- /dev/null +++ b/source/slang/slang-ir-lower-buffer-element-type.cpp @@ -0,0 +1,658 @@ +#include "slang-ir-lower-buffer-element-type.h" +#include "slang-ir.h" +#include "slang-ir-insts.h" +#include "slang-ir-util.h" +#include "slang-ir-clone.h" + +namespace Slang +{ + struct LoweredElementTypeContext + { + struct LoweredElementTypeInfo + { + IRType* originalType; + IRType* loweredType; + IRType* loweredInnerArrayType = nullptr; // For matrix/array types that are lowered into a struct type, this is the inner array type of the data field. + IRStructKey* loweredInnerStructKey = nullptr; // For matrix/array types that are lowered into a struct type, this is the struct key of the data field. 
+ IRFunc* convertOriginalToLowered = nullptr; + IRFunc* convertLoweredToOriginal = nullptr; + }; + Dictionary<IRType*, LoweredElementTypeInfo> loweredTypeInfo; + Dictionary<IRType*, LoweredElementTypeInfo> mapLoweredTypeToInfo; + + SlangMatrixLayoutMode defaultMatrixLayout = SLANG_MATRIX_LAYOUT_ROW_MAJOR; + + LoweredElementTypeContext(SlangMatrixLayoutMode inDefaultMatrixLayout) + : defaultMatrixLayout(inDefaultMatrixLayout) + {} + + IRFunc* createMatrixUnpackFunc( + IRMatrixType* matrixType, + IRStructType* structType, + IRStructKey* dataKey, + IRArrayType* arrayType) + { + IRBuilder builder(structType); + builder.setInsertAfter(structType); + auto func = builder.createFunc(); + auto funcType = builder.getFuncType(1, (IRType**)&structType, matrixType); + func->setFullType(funcType); + builder.addNameHintDecoration(func, UnownedStringSlice("unpackStorage")); + builder.setInsertInto(func); + builder.emitBlock(); + auto rowCount = (Index)getIntVal(matrixType->getRowCount()); + auto colCount = (Index)getIntVal(matrixType->getColumnCount()); + auto packedParam = builder.emitParam(structType); + auto vectorArray = builder.emitFieldExtract(arrayType, packedParam, dataKey); + List<IRInst*> args; + args.setCount(rowCount * colCount); + if (getIntVal(matrixType->getLayout()) == SLANG_MATRIX_LAYOUT_COLUMN_MAJOR) + { + for (IRIntegerValue c = 0; c < colCount; c++) + { + auto vector = builder.emitElementExtract(vectorArray, c); + for (IRIntegerValue r = 0; r < rowCount; r++) + { + auto element = builder.emitElementExtract(vector, r); + args[(Index)(r*colCount + c)] = element; + } + } + } + else + { + for (IRIntegerValue r = 0; r < rowCount; r++) + { + auto vector = builder.emitElementExtract(vectorArray, r); + for (IRIntegerValue c = 0; c < colCount; c++) + { + auto element = builder.emitElementExtract(vector, c); + args[(Index)(r * colCount + c)] = element; + } + } + } + IRInst* result = builder.emitMakeMatrix(matrixType, (UInt)args.getCount(), args.getBuffer()); + 
builder.emitReturn(result); + return func; + } + + IRFunc* createMatrixPackFunc( + IRMatrixType* matrixType, + IRStructType* structType, + IRVectorType* vectorType, + IRArrayType* arrayType) + { + IRBuilder builder(structType); + builder.setInsertAfter(structType); + auto func = builder.createFunc(); + auto funcType = builder.getFuncType(1, (IRType**)&matrixType, structType); + func->setFullType(funcType); + builder.addNameHintDecoration(func, UnownedStringSlice("packMatrix")); + builder.setInsertInto(func); + builder.emitBlock(); + auto rowCount = getIntVal(matrixType->getRowCount()); + auto colCount = getIntVal(matrixType->getColumnCount()); + auto originalParam = builder.emitParam(matrixType); + List<IRInst*> elements; + elements.setCount((Index)(rowCount * colCount)); + for (IRIntegerValue r = 0; r < rowCount; r++) + { + auto vector = builder.emitElementExtract(originalParam, r); + for (IRIntegerValue c = 0; c < colCount; c++) + { + auto element = builder.emitElementExtract(vector, c); + elements[(Index)(r * colCount + c)] = element; + } + } + List<IRInst*> vectors; + if (getIntVal(matrixType->getLayout()) == SLANG_MATRIX_LAYOUT_COLUMN_MAJOR) + { + for (IRIntegerValue c = 0; c < colCount; c++) + { + List<IRInst*> vecArgs; + for (IRIntegerValue r = 0; r < rowCount; r++) + { + auto element = elements[(Index)(r * colCount + c)]; + vecArgs.add(element); + } + auto colVector = builder.emitMakeVector(vectorType, (UInt)vecArgs.getCount(), vecArgs.getBuffer()); + vectors.add(colVector); + } + } + else + { + for (IRIntegerValue r = 0; r < rowCount; r++) + { + List<IRInst*> vecArgs; + for (IRIntegerValue c = 0; c < colCount; c++) + { + auto element = elements[(Index)(r * colCount + c)]; + vecArgs.add(element); + } + auto rowVector = builder.emitMakeVector(vectorType, (UInt)vecArgs.getCount(), vecArgs.getBuffer()); + vectors.add(rowVector); + } + } + + auto vectorArray = builder.emitMakeArray(arrayType, (UInt)vectors.getCount(), vectors.getBuffer()); + auto result = 
builder.emitMakeStruct(structType, 1, &vectorArray); + builder.emitReturn(result); + return func; + } + + // Builds the `unpackStorage` function for a lowered array: takes the storage struct, + // unpacks every element of its data field with the element's own unpack function, + // and returns the original array type. + IRFunc* createArrayUnpackFunc( + IRArrayType* arrayType, + IRStructType* structType, + IRStructKey* dataKey, + IRArrayType* innerArrayType, + LoweredElementTypeInfo innerTypeInfo) + { + IRBuilder builder(structType); + builder.setInsertAfter(structType); + auto func = builder.createFunc(); + auto funcType = builder.getFuncType(1, (IRType**)&structType, arrayType); + func->setFullType(funcType); + builder.addNameHintDecoration(func, UnownedStringSlice("unpackStorage")); + builder.setInsertInto(func); + builder.emitBlock(); + auto packedParam = builder.emitParam(structType); + auto packedArray = builder.emitFieldExtract(innerArrayType, packedParam, dataKey); + auto count = getIntVal(arrayType->getElementCount()); + List<IRInst*> args; + args.setCount((Index)count); + for (IRIntegerValue ii = 0; ii < count; ++ii) + { + auto packedElement = builder.emitElementExtract(packedArray, ii); + auto originalElement = builder.emitCallInst(innerTypeInfo.originalType, innerTypeInfo.convertLoweredToOriginal, 1, &packedElement); + args[(Index)ii] = originalElement; + } + auto result = builder.emitMakeArray(arrayType, (UInt)args.getCount(), args.getBuffer()); + builder.emitReturn(result); + return func; + } + + // Builds the `packStorage` function for a lowered array: takes the original array, + // packs every element with the element's own pack function, and returns the storage struct. + IRFunc* createArrayPackFunc( + IRArrayType* arrayType, + IRStructType* structType, + IRArrayType* innerArrayType, + LoweredElementTypeInfo innerTypeInfo) + { + IRBuilder builder(structType); + builder.setInsertAfter(structType); + auto func = builder.createFunc(); + // The parameter emitted below is the original array and the returned value is the storage struct, + // so the function type is array -> struct (the reverse of the unpack function). + auto funcType = builder.getFuncType(1, (IRType**)&arrayType, structType); + func->setFullType(funcType); + builder.addNameHintDecoration(func, UnownedStringSlice("packStorage")); + builder.setInsertInto(func); + builder.emitBlock(); + auto originalParam = builder.emitParam(arrayType); + auto count = getIntVal(arrayType->getElementCount()); + List<IRInst*> args; + args.setCount((Index)count); + for
(IRIntegerValue ii = 0; ii < count; ++ii) + { + auto originalElement = builder.emitElementExtract(originalParam, ii); + auto packedElement = builder.emitCallInst(innerTypeInfo.loweredType, innerTypeInfo.convertOriginalToLowered, 1, &originalElement); + args[(Index)ii] = packedElement; + } + auto packedArray = builder.emitMakeArray(innerArrayType, (UInt)args.getCount(), args.getBuffer()); + auto result = builder.emitMakeStruct(structType, 1, &packedArray); + builder.emitReturn(result); + return func; + } + + LoweredElementTypeInfo getLoweredTypeInfoImpl(IRType* type) + { + IRBuilder builder(type); + builder.setInsertAfter(type); + + LoweredElementTypeInfo info; + info.originalType = type; + + if (auto matrixType = as<IRMatrixType>(type)) + { + if (getIntVal(matrixType->getLayout()) == defaultMatrixLayout) + { + info.loweredType = type; + return info; + } + + auto loweredType = builder.createStructType(); + StringBuilder nameSB; + bool isColMajor = getIntVal(matrixType->getLayout()) == SLANG_MATRIX_LAYOUT_COLUMN_MAJOR; + nameSB << "_MatrixStorage_"; + getTypeNameHint(nameSB, matrixType->getElementType()); + nameSB << getIntVal(matrixType->getRowCount()) << "x" << getIntVal(matrixType->getColumnCount()); + if (isColMajor) + nameSB << "_ColMajor"; + builder.addNameHintDecoration(loweredType, nameSB.produceString().getUnownedSlice()); + auto structKey = builder.createStructKey(); + builder.addNameHintDecoration(structKey, UnownedStringSlice("data")); + auto vectorType = builder.getVectorType(matrixType->getElementType(), + isColMajor?matrixType->getRowCount():matrixType->getColumnCount()); + auto arrayType = builder.getArrayType(vectorType, isColMajor?matrixType->getColumnCount():matrixType->getRowCount()); + builder.createStructField(loweredType, structKey, arrayType); + + info.loweredType = loweredType; + info.loweredInnerArrayType = arrayType; + info.loweredInnerStructKey = structKey; + info.convertLoweredToOriginal = createMatrixUnpackFunc(matrixType, loweredType, 
structKey, arrayType); + info.convertOriginalToLowered = createMatrixPackFunc(matrixType, loweredType, vectorType, arrayType); + return info; + } + else if (auto arrayType = as<IRArrayType>(type)) + { + auto loweredInnerTypeInfo = getLoweredTypeInfo(arrayType->getElementType()); + + if (loweredInnerTypeInfo.loweredType != loweredInnerTypeInfo.originalType) + { + auto loweredType = builder.createStructType(); + info.loweredType = loweredType; + StringBuilder nameSB; + nameSB << "_ArrayStorage_"; + getTypeNameHint(nameSB, arrayType->getElementType()); + nameSB << getIntVal(arrayType->getElementCount()); + builder.addNameHintDecoration(loweredType, nameSB.produceString().getUnownedSlice()); + auto structKey = builder.createStructKey(); + builder.addNameHintDecoration(structKey, UnownedStringSlice("data")); + auto innerArrayType = builder.getArrayType(loweredInnerTypeInfo.loweredType, arrayType->getElementCount()); + builder.createStructField(loweredType, structKey, innerArrayType); + info.loweredInnerArrayType = innerArrayType; + info.loweredInnerStructKey = structKey; + info.convertLoweredToOriginal = createArrayUnpackFunc(arrayType, loweredType, structKey, innerArrayType, loweredInnerTypeInfo); + info.convertOriginalToLowered = createArrayPackFunc(arrayType, loweredType, innerArrayType, loweredInnerTypeInfo); + } + else + { + info.loweredType = type; + } + return info; + } + else if (auto structType = as<IRStructType>(type)) + { + bool hasNonTrivialField = false; + List<LoweredElementTypeInfo> fieldLoweredTypeInfo; + for (auto field : structType->getFields()) + { + auto loweredFieldTypeInfo = getLoweredTypeInfo(field->getFieldType()); + fieldLoweredTypeInfo.add(loweredFieldTypeInfo); + if (loweredFieldTypeInfo.loweredType != loweredFieldTypeInfo.originalType) + hasNonTrivialField = true; + } + + if (!hasNonTrivialField) + { + info.loweredType = type; + return info; + } + + auto loweredType = builder.createStructType(); + StringBuilder nameSB; + 
getTypeNameHint(nameSB, type); + nameSB << "_Storage"; + builder.addNameHintDecoration(loweredType, nameSB.produceString().getUnownedSlice()); + info.loweredType = loweredType; + + // Create fields. + { + Index fieldId = 0; + for (auto field : structType->getFields()) + { + auto loweredFieldTypeInfo = fieldLoweredTypeInfo[fieldId]; + builder.createStructField(loweredType, field->getKey(), loweredFieldTypeInfo.loweredType); + fieldId++; + } + } + + // Create unpack func. + { + builder.setInsertAfter(loweredType); + info.convertLoweredToOriginal = builder.createFunc(); + builder.setInsertInto(info.convertLoweredToOriginal); + builder.addNameHintDecoration(info.convertLoweredToOriginal, UnownedStringSlice("unpackStorage")); + info.convertLoweredToOriginal->setFullType(builder.getFuncType(1, (IRType**)&loweredType, type)); + builder.emitBlock(); + auto loweredParam = builder.emitParam(loweredType); + List<IRInst*> args; + Index fieldId = 0; + for (auto field : structType->getFields()) + { + auto storageField = builder.emitFieldExtract(fieldLoweredTypeInfo[fieldId].loweredType, loweredParam, field->getKey()); + auto unpackedField = fieldLoweredTypeInfo[fieldId].convertLoweredToOriginal + ? builder.emitCallInst(field->getFieldType(), fieldLoweredTypeInfo[fieldId].convertLoweredToOriginal, 1, &storageField) + : storageField; + args.add(unpackedField); + fieldId++; + } + auto result = builder.emitMakeStruct(type, args); + builder.emitReturn(result); + } + + // Create pack func. 
+ { + builder.setInsertAfter(info.convertLoweredToOriginal); + info.convertOriginalToLowered = builder.createFunc(); + builder.setInsertInto(info.convertOriginalToLowered); + builder.addNameHintDecoration(info.convertOriginalToLowered, UnownedStringSlice("packStorage")); + info.convertOriginalToLowered->setFullType(builder.getFuncType(1, (IRType**)&type, loweredType)); + builder.emitBlock(); + auto param = builder.emitParam(type); + List<IRInst*> args; + Index fieldId = 0; + for (auto field : structType->getFields()) + { + // The extracted value has the field's own type; `type` is the enclosing struct being packed. + auto fieldVal = builder.emitFieldExtract(field->getFieldType(), param, field->getKey()); + // Fields whose storage type equals their original type have no pack function and are passed through. + auto packedField = fieldLoweredTypeInfo[fieldId].convertOriginalToLowered + ? builder.emitCallInst(fieldLoweredTypeInfo[fieldId].loweredType, fieldLoweredTypeInfo[fieldId].convertOriginalToLowered, 1, &fieldVal) + : fieldVal; + args.add(packedField); + fieldId++; + } + auto result = builder.emitMakeStruct(loweredType, args); + builder.emitReturn(result); + } + + return info; + } + + info.loweredType = type; + return info; + } + + // Returns the lowering info for `type`, computing it on first request and caching it + // both by original type and by lowered type. + LoweredElementTypeInfo getLoweredTypeInfo(IRType* type) + { + LoweredElementTypeInfo info; + if (loweredTypeInfo.tryGetValue(type, info)) + return info; + info = getLoweredTypeInfoImpl(type); + loweredTypeInfo[type] = info; + mapLoweredTypeToInfo[info.loweredType] = info; + return info; + } + + // Rebuilds a pointer-like / pointer / structured-buffer type with the same opcode + // but with `newElementType` as its element type. + IRType* getLoweredPtrLikeType(IRType* originalPtrLikeType, IRType* newElementType) + { + if (as<IRPointerLikeType>(originalPtrLikeType) || as<IRPtrTypeBase>(originalPtrLikeType) || as<IRHLSLStructuredBufferTypeBase>(originalPtrLikeType)) + { + IRBuilder builder(newElementType); + builder.setInsertAfter(newElementType); + return builder.getType(originalPtrLikeType->getOp(), newElementType); + } + SLANG_UNREACHABLE("unhandled ptr like or buffer type"); + } + + IRInst* getStoreVal(IRInst* storeInst) + { + if (auto store = as<IRStore>(storeInst)) + return store->getVal(); + else if (auto sbStore = as<IRRWStructuredBufferStore>(storeInst)) + return
sbStore->getVal(); + return nullptr; + } + + // Rewrites every structured buffer / uniform parameter group whose element type needs a + // storage type: the buffer type is replaced by one over the lowered element type, loads are + // followed by an unpack call, stored values are packed first, and derived addresses are + // processed recursively through a work list. + void processModule(IRModule* module) + { + IRBuilder builder(module); + struct BufferTypeInfo + { + IRType* bufferType; + IRType* elementType; + }; + List<BufferTypeInfo> bufferTypeInsts; + for (auto globalInst : module->getGlobalInsts()) + { + IRType* elementType = nullptr; + if (auto structBuffer = as<IRHLSLStructuredBufferTypeBase>(globalInst)) + elementType = structBuffer->getElementType(); + else if (auto constBuffer = as<IRUniformParameterGroupType>(globalInst)) + elementType = constBuffer->getElementType(); + if (as<IRTextureBufferType>(globalInst)) + continue; + if (!as<IRStructType>(elementType) && !as<IRMatrixType>(elementType) && !as<IRArrayType>(elementType)) + continue; + bufferTypeInsts.add(BufferTypeInfo{ (IRType*)globalInst, elementType }); + } + + // Maintain a pending work list of all matrix addresses, and try to lower them out of existence + // after everything else has been lowered. + List<IRInst*> matrixAddrInsts; + + for (auto bufferTypeInfo : bufferTypeInsts) + { + auto bufferType = bufferTypeInfo.bufferType; + auto elementType = bufferTypeInfo.elementType; + auto loweredBufferElementTypeInfo = getLoweredTypeInfo(elementType); + + // If the lowered type is the same as original type, no change is required. + if (!loweredBufferElementTypeInfo.convertLoweredToOriginal) + continue; + + builder.setInsertBefore(bufferType); + + auto loweredBufferType = builder.getType( + bufferType->getOp(), + loweredBufferElementTypeInfo.loweredType); + + // We treat a value of a buffer type as a pointer, and use a work list to translate + // all loads and stores through the pointer values that need lowering. + + List<IRInst*> ptrValsWorkList; + traverseUses(bufferType, [&](IRUse* use) + { + auto user = use->getUser(); + if (use != &user->typeUse) + return; + ptrValsWorkList.add(use->getUser()); + }); + + // Translate the values to use new lowered buffer type instead.
+ for (Index i = 0; i < ptrValsWorkList.getCount(); i++) + { + auto ptrVal = ptrValsWorkList[i]; + auto oldPtrType = ptrVal->getFullType(); + auto originalElementType = oldPtrType->getOperand(0); + auto loweredElementTypeInfo = getLoweredTypeInfo((IRType*)originalElementType); + if (!loweredElementTypeInfo.convertLoweredToOriginal) + continue; + + ptrVal->setFullType(getLoweredPtrLikeType(ptrVal->getFullType(), loweredElementTypeInfo.loweredType)); + + traverseUses(ptrVal, [&](IRUse* use) + { + auto user = use->getUser(); + switch (user->getOp()) + { + case kIROp_Load: + case kIROp_StructuredBufferLoad: + case kIROp_StructuredBufferLoadStatus: + case kIROp_RWStructuredBufferLoad: + case kIROp_RWStructuredBufferLoadStatus: + { + IRCloneEnv cloneEnv = {}; + builder.setInsertBefore(user); + auto newLoad = cloneInst(&cloneEnv, &builder, user); + newLoad->setFullType(loweredElementTypeInfo.loweredType); + // Unpack to the element type of the pointer being processed; for derived addresses this + // differs from the element type of the enclosing buffer. + auto unpackedVal = builder.emitCallInst(loweredElementTypeInfo.originalType, loweredElementTypeInfo.convertLoweredToOriginal, 1, &newLoad); + user->replaceUsesWith(unpackedVal); + user->removeAndDeallocate(); + break; + } + case kIROp_Store: + case kIROp_RWStructuredBufferStore: + { + // Use must be the dest operand of the store inst. + if (use != user->getOperands() + 0) + break; + builder.setInsertBefore(user); + auto originalVal = getStoreVal(user); + auto packedVal = builder.emitCallInst(loweredElementTypeInfo.loweredType, loweredElementTypeInfo.convertOriginalToLowered, 1, &originalVal); + if (auto store = as<IRStore>(user)) + store->val.set(packedVal); + else if (auto sbStore = as<IRRWStructuredBufferStore>(user)) + sbStore->setOperand(2, packedVal); + else + SLANG_UNREACHABLE("unhandled store type"); + break; + } + case kIROp_GetElementPtr: + case kIROp_FieldAddress: + { + // If original type is an array, the lowered type will be a struct. + // In that case, all existing address insts should be appended with a field extract.
+ if (as<IRArrayType>(originalElementType)) + { + builder.setInsertBefore(user); + auto newArrayPtrVal = builder.emitFieldAddress( + builder.getPtrType(loweredElementTypeInfo.loweredInnerArrayType), + ptrVal, + loweredElementTypeInfo.loweredInnerStructKey); + builder.replaceOperand(use, newArrayPtrVal); + ptrValsWorkList.add(user); + } + else if (as<IRMatrixType>(originalElementType)) + { + // We are trying to get a pointer to a lowered matrix element. + // We process these insts at a later phase. + SLANG_ASSERT(user->getOp() == kIROp_GetElementPtr); + matrixAddrInsts.add(user); + } + else + { + // If we are getting a derived address from the pointer, we need to recursively + // lower the new address. We do so by pushing the address inst into the + // work list. + ptrValsWorkList.add(user); + } + } + break; + case kIROp_RWStructuredBufferGetElementPtr: + ptrValsWorkList.add(user); + break; + default: + SLANG_UNREACHABLE("unhandled inst of a buffer/pointer value that needs storage lowering."); + break; + } + }); + } + + // Replace all remaining uses of bufferType to loweredBufferType, these uses are non-operational and should be + // directly replaceable, such as uses in `IRFuncType`. + bufferType->replaceUsesWith(loweredBufferType); + bufferType->removeAndDeallocate(); + } + + lowerMatrixAddresses(module, matrixAddrInsts); + } + + // Lower all getElementPtr insts of a lowered matrix out of existence.
+ void lowerMatrixAddresses(IRModule* module, List<IRInst*>& matrixAddrInsts) + { + IRBuilder builder(module); + for (auto majorAddr : matrixAddrInsts) + { + auto majorGEP = as<IRGetElementPtr>(majorAddr); + SLANG_ASSERT(majorGEP); + auto loweredMatrixType = cast<IRPtrTypeBase>(majorGEP->getBase()->getFullType())->getValueType(); + auto matrixTypeInfo = mapLoweredTypeToInfo.tryGetValue(loweredMatrixType); + SLANG_ASSERT(matrixTypeInfo); + auto matrixType = as<IRMatrixType>(matrixTypeInfo->originalType); + auto rowCount = getIntVal(matrixType->getRowCount()); + traverseUses(majorAddr, [&](IRUse* use) + { + auto user = use->getUser(); + builder.setInsertBefore(user); + switch (user->getOp()) + { + case kIROp_Load: + { + IRInst* resultInst = nullptr; + auto dataPtr = builder.emitFieldAddress( + builder.getPtrType(matrixTypeInfo->loweredInnerArrayType), + majorGEP->getBase(), + matrixTypeInfo->loweredInnerStructKey); + if (getIntVal(matrixType->getLayout()) == SLANG_MATRIX_LAYOUT_COLUMN_MAJOR) + { + List<IRInst*> args; + for (IRIntegerValue i = 0; i < rowCount; i++) + { + auto vector = builder.emitLoad(builder.emitElementAddress(dataPtr, i)); + auto element = builder.emitElementExtract(vector, majorGEP->getIndex()); + args.add(element); + } + resultInst = builder.emitMakeVector(builder.getVectorType(matrixType->getElementType(), (IRIntegerValue)args.getCount()), args); + } + else + { + auto element = builder.emitElementAddress(dataPtr, majorGEP->getIndex()); + resultInst = builder.emitLoad(element); + } + user->replaceUsesWith(resultInst); + user->removeAndDeallocate(); + } + break; + case kIROp_Store: + { + auto storeInst = cast<IRStore>(user); + if (storeInst->getOperand(0) != majorAddr) + break; + auto dataPtr = builder.emitFieldAddress( + builder.getPtrType(matrixTypeInfo->loweredInnerArrayType), + majorGEP->getBase(), + matrixTypeInfo->loweredInnerStructKey); + if (getIntVal(matrixType->getLayout()) == SLANG_MATRIX_LAYOUT_COLUMN_MAJOR) + { + for (IRIntegerValue i 
= 0; i < rowCount; i++) + { + auto vectorAddr = builder.emitElementAddress(dataPtr, i); + auto elementAddr = builder.emitElementAddress(vectorAddr, majorGEP->getIndex()); + builder.emitStore(elementAddr, builder.emitElementExtract(storeInst->getVal(), i)); + } + } + else + { + auto rowAddr = builder.emitElementAddress(dataPtr, majorGEP->getIndex()); + builder.emitStore(rowAddr, storeInst->getVal()); + user->removeAndDeallocate(); + } + break; + } + case kIROp_GetElementPtr: + { + auto gep2 = cast<IRGetElementPtr>(user); + auto rowIndex = majorGEP->getIndex(); + auto colIndex = gep2->getIndex(); + if (getIntVal(matrixType->getLayout()) == SLANG_MATRIX_LAYOUT_COLUMN_MAJOR) + { + Swap(rowIndex, colIndex); + } + auto dataPtr = builder.emitFieldAddress( + builder.getPtrType(matrixTypeInfo->loweredInnerArrayType), + majorGEP->getBase(), + matrixTypeInfo->loweredInnerStructKey); + auto vectorAddr = builder.emitElementAddress(dataPtr, rowIndex); + auto elementAddr = builder.emitElementAddress(vectorAddr, colIndex); + gep2->replaceUsesWith(elementAddr); + gep2->removeAndDeallocate(); + break; + } + default: + SLANG_UNREACHABLE("unhandled inst of a matrix address inst that needs storage lowering."); + break; + } + }); + } + } + }; + + void lowerBufferElementTypeToStorageType(TargetRequest* target, IRModule* module) + { + SlangMatrixLayoutMode defaultMatrixMode = (SlangMatrixLayoutMode)target->getDefaultMatrixLayoutMode(); + if (defaultMatrixMode == SLANG_MATRIX_LAYOUT_MODE_UNKNOWN) + defaultMatrixMode = SLANG_MATRIX_LAYOUT_ROW_MAJOR; + LoweredElementTypeContext context(defaultMatrixMode); + context.processModule(module); + } +} diff --git a/source/slang/slang-ir-lower-buffer-element-type.h b/source/slang/slang-ir-lower-buffer-element-type.h new file mode 100644 index 000000000..bbee71df4 --- /dev/null +++ b/source/slang/slang-ir-lower-buffer-element-type.h @@ -0,0 +1,20 @@ +#ifndef SLANG_IR_LOWER_BUFFER_ELEMENT_TYPE_H +#define SLANG_IR_LOWER_BUFFER_ELEMENT_TYPE_H + 
+namespace Slang +{ + struct IRModule; + class TargetRequest; + + // For each struct type S used as element type of a ConstantBuffer, ParameterBlock or [RW]StructuredBuffer, + // we create a lowered type L, where matrix types are lowered to arrays of vectors based on major-ness, + // and loads from the buffer are converted to L_to_S(load(buffer)), and stores to the buffer are + // converted to store(buffer, S_to_L(val)). + // This pass needs to take place after type legalization, and before array return type lowering + // because it may create functions that return array-typed values. + // + void lowerBufferElementTypeToStorageType(TargetRequest* target, IRModule* module); + +} + +#endif diff --git a/source/slang/slang-ir-specialize-matrix-layout.cpp b/source/slang/slang-ir-specialize-matrix-layout.cpp new file mode 100644 index 000000000..5ce61f4cf --- /dev/null +++ b/source/slang/slang-ir-specialize-matrix-layout.cpp @@ -0,0 +1,46 @@ +#include "slang-ir-specialize-matrix-layout.h" +#include "slang-ir.h" +#include "slang-ir-insts.h" +#include "slang-compiler.h" + +namespace Slang +{ + + // Recursively walks the children of `parent` and appends to `typeWorkList` every matrix type + // whose layout operand is an integer literal equal to SLANG_MATRIX_LAYOUT_MODE_UNKNOWN. + void visitParent(List<IRMatrixType*>& typeWorkList, IRInst* parent) + { + for (auto child : parent->getChildren()) + { + if (auto matrixType = as<IRMatrixType>(child)) + { + if (auto constLayout = as<IRIntLit>(matrixType->getLayout())) + { + if (constLayout->getValue() == SLANG_MATRIX_LAYOUT_MODE_UNKNOWN) + { + typeWorkList.add(matrixType); + } + } + } + visitParent(typeWorkList, child); + } + } + + void specializeMatrixLayout(TargetRequest* target, IRModule* module) + { + List<IRMatrixType*> typeWorkList; + visitParent(typeWorkList, module->getModuleInst()); + + IRIntegerValue defaultLayout = target->getDefaultMatrixLayoutMode(); + if (defaultLayout == SLANG_MATRIX_LAYOUT_MODE_UNKNOWN) + defaultLayout = SLANG_MATRIX_LAYOUT_ROW_MAJOR; + + IRBuilder builder(module); + for (auto matrixType : typeWorkList) + { + builder.setInsertBefore(matrixType); + auto replacementMatrixType =
builder.getMatrixType(matrixType->getElementType(), matrixType->getRowCount(), matrixType->getColumnCount(), + builder.getIntValue(builder.getIntType(), defaultLayout)); + matrixType->replaceUsesWith(replacementMatrixType); + } + } + +} diff --git a/source/slang/slang-ir-specialize-matrix-layout.h b/source/slang/slang-ir-specialize-matrix-layout.h new file mode 100644 index 000000000..3074f72cf --- /dev/null +++ b/source/slang/slang-ir-specialize-matrix-layout.h @@ -0,0 +1,16 @@ +#ifndef SLANG_IR_SPECIALIZE_MATRIX_LAYOUT_H +#define SLANG_IR_SPECIALIZE_MATRIX_LAYOUT_H + +namespace Slang +{ + struct IRModule; + class TargetRequest; + + // Replace all matrix types whose layout is not specified with the default layout + // of the target request. + // + void specializeMatrixLayout(TargetRequest* target, IRModule* module); + +} + +#endif diff --git a/source/slang/slang-ir-spirv-legalize.cpp b/source/slang/slang-ir-spirv-legalize.cpp index daea4582a..cc9bf4164 100644 --- a/source/slang/slang-ir-spirv-legalize.cpp +++ b/source/slang/slang-ir-spirv-legalize.cpp @@ -182,30 +182,41 @@ struct SPIRVLegalizationContext : public SourceEmitterBase } } - void processGetElementPtr(IRGetElementPtr* inst) + void processGetElementPtrImpl(IRInst* gepInst, IRInst* base, IRInst* index) { - if (auto ptrType = as<IRPtrTypeBase>(inst->getBase()->getDataType())) + if (auto ptrType = as<IRPtrTypeBase>(base->getDataType())) { if (!ptrType->hasAddressSpace()) return; - auto oldResultType = as<IRPtrTypeBase>(inst->getDataType()); + auto oldResultType = as<IRPtrTypeBase>(gepInst->getDataType()); if (oldResultType->getAddressSpace() != ptrType->getAddressSpace()) { IRBuilder builder(m_sharedContext->m_irModule); - builder.setInsertBefore(inst); + builder.setInsertBefore(gepInst); auto newPtrType = builder.getPtrType( oldResultType->getOp(), oldResultType->getValueType(), ptrType->getAddressSpace()); + IRInst* args[2] = { base, index }; auto newInst = - builder.emitElementAddress(newPtrType, 
inst->getBase(), inst->getIndex()); - inst->replaceUsesWith(newInst); - inst->removeAndDeallocate(); + builder.emitIntrinsicInst(newPtrType, gepInst->getOp(), 2, args); + gepInst->replaceUsesWith(newInst); + gepInst->removeAndDeallocate(); addUsersToWorkList(newInst); } } } + void processGetElementPtr(IRGetElementPtr* gepInst) + { + processGetElementPtrImpl(gepInst, gepInst->getBase(), gepInst->getIndex()); + } + + void processRWStructuredBufferGetElementPtr(IRRWStructuredBufferGetElementPtr* gepInst) + { + processGetElementPtrImpl(gepInst, gepInst->getBase(), gepInst->getIndex()); + } + void processFieldAddress(IRFieldAddress* inst) { if (auto ptrType = as<IRPtrTypeBase>(inst->getBase()->getDataType())) @@ -286,6 +297,9 @@ struct SPIRVLegalizationContext : public SourceEmitterBase case kIROp_FieldAddress: processFieldAddress(as<IRFieldAddress>(inst)); break; + case kIROp_RWStructuredBufferGetElementPtr: + processRWStructuredBufferGetElementPtr(as<IRRWStructuredBufferGetElementPtr>(inst)); + break; case kIROp_HLSLStructuredBufferType: case kIROp_HLSLRWStructuredBufferType: processStructuredBufferType(as<IRHLSLStructuredBufferTypeBase>(inst)); diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp index 41ac19ee9..f96cc174c 100644 --- a/source/slang/slang-ir-util.cpp +++ b/source/slang/slang-ir-util.cpp @@ -407,6 +407,8 @@ bool isPtrLikeOrHandleType(IRInst* type) return true; if (as<IRPseudoPtrType>(type)) return true; + if (as<IRHLSLStructuredBufferTypeBase>(type)) + return true; switch (type->getOp()) { case kIROp_ComPtrType: @@ -871,10 +873,15 @@ bool isGlobalOrUnknownMutableAddress(IRGlobalValueWithCode* parentFunc, IRInst* if (root) { + // If this is a global readonly resource, it is not a mutable address. 
if (as<IRParameterGroupType>(root->getDataType())) { return false; } + if (as<IRHLSLStructuredBufferType>(root->getDataType())) + { + return false; + } } switch (root->getOp()) @@ -991,6 +998,47 @@ void resetScratchDataBit(IRInst* inst, int bitIndex) } } +UnownedStringSlice getBasicTypeNameHint(IRType* basicType) +{ + switch (basicType->getOp()) + { + case kIROp_IntType: + return UnownedStringSlice::fromLiteral("int"); + case kIROp_Int8Type: + return UnownedStringSlice::fromLiteral("int8"); + case kIROp_Int16Type: + return UnownedStringSlice::fromLiteral("int16"); + case kIROp_Int64Type: + return UnownedStringSlice::fromLiteral("int64"); + case kIROp_IntPtrType: + return UnownedStringSlice::fromLiteral("intptr"); + case kIROp_UIntType: + return UnownedStringSlice::fromLiteral("uint"); + case kIROp_UInt8Type: + return UnownedStringSlice::fromLiteral("uint8"); + case kIROp_UInt16Type: + return UnownedStringSlice::fromLiteral("uint16"); + case kIROp_UInt64Type: + return UnownedStringSlice::fromLiteral("uint64"); + case kIROp_UIntPtrType: + return UnownedStringSlice::fromLiteral("uintptr"); + case kIROp_FloatType: + return UnownedStringSlice::fromLiteral("float"); + case kIROp_HalfType: + return UnownedStringSlice::fromLiteral("half"); + case kIROp_DoubleType: + return UnownedStringSlice::fromLiteral("double"); + case kIROp_BoolType: + return UnownedStringSlice::fromLiteral("bool"); + case kIROp_VoidType: + return UnownedStringSlice::fromLiteral("void"); + case kIROp_CharType: + return UnownedStringSlice::fromLiteral("char"); + default: + return UnownedStringSlice(); + } +} + struct GenericChildrenMigrationContextImpl { IRCloneEnv cloneEnv; diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h index fc44b8d30..a0336e1c2 100644 --- a/source/slang/slang-ir-util.h +++ b/source/slang/slang-ir-util.h @@ -223,6 +223,7 @@ bool isOne(IRInst* inst); void initializeScratchData(IRInst* inst); void resetScratchDataBit(IRInst* inst, int bitIndex); + } #endif diff 
--git a/source/slang/slang-ir-wrap-structured-buffers.cpp b/source/slang/slang-ir-wrap-structured-buffers.cpp index 0e7248320..6b7043416 100644 --- a/source/slang/slang-ir-wrap-structured-buffers.cpp +++ b/source/slang/slang-ir-wrap-structured-buffers.cpp @@ -161,50 +161,21 @@ struct WrapStructuredBuffersContext // are calls that could potentially be intrinsic // operations on `*StructuredBuffer`. // - auto call = as<IRCall>(valueUse->getUser()); - if(!call) - return; - if(call->getArgCount() == 0) - return; - if(call->getArg(0) != valueOfStructuredBufferType) - return; - - // At this point we have a candidate `call` instruction, - // but we need to determine whether it is a call to - // one of the `*StructuredBuffer` intrinsics that we want - // to rewrite, or if it is another user-defined function - // that we should leave along (even if that user-defined - // function happens to return a matrix). - // - // For now we will do this in a somewhat ad-hoc fashion. - // We know that the `Load` and `operator[]` operations - // on `*StructuredBuffer` are generic, and unlike user-defined - // generic functions they will not have been specialized - // before we get here. - // - // We will thus use the fact that the callee of the call - // is a `specialize` instruction to let us know that it - // is an intrinsic, and thus should be one of the functions - // we care about. - // - // TODO: Figure out if there is a more robust way to make - // this check. It is possible that structured buffer - // access should be modeled with explicit IR opcodes - // rather than just as builtin functions. 
- // - auto callee = call->getCallee(); - if(!as<IRSpecialize>(callee)) + auto user = valueUse->getUser(); + switch (user->getOp()) + { + case kIROp_StructuredBufferLoad: + case kIROp_StructuredBufferLoadStatus: + case kIROp_RWStructuredBufferStore: + case kIROp_RWStructuredBufferLoadStatus: + case kIROp_RWStructuredBufferGetElementPtr: + break; + default: return; + } - // At this point it seems likely we have one of the calls - // we want to rewrite, but there are still intrinsics - // like `GetDimensions` that we want to leave alone. - // - // For now we will look at the return type of the call, - // where we care about two cases. - // - builder->setInsertBefore(call->getNextInst()); - auto oldResultType = call->getDataType(); + builder->setInsertAfter(user); + auto oldResultType = user->getDataType(); // First we care about the case for `Load`, which // will return the element type, which would be @@ -217,7 +188,7 @@ struct WrapStructuredBuffersContext // go ahead and modify its type to be correct. // auto newResultType = wrapperStruct; - builder->setDataType(call, newResultType); + builder->setDataType(user, newResultType); // Next, we need to make sure to extract the // field from the wrapper struct, so that @@ -232,12 +203,12 @@ struct WrapStructuredBuffersContext // // float4x4 newVal = call.wrapped; // - auto newVal = builder->emitFieldExtract(oldResultType, call, wrappedFieldKey); + auto newVal = builder->emitFieldExtract(oldResultType, user, wrappedFieldKey); // Any code that used the value of `call` should // now use `newVal` instead... // - call->replaceUsesWith(newVal); + user->replaceUsesWith(newVal); // // ... except for one important gotcha, which is // that `newVal` itself used `call`, and replacing @@ -251,7 +222,7 @@ struct WrapStructuredBuffersContext // of `replaceUsesWith` that can handle cases like // this. 
// - newVal->setOperand(0, call); + newVal->setOperand(0, user); } // // The second interesting case is the `ref` accessor @@ -276,11 +247,11 @@ struct WrapStructuredBuffersContext // there if you want the comments. auto newResultType = builder->getPtrType(oldPtrType->getOp(), wrapperStruct); - builder->setDataType(call, newResultType); + builder->setDataType(user, newResultType); - auto newVal = builder->emitFieldAddress(oldResultType, call, wrappedFieldKey); - call->replaceUsesWith(newVal); - newVal->setOperand(0, call); + auto newVal = builder->emitFieldAddress(oldResultType, user, wrappedFieldKey); + user->replaceUsesWith(newVal); + newVal->setOperand(0, user); } } }); diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 0a79cec57..49c91dc22 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -2808,9 +2808,10 @@ namespace Slang IRMatrixType* IRBuilder::getMatrixType( IRType* elementType, IRInst* rowCount, - IRInst* columnCount) + IRInst* columnCount, + IRInst* layout) { - IRInst* operands[] = { elementType, rowCount, columnCount }; + IRInst* operands[] = { elementType, rowCount, columnCount, layout }; return (IRMatrixType*)getType( kIROp_MatrixType, sizeof(operands) / sizeof(operands[0]), @@ -7304,6 +7305,9 @@ namespace Slang case kIROp_MakeDifferentialPair: case kIROp_MakeTuple: case kIROp_GetTupleElement: + case kIROp_StructuredBufferLoad: + case kIROp_RWStructuredBufferLoad: + case kIROp_RWStructuredBufferGetElementPtr: case kIROp_Load: // We are ignoring the possibility of loads from bad addresses, or `volatile` loads case kIROp_LoadReverseGradient: case kIROp_ReverseGradientDiffPairRef: diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h index cdce11bbb..c5778004b 100644 --- a/source/slang/slang-ir.h +++ b/source/slang/slang-ir.h @@ -1553,6 +1553,7 @@ struct IRMatrixType : IRType IRType* getElementType() { return (IRType*)getOperand(0); } IRInst* getRowCount() { return getOperand(1); } IRInst* 
getColumnCount() { return getOperand(2); } + IRInst* getLayout() { return getOperand(3); } IR_LEAF_ISA(MatrixType) }; diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index b1a38febd..d67d57001 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -1817,11 +1817,12 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower auto elementType = lowerType(context, type->getElementType()); auto rowCount = lowerSimpleVal(context, type->getRowCount()); auto columnCount = lowerSimpleVal(context, type->getColumnCount()); - + auto layout = lowerSimpleVal(context, type->getLayout()); return getBuilder()->getMatrixType( elementType, rowCount, - columnCount); + columnCount, + layout); } IRType* visitArrayExpressionType(ArrayExpressionType* type) diff --git a/source/slang/slang-options.cpp b/source/slang/slang-options.cpp index dbd6b97a4..85602b744 100644 --- a/source/slang/slang-options.cpp +++ b/source/slang/slang-options.cpp @@ -372,11 +372,11 @@ void initCommandOptions(CommandOptions& options) const Option generalOpts[] = { - { OptionKind::MacroDefine, "-D?...", "-D<name>[=<value>], -D <name>[=<value>]", - "Insert a preprocessor macro.\n" + { OptionKind::MacroDefine, "-D?...", "-D<name>[=<value>], -D <name>[=<value>]", + "Insert a preprocessor macro.\n" "The space between - D and <name> is optional. If no <value> is specified, Slang will define the macro with an empty value." }, { OptionKind::DepFile, "-depfile", "-depfile <path>", "Save the source file dependency list in a file." }, - { OptionKind::EntryPointName, "-entry", "-entry <name>", + { OptionKind::EntryPointName, "-entry", "-entry <name>", "Specify the name of an entry-point function.\n" "When compiling from a single file, this defaults to main if you specify a stage using -stage.\n" "Multiple -entry options may be used in a single invocation. 
" @@ -386,7 +386,7 @@ void initCommandOptions(CommandOptions& options) { OptionKind::EmitIr, "-emit-ir", nullptr, "Emit IR typically as a '.slang-module' when outputting to a container." }, { OptionKind::Help, "-h,-help,--help", "-h or -h <help-category>", "Print this message, or help in specified category." }, { OptionKind::HelpStyle, "-help-style", "-help-style <help-style>", "Help formatting style" }, - { OptionKind::Include, "-I?...", "-I<path>, -I <path>", + { OptionKind::Include, "-I?...", "-I<path>, -I <path>", "Add a path to be used in resolving '#include' " "and 'import' operations."}, { OptionKind::Language, "-lang", "-lang <language>", "Set the language for the following input files."}, diff --git a/source/slang/slang-spirv-val.cpp b/source/slang/slang-spirv-val.cpp index 990ccd909..54bf5348b 100644 --- a/source/slang/slang-spirv-val.cpp +++ b/source/slang/slang-spirv-val.cpp @@ -3,6 +3,29 @@ namespace Slang { +SlangResult debugDisassembleSPIRV(const List<uint8_t>& spirv, String& outDis) +{ + CommandLine commandLine; + commandLine.m_executableLocation.setName("spirv-dis"); + RefPtr<Process> p; + const auto createResult = Process::create(commandLine, 0, p); + // If we failed to even start the process, then validation isn't available + SLANG_RETURN_ON_FAIL(createResult); + const auto in = p->getStream(StdStreamType::In); + const auto out = p->getStream(StdStreamType::Out); + // Write the assembly + SLANG_RETURN_ON_FAIL(in->write(spirv.getBuffer(), spirv.getCount())); + in->close(); + // Wait for it to finish + if (!p->waitForTermination(1000)) + return SLANG_FAIL; + + List<Byte> outData; + SLANG_RETURN_ON_FAIL(StreamUtil::readAll(out, 0, outData)); + outDis = String((const char*)outData.getBuffer()); + return SLANG_OK; +} + SlangResult debugValidateSPIRV(const List<uint8_t>& spirv) { // Set up our process @@ -25,6 +48,7 @@ SlangResult debugValidateSPIRV(const List<uint8_t>& spirv) if(!p->waitForTermination(1000)) return SLANG_FAIL; + // TODO: allow 
inheriting stderr in Process List<Byte> outData; SLANG_RETURN_ON_FAIL(StreamUtil::readAll(out, 0, outData)); @@ -32,8 +56,14 @@ SlangResult debugValidateSPIRV(const List<uint8_t>& spirv) outData.clear(); SLANG_RETURN_ON_FAIL(StreamUtil::readAll(err, 0, outData)); fwrite(outData.getBuffer(), outData.getCount(), 1, stderr); - const auto ret = p->getReturnValue(); + if (SLANG_FAILED(ret)) + { + String spirvDis; + debugDisassembleSPIRV(spirv, spirvDis); + fwrite(spirvDis.getBuffer(), spirvDis.getLength(), 1, stderr); + } + return ret == 0 ? SLANG_OK : SLANG_FAIL; } diff --git a/source/slang/slang-type-layout.cpp b/source/slang/slang-type-layout.cpp index d075b12e2..bf3df4adc 100644 --- a/source/slang/slang-type-layout.cpp +++ b/source/slang/slang-type-layout.cpp @@ -2866,23 +2866,6 @@ static TypeLayoutResult _createTypeLayout( if (declForModifiers) { - // TODO: The approach implemented here has a row/column-major - // layout model recursively affect any sub-fields (so that - // the layout of a nested struct depends on the context where - // it is nested). This is consistent with the GLSL behavior - // for these modifiers, but it is *not* how HLSL is supposed - // to work. - // - // In the trivial case where `row_major` and `column_major` - // are only applied to leaf fields/variables of matrix type - // the difference should be immaterial. - - if (declForModifiers->hasModifier<RowMajorLayoutModifier>()) - subContext.matrixLayoutMode = kMatrixLayoutMode_RowMajor; - - if (declForModifiers->hasModifier<ColumnMajorLayoutModifier>()) - subContext.matrixLayoutMode = kMatrixLayoutMode_ColumnMajor; - // TODO: really need to look for other modifiers that affect // layout, such as GLSL `std140`. 
} @@ -3866,7 +3849,12 @@ static TypeLayoutResult _createTypeLayout( // size_t layoutMajorCount = rowCount; size_t layoutMinorCount = colCount; - if (context.matrixLayoutMode == kMatrixLayoutMode_ColumnMajor) + auto matrixLayout = getIntVal(matType->getLayout()); + if (matrixLayout == SLANG_MATRIX_LAYOUT_MODE_UNKNOWN) + { + matrixLayout = context.matrixLayoutMode; + } + if (matrixLayout == SLANG_MATRIX_LAYOUT_COLUMN_MAJOR) { size_t tmp = layoutMajorCount; layoutMajorCount = layoutMinorCount; @@ -3891,7 +3879,7 @@ static TypeLayoutResult _createTypeLayout( size_t rowStride = 0; size_t colStride = 0; - if(context.matrixLayoutMode == kMatrixLayoutMode_ColumnMajor) + if (matrixLayout == SLANG_MATRIX_LAYOUT_COLUMN_MAJOR) { colStride = majorStride; rowStride = minorStride; @@ -3918,7 +3906,7 @@ static TypeLayoutResult _createTypeLayout( typeLayout->elementTypeLayout = rowTypeLayout; typeLayout->uniformStride = rowStride; - typeLayout->mode = context.matrixLayoutMode; + typeLayout->mode = (MatrixLayoutMode)matrixLayout; typeLayout->addResourceUsage(info.kind, info.size); diff --git a/source/slang/slang.natvis b/source/slang/slang.natvis index 912a8f2a7..619b76b85 100644 --- a/source/slang/slang.natvis +++ b/source/slang/slang.natvis @@ -792,4 +792,16 @@ <Type Name="Slang::TypeEqualityWitness"> <DisplayString>{sub,na} == {sup,na}</DisplayString> </Type> + + <Type Name="Slang::ConstantIntVal"> + <DisplayString>ConstantIntVal ({m_operands.m_buffer[1].values.intOperand} : {*(Type*)m_operands.m_buffer[0].values.nodeOperand})</DisplayString> + </Type> + + <Type Name="Slang::GenericParamIntVal"> + <DisplayString>GenericParamIntVal ({*(DeclRefBase*)m_operands.m_buffer[1].values.nodeOperand})</DisplayString> + </Type> + + <Type Name="Slang::BasicExpressionType"> + <DisplayString>BasicExpressionType ({*(DeclRefBase*)m_operands.m_buffer[0].values.nodeOperand})</DisplayString> + </Type> </AutoVisualizer> |
