diff options
| author | Yong He <yonghe@outlook.com> | 2023-02-16 13:55:32 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-02-16 13:55:32 -0800 |
| commit | 4c4826d47eeef4675daae4ae53ff76f4d5ebd84a (patch) | |
| tree | ed4af0ded878e4f06e9641ce61d26ffd7c89ccbc /source/slang | |
| parent | eda88e513e8b1e2abc05e9dc8555f237d96472df (diff) | |
Overhaul global inst deduplication and cpp/cuda backend. (#2654)
* Overhaul global inst deduplication and cpp/cuda backend.
* Update IR documentation.
---------
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang')
41 files changed, 1497 insertions, 3653 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index 2a8344e3a..6357d58bd 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -786,6 +786,8 @@ __generic<T = float, let R : int = 4, let C : int = 4> __magic_type(Matrix) struct matrix { + __intrinsic_op($(kIROp_MakeMatrixFromScalar)) + __init(T val); } ${{{{ @@ -1093,9 +1095,6 @@ extension matrix<T, R, C> : IDifferentiable { typedef matrix<T, R, C> Differential; - __intrinsic_op($(kIROp_MakeMatrixFromScalar)) - __init(T val); - [__unsafeForceInlineEarly] static Differential dzero() { diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index 464811a96..1d2b327d2 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -872,36 +872,31 @@ matrix<T, N, M> acos(matrix<T, N, M> x) // Test if all components are non-zero (HLSL SM 1.0) __generic<T : __BuiltinType> +__target_intrinsic(cpp, "bool($0)") +__target_intrinsic(cuda, "bool($0)") __target_intrinsic(glsl, "bool($0)") bool all(T x); __generic<T : __BuiltinType, let N : int> __target_intrinsic(hlsl) __target_intrinsic(glsl, "all(bvec$N0($0))") -bool all(vector<T,N> x); -// TODO: implementation of `all()` in the stdlib is -// blocked on fixing implementation of `bool` vector -// `getAt` on the CUDA codegen path. -/* +bool all(vector<T,N> x) { bool result = true; for(int i = 0; i < N; ++i) result = result && all(x[i]); return result; } -*/ __generic<T : __BuiltinType, let N : int, let M : int> __target_intrinsic(hlsl) -bool all(matrix<T,N,M> x); -/* +bool all(matrix<T,N,M> x) { bool result = true; for(int i = 0; i < N; ++i) result = result && all(x[i]); return result; } -*/ // Barrier for writes to all memory spaces (HLSL SM 5.0) __target_intrinsic(glsl, "memoryBarrier(), groupMemoryBarrier(), memoryBarrierImage(), memoryBarrierBuffer()") @@ -916,42 +911,39 @@ void AllMemoryBarrierWithGroupSync(); // Test if any components is non-zero (HLSL SM 1.0) __generic<T : __BuiltinType> +__target_intrinsic(cpp, "bool($0)") +__target_intrinsic(cuda, "bool($0)") __target_intrinsic(glsl, "bool($0)") bool any(T x); __generic<T : __BuiltinType, let N : int> __target_intrinsic(hlsl) __target_intrinsic(glsl, "any(bvec$N0($0))") -bool any(vector<T, N> x); -// TODO: implementation of `any()` in the stdlib is -// blocked on fixing implementation of `bool` vector -// `getAt` on the CUDA codegen path. -/* +bool any(vector<T, N> x) { bool result = false; for(int i = 0; i < N; ++i) result = result || any(x[i]); return result; } -*/ __generic<T : __BuiltinType, let N : int, let M : int> __target_intrinsic(hlsl) -bool any(matrix<T, N, M> x); -/* +bool any(matrix<T, N, M> x) { bool result = false; for(int i = 0; i < N; ++i) result = result || any(x[i]); return result; } -*/ // Reinterpret bits as a double (HLSL SM 5.0) __target_intrinsic(hlsl) __target_intrinsic(glsl, "packDouble2x32(uvec2($0, $1))") +__target_intrinsic(cpp, "$P_asdouble($0, $1)") +__target_intrinsic(cuda, "$P_asdouble($0, $1)") __target_intrinsic(spirv_direct, "%v = OpCompositeConstruct _type(uint2) resultId _0 _1; OpExtInst resultType resultId glsl450 59 %v") __glsl_extension(GL_ARB_gpu_shader5) double asdouble(uint lowbits, uint highbits); @@ -960,11 +952,15 @@ double asdouble(uint lowbits, uint highbits); __target_intrinsic(hlsl) __target_intrinsic(glsl, "intBitsToFloat") +__target_intrinsic(cpp, "$P_asfloat($0)") +__target_intrinsic(cuda, "$P_asfloat($0)") __target_intrinsic(spirv_direct, "OpBitcast resultType resultId _0") float asfloat(int x); __target_intrinsic(hlsl) __target_intrinsic(glsl, "uintBitsToFloat") +__target_intrinsic(cpp, "$P_asfloat($0)") +__target_intrinsic(cuda, "$P_asfloat($0)") __target_intrinsic(spirv_direct, "OpBitcast resultType resultId _0") float asfloat(uint x); @@ -1044,11 +1040,15 @@ matrix<T, N, M> asin(matrix<T, N, M> x) __target_intrinsic(hlsl) __target_intrinsic(glsl, "floatBitsToInt") +__target_intrinsic(cpp, "$P_asint($0)") +__target_intrinsic(cuda, "$P_asint($0)") __target_intrinsic(spirv_direct, "OpBitcast resultType resultId _0") int asint(float x); __target_intrinsic(hlsl) __target_intrinsic(glsl, "int($0)") +__target_intrinsic(cpp, "$P_asint($0)") +__target_intrinsic(cuda, "$P_asint($0)") __target_intrinsic(spirv_direct, "OpBitcast resultType resultId _0") int asint(uint x); @@ -1104,6 +1104,8 @@ matrix<int,N,M> asint(matrix<int,N,M> x) __target_intrinsic(hlsl) __target_intrinsic(glsl, "{ uvec2 v = unpackDouble2x32($0); $1 = v.x; $2 = v.y; }") __glsl_extension(GL_ARB_gpu_shader5) +__target_intrinsic(cpp, "$P_asuint($0, $1, $2)") +__target_intrinsic(cuda, "$P_asuint($0, $1, $2)") void asuint(double value, out uint lowbits, out uint highbits); // Reinterpret bits as a uint (HLSL SM 4.0) @@ -1111,11 +1113,15 @@ void asuint(double value, out uint lowbits, out uint highbits); __target_intrinsic(hlsl) __target_intrinsic(glsl, "floatBitsToUint") __target_intrinsic(spirv_direct, "OpBitcast resultType resultId _0") +__target_intrinsic(cpp, "$P_asuint($0)") +__target_intrinsic(cuda, "$P_asuint($0)") uint asuint(float x); __target_intrinsic(hlsl) __target_intrinsic(glsl, "uint($0)") __target_intrinsic(spirv_direct, "OpBitcast resultType resultId _0") +__target_intrinsic(cpp, "$P_asuint($0)") +__target_intrinsic(cuda, "$P_asuint($0)") uint asuint(int x); __generic<let N : int> @@ -1812,7 +1818,7 @@ __target_intrinsic(glsl, "unpackHalf2x16($0).x") __glsl_version(420) __target_intrinsic(hlsl) __cuda_sm_version(6.0) -__target_intrinsic(cuda, "__half2float(__short_as_half($0))") +__target_intrinsic(cuda, "__half2float(__ushort_as_half($0))") float f16tof32(uint value); __generic<let N : int> diff --git a/source/slang/slang-emit-cpp.cpp b/source/slang/slang-emit-cpp.cpp index 87b620ed2..ba6b26ec6 100644 --- a/source/slang/slang-emit-cpp.cpp +++ b/source/slang/slang-emit-cpp.cpp @@ -66,111 +66,6 @@ namespace Slang { static const char s_xyzwNames[] = "xyzw"; -static UnownedStringSlice _getTypePrefix(IROp op) -{ - switch (op) - { - case kIROp_BoolType: return UnownedStringSlice::fromLiteral("Bool"); - case kIROp_IntType: return UnownedStringSlice::fromLiteral("I32"); - case kIROp_UIntType: return UnownedStringSlice::fromLiteral("U32"); - case kIROp_FloatType: return UnownedStringSlice::fromLiteral("F32"); - case kIROp_Int64Type: return UnownedStringSlice::fromLiteral("I64"); - case kIROp_UInt64Type: return UnownedStringSlice::fromLiteral("U64"); - case kIROp_DoubleType: return UnownedStringSlice::fromLiteral("F64"); - default: return UnownedStringSlice::fromLiteral("?"); - } -} - - -static IROp _getCType(IROp op) -{ - switch (op) - { - case kIROp_VoidType: - case kIROp_BoolType: - { - return op; - } - case kIROp_Int8Type: - case kIROp_Int16Type: - case kIROp_IntType: - case kIROp_UInt8Type: - case kIROp_UInt16Type: - case kIROp_UIntType: - { - // Promote all these to Int - return kIROp_IntType; - } - case kIROp_IntPtrType: - case kIROp_UIntPtrType: - { - return kIROp_IntPtrType; - } - case kIROp_Int64Type: - case kIROp_UInt64Type: - { - // Promote all these to Int64, we can just vary the call to make these work - return kIROp_Int64Type; - } - case kIROp_DoubleType: - { - return kIROp_DoubleType; - } - case kIROp_HalfType: - case kIROp_FloatType: - { - // Promote both to float - return kIROp_FloatType; - } - default: - { - SLANG_ASSERT(!"Unhandled type"); - return kIROp_undefined; - } - } -} - -static UnownedStringSlice _getCTypeVecPostFix(IROp op) -{ - switch (op) - { - case kIROp_BoolType: return UnownedStringSlice::fromLiteral("B"); - case kIROp_IntType: return UnownedStringSlice::fromLiteral("I"); - case kIROp_UIntType: return UnownedStringSlice::fromLiteral("U"); - case kIROp_FloatType: return UnownedStringSlice::fromLiteral("F"); - case kIROp_Int64Type: return UnownedStringSlice::fromLiteral("I64"); - case kIROp_DoubleType: return UnownedStringSlice::fromLiteral("F64"); - case kIROp_IntPtrType: return UnownedStringSlice::fromLiteral(""); - case kIROp_UIntPtrType: return UnownedStringSlice::fromLiteral(""); - default: return UnownedStringSlice::fromLiteral("?"); - } -} - -static bool _isCppTarget(CodeGenTarget target) -{ - switch (target) - { - case CodeGenTarget::CPPSource: - case CodeGenTarget::HostCPPSource: - return true; - default: - return false; - } -} - -static bool _isCppOrCudaTarget(CodeGenTarget target) -{ - switch (target) - { - case CodeGenTarget::CPPSource: - case CodeGenTarget::HostCPPSource: - case CodeGenTarget::CUDASource: - return true; - default: - return false; - } -} - /* !!!!!!!!!!!!!!!!!!!!!!!! CPPEmitHandler !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! */ /* static */ UnownedStringSlice CPPSourceEmitter::getBuiltinTypeName(IROp op) @@ -204,118 +99,8 @@ static bool _isCppOrCudaTarget(CodeGenTarget target) } } -void CPPSourceEmitter::emitTypeDefinition(IRType* inType) +UnownedStringSlice CPPSourceEmitter::_getTypeName(IRType* type) { - if (_isCppTarget(m_target)) - { - // All types are templates in C++ - return; - } - - IRType* type = m_typeSet.getType(inType); - if (!m_typeSet.isOwned(type)) - { - // If defined in a different module, we assume they are emitted already. (Assumed to - // be a nominal type) - return; - } - - SourceWriter* writer = getSourceWriter(); - - switch (type->getOp()) - { - case kIROp_VectorType: - { - auto vecType = static_cast<IRVectorType*>(type); - - const UnownedStringSlice* elemNames = getVectorElementNames(vecType); - - int count = int(getIntVal(vecType->getElementCount())); - - SLANG_ASSERT(count > 0 && count < 4); - - UnownedStringSlice typeName = _getTypeName(type); - UnownedStringSlice elemName = _getTypeName(vecType->getElementType()); - - writer->emit("struct "); - writer->emit(typeName); - writer->emit("\n{\n"); - writer->indent(); - - writer->emit(elemName); - writer->emit(" "); - for (int i = 0; i < count; ++i) - { - if (i > 0) - { - writer->emit(", "); - } - writer->emit(elemNames[i]); - } - writer->emit(";\n"); - - writer->dedent(); - writer->emit("};\n\n"); - break; - } - case kIROp_MatrixType: - { - auto matType = static_cast<IRMatrixType*>(type); - - const auto rowCount = int(getIntVal(matType->getRowCount())); - const auto colCount = int(getIntVal(matType->getColumnCount())); - - IRType* vecType = m_typeSet.addVectorType(matType->getElementType(), colCount); - - UnownedStringSlice typeName = _getTypeName(type); - UnownedStringSlice rowTypeName = _getTypeName(vecType); - - writer->emit("template<>\n"); - writer->emit("struct "); - writer->emit(typeName); - writer->emit("\n{\n"); - writer->indent(); - - writer->emit(rowTypeName); - writer->emit(" rows["); - writer->emit(rowCount); - writer->emit("];\n"); - - writer->dedent(); - writer->emit("};\n\n"); - break; - } - case kIROp_PtrType: - case kIROp_RefType: - { - // We don't need to output a definition for these types - break; - } - case kIROp_ArrayType: - case kIROp_UnsizedArrayType: - case kIROp_HLSLRWStructuredBufferType: - { - // We don't need to output a definition for these with C++ templates - // For C we may need to (or do casting at point of usage) - break; - } - default: - { - if (IRBasicType::isaImpl(type->getOp())) - { - // Don't emit anything for built in types - return; - } - SLANG_ASSERT(!"Unhandled type"); - break; - } - } -} - -UnownedStringSlice CPPSourceEmitter::_getTypeName(IRType* inType) -{ - IRType* type = m_typeSet.getType(inType); - StringSlicePool::Handle handle = StringSlicePool::kNullHandle; if (m_typeNameMap.TryGetValue(type, handle)) { @@ -424,22 +209,7 @@ SlangResult CPPSourceEmitter::calcTypeName(IRType* type, CodeGenTarget target, S auto vecCount = int(getIntVal(vecType->getElementCount())); auto elemType = vecType->getElementType(); - if (_isCppOrCudaTarget(target)) - { - out << "Vector<" << _getTypeName(elemType) << ", " << vecCount << ">"; - } - else - { - out << "Vec"; - UnownedStringSlice postFix = _getCTypeVecPostFix(elemType->getOp()); - - out << postFix; - if (postFix.getLength() > 1) - { - out << "_"; - } - out << vecCount; - } + out << "Vector<" << _getTypeName(elemType) << ", " << vecCount << ">"; return SLANG_OK; } case kIROp_MatrixType: @@ -450,22 +220,8 @@ SlangResult CPPSourceEmitter::calcTypeName(IRType* type, CodeGenTarget target, S const auto rowCount = int(getIntVal(matType->getRowCount())); const auto colCount = int(getIntVal(matType->getColumnCount())); - if (_isCppOrCudaTarget(target)) - { - out << "Matrix<" << _getTypeName(elementType) << ", " << rowCount << ", " << colCount << ">"; - } - else - { - out << "Mat"; - const UnownedStringSlice postFix = _getCTypeVecPostFix(_getCType(elementType->getOp())); - out << postFix; - if (postFix.getLength() > 1) - { - out << "_"; - } - out << rowCount; - out << colCount; - } + out << "Matrix<" << _getTypeName(elementType) << ", " << rowCount << ", " << colCount << ">"; + return SLANG_OK; } case kIROp_WitnessTableType: @@ -625,17 +381,6 @@ void CPPSourceEmitter::useType(IRType* type) _getTypeName(type); } -static IRBasicType* _getElementType(IRType* type) -{ - switch (type->getOp()) - { - case kIROp_VectorType: type = static_cast<IRVectorType*>(type)->getElementType(); break; - case kIROp_MatrixType: type = static_cast<IRMatrixType*>(type)->getElementType(); break; - default: break; - } - return dynamicCast<IRBasicType>(type); -} - /* static */CPPSourceEmitter::TypeDimension CPPSourceEmitter::_getTypeDimension(IRType* type, bool vecSwap) { switch (type->getOp()) @@ -735,943 +480,11 @@ void CPPSourceEmitter::_emitAccess(const UnownedStringSlice& name, const TypeDim } } -static bool _isOperator(const UnownedStringSlice& funcName) -{ - if (funcName.getLength() > 0) - { - const char c = funcName[0]; - return !((c >= 'a' && c <='z') || (c >= 'A' && c <= 'Z') || c == '_'); - } - return false; -} - -void CPPSourceEmitter::_emitAryDefinition(const HLSLIntrinsic* specOp) -{ - auto info = HLSLIntrinsic::getInfo(specOp->op); - auto funcName = info.funcName; - SLANG_ASSERT(funcName.getLength() > 0); - - const bool isOperator = _isOperator(funcName); - - SourceWriter* writer = getSourceWriter(); - - IRFuncType* funcType = specOp->signatureType; - const int numParams = int(funcType->getParamCount()); - SLANG_ASSERT(numParams <= 3); - - bool areAllScalar = true; - TypeDimension paramDims[3]; - for (int i = 0; i < numParams; ++i) - { - paramDims[i]= _getTypeDimension(funcType->getParamType(i), false); - areAllScalar = areAllScalar && paramDims[i].isScalar(); - } - - // If all are scalar, then we don't need to emit a definition - if (areAllScalar) - { - return; - } - - IRType* retType = specOp->returnType; - - UnownedStringSlice scalarFuncName(funcName); - if (isOperator) - { - StringBuilder builder; - builder << "operator"; - builder << funcName; - _emitSignature(builder.getUnownedSlice(), specOp); - } - else - { - scalarFuncName = _getScalarFuncName(specOp->op, _getElementType(funcType->getParamType(0))); - _emitSignature(funcName, specOp); - } - - writer->emit("\n{\n"); - writer->indent(); - - const bool hasReturnType = retType->getOp() != kIROp_VoidType; - - TypeDimension calcDim; - if (hasReturnType) - { - emitType(retType); - writer->emit(" r;\n"); - - calcDim = _getTypeDimension(retType, false); - } - else - { - calcDim = _getTypeDimension(funcType->getParamType(0), false); - } - - for (int i = 0; i < calcDim.rowCount; ++i) - { - for (int j = 0; j < calcDim.colCount; ++j) - { - if (hasReturnType) - { - _emitAccess(UnownedStringSlice::fromLiteral("r"), calcDim, i, j, writer); - writer->emit(" = "); - } - - if (isOperator) - { - switch (numParams) - { - case 1: - { - writer->emit(funcName); - _emitAccess(UnownedStringSlice::fromLiteral("a"), paramDims[0], i, j, writer); - break; - } - case 2: - { - _emitAccess(UnownedStringSlice::fromLiteral("a"), paramDims[0], i, j, writer); - writer->emit(" "); - writer->emit(funcName); - writer->emit(" "); - _emitAccess(UnownedStringSlice::fromLiteral("b"), paramDims[1], i, j, writer); - break; - } - default: SLANG_ASSERT(!"Unhandled"); - } - } - else - { - writer->emit(scalarFuncName); - writer->emit("("); - for (int k = 0; k < numParams; k++) - { - if (k > 0) - { - writer->emit(", "); - } - char c = char('a' + k); - _emitAccess(UnownedStringSlice(&c, 1), paramDims[k], i, j, writer); - } - writer->emit(")"); - } - writer->emit(";\n"); - } - } - - if (hasReturnType) - { - writer->emit("return r;\n"); - } - - writer->dedent(); - writer->emit("}\n\n"); -} - -void CPPSourceEmitter::_emitAnyAllDefinition(const UnownedStringSlice& funcName, const HLSLIntrinsic* specOp) -{ - IRFuncType* funcType = specOp->signatureType; - SLANG_ASSERT(funcType->getParamCount() == 1); - IRType* paramType0 = funcType->getParamType(0); - - SourceWriter* writer = getSourceWriter(); - - IRType* elementType = _getElementType(paramType0); - SLANG_ASSERT(elementType); - IRType* retType = specOp->returnType; - auto retTypeName = _getTypeName(retType); - - IROp style = getTypeStyle(elementType->getOp()); - - const TypeDimension dim = _getTypeDimension(paramType0, false); - - _emitSignature(funcName, specOp); - writer->emit("\n{\n"); - writer->indent(); - - writer->emit("return "); - - for (int i = 0; i < dim.rowCount; ++i) - { - for (int j = 0; j < dim.colCount; ++j) - { - if (i > 0 || j > 0) - { - if (specOp->op == HLSLIntrinsic::Op::All) - { - writer->emit(" && "); - } - else - { - writer->emit(" || "); - } - } - - switch (style) - { - case kIROp_BoolType: - { - _emitAccess(UnownedStringSlice::fromLiteral("a"), dim, i, j, writer); - break; - } - case kIROp_IntType: - { - writer->emit("("); - _emitAccess(UnownedStringSlice::fromLiteral("a"), dim, i, j, writer); - writer->emit(" != 0)"); - break; - } - case kIROp_FloatType: - { - writer->emit("("); - _emitAccess(UnownedStringSlice::fromLiteral("a"), dim, i, j, writer); - writer->emit(" != 0.0)"); - break; - } - } - } - } - - writer->emit(";\n"); - - writer->dedent(); - writer->emit("}\n\n"); -} - -void CPPSourceEmitter::_emitSignature(const UnownedStringSlice& funcName, const HLSLIntrinsic* specOp) -{ - IRFuncType* funcType = specOp->signatureType; - const int paramsCount = int(funcType->getParamCount()); - IRType* retType = specOp->returnType; - - emitFunctionPreambleImpl(nullptr); - - SourceWriter* writer = getSourceWriter(); - - emitType(retType); - writer->emit(" "); - writer->emit(funcName); - writer->emit("("); - - for (int i = 0; i < paramsCount; ++i) - { - if (i > 0) - { - writer->emit(", "); - } - - // We can't pass as const& for vector, scalar, array types, as they are pass by value - // For types passed by reference, we should do something different - IRType* paramType = funcType->getParamType(i); -#if 0 - writer->emit("const "); -#endif - emitType(paramType); -#if 0 - if (dynamicCast<IRBasicType>(paramType)) - { - writer->emit(" "); - } - else - { - writer->emit("& "); - } -#else - - writer->emit(" "); -#endif - - writer->emitChar(char('a' + i)); - } - writer->emit(")"); -} - -UnownedStringSlice CPPSourceEmitter::_getAndEmitSpecializedOperationDefinition(HLSLIntrinsic::Op op, IRType*const* argTypes, Int argCount, IRType* retType) -{ - HLSLIntrinsic intrinsic; - m_intrinsicSet.calcIntrinsic(op, retType, argTypes, argCount, intrinsic); - auto specOp = m_intrinsicSet.add(intrinsic); - _maybeEmitSpecializedOperationDefinition(specOp); - return _getFuncName(specOp); -} - -void CPPSourceEmitter::_emitGetAtDefinition(const UnownedStringSlice& funcName, const HLSLIntrinsic* specOp) -{ - SourceWriter* writer = getSourceWriter(); - - IRFuncType* funcType = specOp->signatureType; - SLANG_ASSERT(funcType->getParamCount() == 2); - - IRType* srcType = funcType->getParamType(0); - - for (Index i = 0; i < 3; ++i) - { - UnownedStringSlice typePrefix = (i == 0) ? UnownedStringSlice::fromLiteral("const ") : UnownedStringSlice(); - bool lValue = (i != 2); - - emitFunctionPreambleImpl(nullptr); - - writer->emit(typePrefix); - emitType(specOp->returnType); - if (lValue) - m_writer->emit("*"); - writer->emit(" "); - writer->emit(funcName); - writer->emit("("); - - writer->emit(typePrefix); - emitType(funcType->getParamType(0)); - if (lValue) - writer->emit("*"); - writer->emit(" a, "); - emitType(funcType->getParamType(1)); - writer->emit(" b)\n{\n"); - - writer->indent(); - - if (auto vectorType = as<IRVectorType>(srcType)) - { - int vecSize = int(getIntVal(vectorType->getElementCount())); - - writer->emit("SLANG_PRELUDE_ASSERT(b >= 0 && b < "); - writer->emit(vecSize); - writer->emit(");\n"); - - writer->emit("return (("); - emitType(specOp->returnType); - writer->emit("*)"); - - if (lValue) - writer->emit("a) + b;\n"); - else - writer->emit("&a)[b];\n"); - } - else if (auto matrixType = as<IRMatrixType>(srcType)) - { - //int colCount = int(getIntVal(matrixType->getColumnCount())); - int rowCount = int(getIntVal(matrixType->getRowCount())); - - writer->emit("SLANG_PRELUDE_ASSERT(b >= 0 && b < "); - writer->emit(rowCount); - writer->emit(");\n"); - - if (lValue) - writer->emit("return &(a->rows[b]);\n"); - else - writer->emit("return a.rows[b];\n"); - } - - writer->dedent(); - writer->emit("}\n\n"); - } -} - -void CPPSourceEmitter::_emitConstructConvertDefinition(const UnownedStringSlice& funcName, const HLSLIntrinsic* specOp) -{ - SourceWriter* writer = getSourceWriter(); - IRFuncType* funcType = specOp->signatureType; - - SLANG_ASSERT(funcType->getParamCount() == 2); - - IRType* srcType = funcType->getParamType(1); - IRType* retType = specOp->returnType; - - emitFunctionPreambleImpl(nullptr); - - emitType(retType); - writer->emit(" "); - writer->emit(funcName); - writer->emit("("); - emitType(srcType); - writer->emitChar(' '); - writer->emitChar(char('a' + 0)); - writer->emit(")"); - - writer->emit("\n{\n"); - writer->indent(); - - writer->emit("return "); - emitType(retType); - writer->emit("{ "); - - - IRType* dstElemType = _getElementType(retType); - //IRType* srcElemType = _getElementType(srcType); - - TypeDimension dim = _getTypeDimension(retType, false); - - UnownedStringSlice rowTypeName; - if (dim.rowCount > 1) - { - IRType* rowType = m_typeSet.addVectorType(dstElemType, int(dim.colCount)); - rowTypeName = _getTypeName(rowType); - } - - for (int i = 0; i < dim.rowCount; ++i) - { - if (dim.rowCount > 1) - { - if (i > 0) - { - writer->emit(", \n"); - } - - if (m_target == CodeGenTarget::CUDASource) - { - m_writer->emit("make_"); - writer->emit(rowTypeName); - m_writer->emit("("); - } - else - { - writer->emit(rowTypeName); - writer->emit("{ "); - } - } - - for (int j = 0; j < dim.colCount; ++j) - { - if (j > 0) - { - writer->emit(", "); - } - - emitType(dstElemType); - writer->emit("("); - _emitAccess(UnownedStringSlice::fromLiteral("a"), dim, i, j, writer); - writer->emit(")"); - } - if (dim.rowCount > 1) - { - if (m_target == CodeGenTarget::CUDASource) - { - writer->emit(")"); - } - else - { - writer->emit("}"); - } - } - } - - writer->emit("};\n"); - - writer->dedent(); - writer->emit("}\n\n"); -} - -void CPPSourceEmitter::_emitInitDefinition(const UnownedStringSlice& funcName, const HLSLIntrinsic* specOp) -{ - SourceWriter* writer = getSourceWriter(); - IRFuncType* funcType = specOp->signatureType; - - emitFunctionPreambleImpl(nullptr); - - IRType* retType = specOp->returnType; - - _emitSignature(funcName, specOp); - writer->emit("\n{\n"); - writer->indent(); - - // Use C++ construction - writer->emit("return "); - emitType(retType); - writer->emit("{ "); - - const Index paramCount = Index(funcType->getParamCount()); - bool handled = false; - - if (IRVectorType* vecType = as<IRVectorType>(retType)) - { - Index elementCount = Index(getIntVal(vecType->getElementCount())); - - Index paramIndex = 0; - Index paramSubIndex = 0; - - for (Index i = 0; i < elementCount; ++i) - { - if (i > 0) - { - writer->emit(", "); - } - - if (paramIndex >= paramCount) - { - writer->emit("0"); - } - else - { - IRType* paramType = funcType->getParamType(paramIndex); - - if (IRVectorType* paramVecType = as<IRVectorType>(paramType)) - { - Index paramElementCount = Index(getIntVal(paramVecType->getElementCount())); - - const UnownedStringSlice* elemNames = getVectorElementNames(paramVecType); - - writer->emitChar('a' + char(paramIndex)); - writer->emit("."); - writer->emit(elemNames[paramSubIndex]); - - paramSubIndex++; - - if (paramSubIndex >= paramElementCount) - { - paramIndex++; - paramSubIndex = 0; - } - } - else - { - writer->emitChar('a' + char(paramIndex)); - paramIndex++; - } - } - } - handled = true; - } - else if (IRMatrixType* matType = as<IRMatrixType>(retType)) - { - if (paramCount != 1) - goto fallback; - - auto paramMat = as<IRMatrixType>(funcType->getParamType(0)); - if (!paramMat) - goto fallback; - - // We are constructing a matrix from a differently sized matrix. - - Index rows = Index(getIntVal(matType->getRowCount())); - Index cols = Index(getIntVal(matType->getColumnCount())); - Index paramRows = Index(getIntVal(paramMat->getRowCount())); - Index paramCols = Index(getIntVal(paramMat->getColumnCount())); - char elementNames[] = { 'x', 'y', 'z', 'w' }; - - for (Index r = 0; r < rows; r++) - { - for (Index c = 0; c < cols; c++) - { - if (r != 0 || c != 0) - writer->emit(", "); - - if (r < paramRows && c < paramCols && c < 4) - { - writer->emitRawText("a.rows["); - writer->emit(r); - writer->emitRawText("]."); - writer->emitChar(elementNames[c]); - } - else - { - writer->emit("0"); - } - } - } - handled = true; - } -fallback: - if (!handled) - { - // Fallback default: just use all params to construct. - for (Index i = 0; i < paramCount; ++i) - { - if (i > 0) - { - writer->emit(", "); - } - writer->emitChar('a' + char(i)); - } - } - - writer->emit("};\n"); - - writer->dedent(); - writer->emit("}\n\n"); -} - - -void CPPSourceEmitter::_emitConstructFromScalarDefinition(const UnownedStringSlice& funcName, const HLSLIntrinsic* specOp) -{ - SourceWriter* writer = getSourceWriter(); - IRFuncType* funcType = specOp->signatureType; - - SLANG_ASSERT(funcType->getParamCount() == 2); - - IRType* srcType = funcType->getParamType(1); - IRType* retType = specOp->returnType; - - emitFunctionPreambleImpl(nullptr); - - emitType(retType); - writer->emit(" "); - writer->emit(funcName); - writer->emit("("); - emitType(srcType); - writer->emitChar(' '); - writer->emitChar(char('a' + 0)); - writer->emit(")"); - - writer->emit("\n{\n"); - writer->indent(); - - writer->emit("return "); - emitType(retType); - writer->emit("{ "); - - const TypeDimension dim = _getTypeDimension(retType, false); - - for (int i = 0; i < dim.rowCount; ++i) - { - if (dim.rowCount > 1) - { - if (i > 0) - { - writer->emit(", \n"); - } - writer->emit("{ "); - } - for (int j = 0; j < dim.colCount; ++j) - { - if (j > 0) - { - writer->emit(", "); - } - writer->emit("a"); - } - if (dim.rowCount > 1) - { - writer->emit("}"); - } - } - - writer->emit("};\n"); - - writer->dedent(); - writer->emit("}\n\n"); -} - -void CPPSourceEmitter::_maybeEmitSpecializedOperationDefinition(const HLSLIntrinsic* specOp) -{ - // Check if it's been emitted already, if not add it. - if (!m_intrinsicEmitted.Add(specOp)) - { - return; - } - emitSpecializedOperationDefinition(specOp); -} - -void CPPSourceEmitter::emitSpecializedOperationDefinition(const HLSLIntrinsic* specOp) -{ - typedef HLSLIntrinsic::Op Op; - - switch (specOp->op) - { - case Op::Init: - { - return _emitInitDefinition(_getFuncName(specOp), specOp); - } - case Op::Any: - case Op::All: - { - return _emitAnyAllDefinition(_getFuncName(specOp), specOp); - } - case Op::ConstructConvert: - { - return _emitConstructConvertDefinition(_getFuncName(specOp), specOp); - } - case Op::ConstructFromScalar: - { - return _emitConstructFromScalarDefinition(_getFuncName(specOp), specOp); - } - case Op::GetAt: - { - return _emitGetAtDefinition(_getFuncName(specOp), specOp); - } - case Op::Swizzle: - { - // Don't have to output anything for swizzle for now - return; - } - default: - { - const auto& info = HLSLIntrinsic::getInfo(specOp->op); - const int paramCount = (info.numOperands < 0) ? int(specOp->signatureType->getParamCount()) : info.numOperands; - - if (paramCount >= 1 && paramCount <= 3) - { - return _emitAryDefinition(specOp); - } - break; - } - } - - SLANG_ASSERT(!"Unhandled"); -} - -void CPPSourceEmitter::emitCall(const HLSLIntrinsic* specOp, IRInst* inst, const IRUse* operands, int numOperands, const EmitOpInfo& inOuterPrec) -{ - typedef HLSLIntrinsic::Op Op; - - SLANG_UNUSED(inOuterPrec); - SourceWriter* writer = getSourceWriter(); - - switch (specOp->op) - { - case Op::Init: - { - IRType* retType = specOp->returnType; - if (IRBasicType::isaImpl(retType->getOp())) - { - SLANG_ASSERT(numOperands == 1); - - writer->emit(_getTypeName(retType)); - writer->emitChar('('); - - emitOperand(operands[0].get(), getInfo(EmitOp::General)); - - writer->emitChar(')'); - return; - } - break; - } - case Op::Swizzle: - { - // Currently only works for C++ (we use {} constuction) - which means we don't need to generate a function. - // For C we need to generate suitable construction function - auto swizzleInst = static_cast<IRSwizzle*>(inst); - const Index elementCount = Index(swizzleInst->getElementCount()); - - IRType* srcType = swizzleInst->getBase()->getDataType(); - IRVectorType* srcVecType = as<IRVectorType>(srcType); - - const UnownedStringSlice* elemNames = getVectorElementNames(srcVecType); - - // TODO(JS): Not 100% sure this is correct on the parens handling front - IRType* retType = specOp->returnType; - emitType(retType); - writer->emit("{"); - - for (Index i = 0; i < elementCount; ++i) - { - if (i > 0) - { - writer->emit(", "); - } - - auto outerPrec = getInfo(EmitOp::General); - - auto prec = getInfo(EmitOp::Postfix); - emitOperand(swizzleInst->getBase(), leftSide(outerPrec, prec)); - - writer->emit("."); - - IRInst* irElementIndex = swizzleInst->getElementIndex(i); - SLANG_RELEASE_ASSERT(irElementIndex->getOp() == kIROp_IntLit); - IRConstant* irConst = (IRConstant*)irElementIndex; - UInt elementIndex = (UInt)irConst->value.intVal; - SLANG_RELEASE_ASSERT(elementIndex < 4); - - writer->emit(elemNames[elementIndex]); - } - - writer->emit("}"); - return; - } - default: break; - } - - { - const auto& info = HLSLIntrinsic::getInfo(specOp->op); - // Make sure that the return type is available - const bool isOperator = _isOperator(info.funcName); - const UnownedStringSlice funcName = _getFuncName(specOp); - - switch (specOp->op) - { - case Op::ConstructFromScalar: - { - // We need to special case, because this may have come from a swizzle from a built in - // type, in that case the only parameter we want is the first one - numOperands = 1; - break; - } - - default: break; - } - - // add that we want a function - SLANG_ASSERT(info.numOperands < 0 || numOperands == info.numOperands); - - useType(specOp->returnType); - - if (isOperator) - { - // Just do the default output - defaultEmitInstExpr(inst, inOuterPrec); - } - else - { - writer->emit(funcName); - writer->emitChar('('); - - for (int i = 0; i < numOperands; ++i) - { - if (i > 0) - { - writer->emit(", "); - } - emitOperand(operands[i].get(), getInfo(EmitOp::General)); - } - - writer->emitChar(')'); - } - } -} - -HLSLIntrinsic* CPPSourceEmitter::_addIntrinsic(HLSLIntrinsic::Op op, IRType* returnType, IRType*const* argTypes, Index argTypeCount) -{ - HLSLIntrinsic intrinsic; - m_intrinsicSet.calcIntrinsic(op, returnType, argTypes, argTypeCount, intrinsic); - HLSLIntrinsic* addedIntrinsic = m_intrinsicSet.add(intrinsic); - _getFuncName(addedIntrinsic); - return addedIntrinsic; -} - -SlangResult CPPSourceEmitter::calcScalarFuncName(HLSLIntrinsic::Op op, IRBasicType* type, StringBuilder& outBuilder) -{ - outBuilder << _getTypePrefix(type->getOp()) << "_" << HLSLIntrinsic::getInfo(op).funcName; - return SLANG_OK; -} - -UnownedStringSlice CPPSourceEmitter::_getScalarFuncName(HLSLIntrinsic::Op op, IRBasicType* type) -{ - /* TODO(JS): This is kind of fast and loose. That we don't know all the parameters that are taken or - what the return type is, so we can't add to the HLSLIntrinsic map - we just generate the scalar - function name and use it (whilst also adding to the slice pool, so that we can return an - unowned slice). */ - - StringBuilder builder; - if (SLANG_FAILED(calcScalarFuncName(op, type, builder))) - { - SLANG_ASSERT(!"Unable to create scalar function name"); - return UnownedStringSlice(); - } - - // Add to the pool. - auto handle = m_slicePool.add(builder); - return m_slicePool.getSlice(handle); -} - -UnownedStringSlice CPPSourceEmitter::_getFuncName(const HLSLIntrinsic* specOp) -{ - StringSlicePool::Handle handle = StringSlicePool::kNullHandle; - if (m_intrinsicNameMap.TryGetValue(specOp, handle)) - { - return m_slicePool.getSlice(handle); - } - - StringBuilder builder; - if (SLANG_FAILED(calcFuncName(specOp, builder))) - { - SLANG_ASSERT(!"Unable to create function name"); - // Return an empty slice, as an error... - return UnownedStringSlice(); - } - - handle = m_slicePool.add(builder); - m_intrinsicNameMap.Add(specOp, handle); - - SLANG_ASSERT(handle != StringSlicePool::kNullHandle); - return m_slicePool.getSlice(handle); -} - -SlangResult CPPSourceEmitter::calcFuncName(const HLSLIntrinsic* specOp, StringBuilder& outBuilder) -{ - typedef HLSLIntrinsic::Op Op; - - if (specOp->isScalar()) - { - IRType* paramType = specOp->signatureType->getParamType(0); - IRBasicType* basicType = as<IRBasicType>(paramType); - if (basicType) - { - return calcScalarFuncName(specOp->op, basicType, outBuilder); - } - else - { - outBuilder << getName(paramType) << HLSLIntrinsic::getInfo(specOp->op).name; - return SLANG_OK; - } - } - else - { - switch (specOp->op) - { - case Op::ConstructConvert: - { - // Work out the function name - IRFuncType* signatureType = specOp->signatureType; - SLANG_ASSERT(signatureType->getParamCount() == 2); - - IRType* dstType = signatureType->getParamType(0); - //IRType* srcType = signatureType->getParamType(1); - - outBuilder << "convert_"; - // I need a function that is called that will construct this - SLANG_RETURN_ON_FAIL(calcTypeName(dstType, CodeGenTarget::CSource, outBuilder)); - return SLANG_OK; - } - case Op::ConstructFromScalar: - { - // Work out the function name - IRFuncType* signatureType = specOp->signatureType; - SLANG_ASSERT(signatureType->getParamCount() == 2); - - IRType* dstType = signatureType->getParamType(0); - - outBuilder << "constructFromScalar_"; - // I need a function that is called that will construct this - SLANG_RETURN_ON_FAIL(calcTypeName(dstType, CodeGenTarget::CSource, outBuilder)); - return SLANG_OK; - } - case Op::GetAt: - { - outBuilder << "getAt"; - return SLANG_OK; - } - case Op::Init: - { - outBuilder << "make_"; - SLANG_RETURN_ON_FAIL(calcTypeName(specOp->returnType, CodeGenTarget::CSource, outBuilder)); - return SLANG_OK; - } - default: break; - } - - const auto& info = HLSLIntrinsic::getInfo(specOp->op); - if (info.funcName.getLength()) - { - if (!_isOperator(info.funcName)) - { - // If there is a standard default name, just use that - outBuilder << info.funcName; - return SLANG_OK; - } - } - - // Just use the name of the Op. This is probably wrong, but gives a pretty good idea of what the desired (presumably missing) op is. - outBuilder << info.name; - return SLANG_OK; - } -} - /* !!!!!!!!!!!!!!!!!!!!!! CPPSourceEmitter !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! */ CPPSourceEmitter::CPPSourceEmitter(const Desc& desc): Super(desc), - m_slicePool(StringSlicePool::Style::Default), - m_typeSet(desc.codeGenContext->getSession()), - m_opLookup(new HLSLIntrinsicOpLookup), - m_intrinsicSet(&m_typeSet, m_opLookup) + m_slicePool(StringSlicePool::Style::Default) { m_semanticUsedFlags = 0; //m_semanticUsedFlags = SemanticUsedFlag::GroupID | SemanticUsedFlag::GroupThreadID | SemanticUsedFlag::DispatchThreadID; @@ -2145,12 +958,16 @@ void CPPSourceEmitter::emitSimpleFuncParamImpl(IRParam* param) void CPPSourceEmitter::emitVectorTypeNameImpl(IRType* elementType, IRIntegerValue elementCount) { - emitSimpleType(m_typeSet.addVectorType(elementType, int(elementCount))); + m_writer->emit("Vector<"); + m_writer->emit(_getTypeName(elementType)); + m_writer->emit(", "); + m_writer->emit(elementCount); + m_writer->emit(">"); } void CPPSourceEmitter::emitSimpleTypeImpl(IRType* inType) { - UnownedStringSlice slice = _getTypeName(m_typeSet.getType(inType)); + UnownedStringSlice slice = _getTypeName(inType); m_writer->emit(slice); } @@ -2225,8 +1042,6 @@ void CPPSourceEmitter::emitIntrinsicCallExprImpl( IRTargetIntrinsicDecoration* targetIntrinsic, EmitOpInfo const& inOuterPrec) { - typedef HLSLIntrinsic::Op Op; - // TODO: Much of this logic duplicates code that is already // in `CLikeSourceEmitter::emitIntrinsicCallExpr`. The only // real difference is that when things bottom out on an ordinary @@ -2248,36 +1063,6 @@ void CPPSourceEmitter::emitIntrinsicCallExprImpl( if (name == ".operator[]") { SLANG_ASSERT(argCount == 2 || argCount == 3); - - // If the first item is either a matrix or a vector, we use 'getAt' logic - IRType* targetType = args[0].get()->getDataType(); - if (targetType->getOp() == kIROp_VectorType || targetType->getOp() == kIROp_MatrixType) - { - // Work out the intrinsic used - HLSLIntrinsic intrinsic; - m_intrinsicSet.calcIntrinsic(HLSLIntrinsic::Op::GetAt, inst->getDataType(), args, 2, intrinsic); - HLSLIntrinsic* specOp = m_intrinsicSet.add(intrinsic); - - if (argCount == 2) - { - // Load - emitCall(specOp, inst, args, 2, inOuterPrec); - } - else - { - // Store - auto prec = getInfo(EmitOp::Postfix); - needClose = maybeEmitParens(outerPrec, prec); - - emitCall(specOp, inst, inst->getOperands(), 2, inOuterPrec); - - m_writer->emit(" = "); - emitOperand(inst->getOperand(2), getInfo(EmitOp::General)); - - maybeCloseParens(needClose); - } - } - else { // The user is invoking a built-in subscript operator @@ -2318,21 +1103,6 @@ void CPPSourceEmitter::emitIntrinsicCallExprImpl( return; } - { - Op op = m_opLookup->getOpByName(name); - if (op != Op::Invalid) - { - - // Work out the intrinsic used - HLSLIntrinsic intrinsic; - m_intrinsicSet.calcIntrinsic(op, inst->getDataType(), args, argCount, intrinsic); - HLSLIntrinsic* specOp = m_intrinsicSet.add(intrinsic); - - emitCall(specOp, inst, args, int(argCount), inOuterPrec); - return; - } - } - // Use default impl (which will do intrinsic special macro expansion as necessary) return Super::emitIntrinsicCallExprImpl(inst, targetIntrinsic, inOuterPrec); } @@ -2372,32 +1142,147 @@ const UnownedStringSlice* CPPSourceEmitter::getVectorElementNames(IRVectorType* return getVectorElementNames(basicType->getBaseType(), elemCount); } -bool CPPSourceEmitter::_tryEmitInstExprAsIntrinsic(IRInst* inst, const EmitOpInfo& inOuterPrec) +bool CPPSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOuterPrec) { - HLSLIntrinsic* specOp = m_intrinsicSet.add(inst); - if (specOp) + switch (inst->getOp()) { - if (inst->getOp() == kIROp_Call) + default: { - IRCall* call = static_cast<IRCall*>(inst); - emitCall(specOp, inst, call->getArgs(), int(call->getArgCount()), inOuterPrec); + return false; } - else + case kIROp_MakeVector: { - emitCall(specOp, inst, inst->getOperands(), int(inst->getOperandCount()), inOuterPrec); + IRType* retType = inst->getFullType(); + emitType(retType); + m_writer->emit("("); + bool isFirst = true; + for (UInt i = 0; i < inst->getOperandCount(); i++) + { + auto arg = inst->getOperand(i); + if (auto vectorType = as<IRVectorType>(arg->getDataType())) + { + for (int j = 0; j < cast<IRIntLit>(vectorType->getElementCount())->getValue(); j++) + { + if (isFirst) + isFirst = false; + else + m_writer->emit(", "); + auto outerPrec = getInfo(EmitOp::General); + auto prec = getInfo(EmitOp::Postfix); + emitOperand(arg, leftSide(outerPrec, prec)); + m_writer->emit("."); + m_writer->emitChar(s_xyzwNames[j]); + } + } + else + { + if (isFirst) + isFirst = false; + else + m_writer->emit(", "); + emitOperand(arg, getInfo(EmitOp::General)); + } + } + m_writer->emit(")"); + + return true; } - return true; - } - return false; -} + case kIROp_CastFloatToInt: + case kIROp_CastIntToFloat: + case kIROp_FloatCast: + case kIROp_IntCast: + { + if (auto vectorType = as<IRVectorType>(inst->getDataType())) + { + emitType(vectorType); + m_writer->emit("{"); + for (Index i = 0; i < cast<IRIntLit>(vectorType->getElementCount())->getValue(); i++) + { + if (i > 0) + m_writer->emit(", "); + m_writer->emit("("); + emitType(vectorType->getElementType()); + m_writer->emit(")_slang_vector_get_element("); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(", "); + m_writer->emit(i); + m_writer->emit(")"); + } + m_writer->emit("}"); + return true; + } + return false; + } + case kIROp_VectorReshape: + { + if (auto vectorType = as<IRVectorType>(inst->getDataType())) + { + m_writer->emit("_slang_vector_reshape<"); + emitType(vectorType->getElementType()); + m_writer->emit(", "); + emitOperand(vectorType->getElementCount(), getInfo(EmitOp::General)); + m_writer->emit(">("); + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(")"); + return true; + } + return false; + } + case kIROp_GetElement: + { + auto getElementInst = static_cast<IRGetElement*>(inst); -bool CPPSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOuterPrec) -{ - switch (inst->getOp()) - { - default: + IRInst* baseInst = getElementInst->getBase(); + IRType* baseType = baseInst->getDataType(); + if (as<IRVectorType>(baseType)) + { + m_writer->emit("_slang_vector_get_element("); + emitOperand(baseInst, getInfo(EmitOp::General)); + m_writer->emit(", "); + emitOperand(getElementInst->getIndex(), getInfo(EmitOp::General)); + m_writer->emit(")"); + return true; + } + else if (as<IRMatrixType>(baseType)) + { + auto outerPrec = getInfo(EmitOp::General); + auto prec = getInfo(EmitOp::Postfix); + emitOperand(baseInst, leftSide(outerPrec, prec)); + m_writer->emit(".rows["); + emitOperand(getElementInst->getIndex(), getInfo(EmitOp::General)); + m_writer->emit("]"); + return true; + } + return false; + } + case kIROp_GetElementPtr: { - return _tryEmitInstExprAsIntrinsic(inst, inOuterPrec); + auto getElementInst = static_cast<IRGetElement*>(inst); + + IRInst* baseInst = getElementInst->getBase(); + IRType* baseType = as<IRPtrTypeBase>(baseInst->getDataType())->getValueType(); + if (as<IRVectorType>(baseType)) + { + m_writer->emit("_slang_vector_get_element_ptr("); + emitOperand(baseInst, getInfo(EmitOp::General)); + m_writer->emit(", "); + emitOperand(getElementInst->getIndex(), getInfo(EmitOp::General)); + m_writer->emit(")"); + return true; + } + else if (as<IRMatrixType>(baseType)) + { + m_writer->emit("&("); + auto outerPrec = getInfo(EmitOp::General); + auto prec = getInfo(EmitOp::Postfix); + emitOperand(baseInst, leftSide(outerPrec, prec)); + m_writer->emit("->rows["); + emitOperand(getElementInst->getIndex(), getInfo(EmitOp::General)); + m_writer->emit("]"); + m_writer->emit(")"); + return true; + } + return false; } case kIROp_swizzle: { @@ -2430,8 +1315,79 @@ bool CPPSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOut return true; } } - // try doing automatically - return _tryEmitInstExprAsIntrinsic(inst, inOuterPrec); + + { + // Currently only works for C++ (we use {} constuction) - which means we don't need to generate a function. + // For C we need to generate suitable construction function + + const Index elementCount = Index(swizzleInst->getElementCount()); + + IRType* srcType = swizzleInst->getBase()->getDataType(); + IRVectorType* srcVecType = as<IRVectorType>(srcType); + + const UnownedStringSlice* elemNames = nullptr; + if (srcVecType) + elemNames = getVectorElementNames(srcVecType); + + IRType* retType = swizzleInst->getFullType(); + emitType(retType); + m_writer->emit("{"); + + for (Index i = 0; i < elementCount; ++i) + { + if (i > 0) + { + m_writer->emit(", "); + } + + auto outerPrec = getInfo(EmitOp::General); + + auto prec = getInfo(EmitOp::Postfix); + emitOperand(swizzleInst->getBase(), leftSide(outerPrec, prec)); + + if (srcVecType) + { + m_writer->emit("."); + + IRInst* irElementIndex = swizzleInst->getElementIndex(i); + SLANG_RELEASE_ASSERT(irElementIndex->getOp() == kIROp_IntLit); + IRConstant* irConst = (IRConstant*)irElementIndex; + UInt elementIndex = (UInt)irConst->value.intVal; + SLANG_RELEASE_ASSERT(elementIndex < 4); + + m_writer->emit(elemNames[elementIndex]); + } + } + + m_writer->emit("}"); + } + return true; + } + case kIROp_FRem: + { + if (auto basicType = as<IRBasicType>(inst->getDataType())) + { + switch (basicType->getOp()) + { + case kIROp_HalfType: + m_writer->emit("F16_fmod("); + break; + case kIROp_FloatType: + m_writer->emit("F32_fmod("); + break; + case kIROp_DoubleType: + m_writer->emit("F64_fmod("); + break; + default: + return false; + } + emitOperand(inst->getOperand(0), getInfo(EmitOp::General)); + m_writer->emit(", "); + emitOperand(inst->getOperand(1), getInfo(EmitOp::General)); + m_writer->emit(")"); + return true; + } + return false; } case kIROp_Call: { @@ -2441,7 +1397,7 @@ bool CPPSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOut handleRequiredCapabilities(funcValue); // try doing automatically - return _tryEmitInstExprAsIntrinsic(inst, inOuterPrec); + return false; } case kIROp_LookupWitness: { @@ -2562,29 +1518,6 @@ bool CPPSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOut } } -// We want order of built in types (typically output nothing), vector, matrix, other types -// Types that aren't output have negative indices -static Index _calcTypeOrder(IRType* a) -{ - switch (a->getOp()) - { - case kIROp_FuncType: - { - return -2; - } - case kIROp_VectorType: return 1; - case kIROp_MatrixType: return 2; - default: - { - if (as<IRBasicType>(a)) - { - return -1; - } - return 3; - } - } -} - void CPPSourceEmitter::emitPreModuleImpl() { if (m_target == CodeGenTarget::CPPSource) @@ -2604,45 +1537,6 @@ void CPPSourceEmitter::emitPreModuleImpl() m_writer->emit("using namespace SLANG_PRELUDE_NAMESPACE;\n"); m_writer->emit("#endif\n\n"); } - - // Emit generated functions and types - - if (m_target == CodeGenTarget::CSource) - { - // For C output we need to emit type definitions. - List<IRType*> types; - m_typeSet.getTypes(types); - - // Remove ones we don't need to emit - for (Index i = 0; i < types.getCount(); ++i) - { - if (_calcTypeOrder(types[i]) < 0) - { - types.fastRemoveAt(i); - --i; - } - } - - // Sort them so that vectors come before matrices and everything else after that - types.sort([&](IRType* a, IRType* b) { return _calcTypeOrder(a) < _calcTypeOrder(b); }); - - // Emit the type definitions - for (auto type : types) - { - emitTypeDefinition(type); - } - } - - { - List<const HLSLIntrinsic*> intrinsics; - m_intrinsicSet.getIntrinsics(intrinsics); - - // Emit all the intrinsics that were used - for (auto intrinsic : intrinsics) - { - _maybeEmitSpecializedOperationDefinition(intrinsic); - } - } } @@ -2980,11 +1874,6 @@ void CPPSourceEmitter::emitModuleImpl(IRModule* module, DiagnosticSink* sink) { SLANG_UNUSED(sink); - // Setup all built in types used in the module - m_typeSet.addAllBuiltinTypes(module); - // If any matrix types are used, then we need appropriate vector types too. - m_typeSet.addVectorForMatrixTypes(); - List<EmitAction> actions; computeEmitActions(module, actions); diff --git a/source/slang/slang-emit-cpp.h b/source/slang/slang-emit-cpp.h index c5b9f3d9c..ec70b02b8 100644 --- a/source/slang/slang-emit-cpp.h +++ b/source/slang/slang-emit-cpp.h @@ -39,9 +39,6 @@ public: }; virtual void useType(IRType* type); - virtual void emitCall(const HLSLIntrinsic* specOp, IRInst* inst, const IRUse* operands, int numOperands, const EmitOpInfo& inOuterPrec); - virtual void emitTypeDefinition(IRType* type); - virtual void emitSpecializedOperationDefinition(const HLSLIntrinsic* specOp); static UnownedStringSlice getBuiltinTypeName(IROp op); @@ -78,43 +75,21 @@ protected: virtual void emitVarDecorationsImpl(IRInst* var) SLANG_OVERRIDE; virtual void emitGlobalInstImpl(IRInst* inst) SLANG_OVERRIDE; - virtual const UnownedStringSlice* getVectorElementNames(BaseType elemType, Index elemCount); + const UnownedStringSlice* getVectorElementNames(BaseType elemType, Index elemCount); // Replaceable for classes derived from CPPSourceEmitter virtual SlangResult calcTypeName(IRType* type, CodeGenTarget target, StringBuilder& out); - virtual SlangResult calcFuncName(const HLSLIntrinsic* specOp, StringBuilder& out); - virtual SlangResult calcScalarFuncName(HLSLIntrinsic::Op op, IRBasicType* type, StringBuilder& outBuilder); const UnownedStringSlice* getVectorElementNames(IRVectorType* vectorType); - void _maybeEmitSpecializedOperationDefinition(const HLSLIntrinsic* specOp); - void _emitForwardDeclarations(const List<EmitAction>& actions); - void _emitAryDefinition(const HLSLIntrinsic* specOp); - - // Really we don't want any of these defined like they are here, they should be defined in slang stdlib - void _emitAnyAllDefinition(const UnownedStringSlice& funcName, const HLSLIntrinsic* specOp); - void _emitConstructConvertDefinition(const UnownedStringSlice& funcName, const HLSLIntrinsic* specOp); - void _emitConstructFromScalarDefinition(const UnownedStringSlice& funcName, const HLSLIntrinsic* specOp); - void _emitGetAtDefinition(const UnownedStringSlice& funcName, const HLSLIntrinsic* specOp); - void _emitInitDefinition(const UnownedStringSlice& funcName, const HLSLIntrinsic* specOp); - - void _emitSignature(const UnownedStringSlice& funcName, const HLSLIntrinsic* specOp); - void _emitInOutParamType(IRType* type, String const& name, IRType* valueType); - - UnownedStringSlice _getAndEmitSpecializedOperationDefinition(HLSLIntrinsic::Op op, IRType*const* argTypes, Int argCount, IRType* retType); - static TypeDimension _getTypeDimension(IRType* type, bool vecSwap); void _emitAccess(const UnownedStringSlice& name, const TypeDimension& dimension, int row, int col, SourceWriter* writer); - UnownedStringSlice _getScalarFuncName(HLSLIntrinsic::Op operation, IRBasicType* scalarType); - - UnownedStringSlice _getFuncName(const HLSLIntrinsic* specOp); - UnownedStringSlice _getTypeName(IRType* type); SlangResult _calcCPPTextureTypeName(IRTextureTypeBase* texType, StringBuilder& outName); @@ -126,8 +101,6 @@ protected: void _emitInitAxisValues(const Int sizeAlongAxis[kThreadGroupAxisCount], const UnownedStringSlice& mulName, const UnownedStringSlice& addName); - bool _tryEmitInstExprAsIntrinsic(IRInst* inst, const EmitOpInfo& inOuterPrec); - // Emit the actual definition (including intializer list) // of all the witness table objects in `pendingWitnessTableDefinitions`. void _emitWitnessTableDefinitions(); @@ -136,18 +109,9 @@ protected: void _getExportStyle(IRInst* inst, bool& outIsExport, bool& outIsExternC); void _maybeEmitExportLike(IRInst* inst); - HLSLIntrinsic* _addIntrinsic(HLSLIntrinsic::Op op, IRType* returnType, IRType*const* argTypes, Index argTypeCount); - static bool _isVariable(IROp op); Dictionary<IRType*, StringSlicePool::Handle> m_typeNameMap; - Dictionary<const HLSLIntrinsic*, StringSlicePool::Handle> m_intrinsicNameMap; - - IRTypeSet m_typeSet; - RefPtr<HLSLIntrinsicOpLookup> m_opLookup; - HLSLIntrinsicSet m_intrinsicSet; - - HashSet<const HLSLIntrinsic*> m_intrinsicEmitted; HashSet<IRInterfaceType*> m_interfaceTypesEmitted; diff --git a/source/slang/slang-emit-cuda.cpp b/source/slang/slang-emit-cuda.cpp index 284652682..a151ab0e2 100644 --- a/source/slang/slang-emit-cuda.cpp +++ b/source/slang/slang-emit-cuda.cpp @@ -123,131 +123,6 @@ SlangResult CUDASourceEmitter::_calcCUDATextureTypeName(IRTextureTypeBase* texTy return SLANG_FAIL; } -SlangResult CUDASourceEmitter::calcScalarFuncName(HLSLIntrinsic::Op op, IRBasicType* type, StringBuilder& outBuilder) -{ - typedef HLSLIntrinsic::Op Op; - - UnownedStringSlice funcName; - - switch (op) - { - case Op::FRem: - { - if (type->getOp() == kIROp_FloatType || type->getOp() == kIROp_DoubleType) - { - funcName = HLSLIntrinsic::getInfo(op).funcName; - } - break; - } - default: break; - } - - if (funcName.getLength()) - { - outBuilder << funcName; - if (type->getOp() == kIROp_FloatType) - { - outBuilder << "f"; - } - return SLANG_OK; - } - - // Defer to the supers impl - return Super::calcScalarFuncName(op, type, outBuilder); -} - -void CUDASourceEmitter::emitSpecializedOperationDefinition(const HLSLIntrinsic* specOp) -{ - typedef HLSLIntrinsic::Op Op; - - if (auto vecType = as <IRVectorType>(specOp->returnType)) - { - // Converting to or from half vector types is implemented prelude as convert___half functions - // Get the from type -> if it's half we ignore - - if (specOp->op == Op::ConstructConvert) - { - auto signatureType = specOp->signatureType; - - // Need to have impl of convert_float, double, int, uint, in prelude - - const auto paramCount = signatureType->getParamCount(); - SLANG_UNUSED(paramCount); - - // We have 2 'params' and param 1 is the source type - SLANG_ASSERT(paramCount == 2); - IRType* paramType = signatureType->getParamType(1); - - auto vecParamType = as<IRVectorType>(paramType); - - if (auto baseType = as<IRBasicType>(vecParamType->getElementType())) - { - if (baseType->getBaseType() == BaseType::Half) - { - return; - } - } - } - - if (auto baseType = as<IRBasicType>(vecType->getElementType())) - { - if (baseType->getBaseType() == BaseType::Half) - { - switch (specOp->op) - { - case Op::Init: - - case Op::Add: - case Op::Mul: - case Op::Div: - case Op::Sub: - - case Op::Neg: - - case Op::ConstructFromScalar: - case Op::ConstructConvert: - - case Op::Leq: - case Op::Less: - case Op::Greater: - case Op::Geq: - case Op::Neq: - case Op::Eql: - { - return; - } - } - } - } - } - - switch (specOp->op) - { - case Op::Init: - { - // Special case handling - auto returnType = specOp->returnType; - - if (auto vecType = as <IRVectorType>(returnType)) - { - if (auto baseType = as<IRBasicType>(vecType->getElementType())) - { - if (baseType->getBaseType() == BaseType::Half) - { - // Defined already in cuda-prelude.h - return; - } - } - } - - break; - } - default: break; - } - - Super::emitSpecializedOperationDefinition(specOp); -} - SlangResult CUDASourceEmitter::calcTypeName(IRType* type, CodeGenTarget target, StringBuilder& out) { SLANG_UNUSED(target); @@ -322,25 +197,6 @@ SlangResult CUDASourceEmitter::calcTypeName(IRType* type, CodeGenTarget target, return Super::calcTypeName(type, target, out); } -const UnownedStringSlice* CUDASourceEmitter::getVectorElementNames(BaseType baseType, Index elemCount) -{ - static const UnownedStringSlice normal[] = { UnownedStringSlice::fromLiteral("x"), UnownedStringSlice::fromLiteral("y"), UnownedStringSlice::fromLiteral("z"), UnownedStringSlice::fromLiteral("w") }; - static const UnownedStringSlice half3[] = { UnownedStringSlice::fromLiteral("xy.x"), UnownedStringSlice::fromLiteral("xy.y"), UnownedStringSlice::fromLiteral("z") }; - static const UnownedStringSlice half4[] = { UnownedStringSlice::fromLiteral("xy.x"), UnownedStringSlice::fromLiteral("xy.y"), UnownedStringSlice::fromLiteral("zw.x"), UnownedStringSlice::fromLiteral("zw.y")}; - - if (baseType == BaseType::Half) - { - switch (elemCount) - { - default: break; - case 3: return half3; - case 4: return half4; - } - } - - return normal; -} - void CUDASourceEmitter::emitLayoutSemanticsImpl(IRInst* inst, char const* uniformSemanticSpelling) { Super::emitLayoutSemanticsImpl(inst, uniformSemanticSpelling); @@ -436,49 +292,6 @@ void CUDASourceEmitter::emitGlobalRTTISymbolPrefix() m_writer->emit("__constant__ "); } -void CUDASourceEmitter::emitCall(const HLSLIntrinsic* specOp, IRInst* inst, const IRUse* operands, int numOperands, const EmitOpInfo& inOuterPrec) -{ - switch (specOp->op) - { - case HLSLIntrinsic::Op::Init: - { - // For CUDA vector types we construct with make_ - - auto writer = m_writer; - - IRType* retType = specOp->returnType; - - if (IRVectorType* vecType = as<IRVectorType>(retType)) - { - if (numOperands == getIntVal(vecType->getElementCount())) - { - // Get the type name - writer->emit("make_"); - emitType(retType); - writer->emitChar('('); - - for (int i = 0; i < numOperands; ++i) - { - if (i > 0) - { - writer->emit(", "); - } - emitOperand(operands[i].get(), getInfo(EmitOp::General)); - } - - writer->emitChar(')'); - return; - } - } - // Just use the default - break; - } - default: break; - } - - return Super::emitCall(specOp, inst, operands, numOperands, inOuterPrec); -} - void CUDASourceEmitter::emitLoopControlDecorationImpl(IRLoopControlDecoration* decl) { if (decl->getMode() == kIRLoopControl_Unroll) @@ -487,59 +300,25 @@ void CUDASourceEmitter::emitLoopControlDecorationImpl(IRLoopControlDecoration* d } } -static bool _areEquivalent(IRType* a, IRType* b) -{ - if (a == b) - { - return true; - } - if (a->getOp() != b->getOp()) - { - return false; - } - - switch (a->getOp()) - { - case kIROp_VectorType: - { - IRVectorType* vecA = static_cast<IRVectorType*>(a); - IRVectorType* vecB = static_cast<IRVectorType*>(b); - return getIntVal(vecA->getElementCount()) == getIntVal(vecB->getElementCount()) && - _areEquivalent(vecA->getElementType(), vecB->getElementType()); - } - case kIROp_MatrixType: - { - IRMatrixType* matA = static_cast<IRMatrixType*>(a); - IRMatrixType* matB = static_cast<IRMatrixType*>(b); - return getIntVal(matA->getColumnCount()) == getIntVal(matB->getColumnCount()) && - getIntVal(matA->getRowCount()) == getIntVal(matB->getRowCount()) && - _areEquivalent(matA->getElementType(), matB->getElementType()); - } - default: - { - return as<IRBasicType>(a) != nullptr; - } - } -} - void CUDASourceEmitter::_emitInitializerListValue(IRType* dstType, IRInst* value) { // When constructing a matrix or vector from a single value this is handled by the default path switch (value->getOp()) { - case kIROp_MakeMatrix: case kIROp_MakeVector: + case kIROp_MakeMatrix: { IRType* type = value->getDataType(); // If the types are the same, we can can just break down and use - if (_areEquivalent(dstType, type)) + if (dstType == type) { if (auto vecType = as<IRVectorType>(type)) { if (UInt(getIntVal(vecType->getElementCount())) == value->getOperandCount()) { + emitType(type); _emitInitializerList(vecType->getElementType(), value->getOperands(), value->getOperandCount()); return; } @@ -551,20 +330,25 @@ void CUDASourceEmitter::_emitInitializerListValue(IRType* dstType, IRInst* value // TODO(JS): If num cols = 1, then it *doesn't* actually return a vector. // That could be argued is an error because we want swizzling or [] to work. - IRType* rowType = m_typeSet.addVectorType(matType->getElementType(), int(colCount)); - IRVectorType* rowVectorType = as<IRVectorType>(rowType); + IRBuilder builder(matType->getModule()); + builder.setInsertBefore(matType); const Index operandCount = Index(value->getOperandCount()); // Can init, with vectors. // For now special case if the rowVectorType is not actually a vector (when elementSize == 1) - if (operandCount == rowCount || rowVectorType == nullptr) + if (operandCount == rowCount) { - // We have to output vectors - - // Emit the braces for the Matrix struct, contains an row array. + // Emit the braces for the Matrix struct, and then each row vector in its own line. + emitType(matType); m_writer->emit("{\n"); m_writer->indent(); - _emitInitializerList(rowType, value->getOperands(), rowCount); + for (Index i = 0; i < rowCount; ++i) + { + if (i != 0) m_writer->emit(",\n"); + emitType(matType->getElementType()); + m_writer->emit(colCount); + _emitInitializerList(matType->getElementType(), value->getOperand(i)->getOperands(), colCount); + } m_writer->dedent(); m_writer->emit("\n}"); return; @@ -575,21 +359,18 @@ void CUDASourceEmitter::_emitInitializerListValue(IRType* dstType, IRInst* value IRType* elementType = matType->getElementType(); IRUse* operands = value->getOperands(); - // Emit the braces for the Matrix struct, and the array of rows - m_writer->emit("{\n"); - m_writer->indent(); + // Emit the braces for the Matrix struct, and the elements of each row in its own line. + emitType(matType); m_writer->emit("{\n"); m_writer->indent(); for (Index i = 0; i < rowCount; ++i) { - if (i != 0) m_writer->emit(", "); - _emitInitializerList(elementType, operands, colCount); + if (i != 0) m_writer->emit(",\n"); + _emitInitializerListContent(elementType, operands, colCount); operands += colCount; } m_writer->dedent(); m_writer->emit("\n}"); - m_writer->dedent(); - m_writer->emit("\n}"); return; } } @@ -603,116 +384,157 @@ void CUDASourceEmitter::_emitInitializerListValue(IRType* dstType, IRInst* value emitOperand(value, getInfo(EmitOp::General)); } -void CUDASourceEmitter::_emitInitializerList(IRType* elementType, IRUse* operands, Index operandCount) +void CUDASourceEmitter::_emitInitializerListContent(IRType* elementType, IRUse* operands, Index operandCount) { - m_writer->emit("{\n"); - m_writer->indent(); - for (Index i = 0; i < operandCount; ++i) { if (i != 0) m_writer->emit(", "); _emitInitializerListValue(elementType, operands[i].get()); } - - m_writer->dedent(); - m_writer->emit("\n}"); } -void CUDASourceEmitter::_emitGetHalfVectorElement(IRInst* base, Index index, Index vecSize, const EmitOpInfo& inOuterPrec) -{ - SLANG_ASSERT(index < vecSize); - - EmitOpInfo outerPrec = inOuterPrec; - - auto prec = getInfo(EmitOp::Postfix); - const bool needClose = maybeEmitParens(outerPrec, prec); - emitOperand(base, leftSide(outerPrec, prec)); +void CUDASourceEmitter::_emitInitializerList(IRType* elementType, IRUse* operands, Index operandCount) +{ + m_writer->emit("{\n"); + m_writer->indent(); - m_writer->emit("."); + _emitInitializerListContent(elementType, operands, operandCount); - switch (vecSize) - { - default: - { - char const* kComponents[] = { "x", "y", "z", "w" }; - m_writer->emit(kComponents[index]); - break; - } - case 3: - { - char const* kComponents[] = { "xy.x", "xy.y", "z"}; - m_writer->emit(kComponents[index]); - break; - } - case 4: - { - char const* kComponents[] = { "xy.x", "xy.y", "zw.x", "zw.y" }; - m_writer->emit(kComponents[index]); - break; - } - } + m_writer->dedent(); + m_writer->emit("\n}"); +} - maybeCloseParens(needClose); +void CUDASourceEmitter::emitIntrinsicCallExprImpl(IRCall* inst, IRTargetIntrinsicDecoration* targetIntrinsic, EmitOpInfo const& inOuterPrec) +{ + if (targetIntrinsic->getDefinition().startsWith("__half")) + m_extensionTracker->requireBaseType(BaseType::Half); + Super::emitIntrinsicCallExprImpl(inst, targetIntrinsic, inOuterPrec); } bool CUDASourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOuterPrec) { switch(inst->getOp()) { - case kIROp_swizzle: + case kIROp_MakeVector: + case kIROp_MakeVectorFromScalar: { - // We need to special case for half types. - auto swizzleInst = static_cast<IRSwizzle*>(inst); - - IRInst* baseInst = swizzleInst->getBase(); - IRType* baseType = baseInst->getDataType(); - - // If we are swizzling from a built in type, - if (as<IRBasicType>(baseType)) + m_writer->emit("make_"); + emitType(inst->getDataType()); + m_writer->emit("("); + bool isFirst = true; + char xyzwNames[] = "xyzw"; + for (UInt i = 0; i < inst->getOperandCount(); i++) { - // Just use the default behavior + auto arg = inst->getOperand(i); + if (auto vectorType = as<IRVectorType>(arg->getDataType())) + { + for (int j = 0; j < cast<IRIntLit>(vectorType->getElementCount())->getValue(); j++) + { + if (isFirst) + isFirst = false; + else + m_writer->emit(", "); + auto outerPrec = getInfo(EmitOp::General); + auto prec = getInfo(EmitOp::Postfix); + emitOperand(arg, leftSide(outerPrec, prec)); + m_writer->emit("."); + m_writer->emitChar(xyzwNames[j]); + } + } + else + { + if (isFirst) + isFirst = false; + else + m_writer->emit(", "); + emitOperand(arg, getInfo(EmitOp::General)); + } } - else if (auto vecType = as<IRVectorType>(baseType)) + m_writer->emit(")"); + return true; + } + case kIROp_FloatCast: + case kIROp_CastIntToFloat: + case kIROp_IntCast: + case kIROp_CastFloatToInt: + { + if (auto dstVectorType = as<IRVectorType>(inst->getDataType())) { - if (auto basicType = as<IRBasicType>(vecType->getElementType())) + m_writer->emit("make_"); + emitType(inst->getDataType()); + m_writer->emit("("); + bool isFirst = true; + char xyzwNames[] = "xyzw"; + for (UInt i = 0; i < inst->getOperandCount(); i++) { - if (basicType->getBaseType() == BaseType::Half) + auto arg = inst->getOperand(i); + if (auto vectorType = as<IRVectorType>(arg->getDataType())) { - const Index vecElementCount = Index(getIntVal(vecType->getElementCount())); - - const Index elementCount = Index(swizzleInst->getElementCount()); - if (elementCount == 1) - { - const Index index = Index(getIntVal(swizzleInst->getElementIndex(0))); - _emitGetHalfVectorElement(baseInst, index, vecElementCount, inOuterPrec); - } - else + for (int j = 0; j < cast<IRIntLit>(vectorType->getElementCount())->getValue(); j++) { - auto outerPrec = getInfo(EmitOp::General); - - m_writer->emit("make___half"); - m_writer->emitInt64(elementCount); + if (isFirst) + isFirst = false; + else + m_writer->emit(", "); m_writer->emit("("); - - for (Index i = 0; i < elementCount; ++i) - { - if (i) - { - m_writer->emit(", "); - } - - const Index index = Index(getIntVal(swizzleInst->getElementIndex(i))); - _emitGetHalfVectorElement(baseInst, index, vecElementCount, outerPrec); - } - + emitType(dstVectorType->getElementType()); m_writer->emit(")"); + auto outerPrec = getInfo(EmitOp::General); + auto prec = getInfo(EmitOp::Postfix); + emitOperand(arg, leftSide(outerPrec, prec)); + m_writer->emit("."); + m_writer->emitChar(xyzwNames[j]); } - return true; + } + else + { + if (isFirst) + isFirst = false; + else + m_writer->emit(", "); + m_writer->emit("("); + emitType(dstVectorType->getElementType()); + m_writer->emit(")"); + emitOperand(arg, getInfo(EmitOp::General)); } } + m_writer->emit(")"); + return true; } - break; + else if (auto matrixType = as<IRMatrixType>(inst->getDataType())) + { + m_writer->emit("make"); + emitType(inst->getDataType()); + m_writer->emit("("); + for (UInt i = 0; i < inst->getOperandCount(); i++) + { + auto arg = inst->getOperand(i); + if (i > 0) + m_writer->emit(", "); + emitOperand(arg, getInfo(EmitOp::General)); + } + m_writer->emit(")"); + return true; + } + return false; + } + case kIROp_MakeMatrix: + case kIROp_MakeMatrixFromScalar: + case kIROp_MatrixReshape: + { + m_writer->emit("make"); + emitType(inst->getDataType()); + m_writer->emit("("); + for (UInt i = 0; i < inst->getOperandCount(); i++) + { + auto arg = inst->getOperand(i); + if (i > 0) + m_writer->emit(", "); + emitOperand(arg, getInfo(EmitOp::General)); + } + m_writer->emit(")"); + return true; } case kIROp_MakeArray: { @@ -722,13 +544,9 @@ bool CUDASourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOu IRType* elementType = arrayType->getElementType(); // Emit braces for the FixedArray struct. - m_writer->emit("{\n"); - m_writer->indent(); _emitInitializerList(elementType, inst->getOperands(), Index(inst->getOperandCount())); - m_writer->dedent(); - m_writer->emit("\n}"); return true; } case kIROp_WaveMaskBallot: @@ -820,7 +638,19 @@ void CUDASourceEmitter::emitVectorTypeNameImpl(IRType* elementType, IRIntegerVal void CUDASourceEmitter::emitSimpleTypeImpl(IRType* type) { - m_writer->emit(_getTypeName(type)); + switch (type->getOp()) + { + case kIROp_VectorType: + { + auto vectorType = as<IRVectorType>(type); + m_writer->emit(getVectorPrefix(vectorType->getElementType()->getOp())); + m_writer->emit(as<IRIntLit>(vectorType->getElementCount())->getValue()); + break; + } + default: + m_writer->emit(_getTypeName(type)); + break; + } } void CUDASourceEmitter::emitRateQualifiersImpl(IRRate* rate) @@ -907,27 +737,6 @@ void CUDASourceEmitter::emitPreModuleImpl() // Emit generated types/functions writer->emit("\n"); - - { - List<IRType*> types; - m_typeSet.getTypes(IRTypeSet::Kind::Matrix, types); - - // Emit the type definitions - for (auto type : types) - { - emitTypeDefinition(type); - } - } - - { - List<const HLSLIntrinsic*> intrinsics; - m_intrinsicSet.getIntrinsics(intrinsics); - // Emit all the intrinsics that were used - for (auto intrinsic : intrinsics) - { - _maybeEmitSpecializedOperationDefinition(intrinsic); - } - } } @@ -951,22 +760,6 @@ bool CUDASourceEmitter::tryEmitGlobalParamImpl(IRGlobalParam* varDecl, IRType* v void CUDASourceEmitter::emitModuleImpl(IRModule* module, DiagnosticSink* sink) { - // Setup all built in types used in the module - m_typeSet.addAllBuiltinTypes(module); - // If any matrix types are used, then we need appropriate vector types too. - m_typeSet.addVectorForMatrixTypes(); - - // We need to add some vector intrinsics - used for calculating thread ids - { - IRType* type = m_typeSet.addVectorType(m_typeSet.getBuilder().getBasicType(BaseType::UInt), 3); - IRType* args[] = { type, type }; - - _addIntrinsic(HLSLIntrinsic::Op::Add, type, args, SLANG_COUNT_OF(args)); - _addIntrinsic(HLSLIntrinsic::Op::Mul, type, args, SLANG_COUNT_OF(args)); - } - - // TODO(JS): We may need to generate types (for example for matrices) - CLikeSourceEmitter::emitModuleImpl(module, sink); // Emit all witness table definitions. diff --git a/source/slang/slang-emit-cuda.h b/source/slang/slang-emit-cuda.h index ff947fe58..8a907dc7c 100644 --- a/source/slang/slang-emit-cuda.h +++ b/source/slang/slang-emit-cuda.h @@ -78,12 +78,9 @@ protected: virtual void emitVectorTypeNameImpl(IRType* elementType, IRIntegerValue elementCount) SLANG_OVERRIDE; virtual void emitVarDecorationsImpl(IRInst* varDecl) SLANG_OVERRIDE; virtual void emitMatrixLayoutModifiersImpl(IRVarLayout* layout) SLANG_OVERRIDE; - virtual void emitCall(const HLSLIntrinsic* specOp, IRInst* inst, const IRUse* operands, int numOperands, const EmitOpInfo& inOuterPrec) SLANG_OVERRIDE; virtual void emitFunctionPreambleImpl(IRInst* inst) SLANG_OVERRIDE; virtual String generateEntryPointNameImpl(IREntryPointDecoration* entryPointDecor) SLANG_OVERRIDE; - virtual const UnownedStringSlice* getVectorElementNames(BaseType baseType, Index elemCount) SLANG_OVERRIDE; - virtual void emitGlobalRTTISymbolPrefix() SLANG_OVERRIDE; virtual void emitLoopControlDecorationImpl(IRLoopControlDecoration* decl) SLANG_OVERRIDE; @@ -92,23 +89,19 @@ protected: virtual bool tryEmitGlobalParamImpl(IRGlobalParam* varDecl, IRType* varType) SLANG_OVERRIDE; virtual bool tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOuterPrec) SLANG_OVERRIDE; - + virtual void emitIntrinsicCallExprImpl(IRCall* inst, IRTargetIntrinsicDecoration* targetIntrinsic, EmitOpInfo const& inOuterPrec) SLANG_OVERRIDE; virtual void emitModuleImpl(IRModule* module, DiagnosticSink* sink) SLANG_OVERRIDE; // CPPSourceEmitter overrides virtual SlangResult calcTypeName(IRType* type, CodeGenTarget target, StringBuilder& out) SLANG_OVERRIDE; - virtual SlangResult calcScalarFuncName(HLSLIntrinsic::Op op, IRBasicType* type, StringBuilder& outBuilder) SLANG_OVERRIDE; - - virtual void emitSpecializedOperationDefinition(const HLSLIntrinsic* specOp) SLANG_OVERRIDE; SlangResult _calcCUDATextureTypeName(IRTextureTypeBase* texType, StringBuilder& outName); void _emitInitializerList(IRType* elementType, IRUse* operands, Index operandCount); + void _emitInitializerListContent(IRType* elementType, IRUse* operands, Index operandCount); void _emitInitializerListValue(IRType* elementType, IRInst* value); - void _emitGetHalfVectorElement(IRInst* baseInst, Index index, Index vecSize, const EmitOpInfo& inOuterPrec); - RefPtr<CUDAExtensionTracker> m_extensionTracker; }; diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index c49265fe7..ef0d062bb 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -1022,7 +1022,7 @@ SlangResult CodeGenContext::emitEntryPointsSourceFromIR(ComPtr<IArtifact>& outAr auto irModule = linkedIR.module; // Perform final simplifications to help emit logic to generate more compact code. - simplifyForEmit(irModule); + simplifyForEmit(irModule, targetRequest); metadata = linkedIR.metadata; diff --git a/source/slang/slang-hlsl-intrinsic-set.cpp b/source/slang/slang-hlsl-intrinsic-set.cpp index ea3476473..e69de29bb 100644 --- a/source/slang/slang-hlsl-intrinsic-set.cpp +++ b/source/slang/slang-hlsl-intrinsic-set.cpp @@ -1,590 +0,0 @@ -// slang-hlsl-intrinsic-set.cpp -#include "slang-hlsl-intrinsic-set.h" - -#include "slang-ir.h" -#include "slang-ir-insts.h" - -namespace Slang -{ - -/* static */const HLSLIntrinsic::Info HLSLIntrinsic::s_operationInfos[] = -{ -#define SLANG_HLSL_INTRINSIC_OP_INFO(x, funcName, numOperands) { UnownedStringSlice::fromLiteral(#x), UnownedStringSlice::fromLiteral(funcName), int8_t(numOperands) }, - SLANG_HLSL_INTRINSIC_OP(SLANG_HLSL_INTRINSIC_OP_INFO) -}; - -// !!!!!!!!!!!!!!!!!!!!!!!!!!!!! HLSLIntrinsicSet !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! - -HLSLIntrinsicSet::HLSLIntrinsicSet(IRTypeSet* typeSet, HLSLIntrinsicOpLookup* lookup): - m_intrinsicFreeList(sizeof(HLSLIntrinsic), SLANG_ALIGN_OF(HLSLIntrinsic), 1024), - m_typeSet(typeSet), - m_opLookup(lookup) -{ -} - -static IRBasicType* _getElementType(IRType* type) -{ - switch (type->getOp()) - { - case kIROp_VectorType: type = static_cast<IRVectorType*>(type)->getElementType(); break; - case kIROp_MatrixType: type = static_cast<IRMatrixType*>(type)->getElementType(); break; - default: break; - } - return dynamicCast<IRBasicType>(type); -} - -void HLSLIntrinsicSet::_calcIntrinsic(HLSLIntrinsic::Op op, IRType* returnType, IRType*const* inArgs, Index argsCount, HLSLIntrinsic& out) -{ - IRBuilder& builder = m_typeSet->getBuilder(); - - // Check all types belong to the module - - IRModule* module = builder.getModule(); - - SLANG_UNUSED(module); - SLANG_ASSERT(returnType->getModule() == module); - - for (Index i = 0; i < argsCount; ++i) - { - SLANG_ASSERT(inArgs[i]->getModule() == module); - } - - // Set up the out - out.op = op; - out.returnType = returnType; - - switch (op) - { - case Op::GetAt: - { - IRType* argTypes[3]; - - SLANG_ASSERT(argsCount == 2 || argsCount == 3); - // TODO(JS): - // HACK! GetAt can be from getElementPtr or from getElement. Get element ptr means the return type will be - // a pointer. We don't want to deal with that, so strip it - if (returnType->getOp() == kIROp_PtrType) - { - returnType = as<IRType>(returnType->getOperand(0)); - } - - // TODO(JS): Similarly for the input parameters - for (Index i = 0; i < argsCount; ++i) - { - IRType* argType = inArgs[i]; - - if (argType->getOp() == kIROp_PtrType) - { - argType = as<IRType>(argType->getOperand(0)); - } - argTypes[i] = argType; - } - - out.returnType = returnType; - out.signatureType = builder.getFuncType(argsCount, argTypes, builder.getVoidType()); - break; - } - case Op::ConstructFromScalar: - { - //SLANG_ASSERT(argsCount == 1); - SLANG_ASSERT(argsCount == 1); - IRType* srcType = _getElementType(returnType); - IRType* argTypes[2] = { returnType, srcType }; - - out.signatureType = builder.getFuncType(2, argTypes, builder.getVoidType()); - break; - } - case Op::ConstructConvert: - { - // Make the return type a parameter, to make the signature take into account - SLANG_ASSERT(argsCount == 1); - IRType* argTypes[2] = { returnType, inArgs[0] }; - - out.signatureType = builder.getFuncType(2, argTypes, builder.getVoidType()); - break; - } - default: - { - out.signatureType = builder.getFuncType(argsCount, inArgs, builder.getVoidType()); - break; - } - } -} - -void HLSLIntrinsicSet::calcIntrinsic(HLSLIntrinsic::Op op, IRType* returnType, IRType*const* inArgTypes, Index argCount, HLSLIntrinsic& out) -{ - returnType = m_typeSet->getType(returnType); - - if (argCount <= 8) - { - IRType* args[8]; - for (Index i = 0; i < argCount; ++i) - { - args[i] = m_typeSet->getType(inArgTypes[i]); - } - _calcIntrinsic(op, returnType, args, argCount, out); - } - else - { - List<IRType*> args; - args.setCount(argCount); - - for (Index i = 0; i < argCount; ++i) - { - args[i] = m_typeSet->getType(inArgTypes[i]); - } - _calcIntrinsic(op, returnType, args.getBuffer(), argCount, out); - } -} - -void HLSLIntrinsicSet::calcIntrinsic(HLSLIntrinsic::Op op, IRInst* inst, Index operandCount, HLSLIntrinsic& out) -{ - IRType* returnType = m_typeSet->getType(inst->getDataType()); - if (operandCount <= 8) - { - IRType* argTypes[8]; - for (Index i = 0; i < operandCount; ++i) - { - auto operand = inst->getOperand(i); - argTypes[i] = m_typeSet->getType(operand->getDataType()); - } - _calcIntrinsic(op, returnType, argTypes, operandCount, out); - } - else - { - List<IRType*> argTypes; - argTypes.setCount(operandCount); - - for (Index i = 0; i < operandCount; ++i) - { - auto operand = inst->getOperand(i); - argTypes[i] = m_typeSet->getType(operand->getDataType()); - } - _calcIntrinsic(op, returnType, argTypes.getBuffer(), operandCount, out); - } -} - -void HLSLIntrinsicSet::calcIntrinsic(HLSLIntrinsic::Op op, IRType* returnType, IRUse* inArgs, Index argCount, HLSLIntrinsic& out) -{ - returnType = m_typeSet->getType(returnType); - - if (argCount <= 8) - { - IRType* argTypes[8]; - - for (Index i = 0; i < argCount; ++i) - { - auto operand = inArgs[i].get(); - argTypes[i] = m_typeSet->getType(operand->getDataType()); - } - _calcIntrinsic(op, returnType, argTypes, argCount, out); - } - else - { - List<IRType*> argTypes; - argTypes.setCount(argCount); - - for (Index i = 0; i < argCount; ++i) - { - auto operand = inArgs[i].get(); - argTypes[i] = m_typeSet->getType(operand->getDataType()); - } - _calcIntrinsic(op, returnType, argTypes.getBuffer(), argCount, out); - } -} - -HLSLIntrinsic* HLSLIntrinsicSet::add(IRInst* inst) -{ - HLSLIntrinsic intrinsic; - if (SLANG_SUCCEEDED(makeIntrinsic(inst, intrinsic))) - { - return add(intrinsic); - } - return nullptr; -} - -SlangResult HLSLIntrinsicSet::makeIntrinsic(IRInst* inst, HLSLIntrinsic& out) -{ - // Mark as invalid... - out.op = Op::Invalid; - - { - // See if we can just directly convert - Op op = HLSLIntrinsicOpLookup::getOpForIROp(inst->getOp()); - - - // HACK: some cases we want to stop handling via the synthesis - // path, but only for vector and matrix types (not scalars). - // - switch( op ) - { - default: break; - - case Op::AsFloat: - case Op::AsInt: - case Op::AsUInt: - // Note: the `any()`/`all()` case can't be handled via a stdlib definition - // right now because `bool` vectors map to `int` vectors on the CUDA - // path, so that the generated `geAt` operation is incorrect. - // -// case Op::Any: -// case Op::All: - { - IRType* srcType = inst->getOperand(0)->getDataType(); - switch( srcType->getOp() ) - { - default: - break; - - case kIROp_VectorType: - case kIROp_MatrixType: - return SLANG_FAIL; - } - } - break; - } - - - if (op != Op::Invalid) - { - calcIntrinsic(op, inst, inst->getOperandCount(), out); - return SLANG_OK; - } - } - - // All the special cases - switch (inst->getOp()) - { - case kIROp_MakeVectorFromScalar: - case kIROp_MakeMatrixFromScalar: - { - SLANG_ASSERT(inst->getOperandCount() == 1); - calcIntrinsic(Op::ConstructFromScalar, inst, 1, out); - return SLANG_OK; - } - case kIROp_IntCast: - case kIROp_FloatCast: - case kIROp_CastIntToFloat: - case kIROp_CastFloatToInt: - { - IRType* dstType = inst->getDataType(); - IRType* srcType = inst->getOperand(0)->getDataType(); - - if ((dstType->getOp() == kIROp_VectorType || dstType->getOp() == kIROp_MatrixType) && - inst->getOperandCount() == 1) - { - if (as<IRBasicType>(srcType)) - { - calcIntrinsic(Op::ConstructFromScalar, inst, out); - } - else - { - SLANG_ASSERT(m_typeSet->getType(dstType) != m_typeSet->getType(srcType)); - // If it's constructed from a type conversion - calcIntrinsic(Op::ConstructConvert, inst, out); - } - return SLANG_OK; - } - else - { - // If we are constructing a basic type, we don't need an Op::Init - if (!IRBasicType::isaImpl(dstType->getOp())) - { - // Emit the 'init' intrinsic - calcIntrinsic(Op::Init, inst, inst->getOperandCount(), out); - return SLANG_OK; - } - } - return SLANG_FAIL; - } - case kIROp_MakeVector: - case kIROp_VectorReshape: - { - if (inst->getOperandCount() == 1 && as<IRBasicType>(inst->getOperand(0)->getDataType())) - { - // This is make from scalar - calcIntrinsic(Op::ConstructFromScalar, inst, out); - } - else - { - calcIntrinsic(Op::Init, inst, inst->getOperandCount(), out); - } - return SLANG_OK; - } - case kIROp_MakeMatrix: - case kIROp_MatrixReshape: - { - // We only emit as if it has one operand, but we can tell how many it actually has from the return type - calcIntrinsic(Op::Init, inst, inst->getOperandCount(), out); - return SLANG_OK; - } - case kIROp_swizzle: - { - // We don't need to add swizzle function, but we do output the need for some other functions - - // For C++ we don't need to emit a swizzle function - // For C we need a construction function - auto swizzleInst = static_cast<IRSwizzle*>(inst); - - IRInst* baseInst = swizzleInst->getBase(); - IRType* baseType = baseInst->getDataType(); - - // If we are swizzling from a built in type, - if (as<IRBasicType>(baseType)) - { - // We can swizzle a scalar type to be a vector, or just a scalar - IRType* dstType = swizzleInst->getDataType(); - if (!as<IRBasicType>(dstType)) - { - // If it's a scalar make sure we have construct from scalar, because we will want to use that - SLANG_ASSERT(dstType->getOp() == kIROp_VectorType); - IRType* argTypes[] = { baseType }; - calcIntrinsic(Op::ConstructFromScalar, inst->getDataType(), argTypes, 1, out); - return SLANG_OK; - } - } - else - { - const Index elementCount = Index(swizzleInst->getElementCount()); - if (elementCount >= 1) - { - // Will need to generate a swizzle method - calcIntrinsic(Op::Swizzle, inst, out); - return SLANG_OK; - } - } - break; - } - case kIROp_GetElement: - { - IRInst* target = inst->getOperand(0); - IRType* targetType = target->getDataType(); - if (targetType->getOp() == kIROp_VectorType || targetType->getOp() == kIROp_MatrixType) - { - // Specially handle this - calcIntrinsic(Op::GetAt, inst, out); - return SLANG_OK; - } - break; - } - case kIROp_GetElementPtr: - { - IRInst* target = inst->getOperand(0); - IRType* targetType = target->getDataType(); - - if (auto ptrType = as<IRPtrType>(targetType)) - { - targetType = as<IRType>(ptrType->getOperand(0)); - if (targetType->getOp() == kIROp_VectorType || targetType->getOp() == kIROp_MatrixType) - { - // Specially handle this - calcIntrinsic(Op::GetAt, inst, out); - return SLANG_OK; - } - } - break; - } - case kIROp_Call: - { - IRCall* callInst = (IRCall*)inst; - auto funcValue = callInst->getCallee(); - - const Op op = m_opLookup->getOpFromTargetDecoration(funcValue); - if (op != Op::Invalid) - { - calcIntrinsic(op, inst->getDataType(), callInst->getArgs(), callInst->getArgCount(), out); - return SLANG_OK; - } - break; - } - - default: break; - } - - return SLANG_FAIL; -} - -void HLSLIntrinsicSet::getIntrinsics(List<const HLSLIntrinsic*>& out) const -{ - for (auto& intrinsic : m_intrinsicsList) - { - out.add(intrinsic); - } -} - -HLSLIntrinsic* HLSLIntrinsicSet::add(const HLSLIntrinsic& intrinsic) -{ - // Make sure it's valid(!) - SLANG_ASSERT(intrinsic.op != Op::Invalid); - - HLSLIntrinsic* copy = (HLSLIntrinsic*)m_intrinsicFreeList.allocate(); - *copy = intrinsic; - HLSLIntrinsicRef ref(copy); - HLSLIntrinsic** found = m_intrinsicsDict.TryGetValueOrAdd(ref, copy); - if (found) - { - // If we have found an intrinsic, we can free the copy - m_intrinsicFreeList.deallocate(copy); - return *found; - } - - // If we are adding an intrinsic for the first time, - // it should be added to the deduplicated list - m_intrinsicsList.add(copy); - - return copy; -} - -// !!!!!!!!!!!!!!!!!!!!!!!!!!!!! HLSLIntrinsicOpLookup !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! - -HLSLIntrinsicOpLookup::HLSLIntrinsicOpLookup(): - m_slicePool(StringSlicePool::Style::Default) -{ - // Add all the operations with names (not ops like -, / etc) to the lookup map - for (int i = 0; i < SLANG_COUNT_OF(HLSLIntrinsic::s_operationInfos); ++i) - { - const auto& info = HLSLIntrinsic::getInfo(Op(i)); - UnownedStringSlice slice = info.funcName; - - if (slice.getLength() > 0 && slice[0] >= 'a' && slice[0] <= 'z') - { - auto handle = m_slicePool.add(slice); - Index index = Index(handle); - // Make sure there is space - if (index >= m_sliceToOpMap.getCount()) - { - Index oldSize = m_sliceToOpMap.getCount(); - m_sliceToOpMap.setCount(index + 1); - for (Index j = oldSize; j < index; j++) - { - m_sliceToOpMap[j] = Op::Invalid; - } - } - m_sliceToOpMap[index] = Op(i); - } - } -} - -HLSLIntrinsic::Op HLSLIntrinsicOpLookup::getOpByName(const UnownedStringSlice& slice) -{ - const Index index = m_slicePool.findIndex(slice); - return (index >= 0 && index < m_sliceToOpMap.getCount()) ? m_sliceToOpMap[index] : Op::Invalid; -} - -static IRInst* _getSpecializedValue(IRSpecialize* specInst) -{ - auto base = specInst->getBase(); - auto baseGeneric = as<IRGeneric>(base); - if (!baseGeneric) - return base; - - auto lastBlock = baseGeneric->getLastBlock(); - if (!lastBlock) - return base; - - auto returnInst = as<IRReturn>(lastBlock->getTerminator()); - if (!returnInst) - return base; - - return returnInst->getVal(); -} - -HLSLIntrinsic::Op HLSLIntrinsicOpLookup::getOpFromTargetDecoration(IRInst* inInst) -{ - // An intrinsic generic function will be invoked through a `specialize` instruction, - // so the callee won't directly be the thing that is decorated. We will look up - // through specializations until we can see the actual thing being called. - // - IRInst* inst = inInst; - while (auto specInst = as<IRSpecialize>(inst)) - { - inst = _getSpecializedValue(specInst); - - // If `getSpecializedValue` can't find the result value - // of the generic being specialized, then it returns - // the original instruction. This would be a disaster - // for use because this loop would go on forever. - // - // This case should never happen if the stdlib is well-formed - // and the compiler is doing its job right. - // - SLANG_ASSERT(inst != specInst); - } - - // We are just looking for the original name so we can match against it - for (auto dd : inst->getDecorations()) - { - if (auto decor = as<IRTargetIntrinsicDecoration>(dd)) - { - // TODO(JS): Should confirm that we'll always have this entry - which we need for lookups to work (we need the name - // not a targets transformation) - // - // It turns out that addCatchAllIntrinsicDecorationIfNeeded will add a target intrinsic with the - // original HLSL name, which has an empty `CapabilitySet`. - // - // It's not 100% clear this covers all the cases, but for now lets go with that - if (decor->getTargetCaps().isEmpty()) - { - Op op = getOpByName(decor->getDefinition()); - if (op != Op::Invalid) - { - return op; - } - } - } - } - - return Op::Invalid; -} - -HLSLIntrinsic::Op HLSLIntrinsicOpLookup::getOpForIROp(IRInst* inst) -{ - switch (inst->getOp()) - { - case kIROp_Call: - { - return getOpFromTargetDecoration(inst); - } - default: break; - } - return getOpForIROp(inst->getOp()); -} - -/* static */HLSLIntrinsic::Op HLSLIntrinsicOpLookup::getOpForIROp(IROp op) -{ - switch (op) - { - case kIROp_Add: return Op::Add; - case kIROp_Mul: return Op::Mul; - case kIROp_Sub: return Op::Sub; - case kIROp_Div: return Op::Div; - case kIROp_Lsh: return Op::Lsh; - case kIROp_Rsh: return Op::Rsh; - case kIROp_IRem: return Op::IRem; - case kIROp_FRem: return Op::FRem; - - case kIROp_Eql: return Op::Eql; - case kIROp_Neq: return Op::Neq; - case kIROp_Greater: return Op::Greater; - case kIROp_Less: return Op::Less; - case kIROp_Geq: return Op::Geq; - case kIROp_Leq: return Op::Leq; - - case kIROp_BitAnd: return Op::BitAnd; - case kIROp_BitXor: return Op::BitXor; - case kIROp_BitOr: return Op::BitOr; - - case kIROp_And: return Op::And; - case kIROp_Or: return Op::Or; - - case kIROp_Neg: return Op::Neg; - case kIROp_Not: return Op::Not; - case kIROp_BitNot: return Op::BitNot; - - case kIROp_MakeVectorFromScalar: return Op::ConstructFromScalar; - - default: return Op::Invalid; - } -} - -} diff --git a/source/slang/slang-hlsl-intrinsic-set.h b/source/slang/slang-hlsl-intrinsic-set.h index 3dc2996c1..8368491db 100644 --- a/source/slang/slang-hlsl-intrinsic-set.h +++ b/source/slang/slang-hlsl-intrinsic-set.h @@ -11,217 +11,5 @@ namespace Slang { -/* TODO(JS): Note that there are multiple methods to handle 'construction' operations. That is because 'construct' is used as a kind of -generic 'construction' for built in types including vectors and matrices. - -For the moment the cpp emit code, determines what kind of construct is needed, and has special handling for ConstructConvert and -ConstructFromScalar. - -That currently we do not see MakeVectorFromScalar - for example when we do... - -int2 fromScalar = 1; - -This appears as a construction from an int. - -That the better thing to do would be that there were IR instructions for the specific types of construction. I suppose there is a question -about whether there should be separate instructions for vector/matrix, or emit code should just use the destination type. In practice I think -it's fine that there isn't an instruction separating vector/matrix. That being the case I guess we arguably don't need MakeVectorFromScalar, -just constructXXXFromScalar. Would be good if there was a suitable name to encompass vector/matrix. -*/ -#define SLANG_HLSL_INTRINSIC_OP(x) \ - x(Invalid, "", -1) \ - x(Init, "", -1) \ - \ - x(Mul, "*", 2) \ - x(Div, "/", 2) \ - x(Add, "+", 2) \ - x(Sub, "-", 2) \ - x(Lsh, "<<", 2) \ - x(Rsh, ">>", 2) \ - x(IRem, "%", 2) \ - x(FRem, "fmod", 2) \ - \ - x(Eql, "==", 2) \ - x(Neq, "!=", 2) \ - x(Greater, ">", 2) \ - x(Less, "<", 2) \ - x(Geq, ">=", 2) \ - x(Leq, "<=", 2) \ - \ - x(BitAnd, "&", 2) \ - x(BitXor, "^", 2) \ - x(BitOr, "|" , 2) \ - \ - x(And, "&&", 2) \ - x(Or, "||", 2) \ - \ - x(Neg, "-", 1) \ - x(Not, "!", 1) \ - x(BitNot, "~", 1) \ - \ - x(Any, "any", 1) \ - x(All, "all", 1) \ - \ - x(Swizzle, "", -1) \ - \ - x(AsFloat, "asfloat", 1) \ - x(AsInt, "asint", -1) \ - x(AsUInt, "asuint", -1) \ - x(AsDouble, "asdouble", 2) \ - \ - x(ConstructConvert, "", 1) \ - x(ConstructFromScalar, "", 1) \ - \ - x(GetAt, "", 2) \ - /* end */ - -struct HLSLIntrinsic -{ - typedef HLSLIntrinsic ThisType; - - enum class Op : uint8_t - { -#define SLANG_HLSL_INTRINSIC_OP_ENUM(name, hlslName, numOperands) name, - SLANG_HLSL_INTRINSIC_OP(SLANG_HLSL_INTRINSIC_OP_ENUM) - }; - - struct Info - { - UnownedStringSlice name; ///< The enum name - UnownedStringSlice funcName; ///< The HLSL function name (if there is one) - int8_t numOperands; ///< -1 if can't be handled automatically via amount of params - }; - - bool operator==(const ThisType& rhs) const { return op == rhs.op && returnType == rhs.returnType && signatureType == rhs.signatureType; } - bool operator!=(const ThisType& rhs) const { return !(*this == rhs); } - - static bool isTypeScalar(IRType* type) - { - // Strip off ptr if it's an operand type - if (type->getOp() == kIROp_PtrType) - { - type = as<IRType>(type->getOperand(0)); - } - // If any are vec or matrix, then we - return !(type->getOp() == kIROp_MatrixType || type->getOp() == kIROp_VectorType); - } - - bool isScalar() const - { - Index paramCount = Index(signatureType->getParamCount()); - for (Index i = 0; i < paramCount; ++i) - { - if (!isTypeScalar(signatureType->getParamType(i))) - { - return false; - } - } - return isTypeScalar(returnType); - } - - HashCode getHashCode() const { return combineHash(int(op), combineHash(Slang::getHashCode(returnType), Slang::getHashCode(signatureType))); } - - static const Info& getInfo(Op op) { return s_operationInfos[Index(op)]; } - static const Info s_operationInfos[]; - - Op op; - IRType* returnType; - IRFuncType* signatureType; // Same as funcType, but has return type of void -}; - -/* A helper type that allows comparing pointers to HLSLIntrinsic types as if they are the values */ -struct HLSLIntrinsicRef -{ - typedef HLSLIntrinsicRef ThisType; - - HashCode getHashCode() const { return m_intrinsic->getHashCode(); } - bool operator==(const ThisType& rhs) const { return m_intrinsic == rhs.m_intrinsic || (*m_intrinsic == *rhs.m_intrinsic); } - bool operator!=(const ThisType& rhs) const { return !(*this == rhs); } - - HLSLIntrinsicRef():m_intrinsic(nullptr) {} - HLSLIntrinsicRef(const ThisType& rhs):m_intrinsic(rhs.m_intrinsic) {} - HLSLIntrinsicRef(const HLSLIntrinsic* intrinsic): m_intrinsic(intrinsic) {} - void operator=(const ThisType& rhs) { m_intrinsic = rhs.m_intrinsic; } - - const HLSLIntrinsic* m_intrinsic; -}; - -class HLSLIntrinsicOpLookup : public RefObject -{ -public: - typedef HLSLIntrinsic::Op Op; - - Op getOpFromTargetDecoration(IRInst* inInst); - Op getOpByName(const UnownedStringSlice& slice); - - Op getOpForIROp(IRInst* inst); - - HLSLIntrinsicOpLookup(); - - /// Given an IROp returns the Op equivalent or Op::Invalid if not found - static Op getOpForIROp(IROp op); - -protected: - - StringSlicePool m_slicePool; - List<Op> m_sliceToOpMap; -}; - - -/* This is used so as to try and use slangs type system to uniquely identify types and specializations on intrinsic. -That we want to have a pointer to a type be unique, and slang supports this through the m_sharedIRBuilder. BUT for this to -work all work on the module must use the same sharedIRBuilder, and that appears to not be the case in terms -of other passes. -Even if it was the case when we may want to add types as part of emitting, we can't use the previously used -shared builder, so again we end up with pointers to the same things not being the same thing. - -To work around this we clone types we want to use as keys into the 'unique module'. -This is not necessary for all types though - as we assume nominal types *must* have unique pointers (that is the -definition of nominal). - -This could be handled in other ways (for example not testing equality on pointer equality). Anyway for now this -works, but probably needs to be handled in a better way. The better way may involve having guarantees about equality -enabled in other code generation and making de-duping possible in emit code. - -Note that one pro for this approach is that it does not alter the source module. That as it stands it's not necessary -for the source module to be immutable, because it is created for emitting and then discarded. - */ -class HLSLIntrinsicSet -{ -public: - typedef HLSLIntrinsic::Op Op; - - /* Note that calculating an intrinsic, the types will be added to the type set. That might mean subsequent code will - emit those types being required, which may not be the case */ - - void calcIntrinsic(Op op, IRType* returnType, IRType*const* args, Index argsCount, HLSLIntrinsic& out); - void calcIntrinsic(Op op, IRInst* inst, Index argsCount, HLSLIntrinsic& out); - void calcIntrinsic(Op op, IRType* returnType, IRUse* args, Index argCount, HLSLIntrinsic& out); - void calcIntrinsic(Op op, IRInst* inst, HLSLIntrinsic& out) { calcIntrinsic(op, inst, Index(inst->getOperandCount()), out); } - - SlangResult makeIntrinsic(IRInst* inst, HLSLIntrinsic& out); - - HLSLIntrinsic* add(const HLSLIntrinsic& intrinsic); - - /// Returns the intrinsic constructed if there is one from the inst. If not possible to construct returns nullptr. - HLSLIntrinsic* add(IRInst* inst); - - void getIntrinsics(List<const HLSLIntrinsic*>& out) const; - - HLSLIntrinsicSet(IRTypeSet* typeSet, HLSLIntrinsicOpLookup* lookup); - -protected: - // All calcs must go through this choke point for some special case handling. - // NOTE that this function must only be called with unique types (ie from the m_typeSet) - void _calcIntrinsic(HLSLIntrinsic::Op op, IRType* returnType, IRType*const* inArgs, Index argsCount, HLSLIntrinsic& out); - - List<HLSLIntrinsic*> m_intrinsicsList; - Dictionary<HLSLIntrinsicRef, HLSLIntrinsic*> m_intrinsicsDict; - - FreeList m_intrinsicFreeList; ///< the storage for the intrinsics when they are in the map - - HLSLIntrinsicOpLookup* m_opLookup; - IRTypeSet* m_typeSet; -}; } // namespace Slang diff --git a/source/slang/slang-ir-address-analysis.cpp b/source/slang/slang-ir-address-analysis.cpp index aba59e1de..1473bc466 100644 --- a/source/slang/slang-ir-address-analysis.cpp +++ b/source/slang/slang-ir-address-analysis.cpp @@ -79,9 +79,8 @@ namespace Slang // Deduplicate and move known address insts. for (auto block : func->getBlocks()) { - for (auto inst = block->getFirstChild(); inst;) + for (auto inst : block->getModifiableChildren()) { - auto next = inst->getNextInst(); switch (inst->getOp()) { case kIROp_Var: @@ -151,7 +150,6 @@ namespace Slang } break; } - inst = next; } } diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index b5d3dba10..1f599a344 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -170,40 +170,36 @@ InstPair ForwardDiffTranscriber::transcribeBinaryLogic(IRBuilder* builder, IRIns { SLANG_ASSERT(origLogic->getOperandCount() == 2); - // TODO: Check other boolean cases. - if (as<IRBoolType>(origLogic->getDataType())) - { - // Boolean operations are not differentiable. For the linearization - // pass, we do not need to do anything but copy them over to the ne - // function. - auto primalLogic = maybeCloneForPrimalInst(builder, origLogic); - return InstPair(primalLogic, nullptr); - } - - SLANG_UNEXPECTED("Logical operation with non-boolean result"); + // Boolean operations are not differentiable. For the linearization + // pass, we do not need to do anything but copy them over to the ne + // function. + auto primalLogic = maybeCloneForPrimalInst(builder, origLogic); + return InstPair(primalLogic, nullptr); } InstPair ForwardDiffTranscriber::transcribeLoad(IRBuilder* builder, IRLoad* origLoad) { auto origPtr = origLoad->getPtr(); auto primalPtr = lookupPrimalInst(builder, origPtr, nullptr); - auto primalPtrValueType = as<IRPtrTypeBase>(primalPtr->getFullType())->getValueType(); - - if (auto diffPairType = as<IRDifferentialPairType>(primalPtrValueType)) + auto primalPtrType = as<IRPtrTypeBase>(primalPtr->getFullType()); + if (primalPtrType) { - // Special case load from an `out` param, which will not have corresponding `diff` and - // `primal` insts yet. - - // TODO: Could we move this load to _after_ DifferentialPairGetPrimal, - // and DifferentialPairGetDifferential? - // - auto load = builder->emitLoad(primalPtr); - builder->markInstAsMixedDifferential(load, diffPairType); + if (auto diffPairType = as<IRDifferentialPairType>(primalPtrType->getValueType())) + { + // Special case load from an `out` param, which will not have corresponding `diff` and + // `primal` insts yet. - auto primalElement = builder->emitDifferentialPairGetPrimal(load); - auto diffElement = builder->emitDifferentialPairGetDifferential( - (IRType*)pairBuilder->getDiffTypeFromPairType(builder, diffPairType), load); - return InstPair(primalElement, diffElement); + // TODO: Could we move this load to _after_ DifferentialPairGetPrimal, + // and DifferentialPairGetDifferential? + // + auto load = builder->emitLoad(primalPtr); + builder->markInstAsMixedDifferential(load, diffPairType); + + auto primalElement = builder->emitDifferentialPairGetPrimal(load); + auto diffElement = builder->emitDifferentialPairGetDifferential( + (IRType*)pairBuilder->getDiffTypeFromPairType(builder, diffPairType), load); + return InstPair(primalElement, diffElement); + } } auto primalLoad = maybeCloneForPrimalInst(builder, origLoad); @@ -492,7 +488,6 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig if (!diffReturnType) { - SLANG_RELEASE_ASSERT(origCall->getFullType()->getOp() == kIROp_VoidType); diffReturnType = argBuilder.getVoidType(); } @@ -1364,6 +1359,8 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst* case kIROp_Or: case kIROp_Geq: case kIROp_Leq: + case kIROp_Eql: + case kIROp_Neq: return transcribeBinaryLogic(builder, origInst); case kIROp_CastIntToFloat: @@ -1452,7 +1449,27 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst* case kIROp_undefined: return transcribeUndefined(builder, origInst); + case kIROp_Not: + case kIROp_BitAnd: + case kIROp_BitNot: + case kIROp_BitXor: + case kIROp_BitCast: + case kIROp_Lsh: + case kIROp_Rsh: + case kIROp_IRem: + case kIROp_ByteAddressBufferLoad: + case kIROp_ByteAddressBufferStore: + case kIROp_StructuredBufferLoad: + case kIROp_StructuredBufferStore: + case kIROp_Reinterpret: + case kIROp_IsType: + case kIROp_ImageSubscript: + case kIROp_ImageLoad: + case kIROp_ImageStore: case kIROp_CreateExistentialObject: + case kIROp_PackAnyValue: + case kIROp_UnpackAnyValue: + case kIROp_GetNativePtr: // A call to createDynamicObject<T>(arbitraryData) cannot provide a diff value, // so we treat this inst as non differentiable. // We can extend the frontend and IR with a separate op-code that can provide an explicit diff value. diff --git a/source/slang/slang-ir-autodiff-unzip.h b/source/slang/slang-ir-autodiff-unzip.h index d83ff57e4..d10a9349d 100644 --- a/source/slang/slang-ir-autodiff-unzip.h +++ b/source/slang/slang-ir-autodiff-unzip.h @@ -1256,10 +1256,8 @@ struct DiffUnzipPass diffBuilder.setInsertInto(diffBlock); List<IRInst*> splitInsts; - for (auto child = block->getFirstChild(); child;) + for (auto child : block->getModifiableChildren()) { - IRInst* nextChild = child->getNextInst(); - if (auto getDiffInst = as<IRDifferentialPairGetDifferential>(child)) { // Replace GetDiff(A) with A.d @@ -1267,7 +1265,6 @@ struct DiffUnzipPass { getDiffInst->replaceUsesWith(lookupDiffInst(getDiffInst->getBase())); getDiffInst->removeAndDeallocate(); - child = nextChild; continue; } } @@ -1278,7 +1275,6 @@ struct DiffUnzipPass { getPrimalInst->replaceUsesWith(lookupPrimalInst(getPrimalInst->getBase())); getPrimalInst->removeAndDeallocate(); - child = nextChild; continue; } } @@ -1296,8 +1292,6 @@ struct DiffUnzipPass { child->insertAtEnd(primalBlock); } - - child = nextChild; } // Remove insts that were split. diff --git a/source/slang/slang-ir-byte-address-legalize.cpp b/source/slang/slang-ir-byte-address-legalize.cpp index 3a8d1852a..721efadaf 100644 --- a/source/slang/slang-ir-byte-address-legalize.cpp +++ b/source/slang/slang-ir-byte-address-legalize.cpp @@ -66,11 +66,8 @@ struct ByteAddressBufferLegalizationContext break; } - - IRInst* nextChild = nullptr; - for( IRInst* child = inst->getFirstChild(); child; child = nextChild ) + for( IRInst* child : inst->getModifiableChildren()) { - nextChild = child->getNextInst(); processInstRec(child); } } diff --git a/source/slang/slang-ir-clone.cpp b/source/slang/slang-ir-clone.cpp index dbeb1e934..8b8b28f09 100644 --- a/source/slang/slang-ir-clone.cpp +++ b/source/slang/slang-ir-clone.cpp @@ -72,29 +72,29 @@ IRInst* cloneInstAndOperands( auto oldType = oldInst->getFullType(); auto newType = (IRType*) findCloneForOperand(env, oldType); - // Next we will create an empty shell of the instruction, - // with space for the operands, but no actual operand - // values attached. - // - UInt operandCount = oldInst->getOperandCount(); - auto newInst = builder->emitIntrinsicInst( - newType, - oldInst->getOp(), - operandCount, - nullptr); - - // Finally we will iterate over the operands of `oldInst` + // Next we will iterate over the operands of `oldInst` // to find their replacements and install them as // the operands of `newInst`. // - for(UInt ii = 0; ii < operandCount; ++ii) + UInt operandCount = oldInst->getOperandCount(); + + ShortList<IRInst*> newOperands; + newOperands.setCount(operandCount); + for (UInt ii = 0; ii < operandCount; ++ii) { auto oldOperand = oldInst->getOperand(ii); auto newOperand = findCloneForOperand(env, oldOperand); - newInst->getOperands()[ii].init(newInst, newOperand); + newOperands[ii] = newOperand; } + // Finally we create the inst with the updated operands. + auto newInst = builder->emitIntrinsicInst( + newType, + oldInst->getOp(), + operandCount, + newOperands.getArrayView().getBuffer()); + newInst->sourceLoc = oldInst->sourceLoc; return newInst; diff --git a/source/slang/slang-ir-collect-global-uniforms.cpp b/source/slang/slang-ir-collect-global-uniforms.cpp index ca5e56b53..ad0dfda91 100644 --- a/source/slang/slang-ir-collect-global-uniforms.cpp +++ b/source/slang/slang-ir-collect-global-uniforms.cpp @@ -192,7 +192,8 @@ struct CollectGlobalUniformParametersContext // per-field layout information to reference the key we created // instead of the existing parameter (which we will be removing). // - fieldLayoutAttr->setOperand(0, fieldKey); + fieldLayoutAttr = as<IRStructFieldLayoutAttr>( + builder->replaceOperand(fieldLayoutAttr->getOperands(), fieldKey)); // If the given parameter doesn't contribute to uniform/ordinary usage, then // we can safely leave it at the global scope and potentially avoid a lot @@ -266,7 +267,7 @@ struct CollectGlobalUniformParametersContext // if(auto layoutAttr = as<IRStructFieldLayoutAttr>(user)) { - layoutAttr->setOperand(0, fieldKey); + builder->replaceOperand(layoutAttr->getOperands(), fieldKey); continue; } diff --git a/source/slang/slang-ir-com-interface.cpp b/source/slang/slang-ir-com-interface.cpp index 3e52054cd..0684cc8e6 100644 --- a/source/slang/slang-ir-com-interface.cpp +++ b/source/slang/slang-ir-com-interface.cpp @@ -105,7 +105,7 @@ void lowerComInterfaces(IRModule* module, ArtifactStyle artifactStyle, Diagnosti for (auto use : uses) { // Do the replacement - use->set(result); + builder.replaceOperand(use, result); } } } diff --git a/source/slang/slang-ir-dce.cpp b/source/slang/slang-ir-dce.cpp index 05c10b317..251b473e0 100644 --- a/source/slang/slang-ir-dce.cpp +++ b/source/slang/slang-ir-dce.cpp @@ -237,14 +237,16 @@ struct DeadCodeEliminationContext // might still be dead. // // The biggest wrinkle is that we walk the linked list of - // children/decorations a bit carefully, using a temporary - // to hold the next node, in case we eliminate one of - // the children as we go. + // children/decorations a bit carefully, because eliminating one inst + // may cause the other nodes to be hoisted out of the current scope. + // We need to cache all children in a work list to ensure they are + // properly traversed. // - IRInst* next = nullptr; - for( IRInst* child = inst->getFirstDecorationOrChild(); child; child = next ) + List<IRInst*> children; + for (auto child : inst->getDecorationsAndChildren()) + children.add(child); + for(IRInst* child : children) { - next = child->getNextInst(); changed |= eliminateDeadInstsRec(child); } } diff --git a/source/slang/slang-ir-deduplicate.cpp b/source/slang/slang-ir-deduplicate.cpp index 51a677627..74efc3cb3 100644 --- a/source/slang/slang-ir-deduplicate.cpp +++ b/source/slang/slang-ir-deduplicate.cpp @@ -2,116 +2,84 @@ namespace Slang { - struct DeduplicateContext + void SharedIRBuilder::deduplicateAndRebuildGlobalNumberingMap() { - SharedIRBuilder* builder; - IRInst* addValue(IRInst* value) - { - if (!value) return nullptr; - if (as<IRType>(value)) - return addTypeValue(value); - if (auto constValue = as<IRConstant>(value)) - return addConstantValue(constValue); - return value; - } - IRInst* addConstantValue(IRConstant* value) - { - IRConstantKey key = { value }; - value->setFullType((IRType*)addValue(value->getFullType())); - if (auto newValue = builder->getConstantMap().TryGetValue(key)) - return *newValue; - builder->getConstantMap()[key] = value; - return value; - } - IRInst* addTypeValue(IRInst* value) - { - // Do not deduplicate struct or interface types. - switch (value->getOp()) - { - case kIROp_StructType: - case kIROp_InterfaceType: - return value; - default: - break; - } + } - for (UInt i = 0; i < value->getOperandCount(); i++) - { - value->setOperand(i, addValue(value->getOperand(i))); - } - value->setFullType((IRType*)addValue(value->getFullType())); - IRInstKey key = { value }; - if (auto newValue = builder->getGlobalValueNumberingMap().TryGetValue(key)) - return *newValue; - builder->getGlobalValueNumberingMap()[key] = value; - return value; - } - }; - void SharedIRBuilder::deduplicateAndRebuildGlobalNumberingMap() + void SharedIRBuilder::replaceGlobalInst(IRInst* oldInst, IRInst* newInst) + { + oldInst->replaceUsesWith(newInst); + } + + void SharedIRBuilder::removeHoistableInstFromGlobalNumberingMap(IRInst* instToRemove) { - DeduplicateContext context; - context.builder = this; - m_constantMap.Clear(); - m_globalValueNumberingMap.Clear(); - List<IRInst*> instToRemove; - for (auto inst : m_module->getGlobalInsts()) + HashSet<IRInst*> userWorkListSet; + List<IRInst*> userWorkList; + auto addToWorkList = [&](IRInst* i) { - if (auto constVal = as<IRConstant>(inst)) - { - auto newConst = context.addConstantValue(constVal); - if (newConst != constVal) - { - constVal->replaceUsesWith(newConst); - instToRemove.add(constVal); - } - } - } - for (auto inst : m_module->getGlobalInsts()) + if (userWorkListSet.Add(i)) + userWorkList.add(i); + }; + addToWorkList(instToRemove); + for (Index i = 0; i < userWorkList.getCount(); i++) { - if (as<IRType>(inst) || as<IRSpecialize>(inst)) + auto inst = userWorkList[i]; + if (getIROpInfo(inst->getOp()).isHoistable()) { - auto newInst = context.addTypeValue(inst); - if (newInst != inst) + _removeGlobalNumberingEntry(inst); + for (auto use = inst->firstUse; use; use = use->nextUse) { - inst->replaceUsesWith(newInst); - instToRemove.add(inst); + addToWorkList(use->getUser()); } } } - for (auto inst : instToRemove) - inst->removeAndDeallocate(); } - void SharedIRBuilder::replaceGlobalInst(IRInst* oldInst, IRInst* newInst) + void addHoistableInst( + IRBuilder* builder, + IRInst* inst); + + void SharedIRBuilder::tryHoistInst(IRInst* inst) { - List<IRUse*> uses; - for (auto use = oldInst->firstUse; use; use = use->nextUse) - { - uses.add(use); - } + List<IRInst*> workList; + HashSet<IRInst*> workListSet; + workList.add(inst); + workListSet.Add(inst); + IRBuilder builder(inst->getModule()); - bool shouldUpdateGlobalNumberedCache = false; - for (auto use : uses) + for (Index i = 0; i < workList.getCount(); i++) { - use->set(newInst); - // depending on the type of the user inst, we may need to rebuild and update the global - // numbering cache. - if (isGloballyNumberedInst(use->getUser())) + auto item = workList[i]; + + // Does inst no longer depend on anything defined locally? + // If so we should hoist it. + bool shouldHoist = false; + for (UInt a = 0; a < item->getOperandCount(); a++) { - shouldUpdateGlobalNumberedCache = true; + auto opParent = item->getOperand(a)->getParent(); + if (opParent != item->getParent()) + { + shouldHoist = true; + break; + } } - } - oldInst->removeAndDeallocate(); - if (shouldUpdateGlobalNumberedCache) - { - deduplicateAndRebuildGlobalNumberingMap(); - } - } - bool SharedIRBuilder::isGloballyNumberedInst(IRInst* inst) - { - if (!inst->getParent() || inst->getParent()->getOp() != kIROp_Module) - return false; - return m_globalValueNumberingMap.ContainsKey(IRInstKey{inst}); + // Hoisting this inst + if (shouldHoist) + { + item->removeFromParent(); + addHoistableInst(&builder, item); + + // Continue to consider all users for hoisting. + for (auto use = item->firstUse; use; use = use->nextUse) + { + if (getIROpInfo(use->getUser()->getOp()).isHoistable()) + { + if (workListSet.Add(use->getUser())) + workList.add(use->getUser()); + } + } + } + } } } diff --git a/source/slang/slang-ir-glsl-legalize.cpp b/source/slang/slang-ir-glsl-legalize.cpp index 0dcd437fe..55d120228 100644 --- a/source/slang/slang-ir-glsl-legalize.cpp +++ b/source/slang/slang-ir-glsl-legalize.cpp @@ -1791,7 +1791,7 @@ void legalizeMeshOutputParam( // the writes may only be writing to parts of the output struct, or may not // be writes at all (i.e. being passed as an out paramter). // - traverseUses(g, [&](IRInst* u) + traverseUsers(g, [&](IRInst* u) { auto l = as<IRLoad>(u); SLANG_EXPECT(l, "Mesh Output sentinel parameter wasn't used in a load"); @@ -1811,7 +1811,7 @@ void legalizeMeshOutputParam( return; } // Otherwise, go through the uses one by one and see what we can do - traverseUses(a, [&](IRInst* s) + traverseUsers(a, [&](IRInst* s) { IRBuilderInsertLocScope locScope{builder}; builder->setInsertBefore(s); @@ -2022,7 +2022,7 @@ void legalizeMeshOutputParam( for(auto builtin : builtins) { - traverseUses(builtin.param, [&](IRInst* u) + traverseUsers(builtin.param, [&](IRInst* u) { auto p = as<IRGetElementPtr>(u); SLANG_EXPECT(p, "Mesh Output sentinel parameter wasn't used as an array"); diff --git a/source/slang/slang-ir-inline.cpp b/source/slang/slang-ir-inline.cpp index 7fc977170..643acdbb8 100644 --- a/source/slang/slang-ir-inline.cpp +++ b/source/slang/slang-ir-inline.cpp @@ -53,10 +53,8 @@ struct InliningPassBase // so that even if `child` gets removed (because of inlining) // we automatically start at the next instruction after it. // - IRInst* next = nullptr; - for( auto child = inst->getFirstChild(); child; child = next ) + for (auto child : inst->getModifiableChildren()) { - next = child->getNextInst(); changed |= considerAllCallSitesRec(child); } return changed; diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index 788e02c90..35877d680 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -10,6 +10,8 @@ #define PARENT kIROpFlag_Parent #define USE_OTHER kIROpFlag_UseOther +#define HOISTABLE kIROpFlag_Hoistable +#define GLOBAL kIROpFlag_Global INST(Nop, nop, 0, 0) @@ -17,7 +19,7 @@ INST(Nop, nop, 0, 0) /* Basic Types */ - #define DEFINE_BASE_TYPE_INST(NAME) INST(NAME ## Type, NAME, 0, 0) + #define DEFINE_BASE_TYPE_INST(NAME) INST(NAME ## Type, NAME, 0, HOISTABLE) FOREACH_BASE_TYPE(DEFINE_BASE_TYPE_INST) #undef DEFINE_BASE_TYPE_INST INST(AfterBaseType, afterBaseType, 0, 0) @@ -25,42 +27,42 @@ INST(Nop, nop, 0, 0) INST_RANGE(BasicType, VoidType, AfterBaseType) /* StringTypeBase */ - INST(StringType, String, 0, 0) - INST(NativeStringType, NativeString, 0, 0) + INST(StringType, String, 0, HOISTABLE) + INST(NativeStringType, NativeString, 0, HOISTABLE) INST_RANGE(StringTypeBase, StringType, NativeStringType) - INST(CapabilitySetType, CapabilitySet, 0, 0) + INST(CapabilitySetType, CapabilitySet, 0, HOISTABLE) - INST(DynamicType, DynamicType, 0, 0) + INST(DynamicType, DynamicType, 0, HOISTABLE) - INST(AnyValueType, AnyValueType, 1, 0) + INST(AnyValueType, AnyValueType, 1, HOISTABLE) - INST(RawPointerType, RawPointerType, 0, 0) - INST(RTTIPointerType, RTTIPointerType, 1, 0) + INST(RawPointerType, RawPointerType, 0, HOISTABLE) + INST(RTTIPointerType, RTTIPointerType, 1, HOISTABLE) INST(AfterRawPointerTypeBase, AfterRawPointerTypeBase, 0, 0) INST_RANGE(RawPointerTypeBase, RawPointerType, AfterRawPointerTypeBase) /* ArrayTypeBase */ - INST(ArrayType, Array, 2, 0) - INST(UnsizedArrayType, UnsizedArray, 1, 0) + INST(ArrayType, Array, 2, HOISTABLE) + INST(UnsizedArrayType, UnsizedArray, 1, HOISTABLE) INST_RANGE(ArrayTypeBase, ArrayType, UnsizedArrayType) - INST(FuncType, Func, 0, 0) - INST(BasicBlockType, BasicBlock, 0, 0) + INST(FuncType, Func, 0, HOISTABLE) + INST(BasicBlockType, BasicBlock, 0, HOISTABLE) - INST(VectorType, Vec, 2, 0) - INST(MatrixType, Mat, 3, 0) + INST(VectorType, Vec, 2, HOISTABLE) + INST(MatrixType, Mat, 3, HOISTABLE) - INST(TaggedUnionType, TaggedUnion, 0, 0) + INST(TaggedUnionType, TaggedUnion, 0, HOISTABLE) - INST(ConjunctionType, Conjunction, 0, 0) - INST(AttributedType, Attributed, 0, 0) - INST(ResultType, Result, 2, 0) - INST(OptionalType, Optional, 1, 0) + INST(ConjunctionType, Conjunction, 0, HOISTABLE) + INST(AttributedType, Attributed, 0, HOISTABLE) + INST(ResultType, Result, 2, HOISTABLE) + INST(OptionalType, Optional, 1, HOISTABLE) - INST(DifferentialPairType, DiffPair, 1, 0) - INST(BackwardDiffIntermediateContextType, BwdDiffIntermediateCtxType, 1, 0) + INST(DifferentialPairType, DiffPair, 1, HOISTABLE) + INST(BackwardDiffIntermediateContextType, BwdDiffIntermediateCtxType, 1, HOISTABLE) /* BindExistentialsTypeBase */ @@ -70,58 +72,58 @@ INST(Nop, nop, 0, 0) // where each `Ti, wi` pair represents the concrete type // and witness table to plug in for parameter `i`. // - INST(BindExistentialsType, BindExistentials, 1, 0) + INST(BindExistentialsType, BindExistentials, 1, HOISTABLE) // An `BindInterface<B, T0, w0>` represents the special case // of a `BindExistentials` where the type `B` is known to be // an interface type. // - INST(BoundInterfaceType, BoundInterface, 3, 0) + INST(BoundInterfaceType, BoundInterface, 3, HOISTABLE) INST_RANGE(BindExistentialsTypeBase, BindExistentialsType, BoundInterfaceType) /* Rate */ - INST(ConstExprRate, ConstExpr, 0, 0) - INST(GroupSharedRate, GroupShared, 0, 0) - INST(ActualGlobalRate, ActualGlobalRate, 0, 0) + INST(ConstExprRate, ConstExpr, 0, HOISTABLE) + INST(GroupSharedRate, GroupShared, 0, HOISTABLE) + INST(ActualGlobalRate, ActualGlobalRate, 0, HOISTABLE) INST_RANGE(Rate, ConstExprRate, GroupSharedRate) - INST(RateQualifiedType, RateQualified, 2, 0) + INST(RateQualifiedType, RateQualified, 2, HOISTABLE) // Kinds represent the "types of types." // They should not really be nested under `IRType` // in the overall hierarchy, but we can fix that later. // /* Kind */ - INST(TypeKind, Type, 0, 0) - INST(RateKind, Rate, 0, 0) - INST(GenericKind, Generic, 0, 0) + INST(TypeKind, Type, 0, HOISTABLE) + INST(RateKind, Rate, 0, HOISTABLE) + INST(GenericKind, Generic, 0, HOISTABLE) INST_RANGE(Kind, TypeKind, GenericKind) /* PtrTypeBase */ - INST(PtrType, Ptr, 1, 0) - INST(RefType, Ref, 1, 0) + INST(PtrType, Ptr, 1, HOISTABLE) + INST(RefType, Ref, 1, HOISTABLE) // A `PsuedoPtr<T>` logically represents a pointer to a value of type // `T` on a platform that cannot support pointers. The expectation // is that the "pointer" will be legalized away by storing a value // of type `T` somewhere out-of-line. - INST(PseudoPtrType, PseudoPtr, 1, 0) + INST(PseudoPtrType, PseudoPtr, 1, HOISTABLE) /* OutTypeBase */ - INST(OutType, Out, 1, 0) - INST(InOutType, InOut, 1, 0) + INST(OutType, Out, 1, HOISTABLE) + INST(InOutType, InOut, 1, HOISTABLE) INST_RANGE(OutTypeBase, OutType, InOutType) INST_RANGE(PtrTypeBase, PtrType, InOutType) // A ComPtr<T> type is treated as a opaque type that represents a reference-counted handle to a COM object. - INST(ComPtrType, ComPtr, 1, 0) + INST(ComPtrType, ComPtr, 1, HOISTABLE) // A NativePtr<T> type represents a native pointer to a managed resource. - INST(NativePtrType, NativePtr, 1, 0) + INST(NativePtrType, NativePtr, 1, HOISTABLE) /* SamplerStateTypeBase */ - INST(SamplerStateType, SamplerState, 0, 0) - INST(SamplerComparisonStateType, SamplerComparisonState, 0, 0) + INST(SamplerStateType, SamplerState, 0, HOISTABLE) + INST(SamplerComparisonStateType, SamplerComparisonState, 0, HOISTABLE) INST_RANGE(SamplerStateTypeBase, SamplerStateType, SamplerComparisonStateType) // TODO: Why do we have all this hierarchy here, when everything @@ -131,11 +133,11 @@ INST(Nop, nop, 0, 0) /* TextureTypeBase */ // NOTE! TextureFlavor::Flavor is stored in 'other' bits for these types. /* TextureType */ - INST(TextureType, TextureType, 0, USE_OTHER) + INST(TextureType, TextureType, 0, USE_OTHER | HOISTABLE) /* TextureSamplerType */ - INST(TextureSamplerType, TextureSamplerType, 0, USE_OTHER) + INST(TextureSamplerType, TextureSamplerType, 0, USE_OTHER | HOISTABLE) /* GLSLImageType */ - INST(GLSLImageType, GLSLImageType, 0, USE_OTHER) + INST(GLSLImageType, GLSLImageType, 0, USE_OTHER | HOISTABLE) INST_RANGE(TextureTypeBase, TextureType, GLSLImageType) INST_RANGE(ResourceType, TextureType, GLSLImageType) INST_RANGE(ResourceTypeBase, TextureType, GLSLImageType) @@ -143,53 +145,53 @@ INST(Nop, nop, 0, 0) /* UntypedBufferResourceType */ /* ByteAddressBufferTypeBase */ - INST(HLSLByteAddressBufferType, ByteAddressBuffer, 0, 0) - INST(HLSLRWByteAddressBufferType, RWByteAddressBuffer, 0, 0) - INST(HLSLRasterizerOrderedByteAddressBufferType, RasterizerOrderedByteAddressBuffer, 0, 0) + INST(HLSLByteAddressBufferType, ByteAddressBuffer, 0, HOISTABLE) + INST(HLSLRWByteAddressBufferType, RWByteAddressBuffer, 0, HOISTABLE) + INST(HLSLRasterizerOrderedByteAddressBufferType, RasterizerOrderedByteAddressBuffer, 0, HOISTABLE) INST_RANGE(ByteAddressBufferTypeBase, HLSLByteAddressBufferType, HLSLRasterizerOrderedByteAddressBufferType) - INST(RaytracingAccelerationStructureType, RaytracingAccelerationStructure, 0, 0) + INST(RaytracingAccelerationStructureType, RaytracingAccelerationStructure, 0, HOISTABLE) INST_RANGE(UntypedBufferResourceType, HLSLByteAddressBufferType, RaytracingAccelerationStructureType) /* HLSLPatchType */ - INST(HLSLInputPatchType, InputPatch, 2, 0) - INST(HLSLOutputPatchType, OutputPatch, 2, 0) + INST(HLSLInputPatchType, InputPatch, 2, HOISTABLE) + INST(HLSLOutputPatchType, OutputPatch, 2, HOISTABLE) INST_RANGE(HLSLPatchType, HLSLInputPatchType, HLSLOutputPatchType) - INST(GLSLInputAttachmentType, GLSLInputAttachment, 0, 0) + INST(GLSLInputAttachmentType, GLSLInputAttachment, 0, HOISTABLE) /* BuiltinGenericType */ /* HLSLStreamOutputType */ - INST(HLSLPointStreamType, PointStream, 1, 0) - INST(HLSLLineStreamType, LineStream, 1, 0) - INST(HLSLTriangleStreamType, TriangleStream, 1, 0) + INST(HLSLPointStreamType, PointStream, 1, HOISTABLE) + INST(HLSLLineStreamType, LineStream, 1, HOISTABLE) + INST(HLSLTriangleStreamType, TriangleStream, 1, HOISTABLE) INST_RANGE(HLSLStreamOutputType, HLSLPointStreamType, HLSLTriangleStreamType) /* MeshOutputType */ - INST(VerticesType, Vertices, 2, 0) - INST(IndicesType, Indices, 2, 0) - INST(PrimitivesType, Primitives, 2, 0) + INST(VerticesType, Vertices, 2, HOISTABLE) + INST(IndicesType, Indices, 2, HOISTABLE) + INST(PrimitivesType, Primitives, 2, HOISTABLE) INST_RANGE(MeshOutputType, VerticesType, PrimitivesType) /* HLSLStructuredBufferTypeBase */ - INST(HLSLStructuredBufferType, StructuredBuffer, 0, 0) - INST(HLSLRWStructuredBufferType, RWStructuredBuffer, 0, 0) - INST(HLSLRasterizerOrderedStructuredBufferType, RasterizerOrderedStructuredBuffer, 0, 0) - INST(HLSLAppendStructuredBufferType, AppendStructuredBuffer, 0, 0) - INST(HLSLConsumeStructuredBufferType, ConsumeStructuredBuffer, 0, 0) + INST(HLSLStructuredBufferType, StructuredBuffer, 0, HOISTABLE) + INST(HLSLRWStructuredBufferType, RWStructuredBuffer, 0, HOISTABLE) + INST(HLSLRasterizerOrderedStructuredBufferType, RasterizerOrderedStructuredBuffer, 0, HOISTABLE) + INST(HLSLAppendStructuredBufferType, AppendStructuredBuffer, 0, HOISTABLE) + INST(HLSLConsumeStructuredBufferType, ConsumeStructuredBuffer, 0, HOISTABLE) INST_RANGE(HLSLStructuredBufferTypeBase, HLSLStructuredBufferType, HLSLConsumeStructuredBufferType) /* PointerLikeType */ /* ParameterGroupType */ /* UniformParameterGroupType */ - INST(ConstantBufferType, ConstantBuffer, 1, 0) - INST(TextureBufferType, TextureBuffer, 1, 0) - INST(ParameterBlockType, ParameterBlock, 1, 0) - INST(GLSLShaderStorageBufferType, GLSLShaderStorageBuffer, 0, 0) + INST(ConstantBufferType, ConstantBuffer, 1, HOISTABLE) + INST(TextureBufferType, TextureBuffer, 1, HOISTABLE) + INST(ParameterBlockType, ParameterBlock, 1, HOISTABLE) + INST(GLSLShaderStorageBufferType, GLSLShaderStorageBuffer, 0, HOISTABLE) INST_RANGE(UniformParameterGroupType, ConstantBufferType, GLSLShaderStorageBufferType) /* VaryingParameterGroupType */ - INST(GLSLInputParameterGroupType, GLSLInputParameterGroup, 0, 0) - INST(GLSLOutputParameterGroupType, GLSLOutputParameterGroup, 0, 0) + INST(GLSLInputParameterGroupType, GLSLInputParameterGroup, 0, HOISTABLE) + INST(GLSLOutputParameterGroupType, GLSLOutputParameterGroup, 0, HOISTABLE) INST_RANGE(VaryingParameterGroupType, GLSLInputParameterGroupType, GLSLOutputParameterGroupType) INST_RANGE(ParameterGroupType, ConstantBufferType, GLSLOutputParameterGroupType) INST_RANGE(PointerLikeType, ConstantBufferType, GLSLOutputParameterGroupType) @@ -209,28 +211,28 @@ INST(Nop, nop, 0, 0) // INST(StructType, struct, 0, PARENT) INST(ClassType, class, 0, PARENT) -INST(InterfaceType, interface, 0, 0) -INST(AssociatedType, associated_type, 0, 0) -INST(ThisType, this_type, 0, 0) -INST(RTTIType, rtti_type, 0, 0) -INST(RTTIHandleType, rtti_handle_type, 0, 0) -INST(TupleType, tuple_type, 0, 0) +INST(InterfaceType, interface, 0, GLOBAL) +INST(AssociatedType, associated_type, 0, HOISTABLE) +INST(ThisType, this_type, 0, HOISTABLE) +INST(RTTIType, rtti_type, 0, HOISTABLE) +INST(RTTIHandleType, rtti_handle_type, 0, HOISTABLE) +INST(TupleType, tuple_type, 0, HOISTABLE) // A type that identifies it's contained type as being emittable as `spirv_literal. -INST(SPIRVLiteralType, spirvLiteralType, 1, 0) +INST(SPIRVLiteralType, spirvLiteralType, 1, HOISTABLE) // A TypeType-typed IRValue represents a IRType. // It is used to represent a type parameter/argument in a generics. -INST(TypeType, type_t, 0, 0) +INST(TypeType, type_t, 0, HOISTABLE) /*IRWitnessTableTypeBase*/ // An `IRWitnessTable` has type `WitnessTableType`. - INST(WitnessTableType, witness_table_t, 1, 0) + INST(WitnessTableType, witness_table_t, 1, HOISTABLE) // An integer type representing a witness table for targets where // witness tables are represented as integer IDs. This type is used // during the lower-generics pass while generating dynamic dispatch // code and will eventually lower into an uint type. - INST(WitnessTableIDType, witness_table_id_t, 1, 0) + INST(WitnessTableIDType, witness_table_id_t, 1, HOISTABLE) INST_RANGE(WitnessTableTypeBase, WitnessTableType, WitnessTableIDType) INST_RANGE(Type, VoidType, WitnessTableIDType) @@ -240,14 +242,14 @@ INST_RANGE(Type, VoidType, WitnessTableIDType) INST(Generic, generic, 0, PARENT) INST_RANGE(GlobalValueWithParams, Func, Generic) - INST(GlobalVar, global_var, 0, 0) + INST(GlobalVar, global_var, 0, GLOBAL) INST_RANGE(GlobalValueWithCode, Func, GlobalVar) -INST(GlobalParam, global_param, 0, 0) -INST(GlobalConstant, globalConstant, 0, 0) +INST(GlobalParam, global_param, 0, GLOBAL) +INST(GlobalConstant, globalConstant, 0, GLOBAL) -INST(StructKey, key, 0, 0) -INST(GlobalGenericParam, global_generic_param, 0, 0) +INST(StructKey, key, 0, GLOBAL) +INST(GlobalGenericParam, global_generic_param, 0, GLOBAL) INST(WitnessTable, witness_table, 0, 0) INST(GlobalHashedStringLiterals, global_hashed_string_literals, 0, 0) @@ -265,7 +267,7 @@ INST(Block, block, 0, PARENT) INST(VoidLit, void_constant, 0, 0) INST_RANGE(Constant, BoolLit, VoidLit) -INST(CapabilitySet, capabilitySet, 0, 0) +INST(CapabilitySet, capabilitySet, 0, HOISTABLE) INST(undefined, undefined, 0, 0) @@ -279,10 +281,9 @@ INST(MakeDifferentialPair, MakeDiffPair, 2, 0) INST(DifferentialPairGetDifferential, GetDifferential, 1, 0) INST(DifferentialPairGetPrimal, GetPrimal, 1, 0) -INST(Specialize, specialize, 2, 0) -INST(LookupWitness, lookupWitness, 2, 0) +INST(Specialize, specialize, 2, HOISTABLE) +INST(LookupWitness, lookupWitness, 2, HOISTABLE) INST(GetSequentialID, GetSequentialID, 1, 0) -INST(lookup_witness_table, lookup_witness_table, 2, 0) INST(BindGlobalGenericParam, bind_global_generic_param, 2, 0) INST(AllocObj, allocObj, 0, 0) @@ -317,7 +318,7 @@ INST(PackAnyValue, packAnyValue, 1, 0) INST(UnpackAnyValue, unpackAnyValue, 1, 0) INST(WitnessTableEntry, witness_table_entry, 2, 0) -INST(InterfaceRequirementEntry, interface_req_entry, 2, 0) +INST(InterfaceRequirementEntry, interface_req_entry, 2, GLOBAL) INST(Param, param, 0, 0) INST(StructField, field, 2, 0) @@ -558,8 +559,6 @@ INST(BitNot, bitnot, 1, 0) INST(Select, select, 3, 0) -INST(Dot, dot, 2, 0) - INST(GetStringHash, getStringHash, 1, 0) INST(WaveGetActiveMask, waveGetActiveMask, 0, 0) @@ -880,40 +879,40 @@ INST(BackwardDifferentiate, BackwardDifferentiate, 1, 0) INST(GetEquivalentStructuredBuffer, getEquivalentStructuredBuffer, 1, 0) /* Layout */ - INST(VarLayout, varLayout, 1, 0) + INST(VarLayout, varLayout, 1, HOISTABLE) /* TypeLayout */ - INST(TypeLayoutBase, typeLayout, 0, 0) - INST(ParameterGroupTypeLayout, parameterGroupTypeLayout, 2, 0) - INST(ArrayTypeLayout, arrayTypeLayout, 1, 0) - INST(StreamOutputTypeLayout, streamOutputTypeLayout, 1, 0) - INST(MatrixTypeLayout, matrixTypeLayout, 1, 0) - INST(TaggedUnionTypeLayout, taggedUnionTypeLayout, 0, 0) - INST(ExistentialTypeLayout, existentialTypeLayout, 0, 0) - INST(StructTypeLayout, structTypeLayout, 0, 0) + INST(TypeLayoutBase, typeLayout, 0, HOISTABLE) + INST(ParameterGroupTypeLayout, parameterGroupTypeLayout, 2, HOISTABLE) + INST(ArrayTypeLayout, arrayTypeLayout, 1, HOISTABLE) + INST(StreamOutputTypeLayout, streamOutputTypeLayout, 1, HOISTABLE) + INST(MatrixTypeLayout, matrixTypeLayout, 1, HOISTABLE) + INST(TaggedUnionTypeLayout, taggedUnionTypeLayout, 0, HOISTABLE) + INST(ExistentialTypeLayout, existentialTypeLayout, 0, HOISTABLE) + INST(StructTypeLayout, structTypeLayout, 0, HOISTABLE) INST_RANGE(TypeLayout, TypeLayoutBase, StructTypeLayout) - INST(EntryPointLayout, EntryPointLayout, 1, 0) + INST(EntryPointLayout, EntryPointLayout, 1, HOISTABLE) INST_RANGE(Layout, VarLayout, EntryPointLayout) /* Attr */ - INST(PendingLayoutAttr, pendingLayout, 1, 0) - INST(StageAttr, stage, 1, 0) - INST(StructFieldLayoutAttr, fieldLayout, 2, 0) - INST(CaseTypeLayoutAttr, caseLayout, 1, 0) - INST(UNormAttr, unorm, 0, 0) - INST(SNormAttr, snorm, 0, 0) - INST(NoDiffAttr, no_diff, 0, 0) + INST(PendingLayoutAttr, pendingLayout, 1, HOISTABLE) + INST(StageAttr, stage, 1, HOISTABLE) + INST(StructFieldLayoutAttr, fieldLayout, 2, HOISTABLE) + INST(CaseTypeLayoutAttr, caseLayout, 1, HOISTABLE) + INST(UNormAttr, unorm, 0, HOISTABLE) + INST(SNormAttr, snorm, 0, HOISTABLE) + INST(NoDiffAttr, no_diff, 0, HOISTABLE) /* SemanticAttr */ - INST(UserSemanticAttr, userSemantic, 2, 0) - INST(SystemValueSemanticAttr, systemValueSemantic, 2, 0) + INST(UserSemanticAttr, userSemantic, 2, HOISTABLE) + INST(SystemValueSemanticAttr, systemValueSemantic, 2, HOISTABLE) INST_RANGE(SemanticAttr, UserSemanticAttr, SystemValueSemanticAttr) /* LayoutResourceInfoAttr */ - INST(TypeSizeAttr, size, 2, 0) - INST(VarOffsetAttr, offset, 2, 0) + INST(TypeSizeAttr, size, 2, HOISTABLE) + INST(VarOffsetAttr, offset, 2, HOISTABLE) INST_RANGE(LayoutResourceInfoAttr, TypeSizeAttr, VarOffsetAttr) - INST(FuncThrowTypeAttr, FuncThrowType, 1, 0) + INST(FuncThrowTypeAttr, FuncThrowType, 1, HOISTABLE) INST_RANGE(Attr, PendingLayoutAttr, FuncThrowTypeAttr) /* Liveness */ diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 7bc711f97..7a2e1f0e2 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -2436,106 +2436,37 @@ struct IRLiveRangeEnd : IRLiveRangeMarker IR_LEAF_ISA(LiveRangeEnd); }; -// Description of an instruction to be used for global value numbering -struct IRInstKey -{ - IRInst* inst; - - HashCode getHashCode(); -}; - -bool operator==(IRInstKey const& left, IRInstKey const& right); - -struct IRConstantKey -{ - IRConstant* inst; - - bool operator==(const IRConstantKey& rhs) const { return inst->equal(rhs.inst); } - HashCode getHashCode() const { return inst->getHashCode(); } -}; - -struct SharedIRBuilder -{ -public: - SharedIRBuilder() - {} - - explicit SharedIRBuilder(IRModule* module) - { - init(module); - } - - void init(IRModule* module) - { - m_module = module; - m_session = module->getSession(); - - m_globalValueNumberingMap.Clear(); - m_constantMap.Clear(); - } - - IRModule* getModule() - { - return m_module; - } - - Session* getSession() - { - return m_session; - } - - void insertBlockAlongEdge(IREdge const& edge); - - // Rebuilds `globalValueNumberingMap`. This is necessary if any existing - // keys are modified (thus its hash code is changed). - void deduplicateAndRebuildGlobalNumberingMap(); - - // Replaces all uses of oldInst with newInst, and ensures the global numbering map is valid after the replacement. - void replaceGlobalInst(IRInst* oldInst, IRInst* newInst); - - typedef Dictionary<IRInstKey, IRInst*> GlobalValueNumberingMap; - typedef Dictionary<IRConstantKey, IRConstant*> ConstantMap; - - GlobalValueNumberingMap& getGlobalValueNumberingMap() { return m_globalValueNumberingMap; } - ConstantMap& getConstantMap() { return m_constantMap; } - - bool isGloballyNumberedInst(IRInst* inst); - -private: - // The module that will own all of the IR - IRModule* m_module; - - // The parent compilation session - Session* m_session; - - GlobalValueNumberingMap m_globalValueNumberingMap; - ConstantMap m_constantMap; -}; - struct IRBuilderSourceLocRAII; struct IRBuilder { private: - /// Shared state for all IR builders working on the same module - SharedIRBuilder* m_sharedBuilder = nullptr; + /// Shared state for all IR builders working on the same module + SharedIRBuilder* m_sharedBuilder = nullptr; - /// Default location for inserting new instructions as they are emitted + IRModule* m_module = nullptr; + + /// Default location for inserting new instructions as they are emitted IRInsertLoc m_insertLoc; - /// Information that controls how source locations are associatd with instructions that get emitted + /// Information that controls how source locations are associatd with instructions that get emitted IRBuilderSourceLocRAII* m_sourceLocInfo = nullptr; public: IRBuilder() {} + explicit IRBuilder(IRModule* module) + : m_module(module) + , m_sharedBuilder(module->getSharedBuilder()) + {} + explicit IRBuilder(SharedIRBuilder* sharedBuilder) - : m_sharedBuilder(sharedBuilder) + : IRBuilder(sharedBuilder->getModule()) {} explicit IRBuilder(SharedIRBuilder& sharedBuilder) - : m_sharedBuilder(&sharedBuilder) + : IRBuilder(sharedBuilder.getModule()) {} void init(SharedIRBuilder* sharedBuilder) @@ -2550,17 +2481,17 @@ public: SharedIRBuilder* getSharedBuilder() const { - return m_sharedBuilder; + return m_module->getSharedBuilder(); } Session* getSession() const { - return m_sharedBuilder->getSession(); + return m_module->getSession(); } IRModule* getModule() const { - return m_sharedBuilder->getModule(); + return m_module; } IRInsertLoc const& getInsertLoc() const { return m_insertLoc; } @@ -2597,6 +2528,18 @@ public: IRConstant* _findOrEmitConstant( IRConstant& keyInst); + /// Implements a special case of inst creation (intended only for calling from `_createInst`) + /// that returns an matching existing hoistable inst if it exists, otherwise it creates the inst and + /// add it to the global numbering map. + IRInst* _findOrEmitHoistableInst( + IRType* type, + IROp op, + Int fixedArgCount, + IRInst* const* fixedArgs, + Int varArgListCount, + Int const* listArgCounts, + IRInst* const* const* listArgs); + /// Create a new instruction with the given `type` and `op`, with an allocated /// size of at least `minSizeInBytes`, and with its operand list initialized /// from the provided lists of "fixed" and "variable" operands. @@ -2615,7 +2558,8 @@ public: /// size. /// /// Note: This is an extremely low-level operation and clients of an `IRBuilder` - /// should not be using it when other options are available. + /// should not be using it when other options are available. This is also where + /// all insts creation are bottlenecked through. /// IRInst* _createInst( size_t minSizeInBytes, @@ -2654,6 +2598,12 @@ public: void addInst(IRInst* inst); + // Replace the operand of a potentially hoistable inst. + // If the hoistable inst become duplicate of an existing inst, + // all uses of the original user will be replaced with the existing inst. + // The function returns the new user after any potential updates. + IRInst* replaceOperand(IRUse* use, IRInst* newValue); + IRInst* getBoolValue(bool value); IRInst* getIntValue(IRType* type, IRIntegerValue value); IRInst* getFloatValue(IRType* type, IRFloatingPointValue value); @@ -2918,6 +2868,20 @@ public: UInt argCount, IRInst* const* args); + IRInst* createIntrinsicInst( + IRType* type, + IROp op, + IRInst* operand, + UInt operandCount, + IRInst* const* operands); + + IRInst* createIntrinsicInst( + IRType* type, + IROp op, + UInt operandListCount, + UInt const* listOperandCounts, + IRInst* const* const* listOperands); + IRInst* emitIntrinsicInst( IRType* type, IROp op, @@ -3001,6 +2965,10 @@ public: UInt argCount, IRInst* const* args); + IRInst* emitMakeMatrixFromScalar( + IRType* type, + IRInst* scalarValue); + IRInst* emitMakeArray( IRType* type, UInt argCount, @@ -3066,31 +3034,6 @@ public: IRInst* emitReinterpret(IRInst* type, IRInst* value); - IRInst* findOrAddInst( - IRType* type, - IROp op, - UInt operandListCount, - UInt const* listOperandCounts, - IRInst* const* const* listOperands); - - IRInst* findOrEmitHoistableInst( - IRType* type, - IROp op, - UInt operandListCount, - UInt const* listOperandCounts, - IRInst* const* const* listOperands); - IRInst* findOrEmitHoistableInst( - IRType* type, - IROp op, - UInt operandCount, - IRInst* const* operands); - IRInst* findOrEmitHoistableInst( - IRType* type, - IROp op, - IRInst* operand, - UInt operandCount, - IRInst* const* operands); - IRFunc* createFunc(); IRGlobalVar* createGlobalVar( IRType* valueType); @@ -3841,10 +3784,6 @@ public: } }; -void addHoistableInst( - IRBuilder* builder, - IRInst* inst); - // Helper to establish the source location that will be used // by an IRBuilder. struct IRBuilderSourceLocRAII diff --git a/source/slang/slang-ir-legalize-mesh-outputs.cpp b/source/slang/slang-ir-legalize-mesh-outputs.cpp index 7c6d256ab..db4d74ddb 100644 --- a/source/slang/slang-ir-legalize-mesh-outputs.cpp +++ b/source/slang/slang-ir-legalize-mesh-outputs.cpp @@ -25,7 +25,7 @@ void legalizeMeshOutputTypes(IRModule* module) : as<IRPrimitivesType>(meshOutput) ? kIROp_PrimitivesDecoration : (SLANG_UNREACHABLE("Missing case for IRMeshOutputType"), IROp(0)); // Ensure that all params are marked up as vertices/indices/primitives - traverseUses<IRParam>(meshOutput, [&](IRParam* i) + traverseUsers<IRParam>(meshOutput, [&](IRParam* i) { builder.addMeshOutputDecoration(decorationOp, i, maxCount); }); diff --git a/source/slang/slang-ir-legalize-types.cpp b/source/slang/slang-ir-legalize-types.cpp index 38503155d..d916fa691 100644 --- a/source/slang/slang-ir-legalize-types.cpp +++ b/source/slang/slang-ir-legalize-types.cpp @@ -1861,14 +1861,27 @@ static LegalVal legalizeInst( // While the operands are all "simple," they might not necessarily // be equal to the operands we started with. // + ShortList<IRInst*> newArgs; + newArgs.setCount(argCount); + bool recreate = false; for (UInt aa = 0; aa < argCount; ++aa) { auto legalArg = legalArgs[aa]; - inst->setOperand(aa, legalArg.getSimple()); + newArgs[aa] = legalArg.getSimple(); + if (newArgs[aa] != inst->getOperand(aa)) + recreate = true; + } + if (recreate) + { + IRBuilder builder(inst->getModule()); + builder.setInsertBefore(inst); + auto newInst = builder.emitIntrinsicInst(legalType.getSimple(), inst->getOp(), argCount, newArgs.getArrayView().getBuffer()); + inst->replaceUsesWith(newInst); + inst->removeFromParent(); + context->replacedInstructions.add(inst); + return LegalVal::simple(newInst); } - inst->setFullType(legalType.getSimple()); - return LegalVal::simple(inst); } @@ -1888,6 +1901,10 @@ static LegalVal legalizeInst( legalType, legalArgs.getBuffer()); + if (legalVal.flavor == LegalVal::Flavor::simple) + { + inst->replaceUsesWith(legalVal.getSimple()); + } // After we are done, we will eliminate the // original instruction by removing it from // the IR. diff --git a/source/slang/slang-ir-link.cpp b/source/slang/slang-ir-link.cpp index 80f974536..55048484f 100644 --- a/source/slang/slang-ir-link.cpp +++ b/source/slang/slang-ir-link.cpp @@ -229,11 +229,14 @@ IRInst* IRSpecContext::maybeCloneValue(IRInst* originalValue) switch (originalValue->getOp()) { case kIROp_StructType: + case kIROp_ClassType: case kIROp_Func: case kIROp_Generic: case kIROp_GlobalVar: case kIROp_GlobalParam: + case kIROp_GlobalConstant: case kIROp_StructKey: + case kIROp_InterfaceRequirementEntry: case kIROp_GlobalGenericParam: case kIROp_WitnessTable: case kIROp_InterfaceType: @@ -277,26 +280,34 @@ IRInst* IRSpecContext::maybeCloneValue(IRInst* originalValue) } break; + case kIROp_VoidLit: + { + return builder->getVoidValue(); + } + break; + default: { // In the default case, assume that we have some sort of "hoistable" // instruction that requires us to create a clone of it. UInt argCount = originalValue->getOperandCount(); - IRInst* clonedValue = builder->createIntrinsicInst( - cloneType(this, originalValue->getFullType()), - originalValue->getOp(), - argCount, nullptr); - registerClonedValue(this, clonedValue, originalValue); + ShortList<IRInst*> newArgs; + newArgs.setCount(argCount); for (UInt aa = 0; aa < argCount; ++aa) { IRInst* originalArg = originalValue->getOperand(aa); IRInst* clonedArg = cloneValue(this, originalArg); - clonedValue->getOperands()[aa].init(clonedValue, clonedArg); + newArgs[aa] = clonedArg; } + IRInst* clonedValue = builder->createIntrinsicInst( + cloneType(this, originalValue->getFullType()), + originalValue->getOp(), + argCount, newArgs.getArrayView().getBuffer()); + registerClonedValue(this, clonedValue, originalValue); + cloneDecorationsAndChildren(this, clonedValue, originalValue); - - addHoistableInst(builder, clonedValue); + builder->addInst(clonedValue); return clonedValue; } @@ -524,6 +535,8 @@ IRGlobalConstant* cloneGlobalConstantImpl( IRGlobalConstant* originalVal, IROriginalValuesForClone const& originalValues) { + auto oldBuilder = context->builder; + context->builder = builder; auto clonedType = cloneType(context, originalVal->getFullType()); IRGlobalConstant* clonedVal = nullptr; if(auto originalInitVal = originalVal->getValue()) @@ -537,7 +550,7 @@ IRGlobalConstant* cloneGlobalConstantImpl( } cloneSimpleGlobalValueImpl(context, originalVal, originalValues, clonedVal); - + context->builder = oldBuilder; return clonedVal; } @@ -1174,21 +1187,24 @@ IRInst* cloneInst( // instruction with the right number of operands, intialize // it, and then add it to the sequence. UInt argCount = originalInst->getOperandCount(); - IRInst* clonedInst = builder->createIntrinsicInst( - cloneType(context, originalInst->getFullType()), - originalInst->getOp(), - argCount, nullptr); - registerClonedValue(context, clonedInst, originalValues); + ShortList<IRInst*> newArgs; + newArgs.setCount(argCount); auto oldBuilder = context->builder; context->builder = builder; for (UInt aa = 0; aa < argCount; ++aa) { IRInst* originalArg = originalInst->getOperand(aa); IRInst* clonedArg = cloneValue(context, originalArg); - clonedInst->getOperands()[aa].init(clonedInst, clonedArg); + newArgs[aa] = clonedArg; } - builder->addInst(clonedInst); context->builder = oldBuilder; + + IRInst* clonedInst = builder->createIntrinsicInst( + cloneType(context, originalInst->getFullType()), + originalInst->getOp(), + argCount, newArgs.getArrayView().getBuffer()); + builder->addInst(clonedInst); + registerClonedValue(context, clonedInst, originalValues); cloneDecorationsAndChildren(context, clonedInst, originalInst); cloneExtraDecorations(context, clonedInst, originalValues); return clonedInst; diff --git a/source/slang/slang-ir-lower-generic-function.cpp b/source/slang/slang-ir-lower-generic-function.cpp index 6f412d579..f2d7159d4 100644 --- a/source/slang/slang-ir-lower-generic-function.cpp +++ b/source/slang/slang-ir-lower-generic-function.cpp @@ -56,25 +56,51 @@ namespace Slang lowerGenericFuncType(&builder, genericParent, cast<IRFuncType>(func->getFullType())); SLANG_ASSERT(loweredGenericType); loweredFunc->setFullType(loweredGenericType); - List<IRInst*> clonedParams; + List<IRInst*> childrenToDemote; + List<IRInst*> clonedParams; for (auto genericChild : genericParent->getFirstBlock()->getChildren()) { - if (genericChild == func) + switch (genericChild->getOp()) + { + case kIROp_Func: continue; - if (genericChild->getOp() == kIROp_Return) + case kIROp_Return: continue; + } // Process all generic parameters and local type definitions. auto clonedChild = cloneInst(&cloneEnv, &builder, genericChild); - if (clonedChild->getOp() == kIROp_Param) + switch (clonedChild->getOp()) { - auto paramType = clonedChild->getFullType(); - auto loweredParamType = sharedContext->lowerType(&builder, paramType); - if (loweredParamType != paramType) + case kIROp_Param: { - clonedChild->setFullType((IRType*)loweredParamType); + auto paramType = clonedChild->getFullType(); + auto loweredParamType = sharedContext->lowerType(&builder, paramType); + if (loweredParamType != paramType) + { + clonedChild->setFullType((IRType*)loweredParamType); + } + clonedParams.add(clonedChild); + } + break; + + case kIROp_LookupWitness: + case kIROp_Specialize: + { + childrenToDemote.add(clonedChild); + // Make sure all uses are from the function body. + for (auto use = genericChild->firstUse; use; use = use->nextUse) + { + if (use->getUser()->getParent() == genericChild->getParent()) + { + // This specialize/lookup is used as operand to some other + // global inst in the generic. This is not supported now. + SLANG_UNIMPLEMENTED_X( + "Unsupported use of specialize/lookupWitness in generic body."); + } + } + continue; } - clonedParams.add(clonedChild); } } cloneInstDecorationsAndChildren(&cloneEnv, &sharedContext->sharedBuilderStorage, func, loweredFunc); @@ -85,6 +111,15 @@ namespace Slang param->removeFromParent(); block->addParam(as<IRParam>(param)); } + + // Demote specialize and lookupWitness insts and their dependents down to function body. + auto insertPoint = block->getFirstOrdinaryInst(); + for (Index i = childrenToDemote.getCount() - 1; i >= 0; i--) + { + auto child = childrenToDemote[i]; + child->insertBefore(insertPoint); + } + // Lower generic typed parameters into AnyValueType. auto firstInst = loweredFunc->getFirstOrdinaryInst(); builder.setInsertBefore(firstInst); @@ -292,7 +327,8 @@ namespace Slang loweredFunc = lowerGenericFunction(funcToSpecialize); if (loweredFunc != funcToSpecialize) { - specializeInst->setOperand(0, loweredFunc); + IRBuilder builder; + builder.replaceOperand(specializeInst->getOperands(), loweredFunc); } } } diff --git a/source/slang/slang-ir-redundancy-removal.cpp b/source/slang/slang-ir-redundancy-removal.cpp index 176142601..f3996fc01 100644 --- a/source/slang/slang-ir-redundancy-removal.cpp +++ b/source/slang/slang-ir-redundancy-removal.cpp @@ -38,8 +38,6 @@ struct RedundancyRemovalContext case kIROp_GetElement: case kIROp_GetElementPtr: case kIROp_UpdateElement: - case kIROp_LookupWitness: - case kIROp_Specialize: case kIROp_OptionalHasValue: case kIROp_GetOptionalValue: case kIROp_MakeOptionalValue: diff --git a/source/slang/slang-ir-simplify-for-emit.cpp b/source/slang/slang-ir-simplify-for-emit.cpp index 5e5f61a4a..67d95c59f 100644 --- a/source/slang/slang-ir-simplify-for-emit.cpp +++ b/source/slang/slang-ir-simplify-for-emit.cpp @@ -5,12 +5,16 @@ namespace Slang { +bool isCPUTarget(TargetRequest* targetReq); +bool isCUDATarget(TargetRequest* targetReq); + struct SimplifyForEmitContext : public InstPassBase { - SimplifyForEmitContext(IRModule* inModule) - : InstPassBase(inModule) + SimplifyForEmitContext(IRModule* inModule, TargetRequest* inTargetReq) + : InstPassBase(inModule), targetReq(inTargetReq) {} + TargetRequest* targetReq; List<IRInst*> followUpWorkList; HashSet<IRInst*> followUpWorkListSet; @@ -134,7 +138,7 @@ struct SimplifyForEmitContext : public InstPassBase IRBuilder builder(sharedBuilderStorage); builder.setInsertBefore(user); auto newLoad = builder.emitLoad(load->getPtr()); - use->set(newLoad); + builder.replaceOperand(use, newLoad); } void processLoad(IRLoad* inst) @@ -330,8 +334,115 @@ struct SimplifyForEmitContext : public InstPassBase processInst(followUpWorkList[i]); } + void unifyBinaryExprOperands(IRGlobalValueWithCode* func) + { + IRBuilder builder(func->getModule()); + + for (auto block : func->getBlocks()) + { + for (auto inst = block->getFirstInst(); inst; inst = inst->getNextInst()) + { + switch (inst->getOp()) + { + case kIROp_Add: + case kIROp_Sub: + case kIROp_Mul: + case kIROp_Div: + case kIROp_IRem: + case kIROp_FRem: + case kIROp_And: + case kIROp_Or: + case kIROp_BitAnd: + case kIROp_BitOr: + case kIROp_BitXor: + case kIROp_Leq: + case kIROp_Less: + case kIROp_Geq: + case kIROp_Greater: + case kIROp_Eql: + case kIROp_Neq: + case kIROp_Lsh: + case kIROp_Rsh: + builder.setInsertBefore(inst); + SLANG_ASSERT(inst->getOperandCount() == 2); + if (as<IRVectorType>(inst->getDataType())) + { + for (UInt a = 0; a < 2; a++) + { + if (as<IRBasicType>(inst->getOperand(a)->getDataType())) + { + auto v = builder.emitMakeVectorFromScalar( + inst->getOperand(1 - a)->getDataType(), inst->getOperand(a)); + inst->setOperand(a, v); + } + } + } + else if (as<IRMatrixType>(inst->getDataType())) + { + for (UInt a = 0; a < 2; a++) + { + if (as<IRBasicType>(inst->getOperand(a)->getDataType())) + { + auto v = builder.emitMakeMatrixFromScalar( + inst->getOperand(1 - a)->getDataType(), inst->getOperand(a)); + inst->setOperand(a, v); + } + } + } + + break; + } + } + } + } + + // Turn single element vector values into scalars before using it to call an intrinsic func. + void lowerTrivialVector(IRGlobalValueWithCode* func) + { + IRBuilder builder(func->getModule()); + List<IRInst*> instsToProcess; + for (auto block : func->getBlocks()) + { + for (auto inst = block->getFirstInst(); inst; inst = inst->getNextInst()) + { + switch (inst->getOp()) + { + case kIROp_Call: + { + // If we are calling an intrinsic with any vector<T,1> argument, replace it with T. + auto callInst = as<IRCall>(inst); + if (getResolvedInstForDecorations(callInst->getCallee())->findDecoration<IRTargetIntrinsicDecoration>()) + { + for (UInt a = 0; a < callInst->getArgCount(); a++) + { + auto arg = callInst->getArg(a); + if (auto argVectorType = as<IRVectorType>(arg->getDataType())) + { + if (cast<IRIntLit>(argVectorType->getElementCount())->getValue() == 1) + { + builder.setInsertBefore(callInst); + UInt idx = 0; + auto newArg = builder.emitSwizzle(argVectorType->getElementType(), arg, 1, &idx); + callInst->setOperand(a + 1, newArg); + } + } + } + } + } + break; + } + } + } + } + + void processFunc(IRGlobalValueWithCode* func) { + if (isCPUTarget(targetReq) || isCUDATarget(targetReq)) + { + unifyBinaryExprOperands(func); + lowerTrivialVector(func); + } eliminateCompositeConstruct(func); deferAndDuplicateElementExtract(func); deferAndDuplicateLoad(func); @@ -345,9 +456,9 @@ struct SimplifyForEmitContext : public InstPassBase } }; -void simplifyForEmit(IRModule* module) +void simplifyForEmit(IRModule* module, TargetRequest* targetRequest) { - SimplifyForEmitContext context(module); + SimplifyForEmitContext context(module, targetRequest); context.processModule(); } diff --git a/source/slang/slang-ir-simplify-for-emit.h b/source/slang/slang-ir-simplify-for-emit.h index a6cf3bad8..e35c74841 100644 --- a/source/slang/slang-ir-simplify-for-emit.h +++ b/source/slang/slang-ir-simplify-for-emit.h @@ -4,6 +4,7 @@ namespace Slang { struct IRModule; + class TargetRequest; - void simplifyForEmit(IRModule* inModule); + void simplifyForEmit(IRModule* inModule, TargetRequest* req); } diff --git a/source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.cpp b/source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.cpp index 39edaeb16..cfc9d9c76 100644 --- a/source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.cpp +++ b/source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.cpp @@ -200,23 +200,22 @@ struct AssociatedTypeLookupSpecializationContext if (!seqId) return; // Insert code to pack sequential ID into an uint2 at all use sites. - IRUse* nextUse = nullptr; - for (auto use = inst->firstUse; use; use = nextUse) + traverseUses(inst, [&](IRUse* use) { - nextUse = use->nextUse; if (as<IRCOMWitnessDecoration>(use->getUser())) - continue; + { + return; + } IRBuilder builder(sharedContext->sharedBuilderStorage); builder.setInsertBefore(use->getUser()); auto uint2Type = builder.getVectorType( builder.getUIntType(), builder.getIntValue(builder.getIntType(), 2)); IRInst* uint2Args[] = { seqId->getSequentialIDOperand(), - builder.getIntValue(builder.getUIntType(), 0)}; + builder.getIntValue(builder.getUIntType(), 0) }; auto uint2seqID = builder.emitMakeVector(uint2Type, 2, uint2Args); - use->set(uint2seqID); - use = nextUse; - } + builder.replaceOperand(use, uint2seqID); + }); } }); @@ -229,14 +228,12 @@ struct AssociatedTypeLookupSpecializationContext builder.setInsertBefore(globalInst); auto witnessTableIDType = builder.getWitnessTableIDType( (IRType*)cast<IRWitnessTableType>(globalInst)->getConformanceType()); - IRUse* nextUse = nullptr; - for (auto use = globalInst->firstUse; use; use = nextUse) + traverseUses(globalInst, [&](IRUse* use) { - nextUse = use->nextUse; if (use->getUser()->getOp() == kIROp_WitnessTable) - continue; - use->set(witnessTableIDType); - } + return; + builder.replaceOperand(use, witnessTableIDType); + }); sharedContext->sharedBuilderStorage.deduplicateAndRebuildGlobalNumberingMap(); } } diff --git a/source/slang/slang-ir-specialize-resources.cpp b/source/slang/slang-ir-specialize-resources.cpp index e4ccf40d5..03eda0d99 100644 --- a/source/slang/slang-ir-specialize-resources.cpp +++ b/source/slang/slang-ir-specialize-resources.cpp @@ -256,16 +256,16 @@ struct ResourceOutputSpecializationPass // the aid of this pass. // List<IRCall*> calls; - for( auto use = oldFunc->firstUse; use; use = use->nextUse ) - { - auto user = use->getUser(); - auto call = as<IRCall>(user); - if(!call) - continue; - if(call->getCallee() != oldFunc) - continue; - calls.add(call); - } + traverseUses(oldFunc, [&](IRUse* use) + { + auto user = use->getUser(); + auto call = as<IRCall>(user); + if (!call) + return; + if (call->getCallee() != oldFunc) + return; + calls.add(call); + }); // Once we have identified the calls to `oldFunc`, we will set about replacing // them with calls to `newFunc`. @@ -833,16 +833,16 @@ struct ResourceOutputSpecializationPass // `out`/`inout` parameters that doesn't have as many "gotcha" cases. // List<IRStore*> stores; - for( auto use = param->firstUse; use; use = use->nextUse ) - { - auto user = use->getUser(); - auto store = as<IRStore>(user); - if(!store) - continue; - if(store->ptr.get() != param) - continue; - stores.add(store); - } + traverseUses(param, [&](IRUse* use) + { + auto user = use->getUser(); + auto store = as<IRStore>(user); + if (!store) + return; + if (store->ptr.get() != param) + return; + stores.add(store); + }); // Having identified the places where a value is stored to // the output parameter, we iterate over those values to @@ -1194,16 +1194,16 @@ bool specializeResourceUsage( // Inline unspecializable resource output functions and then continue trying. for (auto func : unspecializableFuncs) { - for (auto use = func->firstUse; use; use = use->nextUse) + traverseUses(func, [&](IRUse* use) { auto user = use->getUser(); auto call = as<IRCall>(user); if (!call) - continue; + return; if (call->getCallee() != func) - continue; + return; inlineCall(call); - } + }); } simplifyIR(irModule); } diff --git a/source/slang/slang-ir-specialize.cpp b/source/slang/slang-ir-specialize.cpp index cf7acd46c..0044e5745 100644 --- a/source/slang/slang-ir-specialize.cpp +++ b/source/slang/slang-ir-specialize.cpp @@ -897,7 +897,8 @@ struct SpecializationContext // specialization opportunities (generic specialization, // existential specialization, simplifications, etc.) // - iterChanged |= maybeSpecializeInst(inst); + if (inst->hasUses() || inst->mightHaveSideEffects()) + iterChanged |= maybeSpecializeInst(inst); // Finally, we need to make our logic recurse through // the whole IR module, so we want to add the children @@ -1041,7 +1042,6 @@ struct SpecializationContext // The old callee should be in the form of `specialize(.operator[], IInterfaceType)`, // we should update it to be `specialize(.operator[], elementType)`, so the return type // of the load call is `elementType`. - auto oldCallee = inst->getCallee(); // A subscript operation on mutable buffers returns a ptr type instead of a value type. // We need to make sure the pointer-ness is preserved correctly. @@ -1057,9 +1057,6 @@ struct SpecializationContext inst->replaceUsesWith(newWrapExistential); workList.Remove(inst); inst->removeAndDeallocate(); - SLANG_ASSERT(!oldCallee->hasUses()); - workList.Remove(oldCallee); - oldCallee->removeAndDeallocate(); addUsersToWorkList(newWrapExistential); workList.Remove(wrapExistential); diff --git a/source/slang/slang-ir-ssa.cpp b/source/slang/slang-ir-ssa.cpp index 3f250e31e..b195af2cc 100644 --- a/source/slang/slang-ir-ssa.cpp +++ b/source/slang/slang-ir-ssa.cpp @@ -923,6 +923,15 @@ IRBlock* IREdge::getSuccessor() const return cast<IRBlock>(getUse()->get()); } +void SharedIRBuilder::init(IRModule* module) +{ + m_module = module; + m_session = module->getSession(); + + m_globalValueNumberingMap.Clear(); + m_constantMap.Clear(); +} + void SharedIRBuilder::insertBlockAlongEdge( IREdge const& edge) { diff --git a/source/slang/slang-ir-type-set.cpp b/source/slang/slang-ir-type-set.cpp index 0cfe69e42..7ac617bda 100644 --- a/source/slang/slang-ir-type-set.cpp +++ b/source/slang/slang-ir-type-set.cpp @@ -7,313 +7,4 @@ namespace Slang { -IRTypeSet::IRTypeSet(Session* session) -{ - m_module = IRModule::create(session); - - m_sharedBuilder.init(m_module); - m_builder.init(m_sharedBuilder); - - m_builder.setInsertInto(m_module->getModuleInst()); -} - -IRTypeSet::~IRTypeSet() -{ - _clearTypes(); -} - -void IRTypeSet::clear() -{ - _clearTypes(); - - m_cloneMap.Clear(); - - m_module = IRModule::create(m_sharedBuilder.getSession()); - - m_sharedBuilder.init(m_module); - m_builder.init(m_sharedBuilder); - - m_builder.setInsertInto(m_module->getModuleInst()); -} - -void IRTypeSet::_clearTypes() -{ - List<IRType*> types; - getTypes(types); - - for (auto type : types) - { - // We need to destroy references to instructions in other modules - if (type->getModule() == m_module) - { - // We want to remove arguments because an argument *could* be an instruction in another module, - // and we don't want to those modules insts to have uses, in this module which is being destroyed - type->removeArguments(); - } - } -} - -IRInst* IRTypeSet::cloneInst(IRInst* inst) -{ - if (inst == nullptr) - { - return nullptr; - } - - // See if it's already cloned - if (IRInst*const* newInstPtr = m_cloneMap.TryGetValue(inst)) - { - return *newInstPtr; - } - - IRModule* module = inst->getModule(); - // All inst's must belong to a module - SLANG_ASSERT(module); - - // If it's in this module then we don't need to clone - if (module == m_module) - { - return inst; - } - - if (isNominalOp(inst->getOp())) - { - // We can clone without any definition, and add the linkage - - // TODO(JS) - // This is arguably problematic - I'm adding an instruction from another module to the map, to be it's self. - // I did have code which created a copy of the nominal instruction and name hint, but because nominality means - // 'same address' other code would generate a different name for that instruction (say as compared to being a member in - // the original instruction) - // - // Because I use findOrAddInst which doesn't hoist instructions, the hoisting doesn't rely on parenting, that would - // break. - - // If nominal, we just use the original inst - m_cloneMap.Add(inst, inst); - return inst; - } - - // It would be nice if I could use ir-clone.cpp to do this -> but it doesn't clone - // operands. We wouldn't want to clone decorations, and it can't clone IRConstant(!) so - // it's no use - - IRInst* clone = nullptr; - switch (inst->getOp()) - { - case kIROp_IntLit: - { - auto intLit = static_cast<IRConstant*>(inst); - IRType* clonedType = cloneType(intLit->getDataType()); - clone = m_builder.getIntValue(clonedType, intLit->value.intVal); - break; - } - case kIROp_StringLit: - { - auto stringLit = static_cast<IRStringLit*>(inst); - clone = m_builder.getStringValue(stringLit->getStringSlice()); - break; - } - case kIROp_VectorType: - { - auto vecType = static_cast<IRVectorType*>(inst); - const Index elementCount = Index(getIntVal(vecType->getElementCount())); - - if (elementCount <= 1) - { - clone = cloneType(vecType->getElementType()); - } - break; - } - case kIROp_MatrixType: - { - auto matType = static_cast<IRMatrixType*>(inst); - const Index columnCount = Index(getIntVal(matType->getColumnCount())); - const Index rowCount = Index(getIntVal(matType->getRowCount())); - - if (columnCount <= 1 && rowCount <= 1) - { - clone = cloneType(matType->getElementType()); - } - break; - } - default: break; - } - - if (!clone) - { - if (IRBasicType::isaImpl(inst->getOp())) - { - clone = m_builder.getType(inst->getOp()); - } - else - { - IRType* irType = dynamicCast<IRType>(inst); - if (irType) - { - auto clonedType = cloneType(inst->getFullType()); - Index operandCount = Index(inst->getOperandCount()); - - List<IRInst*> cloneOperands; - cloneOperands.setCount(operandCount); - - for (Index i = 0; i < operandCount; ++i) - { - cloneOperands[i] = cloneInst(inst->getOperand(i)); - } - - //clone = m_irBuilder.findOrEmitHoistableInst(cloneType, inst->op, operandCount, cloneOperands.getBuffer()); - - UInt operandCounts[1] = { UInt(operandCount) }; - IRInst*const* listOperands[1] = { cloneOperands.getBuffer() }; - - clone = m_builder.findOrAddInst(clonedType, inst->getOp(), 1, operandCounts, listOperands); - } - else - { - // This cloning style only works on insts that are not unique - auto clonedType = cloneType(inst->getFullType()); - - Index operandCount = Index(inst->getOperandCount()); - clone = m_builder.emitIntrinsicInst(clonedType, inst->getOp(), operandCount, nullptr); - for (Index i = 0; i < operandCount; ++i) - { - auto cloneOperand = cloneInst(inst->getOperand(i)); - clone->getOperands()[i].init(clone, cloneOperand); - } - } - } - } - - m_cloneMap.Add(inst, clone); - return clone; -} - -IRType* IRTypeSet::add(IRType* irType) -{ - if (irType->getModule() == m_module) - { - return irType; - } - // We need to clone the type - return cloneType(irType); -} - -void IRTypeSet::getTypes(List<IRType*>& outTypes) const -{ - outTypes.clear(); - for (auto inst : m_module->getModuleInst()->getChildren()) - { - if (IRType* type = as<IRType>(inst)) - { - outTypes.add(type); - } - } -} - -void IRTypeSet::getTypes(Kind kind, List<IRType*>& outTypes) const -{ - outTypes.clear(); - - for (auto inst : m_module->getModuleInst()->getChildren()) - { - IRType* type = nullptr; - - switch (kind) - { - case Kind::Scalar: - { - type = as<IRBasicType>(inst); - break; - } - case Kind::Vector: - { - type = as<IRVectorType>(inst); - break; - } - case Kind::Matrix: - { - type = as<IRMatrixType>(inst); - break; - } - default: break; - } - - if (type) - { - outTypes.add(type); - } - } -} - -IRType* IRTypeSet::addVectorType(IRType* inElementType, int colsCount) -{ - IRType* elementType = cloneType(inElementType); - if (colsCount == 1) - { - return elementType; - } - return m_builder.getVectorType(elementType, m_builder.getIntValue(m_builder.getIntType(), colsCount)); -} - -void IRTypeSet::addVectorForMatrixTypes() -{ - // Make a copy so we can alter m_types dictionary - List<IRType*> types; - getTypes(Kind::Matrix, types); - for (IRType* type : types) - { - SLANG_ASSERT(as<IRMatrixType>(type)); - IRMatrixType* matType = static_cast<IRMatrixType*>(type); - m_builder.getVectorType(matType->getElementType(), matType->getColumnCount()); - } -} - -static bool _hasNominalOperand(IRInst* inst) -{ - const Index operandCount = Index(inst->getOperandCount()); - auto operands = inst->getOperands(); - - for (Index i = 0; i < operandCount; ++i) - { - IRInst* operand = operands[i].get(); - if (isNominalOp(operand->getOp())) - { - return true; - } - } - - return false; -} - -void IRTypeSet::_addAllBuiltinTypesRec(IRInst* inst) -{ - for (IRInst* child = inst->getFirstDecorationOrChild(); child; child = child->getNextInst()) - { - IRType* type = nullptr; - - if (auto vectorType = as<IRVectorType>(child)) - { - type = vectorType; - } - else if (auto matrixType = as<IRMatrixType>(child)) - { - type = matrixType; - } - if (type && !_hasNominalOperand(type)) - { - add(type); - } - else - { - _addAllBuiltinTypesRec(child); - } - } -} - -void IRTypeSet::addAllBuiltinTypes(IRModule* module) -{ - _addAllBuiltinTypesRec(module->getModuleInst()); -} - } diff --git a/source/slang/slang-ir-type-set.h b/source/slang/slang-ir-type-set.h index 958d71cf1..f60088fcd 100644 --- a/source/slang/slang-ir-type-set.h +++ b/source/slang/slang-ir-type-set.h @@ -9,85 +9,4 @@ namespace Slang { -/* -NOTE! This type set is only designed to work for emitting code to determine unique types. It is envisaged in the -future that it will not be needed because types will be made unique within a module, and thus the pointer to a type -will uniquely identify the type. - -The other reason this type exists, is to allow an IRModule for emit to be immutable. That is not currently possible -within emit code because it may be necessary in order to emit to be able to create other types that needed (for example -vector types required for a matrix type implementation). - -This is used so as to try and use slangs type system to uniquely identify types and specializations on intrinsic. -That we want to have a pointer to a type be unique, and slang supports this through the m_sharedIRBuilder. BUT for this to -work all work on the module must use the same sharedIRBuilder, and that appears to not be the case in terms -of other passes. -Even if it was the case when we may want to add types as part of emitting, we can't use the previously used -shared builder, so again we end up with pointers to the same things not being the same thing. - -To work around this we clone types we want to use as keys into the 'unique module'. -This is not necessary for all types though - as we assume nominal types *must* have unique pointers (that is the -definition of nominal). - -This could be handled in other ways (for example not testing equality on pointer equality). Anyway for now this -works, but probably needs to be handled in a better way. The better way may involve having guarantees about equality -enabled in other code generation and making de-duping possible in emit code. - -Note that one pro for this approach is that it does not alter the source module. That as it stands it's not necessary -for the source module to be immutable, because it is created for emitting and then discarded. - -NOTE! That Vector<X, 1> or Matrix<X, 1, 1> will be turned into the type X. - - */ -class IRTypeSet -{ -public: - enum class Kind - { - Scalar, - Vector, - Matrix, - CountOf, - }; - - IRType* add(IRType* type); - IRType* addVectorType(IRType* elementType, int colsCount); - - void addAllBuiltinTypes(IRModule* module); - - void addVectorForMatrixTypes(); - - void getTypes(List<IRType*>& outTypes) const; - void getTypes(Kind kind, List<IRType*>& outTypes) const; - - IRType* getType(IRType* type) { return cloneType(type); } - - IRType* cloneType(IRType* type) { return (IRType*)cloneInst((IRInst*)type); } - IRInst* cloneInst(IRInst* inst); - - /// Returns true if the type belongs and is created on the module owned by the set - bool isOwned(IRType* type) { return type->getModule() == m_module; } - - IRBuilder& getBuilder() { return m_builder; } - IRModule* getModule() const { return m_module; } - - void clear(); - - IRTypeSet(Session* session); - ~IRTypeSet(); - -protected: - void _addAllBuiltinTypesRec(IRInst* inst); - void _clearTypes(); - - // Maps insts from source modules into m_module. - // NOTE! That nominal types are not cloned, as they are identified by pointer. They are just - Dictionary<IRInst*, IRInst*> m_cloneMap; - - // Can find all types by traversing the types in the m_module - SharedIRBuilder m_sharedBuilder; - IRBuilder m_builder; - RefPtr<IRModule> m_module; -}; - } // namespace Slang diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp index 1ea426715..253686aa5 100644 --- a/source/slang/slang-ir-util.cpp +++ b/source/slang/slang-ir-util.cpp @@ -549,6 +549,8 @@ struct GenericChildrenMigrationContextImpl } if (as<IRConstant>(inst)) return false; + if (getIROpInfo(inst->getOp()).isHoistable()) + return false; return true; }); } diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h index efd38f7b7..2f1ac2d1a 100644 --- a/source/slang/slang-ir-util.h +++ b/source/slang/slang-ir-util.h @@ -49,7 +49,7 @@ struct DeduplicateContext return *newValue; for (UInt i = 0; i < value->getOperandCount(); i++) { - value->setOperand(i, deduplicate(value->getOperand(i), shouldDeduplicate)); + value->unsafeSetOperand(i, deduplicate(value->getOperand(i), shouldDeduplicate)); } value->setFullType((IRType*)deduplicate(value->getFullType(), shouldDeduplicate)); if (auto newValue = deduplicateMap.TryGetValue(key)) diff --git a/source/slang/slang-ir-validate.cpp b/source/slang/slang-ir-validate.cpp index 03db96ac5..d5c0aa432 100644 --- a/source/slang/slang-ir-validate.cpp +++ b/source/slang/slang-ir-validate.cpp @@ -186,6 +186,28 @@ namespace Slang if (pp == operandParent) return; } + + // We allow out-of-order def-use in global scope. + bool allInGlobalScope = inst->getParent() && inst->getParent()->getOp() == kIROp_Module; + if (allInGlobalScope) + { + for (UInt i = 0; i < inst->getOperandCount(); i++) + { + auto op = inst->getOperand(i); + if (!op) + continue; + if (!op->getParent()) + continue; + if (op->getParent()->getOp() != kIROp_Module) + { + allInGlobalScope = false; + break; + } + } + } + if (allInGlobalScope) + return; + // // We failed to find `operandParent` while walking the ancestors of `inst`, // so something had gone wrong. diff --git a/source/slang/slang-ir-wrap-structured-buffers.cpp b/source/slang/slang-ir-wrap-structured-buffers.cpp index 2ad09aa90..53671fa7f 100644 --- a/source/slang/slang-ir-wrap-structured-buffers.cpp +++ b/source/slang/slang-ir-wrap-structured-buffers.cpp @@ -134,7 +134,7 @@ struct WrapStructuredBuffersContext // scanning through its IR uses, since values of that // type are using it as a (type) operand. // - for( auto typeUse = newStructuredBufferType->firstUse; typeUse; typeUse = typeUse->nextUse ) + traverseUses(newStructuredBufferType, [&](IRUse* typeUse) { // There might be uses of `newStructuredBufferType` where // it isn't being used as the type of a value, so we @@ -142,7 +142,7 @@ struct WrapStructuredBuffersContext // auto valueOfStructuredBufferType = typeUse->getUser(); if(valueOfStructuredBufferType->getFullType() != newStructuredBufferType) - continue; + return; // Now we have some `valueOfStructuredBufferType`. In our running // example, this might be `gBuffer`, which is an `IRGlobalParam`. @@ -155,7 +155,7 @@ struct WrapStructuredBuffersContext // because these could be calls to intrinsic functions like // `RWStructuredBuffer.Load` // - for( auto valueUse = valueOfStructuredBufferType->firstUse; valueUse; valueUse = valueUse->nextUse ) + traverseUses(valueOfStructuredBufferType, [&](IRUse* valueUse) { // we are only interested in instructions that are calls, // with at least one argument, where the first argument @@ -165,11 +165,11 @@ struct WrapStructuredBuffersContext // auto call = as<IRCall>(valueUse->getUser()); if(!call) - continue; + return; if(call->getArgCount() == 0) - continue; + return; if(call->getArg(0) != valueOfStructuredBufferType) - continue; + return; // At this point we have a candidate `call` instruction, // but we need to determine whether it is a call to @@ -196,7 +196,7 @@ struct WrapStructuredBuffersContext // auto callee = call->getCallee(); if(!as<IRSpecialize>(callee)) - continue; + return; // At this point it seems likely we have one of the calls // we want to rewrite, but there are still intrinsics @@ -285,8 +285,8 @@ struct WrapStructuredBuffersContext newVal->setOperand(0, call); } } - } - } + }); + }); } /// Get the struture field "key" to use for generated wrappers diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 1b16bfe1f..6cf0f09a5 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -148,7 +148,6 @@ namespace Slang void IRUse::init(IRInst* u, IRInst* v) { clear(); - user = u; usedValue = v; if(v) @@ -170,6 +169,9 @@ namespace Slang void IRUse::set(IRInst* uv) { + // Normally we should never be modifying the operand of an hoistable inst. + // They can be modified by `replaceUsesWith`, or to be replaced by a new inst. + SLANG_ASSERT(!getIROpInfo(user->getOp()).isHoistable() || uv == usedValue); init(user, uv); } @@ -1196,11 +1198,57 @@ namespace Slang return as<IRGlobalValueWithCode>(pp); } + void addHoistableInst( + IRBuilder* builder, + IRInst* inst); + // Add an instruction into the current scope void IRBuilder::addInst( IRInst* inst) { - inst->insertAt(m_insertLoc); + if (getIROpInfo(inst->getOp()).isGlobal()) + { + addHoistableInst(this, inst); + return; + } + + if (!inst->parent) + inst->insertAt(m_insertLoc); + } + + IRInst* IRBuilder::replaceOperand(IRUse* use, IRInst* newValue) + { + auto user = use->getUser(); + if (user->getModule()) + { + user->getModule()->getSharedBuilder()->getInstReplacementMap().TryGetValue(newValue, newValue); + } + + if (!getIROpInfo(user->getOp()).isHoistable()) + { + use->set(newValue); + return user; + } + + // If user is hoistable, we need to remove it from the global number map first, + // perform the update, then try to reinsert it back to the global number map. + // If we find an equivalent entry already exists in the global number map, + // we return the existing entry. + auto builder = user->getModule()->getSharedBuilder(); + builder->_removeGlobalNumberingEntry(user); + use->init(user, newValue); + + IRInst* existingVal = nullptr; + if (builder->getGlobalValueNumberingMap().TryGetValue(IRInstKey{ user }, existingVal)) + { + user->replaceUsesWith(existingVal); + return existingVal; + } + else + { + builder->_addGlobalNumberingEntry(user); + return user; + } } // Given two parent instructions, pick the better one to use as as @@ -1645,6 +1693,13 @@ namespace Slang Int const* listArgCounts, IRInst* const* const* listArgs) { + m_sharedBuilder->getInstReplacementMap().TryGetValue((IRInst*)(type), *(IRInst**)&type); + + if (getIROpInfo(op).flags & kIROpFlag_Hoistable) + { + return _findOrEmitHoistableInst(type, op, fixedArgCount, fixedArgs, varArgListCount, listArgCounts, listArgs); + } + Int varArgCount = 0; for (Int ii = 0; ii < varArgListCount; ++ii) { @@ -1671,7 +1726,9 @@ namespace Slang { if (fixedArgs) { - operand->init(inst, fixedArgs[aa]); + auto arg = fixedArgs[aa]; + m_sharedBuilder->getInstReplacementMap().TryGetValue(arg, arg); + operand->init(inst, arg); } else { @@ -1687,7 +1744,9 @@ namespace Slang { if (listArgs[ii]) { - operand->init(inst, listArgs[ii][jj]); + auto arg = listArgs[ii][jj]; + m_sharedBuilder->getInstReplacementMap().TryGetValue(arg, arg); + operand->init(inst, arg); } else { @@ -2309,21 +2368,23 @@ namespace Slang args.add(getIntValue(capabilityAtomType, Int(atom))); } - return findOrEmitHoistableInst( + return createIntrinsicInst( capabilitySetType, kIROp_CapabilitySet, args.getCount(), args.getBuffer()); } - IRInst* IRBuilder::findOrEmitHoistableInst( - IRType* type, - IROp op, - UInt operandListCount, - UInt const* listOperandCounts, - IRInst* const* const* listOperands) - { - UInt operandCount = 0; - for (UInt ii = 0; ii < operandListCount; ++ii) + IRInst* IRBuilder::_findOrEmitHoistableInst( + IRType* type, + IROp op, + Int fixedArgCount, + IRInst* const* fixedArgs, + Int varArgListCount, + Int const* listArgCounts, + IRInst* const* const* listArgs) + { + UInt operandCount = fixedArgCount; + for (Int ii = 0; ii < varArgListCount; ++ii) { - operandCount += listOperandCounts[ii]; + operandCount += listArgCounts[ii]; } auto& memoryArena = getModule()->getMemoryArena(); @@ -2350,102 +2411,21 @@ namespace Slang // Don't link up as we may free (if we already have this key) { IRUse* operand = inst->getOperands(); - for (UInt ii = 0; ii < operandListCount; ++ii) + for (Int ii = 0; ii < fixedArgCount; ++ii) { - UInt listOperandCount = listOperandCounts[ii]; - for (UInt jj = 0; jj < listOperandCount; ++jj) - { - operand->usedValue = listOperands[ii][jj]; - operand++; - } - } - } - - // Find or add the key/inst - { - IRInstKey key = { inst }; - - // Ideally we would add if not found, else return if was found instead of testing & then adding. - IRInst** found = getSharedBuilder()->getGlobalValueNumberingMap().TryGetValueOrAdd(key, inst); - SLANG_ASSERT(endCursor == memoryArena.getCursor()); - // If it's found, just return, and throw away the instruction - if (found) - { - memoryArena.rewindToCursor(cursor); - return *found; - } - } - - // Make the lookup 'inst' instruction into 'proper' instruction. Equivalent to - // IRInst* inst = createInstImpl<IRInst>(builder, op, type, 0, nullptr, operandListCount, listOperandCounts, listOperands); - { - if (type) - { - inst->typeUse.usedValue = nullptr; - inst->typeUse.init(inst, type); - } - - _maybeSetSourceLoc(inst); - - IRUse*const operands = inst->getOperands(); - for (UInt i = 0; i < operandCount; ++i) - { - IRUse& operand = operands[i]; - auto value = operand.usedValue; - - operand.usedValue = nullptr; - operand.init(inst, value); + auto arg = fixedArgs[ii]; + m_sharedBuilder->getInstReplacementMap().TryGetValue(arg, arg); + operand->usedValue = arg; + operand++; } - } - - addHoistableInst(this, inst); - - return inst; - } - - IRInst* IRBuilder::findOrAddInst( - IRType* type, - IROp op, - UInt operandListCount, - UInt const* listOperandCounts, - IRInst* const* const* listOperands) - { - UInt operandCount = 0; - for (UInt ii = 0; ii < operandListCount; ++ii) - { - operandCount += listOperandCounts[ii]; - } - - auto& memoryArena = getModule()->getMemoryArena(); - void* cursor = memoryArena.getCursor(); - - // We are going to create a 'dummy' instruction on the memoryArena - // which can be used as a key for lookup, so see if we - // already have an equivalent instruction available to use. - size_t keySize = sizeof(IRInst) + operandCount * sizeof(IRUse); - IRInst* inst = (IRInst*)memoryArena.allocateAndZero(keySize); - - void* endCursor = memoryArena.getCursor(); - // Mark as 'unused' cos it is unused on release builds. - SLANG_UNUSED(endCursor); - - new(inst) IRInst(); -#if SLANG_ENABLE_IR_BREAK_ALLOC - inst->_debugUID = _debugGetAndIncreaseInstCounter(); -#endif - inst->m_op = op; - inst->typeUse.usedValue = type; - inst->operandCount = (uint32_t)operandCount; - - // Don't link up as we may free (if we already have this key) - { - IRUse* operand = inst->getOperands(); - for (UInt ii = 0; ii < operandListCount; ++ii) + for (Int ii = 0; ii < varArgListCount; ++ii) { - UInt listOperandCount = listOperandCounts[ii]; + UInt listOperandCount = listArgCounts[ii]; for (UInt jj = 0; jj < listOperandCount; ++jj) { - operand->usedValue = listOperands[ii][jj]; + auto arg = listArgs[ii][jj]; + m_sharedBuilder->getInstReplacementMap().TryGetValue(arg, arg); + operand->usedValue = arg; operand++; } } @@ -2488,50 +2468,17 @@ namespace Slang } } - addInst(inst); - return inst; - } - - - IRInst* IRBuilder::findOrEmitHoistableInst( - IRType* type, - IROp op, - UInt operandCount, - IRInst* const* operands) - { - return findOrEmitHoistableInst( - type, - op, - 1, - &operandCount, - &operands); - } - - IRInst* IRBuilder::findOrEmitHoistableInst( - IRType* type, - IROp op, - IRInst* operand, - UInt operandCount, - IRInst* const* operands) - { - UInt counts[] = { 1, operandCount }; - IRInst* const* lists[] = { &operand, operands }; + addHoistableInst(this, inst); - return findOrEmitHoistableInst( - type, - op, - 2, - counts, - lists); + return inst; } - IRType* IRBuilder::getType( IROp op, UInt operandCount, IRInst* const* operands) { - return (IRType*) findOrEmitHoistableInst( + return (IRType*)createIntrinsicInst( nullptr, op, operandCount, @@ -2831,7 +2778,7 @@ namespace Slang IRType* const* paramTypes, IRType* resultType) { - return (IRFuncType*) findOrEmitHoistableInst( + return (IRFuncType*)createIntrinsicInst( nullptr, kIROp_FuncType, resultType, @@ -2844,13 +2791,13 @@ namespace Slang { UInt counts[3] = {1, paramCount, 1}; IRInst** lists[3] = {(IRInst**)&resultType, (IRInst**)paramTypes, (IRInst**)&attribute}; - return (IRFuncType*)findOrEmitHoistableInst(nullptr, kIROp_FuncType, 3, counts, lists); + return (IRFuncType*)createIntrinsicInst(nullptr, kIROp_FuncType, 3, counts, lists); } IRWitnessTableType* IRBuilder::getWitnessTableType( IRType* baseType) { - return (IRWitnessTableType*)findOrEmitHoistableInst( + return (IRWitnessTableType*)createIntrinsicInst( nullptr, kIROp_WitnessTableType, 1, @@ -2860,7 +2807,7 @@ namespace Slang IRWitnessTableIDType* IRBuilder::getWitnessTableIDType( IRType* baseType) { - return (IRWitnessTableIDType*)findOrEmitHoistableInst( + return (IRWitnessTableIDType*)createIntrinsicInst( nullptr, kIROp_WitnessTableIDType, 1, @@ -2914,7 +2861,7 @@ namespace Slang UInt caseCount, IRType* const* caseTypes) { - return (IRType*) findOrEmitHoistableInst( + return (IRType*)createIntrinsicInst( getTypeKind(), kIROp_TaggedUnionType, caseCount, @@ -2947,7 +2894,7 @@ namespace Slang } } - return (IRType*) findOrEmitHoistableInst( + return (IRType*)createIntrinsicInst( getTypeKind(), kIROp_BindExistentialsType, baseType, @@ -3197,7 +3144,7 @@ namespace Slang if (as<IRWitnessTable>(innerReturnVal)) { - return findOrEmitHoistableInst( + return createIntrinsicInst( type, kIROp_Specialize, genericVal, @@ -3214,7 +3161,8 @@ namespace Slang argCount, args); - addInst(inst); + if (!inst->parent) + addInst(inst); return inst; } @@ -3233,7 +3181,7 @@ namespace Slang IRInst* args[] = {witnessTableVal, interfaceMethodVal}; - return findOrEmitHoistableInst( + return createIntrinsicInst( type, kIROp_LookupWitness, 2, @@ -3331,6 +3279,17 @@ namespace Slang args); } + IRInst* IRBuilder::createIntrinsicInst( + IRType* type, IROp op, IRInst* operand, UInt operandCount, IRInst* const* operands) + { + return createInstWithTrailingArgs<IRInst>(this, op, type, operand, operandCount, operands); + } + + IRInst* IRBuilder::createIntrinsicInst(IRType* type, IROp op, UInt operandListCount, UInt const* listOperandCounts, IRInst* const* const* listOperands) + { + return createInstImpl<IRInst>(this, op, type, 0, nullptr, (Int)operandListCount, (Int const* )listOperandCounts, listOperands); + } + IRInst* IRBuilder::emitIntrinsicInst( IRType* type, @@ -3343,7 +3302,8 @@ namespace Slang op, argCount, args); - addInst(inst); + if (!inst->parent) + addInst(inst); return inst; } @@ -3772,6 +3732,13 @@ namespace Slang return emitIntrinsicInst(type, kIROp_MakeMatrix, argCount, args); } + IRInst* IRBuilder::emitMakeMatrixFromScalar( + IRType* type, + IRInst* scalarValue) + { + return emitIntrinsicInst(type, kIROp_MakeMatrixFromScalar, 1, &scalarValue); + } + IRInst* IRBuilder::emitMakeArray( IRType* type, UInt argCount, @@ -3938,7 +3905,7 @@ namespace Slang value->insertAtEnd(parent); } } - + IRInst* IRBuilder::addDifferentiableTypeDictionaryDecoration(IRInst* target) { return addDecoration(target, kIROp_DifferentiableTypeDictionaryDecoration); @@ -5056,7 +5023,7 @@ namespace Slang this, kIROp_GlobalConstant, type); - addInst(inst); + addGlobalValue(this, inst); return inst; } @@ -5069,7 +5036,7 @@ namespace Slang kIROp_GlobalConstant, type, val); - addInst(inst); + addGlobalValue(this, inst); return inst; } @@ -5349,7 +5316,7 @@ namespace Slang IRInst* operands[] = { kindInst, sizeInst }; - return cast<IRTypeSizeAttr>(findOrEmitHoistableInst( + return cast<IRTypeSizeAttr>(createIntrinsicInst( getVoidType(), kIROp_TypeSizeAttr, SLANG_COUNT_OF(operands), @@ -5376,7 +5343,7 @@ namespace Slang operands[operandCount++] = spaceInst; } - return cast<IRVarOffsetAttr>(findOrEmitHoistableInst( + return cast<IRVarOffsetAttr>(createIntrinsicInst( getVoidType(), kIROp_VarOffsetAttr, operandCount, @@ -5388,7 +5355,7 @@ namespace Slang { IRInst* operands[] = { pendingLayout }; - return cast<IRPendingLayoutAttr>(findOrEmitHoistableInst( + return cast<IRPendingLayoutAttr>(createIntrinsicInst( getVoidType(), kIROp_PendingLayoutAttr, SLANG_COUNT_OF(operands), @@ -5401,7 +5368,7 @@ namespace Slang { IRInst* operands[] = { key, layout }; - return cast<IRStructFieldLayoutAttr>(findOrEmitHoistableInst( + return cast<IRStructFieldLayoutAttr>(createIntrinsicInst( getVoidType(), kIROp_StructFieldLayoutAttr, SLANG_COUNT_OF(operands), @@ -5413,7 +5380,7 @@ namespace Slang { IRInst* operands[] = { layout }; - return cast<IRCaseTypeLayoutAttr>(findOrEmitHoistableInst( + return cast<IRCaseTypeLayoutAttr>(createIntrinsicInst( getVoidType(), kIROp_CaseTypeLayoutAttr, SLANG_COUNT_OF(operands), @@ -5430,7 +5397,7 @@ namespace Slang IRInst* operands[] = { nameInst, indexInst }; - return cast<IRSemanticAttr>(findOrEmitHoistableInst( + return cast<IRSemanticAttr>(createIntrinsicInst( getVoidType(), op, SLANG_COUNT_OF(operands), @@ -5441,7 +5408,7 @@ namespace Slang { auto stageInst = getIntValue(getIntType(), IRIntegerValue(stage)); IRInst* operands[] = { stageInst }; - return cast<IRStageAttr>(findOrEmitHoistableInst( + return cast<IRStageAttr>(createIntrinsicInst( getVoidType(), kIROp_StageAttr, SLANG_COUNT_OF(operands), @@ -5450,7 +5417,7 @@ namespace Slang IRAttr* IRBuilder::getAttr(IROp op, UInt operandCount, IRInst* const* operands) { - return cast<IRAttr>(findOrEmitHoistableInst( + return cast<IRAttr>(createIntrinsicInst( getVoidType(), op, operandCount, @@ -5461,7 +5428,7 @@ namespace Slang IRTypeLayout* IRBuilder::getTypeLayout(IROp op, List<IRInst*> const& operands) { - return cast<IRTypeLayout>(findOrEmitHoistableInst( + return cast<IRTypeLayout>(createIntrinsicInst( getVoidType(), op, operands.getCount(), @@ -5470,7 +5437,7 @@ namespace Slang IRVarLayout* IRBuilder::getVarLayout(List<IRInst*> const& operands) { - return cast<IRVarLayout>(findOrEmitHoistableInst( + return cast<IRVarLayout>(createIntrinsicInst( getVoidType(), kIROp_VarLayout, operands.getCount(), @@ -5483,7 +5450,7 @@ namespace Slang { IRInst* operands[] = { paramsLayout, resultLayout }; - return cast<IREntryPointLayout>(findOrEmitHoistableInst( + return cast<IREntryPointLayout>(createIntrinsicInst( getVoidType(), kIROp_EntryPointLayout, SLANG_COUNT_OF(operands), @@ -6528,70 +6495,146 @@ namespace Slang void validateIRInstOperands(IRInst*); - void IRInst::replaceUsesWith(IRInst* other) + static void _replaceInstUsesWith(IRInst* thisInst, IRInst* other) { - // Safety check: don't try to replace something with itself. - if(other == this) - return; + SharedIRBuilder* sharedBuilder = nullptr; - // We will walk through the list of uses for the current - // instruction, and make them point to the other inst. - IRUse* ff = firstUse; + struct WorkItem + { + IRInst* thisInst; + IRInst* otherInst; + }; - // No uses? Nothing to do. - if(!ff) - return; + // A work list of hoistable users for which we need + // to deduplicate/update their entry in the global numbering map. + List<WorkItem> workList; + HashSet<IRInst*> workListSet; - ff->debugValidate(); + auto addToWorkList = [&](IRInst* src, IRInst* target) + { + if (workListSet.Add(src)) + { + WorkItem item; + item.thisInst = src; + item.otherInst = target; + workList.add(item); + } + }; - IRUse* uu = ff; - for(;;) + addToWorkList(thisInst, other); + + for (Index i = 0; i < workList.getCount(); i++) { - // The uses had better all be uses of this - // instruction, or invariants are broken. - SLANG_ASSERT(uu->get() == this); + auto workItem = workList[i]; + thisInst = workItem.thisInst; + other = workItem.otherInst; - // Swap this use over to use the other value. - uu->usedValue = other; + // Safety check: don't try to replace something with itself. + if (other == thisInst) + continue; - // Try to move to the next use, but bail - // out if we are at the last one. - IRUse* nn = uu->nextUse; - if( !nn ) - break; + if (getIROpInfo(thisInst->getOp()).isHoistable()) + { + if (!sharedBuilder) + { + SLANG_ASSERT(thisInst->getModule()); + sharedBuilder = thisInst->getModule()->getSharedBuilder(); + } + sharedBuilder->getInstReplacementMap()[thisInst] = other; + } - uu = nn; - } + // We will walk through the list of uses for the current + // instruction, and make them point to the other inst. + IRUse* ff = thisInst->firstUse; - // We are at the last use (and there must - // be at least one, because we handled - // the case of an empty list earlier). - SLANG_ASSERT(uu); + // No uses? Nothing to do. + if (!ff) + continue; - // Our job at this point is to splice - // our list of uses onto the other - // value's uses. - // - // If the value already had uses, then - // we need to patch our new list onto - // the front. - if( auto nn = other->firstUse ) - { - uu->nextUse = nn; - nn->prevLink = &uu->nextUse; - } + //ff->debugValidate(); + + IRUse* uu = ff; + for (;;) + { + // The uses had better all be uses of this + // instruction, or invariants are broken. + SLANG_ASSERT(uu->get() == thisInst); + + auto user = uu->getUser(); + bool userIsHoistable = getIROpInfo(user->getOp()).isHoistable(); + if (userIsHoistable) + { + if (!sharedBuilder) + { + SLANG_ASSERT(user->getModule()); + sharedBuilder = user->getModule()->getSharedBuilder(); + } + sharedBuilder->_removeGlobalNumberingEntry(user); + } + + // Swap this use over to use the other value. + uu->usedValue = other; + + if (userIsHoistable) + { + // Is the updated inst already exists in the global numbering map? + // If so, we need to continue work on replacing the updated inst with the existing value. + IRInst* existingVal = nullptr; + if (sharedBuilder->getGlobalValueNumberingMap().TryGetValue(IRInstKey{ user }, existingVal)) + { + addToWorkList(user, existingVal); + } + else + { + sharedBuilder->_addGlobalNumberingEntry(user); + } + } + + // Try to move to the next use, but bail + // out if we are at the last one. + IRUse* nn = uu->nextUse; + if (!nn) + break; + + uu = nn; + } - // No matter what, our list of - // uses will become the start - // of the list of uses for - // `other` - other->firstUse = ff; - ff->prevLink = &other->firstUse; + // We are at the last use (and there must + // be at least one, because we handled + // the case of an empty list earlier). + SLANG_ASSERT(uu); - // And `this` will have no uses any more. - this->firstUse = nullptr; + // Our job at this point is to splice + // our list of uses onto the other + // value's uses. + // + // If the value already had uses, then + // we need to patch our new list onto + // the front. + if (auto nn = other->firstUse) + { + uu->nextUse = nn; + nn->prevLink = &uu->nextUse; + } + + // No matter what, our list of + // uses will become the start + // of the list of uses for + // `other` + other->firstUse = ff; + ff->prevLink = &other->firstUse; + + // And `this` will have no uses any more. + thisInst->firstUse = nullptr; + + ff->debugValidate(); + } - ff->debugValidate(); + } + + void IRInst::replaceUsesWith(IRInst* other) + { + _replaceInstUsesWith(this, other); } // Insert this instruction into the same basic block @@ -6750,9 +6793,21 @@ namespace Slang // and then destroy it (it had better have no uses!) void IRInst::removeAndDeallocate() { - removeFromParent(); + if (auto module = getModule()) + { + if (getIROpInfo(getOp()).isHoistable()) + { + module->getSharedBuilder()->removeHoistableInstFromGlobalNumberingMap(this); + } + else if (auto constInst = as<IRConstant>(this)) + { + module->getSharedBuilder()->getConstantMap().Remove(IRConstantKey{ constInst }); + } + module->getSharedBuilder()->getInstReplacementMap().Remove(this); + } removeArguments(); removeAndDeallocateAllDecorationsAndChildren(); + removeFromParent(); // Run destructor to be sure... this->~IRInst(); @@ -6919,7 +6974,6 @@ namespace Slang case kIROp_Not: case kIROp_BitNot: case kIROp_Select: - case kIROp_Dot: case kIROp_MakeExistential: case kIROp_ExtractExistentialType: case kIROp_ExtractExistentialValue: diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h index 41b140972..9b8aa5cb7 100644 --- a/source/slang/slang-ir.h +++ b/source/slang/slang-ir.h @@ -37,12 +37,14 @@ enum : IROpFlags kIROpFlags_None = 0, kIROpFlag_Parent = 1 << 0, ///< This op is a parent op kIROpFlag_UseOther = 1 << 1, ///< If set this op can use 'other bits' to store information + kIROpFlag_Hoistable = 1 << 2, ///< If set this op is a hoistable inst that needs to be deduplicated. + kIROpFlag_Global = 1 << 3, ///< If set this op should always be hoisted but should never be deduplicated. }; /* Bit usage of IROp is a follows MainOp | Other -Bit range: 0-7 | Remaining bits +Bit range: 0-10 | Remaining bits For doing range checks (for example for doing isa tests), the value is masked by kIROpMeta_OpMask, such that the Other bits don't interfere. The other bits can be used for storage for anything that needs to identify as a different 'op' or 'type'. It is currently @@ -92,6 +94,9 @@ struct IROpInfo // Flags to control how we emit additional info IROpFlags flags; + + bool isHoistable() const { return (flags & kIROpFlag_Hoistable) != 0; } + bool isGlobal() const { return (flags & kIROpFlag_Global) != 0; } }; // Look up the info for an op @@ -206,6 +211,43 @@ struct IRInstList : IRInstListBase }; template<typename T> +struct IRModifiableInstList +{ + IRInst* parent; + List<IRInst*> workList; + + IRModifiableInstList() {} + + IRModifiableInstList(T* parent, T* first, T* last); + + T* getFirst() { return workList.getCount() ? (T*)workList.getFirst() : nullptr; } + T* getLast() { return workList.getCount() ? (T*)workList.getLast() : nullptr; } + + struct Iterator + { + IRModifiableInstList<T>* list; + Index position = 0; + + Iterator() {} + Iterator(IRModifiableInstList<T>* inList, Index inPos) : list(inList), position(inPos) {} + + T* operator*() + { + return (T*)(list->workList[position]); + } + void operator++(); + + bool operator!=(Iterator const& i) + { + return i.list != list || i.position != position; + } + }; + + Iterator begin() { return Iterator(this, 0); } + Iterator end() { return Iterator(this, workList.getCount()); } +}; + +template<typename T> struct IRFilteredInstList : IRInstListBase { IRFilteredInstList() {} @@ -591,6 +633,14 @@ struct IRInst getLastChild()); } + IRModifiableInstList<IRInst> getModifiableChildren() + { + return IRModifiableInstList<IRInst>( + this, + getFirstChild(), + getLastChild()); + } + /// A doubly-linked list containing any decorations and then any children of this instruction. /// /// We store both the decorations and children of an instruction @@ -607,7 +657,13 @@ struct IRInst IRInst* getFirstDecorationOrChild() { return m_decorationsAndChildren.first; } IRInst* getLastDecorationOrChild() { return m_decorationsAndChildren.last; } IRInstListBase getDecorationsAndChildren() { return m_decorationsAndChildren; } - + IRModifiableInstList<IRInst> getModifiableDecorationsAndChildren() + { + return IRModifiableInstList<IRInst>( + this, + m_decorationsAndChildren.first, + m_decorationsAndChildren.last); + } void removeAndDeallocateAllDecorationsAndChildren(); #ifdef SLANG_ENABLE_IR_BREAK_ALLOC @@ -647,6 +703,12 @@ struct IRInst getOperands()[index].set(value); } + void unsafeSetOperand(UInt index, IRInst* value) + { + SLANG_ASSERT(getOperands()[index].user != nullptr); + getOperands()[index].init(this, value); + } + // @@ -773,6 +835,39 @@ typename IRInstList<T>::Iterator IRInstList<T>::end() } template<typename T> +IRModifiableInstList<T>::IRModifiableInstList(T* inParent, T* first, T* last) +{ + parent = inParent; + for (auto item = first; item; item = item->next) + { + workList.add(item); + if (item == last) + break; + } +} + +template<typename T> +void IRModifiableInstList<T>::Iterator::operator++() +{ + position++; + while (position < list->workList.getCount()) + { + auto inst = list->workList[position]; + if (!as<T>(inst)) + { + // Skip insts that are not of type T. + } + else if (list->parent != inst->parent) + { + // Skip insts that are no longer in its original parent. + } + else + break; + position++; + } +} + +template<typename T> IRFilteredInstList<T>::IRFilteredInstList(IRInst* fst, IRInst* lst) { first = fst; @@ -1796,6 +1891,104 @@ struct IRModuleInst : IRInst IR_LEAF_ISA(Module) }; +struct IRModule; + +// Description of an instruction to be used for global value numbering +struct IRInstKey +{ + IRInst* inst; + + HashCode getHashCode(); +}; + +bool operator==(IRInstKey const& left, IRInstKey const& right); + +struct IRConstantKey +{ + IRConstant* inst; + + bool operator==(const IRConstantKey& rhs) const { return inst->equal(rhs.inst); } + HashCode getHashCode() const { return inst->getHashCode(); } +}; + +struct SharedIRBuilder +{ +public: + SharedIRBuilder() + {} + + explicit SharedIRBuilder(IRModule* module) + { + init(module); + } + + void init(IRModule* module); + + IRModule* getModule() + { + return m_module; + } + + Session* getSession() + { + return m_session; + } + + void insertBlockAlongEdge(IREdge const& edge); + + // Rebuilds `globalValueNumberingMap`. This is necessary if any existing + // keys are modified (thus its hash code is changed). + void deduplicateAndRebuildGlobalNumberingMap(); + + // Replaces all uses of oldInst with newInst, and ensures the global numbering map is valid after the replacement. + void replaceGlobalInst(IRInst* oldInst, IRInst* newInst); + + void removeHoistableInstFromGlobalNumberingMap(IRInst* inst); + + void tryHoistInst(IRInst* inst); + + typedef Dictionary<IRInstKey, IRInst*> GlobalValueNumberingMap; + typedef Dictionary<IRConstantKey, IRConstant*> ConstantMap; + + GlobalValueNumberingMap& getGlobalValueNumberingMap() { return m_globalValueNumberingMap; } + Dictionary<IRInst*, IRInst*>& getInstReplacementMap() { return m_instReplacementMap; } + + void _addGlobalNumberingEntry(IRInst* inst) + { + m_globalValueNumberingMap.Add(IRInstKey{ inst }, inst); + m_instReplacementMap.Remove(inst); + tryHoistInst(inst); + } + void _removeGlobalNumberingEntry(IRInst* inst) + { + IRInst* value = nullptr; + if (m_globalValueNumberingMap.TryGetValue(IRInstKey{ inst }, value)) + { + if (value == inst) + { + m_globalValueNumberingMap.Remove(IRInstKey{ inst }); + } + } + } + + ConstantMap& getConstantMap() { return m_constantMap; } + +private: + // The module that will own all of the IR + IRModule* m_module; + + // The parent compilation session + Session* m_session; + + GlobalValueNumberingMap m_globalValueNumberingMap; + + // Duplicate insts that are still alive and needs to be replaced in m_globalValueNumberMap + // when used as an operand to create another inst. + Dictionary<IRInst*, IRInst*> m_instReplacementMap; + + ConstantMap m_constantMap; +}; + struct IRModule : RefObject { public: @@ -1810,6 +2003,8 @@ public: SLANG_FORCE_INLINE IRModuleInst* getModuleInst() const { return m_moduleInst; } SLANG_FORCE_INLINE MemoryArena& getMemoryArena() { return m_memoryArena; } + SharedIRBuilder* getSharedBuilder() const { return &m_sharedBuilder; } + IRInstListBase getGlobalInsts() const { return getModuleInst()->getChildren(); } /// Create an empty instruction with the `op` opcode and space for @@ -1853,6 +2048,7 @@ private: IRModule(Session* session) : m_session(session) , m_memoryArena(kMemoryArenaBlockSize) + , m_sharedBuilder(this) { } @@ -1870,6 +2066,9 @@ private: /// The memory arena from which all IR instructions (and any associated state) in this module are allocated. MemoryArena m_memoryArena; + + /// Shared contexts for constructing and maintaining the IR. + mutable SharedIRBuilder m_sharedBuilder; }; struct IRSpecializationDictionaryItem : public IRInst @@ -1943,13 +2142,17 @@ uint32_t& _debugGetIRAllocCounter(); // TODO: Ellie, comment and move somewhere more appropriate? template<typename I = IRInst, typename F> -static void traverseUses(IRInst* inst, F f) +static void traverseUsers(IRInst* inst, F f) { - auto n = inst->firstUse; - IRUse* u; - while((u = n) != nullptr) + List<IRUse*> uses; + for (auto use = inst->firstUse; use; use = use->nextUse) { - n = u->nextUse; + uses.add(use); + } + for (auto u : uses) + { + if (u->usedValue != inst) + continue; if(auto s = as<I>(u->getUser())) { f(s); @@ -1957,6 +2160,22 @@ static void traverseUses(IRInst* inst, F f) } } +template<typename F> +static void traverseUses(IRInst* inst, F f) +{ + List<IRUse*> uses; + for (auto use = inst->firstUse; use; use = use->nextUse) + { + uses.add(use); + } + for (auto u : uses) + { + if (u->usedValue != inst) + continue; + f(u); + } +} + namespace detail { // A helper to get the singular pointer argument of something callable |
