summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-02-16 13:55:32 -0800
committerGitHub <noreply@github.com>2023-02-16 13:55:32 -0800
commit4c4826d47eeef4675daae4ae53ff76f4d5ebd84a (patch)
treeed4af0ded878e4f06e9641ce61d26ffd7c89ccbc /source
parenteda88e513e8b1e2abc05e9dc8555f237d96472df (diff)
Overhaul global inst deduplication and cpp/cuda backend. (#2654)
* Overhaul global inst deduplication and cpp/cuda backend. * Update IR documentation. --------- Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source')
-rw-r--r--source/slang/core.meta.slang5
-rw-r--r--source/slang/hlsl.meta.slang44
-rw-r--r--source/slang/slang-emit-cpp.cpp1547
-rw-r--r--source/slang/slang-emit-cpp.h38
-rw-r--r--source/slang/slang-emit-cuda.cpp511
-rw-r--r--source/slang/slang-emit-cuda.h11
-rw-r--r--source/slang/slang-emit.cpp2
-rw-r--r--source/slang/slang-hlsl-intrinsic-set.cpp590
-rw-r--r--source/slang/slang-hlsl-intrinsic-set.h212
-rw-r--r--source/slang/slang-ir-address-analysis.cpp4
-rw-r--r--source/slang/slang-ir-autodiff-fwd.cpp71
-rw-r--r--source/slang/slang-ir-autodiff-unzip.h8
-rw-r--r--source/slang/slang-ir-byte-address-legalize.cpp5
-rw-r--r--source/slang/slang-ir-clone.cpp28
-rw-r--r--source/slang/slang-ir-collect-global-uniforms.cpp5
-rw-r--r--source/slang/slang-ir-com-interface.cpp2
-rw-r--r--source/slang/slang-ir-dce.cpp14
-rw-r--r--source/slang/slang-ir-deduplicate.cpp154
-rw-r--r--source/slang/slang-ir-glsl-legalize.cpp6
-rw-r--r--source/slang/slang-ir-inline.cpp4
-rw-r--r--source/slang/slang-ir-inst-defs.h219
-rw-r--r--source/slang/slang-ir-insts.h169
-rw-r--r--source/slang/slang-ir-legalize-mesh-outputs.cpp2
-rw-r--r--source/slang/slang-ir-legalize-types.cpp23
-rw-r--r--source/slang/slang-ir-link.cpp48
-rw-r--r--source/slang/slang-ir-lower-generic-function.cpp56
-rw-r--r--source/slang/slang-ir-redundancy-removal.cpp2
-rw-r--r--source/slang/slang-ir-simplify-for-emit.cpp121
-rw-r--r--source/slang/slang-ir-simplify-for-emit.h3
-rw-r--r--source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.cpp25
-rw-r--r--source/slang/slang-ir-specialize-resources.cpp48
-rw-r--r--source/slang/slang-ir-specialize.cpp7
-rw-r--r--source/slang/slang-ir-ssa.cpp9
-rw-r--r--source/slang/slang-ir-type-set.cpp309
-rw-r--r--source/slang/slang-ir-type-set.h81
-rw-r--r--source/slang/slang-ir-util.cpp2
-rw-r--r--source/slang/slang-ir-util.h2
-rw-r--r--source/slang/slang-ir-validate.cpp22
-rw-r--r--source/slang/slang-ir-wrap-structured-buffers.cpp18
-rw-r--r--source/slang/slang-ir.cpp490
-rw-r--r--source/slang/slang-ir.h233
41 files changed, 1497 insertions, 3653 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang
index 2a8344e3a..6357d58bd 100644
--- a/source/slang/core.meta.slang
+++ b/source/slang/core.meta.slang
@@ -786,6 +786,8 @@ __generic<T = float, let R : int = 4, let C : int = 4>
__magic_type(Matrix)
struct matrix
{
+ __intrinsic_op($(kIROp_MakeMatrixFromScalar))
+ __init(T val);
}
${{{{
@@ -1093,9 +1095,6 @@ extension matrix<T, R, C> : IDifferentiable
{
typedef matrix<T, R, C> Differential;
- __intrinsic_op($(kIROp_MakeMatrixFromScalar))
- __init(T val);
-
[__unsafeForceInlineEarly]
static Differential dzero()
{
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang
index 464811a96..1d2b327d2 100644
--- a/source/slang/hlsl.meta.slang
+++ b/source/slang/hlsl.meta.slang
@@ -872,36 +872,31 @@ matrix<T, N, M> acos(matrix<T, N, M> x)
// Test if all components are non-zero (HLSL SM 1.0)
__generic<T : __BuiltinType>
+__target_intrinsic(cpp, "bool($0)")
+__target_intrinsic(cuda, "bool($0)")
__target_intrinsic(glsl, "bool($0)")
bool all(T x);
__generic<T : __BuiltinType, let N : int>
__target_intrinsic(hlsl)
__target_intrinsic(glsl, "all(bvec$N0($0))")
-bool all(vector<T,N> x);
-// TODO: implementation of `all()` in the stdlib is
-// blocked on fixing implementation of `bool` vector
-// `getAt` on the CUDA codegen path.
-/*
+bool all(vector<T,N> x)
{
bool result = true;
for(int i = 0; i < N; ++i)
result = result && all(x[i]);
return result;
}
-*/
__generic<T : __BuiltinType, let N : int, let M : int>
__target_intrinsic(hlsl)
-bool all(matrix<T,N,M> x);
-/*
+bool all(matrix<T,N,M> x)
{
bool result = true;
for(int i = 0; i < N; ++i)
result = result && all(x[i]);
return result;
}
-*/
// Barrier for writes to all memory spaces (HLSL SM 5.0)
__target_intrinsic(glsl, "memoryBarrier(), groupMemoryBarrier(), memoryBarrierImage(), memoryBarrierBuffer()")
@@ -916,42 +911,39 @@ void AllMemoryBarrierWithGroupSync();
// Test if any components is non-zero (HLSL SM 1.0)
__generic<T : __BuiltinType>
+__target_intrinsic(cpp, "bool($0)")
+__target_intrinsic(cuda, "bool($0)")
__target_intrinsic(glsl, "bool($0)")
bool any(T x);
__generic<T : __BuiltinType, let N : int>
__target_intrinsic(hlsl)
__target_intrinsic(glsl, "any(bvec$N0($0))")
-bool any(vector<T, N> x);
-// TODO: implementation of `any()` in the stdlib is
-// blocked on fixing implementation of `bool` vector
-// `getAt` on the CUDA codegen path.
-/*
+bool any(vector<T, N> x)
{
bool result = false;
for(int i = 0; i < N; ++i)
result = result || any(x[i]);
return result;
}
-*/
__generic<T : __BuiltinType, let N : int, let M : int>
__target_intrinsic(hlsl)
-bool any(matrix<T, N, M> x);
-/*
+bool any(matrix<T, N, M> x)
{
bool result = false;
for(int i = 0; i < N; ++i)
result = result || any(x[i]);
return result;
}
-*/
// Reinterpret bits as a double (HLSL SM 5.0)
__target_intrinsic(hlsl)
__target_intrinsic(glsl, "packDouble2x32(uvec2($0, $1))")
+__target_intrinsic(cpp, "$P_asdouble($0, $1)")
+__target_intrinsic(cuda, "$P_asdouble($0, $1)")
__target_intrinsic(spirv_direct, "%v = OpCompositeConstruct _type(uint2) resultId _0 _1; OpExtInst resultType resultId glsl450 59 %v")
__glsl_extension(GL_ARB_gpu_shader5)
double asdouble(uint lowbits, uint highbits);
@@ -960,11 +952,15 @@ double asdouble(uint lowbits, uint highbits);
__target_intrinsic(hlsl)
__target_intrinsic(glsl, "intBitsToFloat")
+__target_intrinsic(cpp, "$P_asfloat($0)")
+__target_intrinsic(cuda, "$P_asfloat($0)")
__target_intrinsic(spirv_direct, "OpBitcast resultType resultId _0")
float asfloat(int x);
__target_intrinsic(hlsl)
__target_intrinsic(glsl, "uintBitsToFloat")
+__target_intrinsic(cpp, "$P_asfloat($0)")
+__target_intrinsic(cuda, "$P_asfloat($0)")
__target_intrinsic(spirv_direct, "OpBitcast resultType resultId _0")
float asfloat(uint x);
@@ -1044,11 +1040,15 @@ matrix<T, N, M> asin(matrix<T, N, M> x)
__target_intrinsic(hlsl)
__target_intrinsic(glsl, "floatBitsToInt")
+__target_intrinsic(cpp, "$P_asint($0)")
+__target_intrinsic(cuda, "$P_asint($0)")
__target_intrinsic(spirv_direct, "OpBitcast resultType resultId _0")
int asint(float x);
__target_intrinsic(hlsl)
__target_intrinsic(glsl, "int($0)")
+__target_intrinsic(cpp, "$P_asint($0)")
+__target_intrinsic(cuda, "$P_asint($0)")
__target_intrinsic(spirv_direct, "OpBitcast resultType resultId _0")
int asint(uint x);
@@ -1104,6 +1104,8 @@ matrix<int,N,M> asint(matrix<int,N,M> x)
__target_intrinsic(hlsl)
__target_intrinsic(glsl, "{ uvec2 v = unpackDouble2x32($0); $1 = v.x; $2 = v.y; }")
__glsl_extension(GL_ARB_gpu_shader5)
+__target_intrinsic(cpp, "$P_asuint($0, $1, $2)")
+__target_intrinsic(cuda, "$P_asuint($0, $1, $2)")
void asuint(double value, out uint lowbits, out uint highbits);
// Reinterpret bits as a uint (HLSL SM 4.0)
@@ -1111,11 +1113,15 @@ void asuint(double value, out uint lowbits, out uint highbits);
__target_intrinsic(hlsl)
__target_intrinsic(glsl, "floatBitsToUint")
__target_intrinsic(spirv_direct, "OpBitcast resultType resultId _0")
+__target_intrinsic(cpp, "$P_asuint($0)")
+__target_intrinsic(cuda, "$P_asuint($0)")
uint asuint(float x);
__target_intrinsic(hlsl)
__target_intrinsic(glsl, "uint($0)")
__target_intrinsic(spirv_direct, "OpBitcast resultType resultId _0")
+__target_intrinsic(cpp, "$P_asuint($0)")
+__target_intrinsic(cuda, "$P_asuint($0)")
uint asuint(int x);
__generic<let N : int>
@@ -1812,7 +1818,7 @@ __target_intrinsic(glsl, "unpackHalf2x16($0).x")
__glsl_version(420)
__target_intrinsic(hlsl)
__cuda_sm_version(6.0)
-__target_intrinsic(cuda, "__half2float(__short_as_half($0))")
+__target_intrinsic(cuda, "__half2float(__ushort_as_half($0))")
float f16tof32(uint value);
__generic<let N : int>
diff --git a/source/slang/slang-emit-cpp.cpp b/source/slang/slang-emit-cpp.cpp
index 87b620ed2..ba6b26ec6 100644
--- a/source/slang/slang-emit-cpp.cpp
+++ b/source/slang/slang-emit-cpp.cpp
@@ -66,111 +66,6 @@ namespace Slang {
static const char s_xyzwNames[] = "xyzw";
-static UnownedStringSlice _getTypePrefix(IROp op)
-{
- switch (op)
- {
- case kIROp_BoolType: return UnownedStringSlice::fromLiteral("Bool");
- case kIROp_IntType: return UnownedStringSlice::fromLiteral("I32");
- case kIROp_UIntType: return UnownedStringSlice::fromLiteral("U32");
- case kIROp_FloatType: return UnownedStringSlice::fromLiteral("F32");
- case kIROp_Int64Type: return UnownedStringSlice::fromLiteral("I64");
- case kIROp_UInt64Type: return UnownedStringSlice::fromLiteral("U64");
- case kIROp_DoubleType: return UnownedStringSlice::fromLiteral("F64");
- default: return UnownedStringSlice::fromLiteral("?");
- }
-}
-
-
-static IROp _getCType(IROp op)
-{
- switch (op)
- {
- case kIROp_VoidType:
- case kIROp_BoolType:
- {
- return op;
- }
- case kIROp_Int8Type:
- case kIROp_Int16Type:
- case kIROp_IntType:
- case kIROp_UInt8Type:
- case kIROp_UInt16Type:
- case kIROp_UIntType:
- {
- // Promote all these to Int
- return kIROp_IntType;
- }
- case kIROp_IntPtrType:
- case kIROp_UIntPtrType:
- {
- return kIROp_IntPtrType;
- }
- case kIROp_Int64Type:
- case kIROp_UInt64Type:
- {
- // Promote all these to Int64, we can just vary the call to make these work
- return kIROp_Int64Type;
- }
- case kIROp_DoubleType:
- {
- return kIROp_DoubleType;
- }
- case kIROp_HalfType:
- case kIROp_FloatType:
- {
- // Promote both to float
- return kIROp_FloatType;
- }
- default:
- {
- SLANG_ASSERT(!"Unhandled type");
- return kIROp_undefined;
- }
- }
-}
-
-static UnownedStringSlice _getCTypeVecPostFix(IROp op)
-{
- switch (op)
- {
- case kIROp_BoolType: return UnownedStringSlice::fromLiteral("B");
- case kIROp_IntType: return UnownedStringSlice::fromLiteral("I");
- case kIROp_UIntType: return UnownedStringSlice::fromLiteral("U");
- case kIROp_FloatType: return UnownedStringSlice::fromLiteral("F");
- case kIROp_Int64Type: return UnownedStringSlice::fromLiteral("I64");
- case kIROp_DoubleType: return UnownedStringSlice::fromLiteral("F64");
- case kIROp_IntPtrType: return UnownedStringSlice::fromLiteral("");
- case kIROp_UIntPtrType: return UnownedStringSlice::fromLiteral("");
- default: return UnownedStringSlice::fromLiteral("?");
- }
-}
-
-static bool _isCppTarget(CodeGenTarget target)
-{
- switch (target)
- {
- case CodeGenTarget::CPPSource:
- case CodeGenTarget::HostCPPSource:
- return true;
- default:
- return false;
- }
-}
-
-static bool _isCppOrCudaTarget(CodeGenTarget target)
-{
- switch (target)
- {
- case CodeGenTarget::CPPSource:
- case CodeGenTarget::HostCPPSource:
- case CodeGenTarget::CUDASource:
- return true;
- default:
- return false;
- }
-}
-
/* !!!!!!!!!!!!!!!!!!!!!!!! CPPEmitHandler !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! */
/* static */ UnownedStringSlice CPPSourceEmitter::getBuiltinTypeName(IROp op)
@@ -204,118 +99,8 @@ static bool _isCppOrCudaTarget(CodeGenTarget target)
}
}
-void CPPSourceEmitter::emitTypeDefinition(IRType* inType)
+UnownedStringSlice CPPSourceEmitter::_getTypeName(IRType* type)
{
- if (_isCppTarget(m_target))
- {
- // All types are templates in C++
- return;
- }
-
- IRType* type = m_typeSet.getType(inType);
- if (!m_typeSet.isOwned(type))
- {
- // If defined in a different module, we assume they are emitted already. (Assumed to
- // be a nominal type)
- return;
- }
-
- SourceWriter* writer = getSourceWriter();
-
- switch (type->getOp())
- {
- case kIROp_VectorType:
- {
- auto vecType = static_cast<IRVectorType*>(type);
-
- const UnownedStringSlice* elemNames = getVectorElementNames(vecType);
-
- int count = int(getIntVal(vecType->getElementCount()));
-
- SLANG_ASSERT(count > 0 && count < 4);
-
- UnownedStringSlice typeName = _getTypeName(type);
- UnownedStringSlice elemName = _getTypeName(vecType->getElementType());
-
- writer->emit("struct ");
- writer->emit(typeName);
- writer->emit("\n{\n");
- writer->indent();
-
- writer->emit(elemName);
- writer->emit(" ");
- for (int i = 0; i < count; ++i)
- {
- if (i > 0)
- {
- writer->emit(", ");
- }
- writer->emit(elemNames[i]);
- }
- writer->emit(";\n");
-
- writer->dedent();
- writer->emit("};\n\n");
- break;
- }
- case kIROp_MatrixType:
- {
- auto matType = static_cast<IRMatrixType*>(type);
-
- const auto rowCount = int(getIntVal(matType->getRowCount()));
- const auto colCount = int(getIntVal(matType->getColumnCount()));
-
- IRType* vecType = m_typeSet.addVectorType(matType->getElementType(), colCount);
-
- UnownedStringSlice typeName = _getTypeName(type);
- UnownedStringSlice rowTypeName = _getTypeName(vecType);
-
- writer->emit("template<>\n");
- writer->emit("struct ");
- writer->emit(typeName);
- writer->emit("\n{\n");
- writer->indent();
-
- writer->emit(rowTypeName);
- writer->emit(" rows[");
- writer->emit(rowCount);
- writer->emit("];\n");
-
- writer->dedent();
- writer->emit("};\n\n");
- break;
- }
- case kIROp_PtrType:
- case kIROp_RefType:
- {
- // We don't need to output a definition for these types
- break;
- }
- case kIROp_ArrayType:
- case kIROp_UnsizedArrayType:
- case kIROp_HLSLRWStructuredBufferType:
- {
- // We don't need to output a definition for these with C++ templates
- // For C we may need to (or do casting at point of usage)
- break;
- }
- default:
- {
- if (IRBasicType::isaImpl(type->getOp()))
- {
- // Don't emit anything for built in types
- return;
- }
- SLANG_ASSERT(!"Unhandled type");
- break;
- }
- }
-}
-
-UnownedStringSlice CPPSourceEmitter::_getTypeName(IRType* inType)
-{
- IRType* type = m_typeSet.getType(inType);
-
StringSlicePool::Handle handle = StringSlicePool::kNullHandle;
if (m_typeNameMap.TryGetValue(type, handle))
{
@@ -424,22 +209,7 @@ SlangResult CPPSourceEmitter::calcTypeName(IRType* type, CodeGenTarget target, S
auto vecCount = int(getIntVal(vecType->getElementCount()));
auto elemType = vecType->getElementType();
- if (_isCppOrCudaTarget(target))
- {
- out << "Vector<" << _getTypeName(elemType) << ", " << vecCount << ">";
- }
- else
- {
- out << "Vec";
- UnownedStringSlice postFix = _getCTypeVecPostFix(elemType->getOp());
-
- out << postFix;
- if (postFix.getLength() > 1)
- {
- out << "_";
- }
- out << vecCount;
- }
+ out << "Vector<" << _getTypeName(elemType) << ", " << vecCount << ">";
return SLANG_OK;
}
case kIROp_MatrixType:
@@ -450,22 +220,8 @@ SlangResult CPPSourceEmitter::calcTypeName(IRType* type, CodeGenTarget target, S
const auto rowCount = int(getIntVal(matType->getRowCount()));
const auto colCount = int(getIntVal(matType->getColumnCount()));
- if (_isCppOrCudaTarget(target))
- {
- out << "Matrix<" << _getTypeName(elementType) << ", " << rowCount << ", " << colCount << ">";
- }
- else
- {
- out << "Mat";
- const UnownedStringSlice postFix = _getCTypeVecPostFix(_getCType(elementType->getOp()));
- out << postFix;
- if (postFix.getLength() > 1)
- {
- out << "_";
- }
- out << rowCount;
- out << colCount;
- }
+ out << "Matrix<" << _getTypeName(elementType) << ", " << rowCount << ", " << colCount << ">";
+
return SLANG_OK;
}
case kIROp_WitnessTableType:
@@ -625,17 +381,6 @@ void CPPSourceEmitter::useType(IRType* type)
_getTypeName(type);
}
-static IRBasicType* _getElementType(IRType* type)
-{
- switch (type->getOp())
- {
- case kIROp_VectorType: type = static_cast<IRVectorType*>(type)->getElementType(); break;
- case kIROp_MatrixType: type = static_cast<IRMatrixType*>(type)->getElementType(); break;
- default: break;
- }
- return dynamicCast<IRBasicType>(type);
-}
-
/* static */CPPSourceEmitter::TypeDimension CPPSourceEmitter::_getTypeDimension(IRType* type, bool vecSwap)
{
switch (type->getOp())
@@ -735,943 +480,11 @@ void CPPSourceEmitter::_emitAccess(const UnownedStringSlice& name, const TypeDim
}
}
-static bool _isOperator(const UnownedStringSlice& funcName)
-{
- if (funcName.getLength() > 0)
- {
- const char c = funcName[0];
- return !((c >= 'a' && c <='z') || (c >= 'A' && c <= 'Z') || c == '_');
- }
- return false;
-}
-
-void CPPSourceEmitter::_emitAryDefinition(const HLSLIntrinsic* specOp)
-{
- auto info = HLSLIntrinsic::getInfo(specOp->op);
- auto funcName = info.funcName;
- SLANG_ASSERT(funcName.getLength() > 0);
-
- const bool isOperator = _isOperator(funcName);
-
- SourceWriter* writer = getSourceWriter();
-
- IRFuncType* funcType = specOp->signatureType;
- const int numParams = int(funcType->getParamCount());
- SLANG_ASSERT(numParams <= 3);
-
- bool areAllScalar = true;
- TypeDimension paramDims[3];
- for (int i = 0; i < numParams; ++i)
- {
- paramDims[i]= _getTypeDimension(funcType->getParamType(i), false);
- areAllScalar = areAllScalar && paramDims[i].isScalar();
- }
-
- // If all are scalar, then we don't need to emit a definition
- if (areAllScalar)
- {
- return;
- }
-
- IRType* retType = specOp->returnType;
-
- UnownedStringSlice scalarFuncName(funcName);
- if (isOperator)
- {
- StringBuilder builder;
- builder << "operator";
- builder << funcName;
- _emitSignature(builder.getUnownedSlice(), specOp);
- }
- else
- {
- scalarFuncName = _getScalarFuncName(specOp->op, _getElementType(funcType->getParamType(0)));
- _emitSignature(funcName, specOp);
- }
-
- writer->emit("\n{\n");
- writer->indent();
-
- const bool hasReturnType = retType->getOp() != kIROp_VoidType;
-
- TypeDimension calcDim;
- if (hasReturnType)
- {
- emitType(retType);
- writer->emit(" r;\n");
-
- calcDim = _getTypeDimension(retType, false);
- }
- else
- {
- calcDim = _getTypeDimension(funcType->getParamType(0), false);
- }
-
- for (int i = 0; i < calcDim.rowCount; ++i)
- {
- for (int j = 0; j < calcDim.colCount; ++j)
- {
- if (hasReturnType)
- {
- _emitAccess(UnownedStringSlice::fromLiteral("r"), calcDim, i, j, writer);
- writer->emit(" = ");
- }
-
- if (isOperator)
- {
- switch (numParams)
- {
- case 1:
- {
- writer->emit(funcName);
- _emitAccess(UnownedStringSlice::fromLiteral("a"), paramDims[0], i, j, writer);
- break;
- }
- case 2:
- {
- _emitAccess(UnownedStringSlice::fromLiteral("a"), paramDims[0], i, j, writer);
- writer->emit(" ");
- writer->emit(funcName);
- writer->emit(" ");
- _emitAccess(UnownedStringSlice::fromLiteral("b"), paramDims[1], i, j, writer);
- break;
- }
- default: SLANG_ASSERT(!"Unhandled");
- }
- }
- else
- {
- writer->emit(scalarFuncName);
- writer->emit("(");
- for (int k = 0; k < numParams; k++)
- {
- if (k > 0)
- {
- writer->emit(", ");
- }
- char c = char('a' + k);
- _emitAccess(UnownedStringSlice(&c, 1), paramDims[k], i, j, writer);
- }
- writer->emit(")");
- }
- writer->emit(";\n");
- }
- }
-
- if (hasReturnType)
- {
- writer->emit("return r;\n");
- }
-
- writer->dedent();
- writer->emit("}\n\n");
-}
-
-void CPPSourceEmitter::_emitAnyAllDefinition(const UnownedStringSlice& funcName, const HLSLIntrinsic* specOp)
-{
- IRFuncType* funcType = specOp->signatureType;
- SLANG_ASSERT(funcType->getParamCount() == 1);
- IRType* paramType0 = funcType->getParamType(0);
-
- SourceWriter* writer = getSourceWriter();
-
- IRType* elementType = _getElementType(paramType0);
- SLANG_ASSERT(elementType);
- IRType* retType = specOp->returnType;
- auto retTypeName = _getTypeName(retType);
-
- IROp style = getTypeStyle(elementType->getOp());
-
- const TypeDimension dim = _getTypeDimension(paramType0, false);
-
- _emitSignature(funcName, specOp);
- writer->emit("\n{\n");
- writer->indent();
-
- writer->emit("return ");
-
- for (int i = 0; i < dim.rowCount; ++i)
- {
- for (int j = 0; j < dim.colCount; ++j)
- {
- if (i > 0 || j > 0)
- {
- if (specOp->op == HLSLIntrinsic::Op::All)
- {
- writer->emit(" && ");
- }
- else
- {
- writer->emit(" || ");
- }
- }
-
- switch (style)
- {
- case kIROp_BoolType:
- {
- _emitAccess(UnownedStringSlice::fromLiteral("a"), dim, i, j, writer);
- break;
- }
- case kIROp_IntType:
- {
- writer->emit("(");
- _emitAccess(UnownedStringSlice::fromLiteral("a"), dim, i, j, writer);
- writer->emit(" != 0)");
- break;
- }
- case kIROp_FloatType:
- {
- writer->emit("(");
- _emitAccess(UnownedStringSlice::fromLiteral("a"), dim, i, j, writer);
- writer->emit(" != 0.0)");
- break;
- }
- }
- }
- }
-
- writer->emit(";\n");
-
- writer->dedent();
- writer->emit("}\n\n");
-}
-
-void CPPSourceEmitter::_emitSignature(const UnownedStringSlice& funcName, const HLSLIntrinsic* specOp)
-{
- IRFuncType* funcType = specOp->signatureType;
- const int paramsCount = int(funcType->getParamCount());
- IRType* retType = specOp->returnType;
-
- emitFunctionPreambleImpl(nullptr);
-
- SourceWriter* writer = getSourceWriter();
-
- emitType(retType);
- writer->emit(" ");
- writer->emit(funcName);
- writer->emit("(");
-
- for (int i = 0; i < paramsCount; ++i)
- {
- if (i > 0)
- {
- writer->emit(", ");
- }
-
- // We can't pass as const& for vector, scalar, array types, as they are pass by value
- // For types passed by reference, we should do something different
- IRType* paramType = funcType->getParamType(i);
-#if 0
- writer->emit("const ");
-#endif
- emitType(paramType);
-#if 0
- if (dynamicCast<IRBasicType>(paramType))
- {
- writer->emit(" ");
- }
- else
- {
- writer->emit("& ");
- }
-#else
-
- writer->emit(" ");
-#endif
-
- writer->emitChar(char('a' + i));
- }
- writer->emit(")");
-}
-
-UnownedStringSlice CPPSourceEmitter::_getAndEmitSpecializedOperationDefinition(HLSLIntrinsic::Op op, IRType*const* argTypes, Int argCount, IRType* retType)
-{
- HLSLIntrinsic intrinsic;
- m_intrinsicSet.calcIntrinsic(op, retType, argTypes, argCount, intrinsic);
- auto specOp = m_intrinsicSet.add(intrinsic);
- _maybeEmitSpecializedOperationDefinition(specOp);
- return _getFuncName(specOp);
-}
-
-void CPPSourceEmitter::_emitGetAtDefinition(const UnownedStringSlice& funcName, const HLSLIntrinsic* specOp)
-{
- SourceWriter* writer = getSourceWriter();
-
- IRFuncType* funcType = specOp->signatureType;
- SLANG_ASSERT(funcType->getParamCount() == 2);
-
- IRType* srcType = funcType->getParamType(0);
-
- for (Index i = 0; i < 3; ++i)
- {
- UnownedStringSlice typePrefix = (i == 0) ? UnownedStringSlice::fromLiteral("const ") : UnownedStringSlice();
- bool lValue = (i != 2);
-
- emitFunctionPreambleImpl(nullptr);
-
- writer->emit(typePrefix);
- emitType(specOp->returnType);
- if (lValue)
- m_writer->emit("*");
- writer->emit(" ");
- writer->emit(funcName);
- writer->emit("(");
-
- writer->emit(typePrefix);
- emitType(funcType->getParamType(0));
- if (lValue)
- writer->emit("*");
- writer->emit(" a, ");
- emitType(funcType->getParamType(1));
- writer->emit(" b)\n{\n");
-
- writer->indent();
-
- if (auto vectorType = as<IRVectorType>(srcType))
- {
- int vecSize = int(getIntVal(vectorType->getElementCount()));
-
- writer->emit("SLANG_PRELUDE_ASSERT(b >= 0 && b < ");
- writer->emit(vecSize);
- writer->emit(");\n");
-
- writer->emit("return ((");
- emitType(specOp->returnType);
- writer->emit("*)");
-
- if (lValue)
- writer->emit("a) + b;\n");
- else
- writer->emit("&a)[b];\n");
- }
- else if (auto matrixType = as<IRMatrixType>(srcType))
- {
- //int colCount = int(getIntVal(matrixType->getColumnCount()));
- int rowCount = int(getIntVal(matrixType->getRowCount()));
-
- writer->emit("SLANG_PRELUDE_ASSERT(b >= 0 && b < ");
- writer->emit(rowCount);
- writer->emit(");\n");
-
- if (lValue)
- writer->emit("return &(a->rows[b]);\n");
- else
- writer->emit("return a.rows[b];\n");
- }
-
- writer->dedent();
- writer->emit("}\n\n");
- }
-}
-
-void CPPSourceEmitter::_emitConstructConvertDefinition(const UnownedStringSlice& funcName, const HLSLIntrinsic* specOp)
-{
- SourceWriter* writer = getSourceWriter();
- IRFuncType* funcType = specOp->signatureType;
-
- SLANG_ASSERT(funcType->getParamCount() == 2);
-
- IRType* srcType = funcType->getParamType(1);
- IRType* retType = specOp->returnType;
-
- emitFunctionPreambleImpl(nullptr);
-
- emitType(retType);
- writer->emit(" ");
- writer->emit(funcName);
- writer->emit("(");
- emitType(srcType);
- writer->emitChar(' ');
- writer->emitChar(char('a' + 0));
- writer->emit(")");
-
- writer->emit("\n{\n");
- writer->indent();
-
- writer->emit("return ");
- emitType(retType);
- writer->emit("{ ");
-
-
- IRType* dstElemType = _getElementType(retType);
- //IRType* srcElemType = _getElementType(srcType);
-
- TypeDimension dim = _getTypeDimension(retType, false);
-
- UnownedStringSlice rowTypeName;
- if (dim.rowCount > 1)
- {
- IRType* rowType = m_typeSet.addVectorType(dstElemType, int(dim.colCount));
- rowTypeName = _getTypeName(rowType);
- }
-
- for (int i = 0; i < dim.rowCount; ++i)
- {
- if (dim.rowCount > 1)
- {
- if (i > 0)
- {
- writer->emit(", \n");
- }
-
- if (m_target == CodeGenTarget::CUDASource)
- {
- m_writer->emit("make_");
- writer->emit(rowTypeName);
- m_writer->emit("(");
- }
- else
- {
- writer->emit(rowTypeName);
- writer->emit("{ ");
- }
- }
-
- for (int j = 0; j < dim.colCount; ++j)
- {
- if (j > 0)
- {
- writer->emit(", ");
- }
-
- emitType(dstElemType);
- writer->emit("(");
- _emitAccess(UnownedStringSlice::fromLiteral("a"), dim, i, j, writer);
- writer->emit(")");
- }
- if (dim.rowCount > 1)
- {
- if (m_target == CodeGenTarget::CUDASource)
- {
- writer->emit(")");
- }
- else
- {
- writer->emit("}");
- }
- }
- }
-
- writer->emit("};\n");
-
- writer->dedent();
- writer->emit("}\n\n");
-}
-
-void CPPSourceEmitter::_emitInitDefinition(const UnownedStringSlice& funcName, const HLSLIntrinsic* specOp)
-{
- SourceWriter* writer = getSourceWriter();
- IRFuncType* funcType = specOp->signatureType;
-
- emitFunctionPreambleImpl(nullptr);
-
- IRType* retType = specOp->returnType;
-
- _emitSignature(funcName, specOp);
- writer->emit("\n{\n");
- writer->indent();
-
- // Use C++ construction
- writer->emit("return ");
- emitType(retType);
- writer->emit("{ ");
-
- const Index paramCount = Index(funcType->getParamCount());
- bool handled = false;
-
- if (IRVectorType* vecType = as<IRVectorType>(retType))
- {
- Index elementCount = Index(getIntVal(vecType->getElementCount()));
-
- Index paramIndex = 0;
- Index paramSubIndex = 0;
-
- for (Index i = 0; i < elementCount; ++i)
- {
- if (i > 0)
- {
- writer->emit(", ");
- }
-
- if (paramIndex >= paramCount)
- {
- writer->emit("0");
- }
- else
- {
- IRType* paramType = funcType->getParamType(paramIndex);
-
- if (IRVectorType* paramVecType = as<IRVectorType>(paramType))
- {
- Index paramElementCount = Index(getIntVal(paramVecType->getElementCount()));
-
- const UnownedStringSlice* elemNames = getVectorElementNames(paramVecType);
-
- writer->emitChar('a' + char(paramIndex));
- writer->emit(".");
- writer->emit(elemNames[paramSubIndex]);
-
- paramSubIndex++;
-
- if (paramSubIndex >= paramElementCount)
- {
- paramIndex++;
- paramSubIndex = 0;
- }
- }
- else
- {
- writer->emitChar('a' + char(paramIndex));
- paramIndex++;
- }
- }
- }
- handled = true;
- }
- else if (IRMatrixType* matType = as<IRMatrixType>(retType))
- {
- if (paramCount != 1)
- goto fallback;
-
- auto paramMat = as<IRMatrixType>(funcType->getParamType(0));
- if (!paramMat)
- goto fallback;
-
- // We are constructing a matrix from a differently sized matrix.
-
- Index rows = Index(getIntVal(matType->getRowCount()));
- Index cols = Index(getIntVal(matType->getColumnCount()));
- Index paramRows = Index(getIntVal(paramMat->getRowCount()));
- Index paramCols = Index(getIntVal(paramMat->getColumnCount()));
- char elementNames[] = { 'x', 'y', 'z', 'w' };
-
- for (Index r = 0; r < rows; r++)
- {
- for (Index c = 0; c < cols; c++)
- {
- if (r != 0 || c != 0)
- writer->emit(", ");
-
- if (r < paramRows && c < paramCols && c < 4)
- {
- writer->emitRawText("a.rows[");
- writer->emit(r);
- writer->emitRawText("].");
- writer->emitChar(elementNames[c]);
- }
- else
- {
- writer->emit("0");
- }
- }
- }
- handled = true;
- }
-fallback:
- if (!handled)
- {
- // Fallback default: just use all params to construct.
- for (Index i = 0; i < paramCount; ++i)
- {
- if (i > 0)
- {
- writer->emit(", ");
- }
- writer->emitChar('a' + char(i));
- }
- }
-
- writer->emit("};\n");
-
- writer->dedent();
- writer->emit("}\n\n");
-}
-
-
-void CPPSourceEmitter::_emitConstructFromScalarDefinition(const UnownedStringSlice& funcName, const HLSLIntrinsic* specOp)
-{
- SourceWriter* writer = getSourceWriter();
- IRFuncType* funcType = specOp->signatureType;
-
- SLANG_ASSERT(funcType->getParamCount() == 2);
-
- IRType* srcType = funcType->getParamType(1);
- IRType* retType = specOp->returnType;
-
- emitFunctionPreambleImpl(nullptr);
-
- emitType(retType);
- writer->emit(" ");
- writer->emit(funcName);
- writer->emit("(");
- emitType(srcType);
- writer->emitChar(' ');
- writer->emitChar(char('a' + 0));
- writer->emit(")");
-
- writer->emit("\n{\n");
- writer->indent();
-
- writer->emit("return ");
- emitType(retType);
- writer->emit("{ ");
-
- const TypeDimension dim = _getTypeDimension(retType, false);
-
- for (int i = 0; i < dim.rowCount; ++i)
- {
- if (dim.rowCount > 1)
- {
- if (i > 0)
- {
- writer->emit(", \n");
- }
- writer->emit("{ ");
- }
- for (int j = 0; j < dim.colCount; ++j)
- {
- if (j > 0)
- {
- writer->emit(", ");
- }
- writer->emit("a");
- }
- if (dim.rowCount > 1)
- {
- writer->emit("}");
- }
- }
-
- writer->emit("};\n");
-
- writer->dedent();
- writer->emit("}\n\n");
-}
-
-void CPPSourceEmitter::_maybeEmitSpecializedOperationDefinition(const HLSLIntrinsic* specOp)
-{
- // Check if it's been emitted already, if not add it.
- if (!m_intrinsicEmitted.Add(specOp))
- {
- return;
- }
- emitSpecializedOperationDefinition(specOp);
-}
-
-void CPPSourceEmitter::emitSpecializedOperationDefinition(const HLSLIntrinsic* specOp)
-{
- typedef HLSLIntrinsic::Op Op;
-
- switch (specOp->op)
- {
- case Op::Init:
- {
- return _emitInitDefinition(_getFuncName(specOp), specOp);
- }
- case Op::Any:
- case Op::All:
- {
- return _emitAnyAllDefinition(_getFuncName(specOp), specOp);
- }
- case Op::ConstructConvert:
- {
- return _emitConstructConvertDefinition(_getFuncName(specOp), specOp);
- }
- case Op::ConstructFromScalar:
- {
- return _emitConstructFromScalarDefinition(_getFuncName(specOp), specOp);
- }
- case Op::GetAt:
- {
- return _emitGetAtDefinition(_getFuncName(specOp), specOp);
- }
- case Op::Swizzle:
- {
- // Don't have to output anything for swizzle for now
- return;
- }
- default:
- {
- const auto& info = HLSLIntrinsic::getInfo(specOp->op);
- const int paramCount = (info.numOperands < 0) ? int(specOp->signatureType->getParamCount()) : info.numOperands;
-
- if (paramCount >= 1 && paramCount <= 3)
- {
- return _emitAryDefinition(specOp);
- }
- break;
- }
- }
-
- SLANG_ASSERT(!"Unhandled");
-}
-
-void CPPSourceEmitter::emitCall(const HLSLIntrinsic* specOp, IRInst* inst, const IRUse* operands, int numOperands, const EmitOpInfo& inOuterPrec)
-{
- typedef HLSLIntrinsic::Op Op;
-
- SLANG_UNUSED(inOuterPrec);
- SourceWriter* writer = getSourceWriter();
-
- switch (specOp->op)
- {
- case Op::Init:
- {
- IRType* retType = specOp->returnType;
- if (IRBasicType::isaImpl(retType->getOp()))
- {
- SLANG_ASSERT(numOperands == 1);
-
- writer->emit(_getTypeName(retType));
- writer->emitChar('(');
-
- emitOperand(operands[0].get(), getInfo(EmitOp::General));
-
- writer->emitChar(')');
- return;
- }
- break;
- }
- case Op::Swizzle:
- {
- // Currently only works for C++ (we use {} constuction) - which means we don't need to generate a function.
- // For C we need to generate suitable construction function
- auto swizzleInst = static_cast<IRSwizzle*>(inst);
- const Index elementCount = Index(swizzleInst->getElementCount());
-
- IRType* srcType = swizzleInst->getBase()->getDataType();
- IRVectorType* srcVecType = as<IRVectorType>(srcType);
-
- const UnownedStringSlice* elemNames = getVectorElementNames(srcVecType);
-
- // TODO(JS): Not 100% sure this is correct on the parens handling front
- IRType* retType = specOp->returnType;
- emitType(retType);
- writer->emit("{");
-
- for (Index i = 0; i < elementCount; ++i)
- {
- if (i > 0)
- {
- writer->emit(", ");
- }
-
- auto outerPrec = getInfo(EmitOp::General);
-
- auto prec = getInfo(EmitOp::Postfix);
- emitOperand(swizzleInst->getBase(), leftSide(outerPrec, prec));
-
- writer->emit(".");
-
- IRInst* irElementIndex = swizzleInst->getElementIndex(i);
- SLANG_RELEASE_ASSERT(irElementIndex->getOp() == kIROp_IntLit);
- IRConstant* irConst = (IRConstant*)irElementIndex;
- UInt elementIndex = (UInt)irConst->value.intVal;
- SLANG_RELEASE_ASSERT(elementIndex < 4);
-
- writer->emit(elemNames[elementIndex]);
- }
-
- writer->emit("}");
- return;
- }
- default: break;
- }
-
- {
- const auto& info = HLSLIntrinsic::getInfo(specOp->op);
- // Make sure that the return type is available
- const bool isOperator = _isOperator(info.funcName);
- const UnownedStringSlice funcName = _getFuncName(specOp);
-
- switch (specOp->op)
- {
- case Op::ConstructFromScalar:
- {
- // We need to special case, because this may have come from a swizzle from a built in
- // type, in that case the only parameter we want is the first one
- numOperands = 1;
- break;
- }
-
- default: break;
- }
-
- // add that we want a function
- SLANG_ASSERT(info.numOperands < 0 || numOperands == info.numOperands);
-
- useType(specOp->returnType);
-
- if (isOperator)
- {
- // Just do the default output
- defaultEmitInstExpr(inst, inOuterPrec);
- }
- else
- {
- writer->emit(funcName);
- writer->emitChar('(');
-
- for (int i = 0; i < numOperands; ++i)
- {
- if (i > 0)
- {
- writer->emit(", ");
- }
- emitOperand(operands[i].get(), getInfo(EmitOp::General));
- }
-
- writer->emitChar(')');
- }
- }
-}
-
-HLSLIntrinsic* CPPSourceEmitter::_addIntrinsic(HLSLIntrinsic::Op op, IRType* returnType, IRType*const* argTypes, Index argTypeCount)
-{
- HLSLIntrinsic intrinsic;
- m_intrinsicSet.calcIntrinsic(op, returnType, argTypes, argTypeCount, intrinsic);
- HLSLIntrinsic* addedIntrinsic = m_intrinsicSet.add(intrinsic);
- _getFuncName(addedIntrinsic);
- return addedIntrinsic;
-}
-
-SlangResult CPPSourceEmitter::calcScalarFuncName(HLSLIntrinsic::Op op, IRBasicType* type, StringBuilder& outBuilder)
-{
- outBuilder << _getTypePrefix(type->getOp()) << "_" << HLSLIntrinsic::getInfo(op).funcName;
- return SLANG_OK;
-}
-
-UnownedStringSlice CPPSourceEmitter::_getScalarFuncName(HLSLIntrinsic::Op op, IRBasicType* type)
-{
- /* TODO(JS): This is kind of fast and loose. That we don't know all the parameters that are taken or
- what the return type is, so we can't add to the HLSLIntrinsic map - we just generate the scalar
- function name and use it (whilst also adding to the slice pool, so that we can return an
- unowned slice). */
-
- StringBuilder builder;
- if (SLANG_FAILED(calcScalarFuncName(op, type, builder)))
- {
- SLANG_ASSERT(!"Unable to create scalar function name");
- return UnownedStringSlice();
- }
-
- // Add to the pool.
- auto handle = m_slicePool.add(builder);
- return m_slicePool.getSlice(handle);
-}
-
-UnownedStringSlice CPPSourceEmitter::_getFuncName(const HLSLIntrinsic* specOp)
-{
- StringSlicePool::Handle handle = StringSlicePool::kNullHandle;
- if (m_intrinsicNameMap.TryGetValue(specOp, handle))
- {
- return m_slicePool.getSlice(handle);
- }
-
- StringBuilder builder;
- if (SLANG_FAILED(calcFuncName(specOp, builder)))
- {
- SLANG_ASSERT(!"Unable to create function name");
- // Return an empty slice, as an error...
- return UnownedStringSlice();
- }
-
- handle = m_slicePool.add(builder);
- m_intrinsicNameMap.Add(specOp, handle);
-
- SLANG_ASSERT(handle != StringSlicePool::kNullHandle);
- return m_slicePool.getSlice(handle);
-}
-
-SlangResult CPPSourceEmitter::calcFuncName(const HLSLIntrinsic* specOp, StringBuilder& outBuilder)
-{
- typedef HLSLIntrinsic::Op Op;
-
- if (specOp->isScalar())
- {
- IRType* paramType = specOp->signatureType->getParamType(0);
- IRBasicType* basicType = as<IRBasicType>(paramType);
- if (basicType)
- {
- return calcScalarFuncName(specOp->op, basicType, outBuilder);
- }
- else
- {
- outBuilder << getName(paramType) << HLSLIntrinsic::getInfo(specOp->op).name;
- return SLANG_OK;
- }
- }
- else
- {
- switch (specOp->op)
- {
- case Op::ConstructConvert:
- {
- // Work out the function name
- IRFuncType* signatureType = specOp->signatureType;
- SLANG_ASSERT(signatureType->getParamCount() == 2);
-
- IRType* dstType = signatureType->getParamType(0);
- //IRType* srcType = signatureType->getParamType(1);
-
- outBuilder << "convert_";
- // I need a function that is called that will construct this
- SLANG_RETURN_ON_FAIL(calcTypeName(dstType, CodeGenTarget::CSource, outBuilder));
- return SLANG_OK;
- }
- case Op::ConstructFromScalar:
- {
- // Work out the function name
- IRFuncType* signatureType = specOp->signatureType;
- SLANG_ASSERT(signatureType->getParamCount() == 2);
-
- IRType* dstType = signatureType->getParamType(0);
-
- outBuilder << "constructFromScalar_";
- // I need a function that is called that will construct this
- SLANG_RETURN_ON_FAIL(calcTypeName(dstType, CodeGenTarget::CSource, outBuilder));
- return SLANG_OK;
- }
- case Op::GetAt:
- {
- outBuilder << "getAt";
- return SLANG_OK;
- }
- case Op::Init:
- {
- outBuilder << "make_";
- SLANG_RETURN_ON_FAIL(calcTypeName(specOp->returnType, CodeGenTarget::CSource, outBuilder));
- return SLANG_OK;
- }
- default: break;
- }
-
- const auto& info = HLSLIntrinsic::getInfo(specOp->op);
- if (info.funcName.getLength())
- {
- if (!_isOperator(info.funcName))
- {
- // If there is a standard default name, just use that
- outBuilder << info.funcName;
- return SLANG_OK;
- }
- }
-
- // Just use the name of the Op. This is probably wrong, but gives a pretty good idea of what the desired (presumably missing) op is.
- outBuilder << info.name;
- return SLANG_OK;
- }
-}
-
/* !!!!!!!!!!!!!!!!!!!!!! CPPSourceEmitter !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! */
CPPSourceEmitter::CPPSourceEmitter(const Desc& desc):
Super(desc),
- m_slicePool(StringSlicePool::Style::Default),
- m_typeSet(desc.codeGenContext->getSession()),
- m_opLookup(new HLSLIntrinsicOpLookup),
- m_intrinsicSet(&m_typeSet, m_opLookup)
+ m_slicePool(StringSlicePool::Style::Default)
{
m_semanticUsedFlags = 0;
//m_semanticUsedFlags = SemanticUsedFlag::GroupID | SemanticUsedFlag::GroupThreadID | SemanticUsedFlag::DispatchThreadID;
@@ -2145,12 +958,16 @@ void CPPSourceEmitter::emitSimpleFuncParamImpl(IRParam* param)
void CPPSourceEmitter::emitVectorTypeNameImpl(IRType* elementType, IRIntegerValue elementCount)
{
- emitSimpleType(m_typeSet.addVectorType(elementType, int(elementCount)));
+ m_writer->emit("Vector<");
+ m_writer->emit(_getTypeName(elementType));
+ m_writer->emit(", ");
+ m_writer->emit(elementCount);
+ m_writer->emit(">");
}
void CPPSourceEmitter::emitSimpleTypeImpl(IRType* inType)
{
- UnownedStringSlice slice = _getTypeName(m_typeSet.getType(inType));
+ UnownedStringSlice slice = _getTypeName(inType);
m_writer->emit(slice);
}
@@ -2225,8 +1042,6 @@ void CPPSourceEmitter::emitIntrinsicCallExprImpl(
IRTargetIntrinsicDecoration* targetIntrinsic,
EmitOpInfo const& inOuterPrec)
{
- typedef HLSLIntrinsic::Op Op;
-
// TODO: Much of this logic duplicates code that is already
// in `CLikeSourceEmitter::emitIntrinsicCallExpr`. The only
// real difference is that when things bottom out on an ordinary
@@ -2248,36 +1063,6 @@ void CPPSourceEmitter::emitIntrinsicCallExprImpl(
if (name == ".operator[]")
{
SLANG_ASSERT(argCount == 2 || argCount == 3);
-
- // If the first item is either a matrix or a vector, we use 'getAt' logic
- IRType* targetType = args[0].get()->getDataType();
- if (targetType->getOp() == kIROp_VectorType || targetType->getOp() == kIROp_MatrixType)
- {
- // Work out the intrinsic used
- HLSLIntrinsic intrinsic;
- m_intrinsicSet.calcIntrinsic(HLSLIntrinsic::Op::GetAt, inst->getDataType(), args, 2, intrinsic);
- HLSLIntrinsic* specOp = m_intrinsicSet.add(intrinsic);
-
- if (argCount == 2)
- {
- // Load
- emitCall(specOp, inst, args, 2, inOuterPrec);
- }
- else
- {
- // Store
- auto prec = getInfo(EmitOp::Postfix);
- needClose = maybeEmitParens(outerPrec, prec);
-
- emitCall(specOp, inst, inst->getOperands(), 2, inOuterPrec);
-
- m_writer->emit(" = ");
- emitOperand(inst->getOperand(2), getInfo(EmitOp::General));
-
- maybeCloseParens(needClose);
- }
- }
- else
{
// The user is invoking a built-in subscript operator
@@ -2318,21 +1103,6 @@ void CPPSourceEmitter::emitIntrinsicCallExprImpl(
return;
}
- {
- Op op = m_opLookup->getOpByName(name);
- if (op != Op::Invalid)
- {
-
- // Work out the intrinsic used
- HLSLIntrinsic intrinsic;
- m_intrinsicSet.calcIntrinsic(op, inst->getDataType(), args, argCount, intrinsic);
- HLSLIntrinsic* specOp = m_intrinsicSet.add(intrinsic);
-
- emitCall(specOp, inst, args, int(argCount), inOuterPrec);
- return;
- }
- }
-
// Use default impl (which will do intrinsic special macro expansion as necessary)
return Super::emitIntrinsicCallExprImpl(inst, targetIntrinsic, inOuterPrec);
}
@@ -2372,32 +1142,147 @@ const UnownedStringSlice* CPPSourceEmitter::getVectorElementNames(IRVectorType*
return getVectorElementNames(basicType->getBaseType(), elemCount);
}
-bool CPPSourceEmitter::_tryEmitInstExprAsIntrinsic(IRInst* inst, const EmitOpInfo& inOuterPrec)
+bool CPPSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOuterPrec)
{
- HLSLIntrinsic* specOp = m_intrinsicSet.add(inst);
- if (specOp)
+ switch (inst->getOp())
{
- if (inst->getOp() == kIROp_Call)
+ default:
{
- IRCall* call = static_cast<IRCall*>(inst);
- emitCall(specOp, inst, call->getArgs(), int(call->getArgCount()), inOuterPrec);
+ return false;
}
- else
+ case kIROp_MakeVector:
{
- emitCall(specOp, inst, inst->getOperands(), int(inst->getOperandCount()), inOuterPrec);
+ IRType* retType = inst->getFullType();
+ emitType(retType);
+ m_writer->emit("(");
+ bool isFirst = true;
+ for (UInt i = 0; i < inst->getOperandCount(); i++)
+ {
+ auto arg = inst->getOperand(i);
+ if (auto vectorType = as<IRVectorType>(arg->getDataType()))
+ {
+ for (int j = 0; j < cast<IRIntLit>(vectorType->getElementCount())->getValue(); j++)
+ {
+ if (isFirst)
+ isFirst = false;
+ else
+ m_writer->emit(", ");
+ auto outerPrec = getInfo(EmitOp::General);
+ auto prec = getInfo(EmitOp::Postfix);
+ emitOperand(arg, leftSide(outerPrec, prec));
+ m_writer->emit(".");
+ m_writer->emitChar(s_xyzwNames[j]);
+ }
+ }
+ else
+ {
+ if (isFirst)
+ isFirst = false;
+ else
+ m_writer->emit(", ");
+ emitOperand(arg, getInfo(EmitOp::General));
+ }
+ }
+ m_writer->emit(")");
+
+ return true;
}
- return true;
- }
- return false;
-}
+ case kIROp_CastFloatToInt:
+ case kIROp_CastIntToFloat:
+ case kIROp_FloatCast:
+ case kIROp_IntCast:
+ {
+ if (auto vectorType = as<IRVectorType>(inst->getDataType()))
+ {
+ emitType(vectorType);
+ m_writer->emit("{");
+ for (Index i = 0; i < cast<IRIntLit>(vectorType->getElementCount())->getValue(); i++)
+ {
+ if (i > 0)
+ m_writer->emit(", ");
+ m_writer->emit("(");
+ emitType(vectorType->getElementType());
+ m_writer->emit(")_slang_vector_get_element(");
+ emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
+ m_writer->emit(", ");
+ m_writer->emit(i);
+ m_writer->emit(")");
+ }
+ m_writer->emit("}");
+ return true;
+ }
+ return false;
+ }
+ case kIROp_VectorReshape:
+ {
+ if (auto vectorType = as<IRVectorType>(inst->getDataType()))
+ {
+ m_writer->emit("_slang_vector_reshape<");
+ emitType(vectorType->getElementType());
+ m_writer->emit(", ");
+ emitOperand(vectorType->getElementCount(), getInfo(EmitOp::General));
+ m_writer->emit(">(");
+ emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
+ m_writer->emit(")");
+ return true;
+ }
+ return false;
+ }
+ case kIROp_GetElement:
+ {
+ auto getElementInst = static_cast<IRGetElement*>(inst);
-bool CPPSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOuterPrec)
-{
- switch (inst->getOp())
- {
- default:
+ IRInst* baseInst = getElementInst->getBase();
+ IRType* baseType = baseInst->getDataType();
+ if (as<IRVectorType>(baseType))
+ {
+ m_writer->emit("_slang_vector_get_element(");
+ emitOperand(baseInst, getInfo(EmitOp::General));
+ m_writer->emit(", ");
+ emitOperand(getElementInst->getIndex(), getInfo(EmitOp::General));
+ m_writer->emit(")");
+ return true;
+ }
+ else if (as<IRMatrixType>(baseType))
+ {
+ auto outerPrec = getInfo(EmitOp::General);
+ auto prec = getInfo(EmitOp::Postfix);
+ emitOperand(baseInst, leftSide(outerPrec, prec));
+ m_writer->emit(".rows[");
+ emitOperand(getElementInst->getIndex(), getInfo(EmitOp::General));
+ m_writer->emit("]");
+ return true;
+ }
+ return false;
+ }
+ case kIROp_GetElementPtr:
{
- return _tryEmitInstExprAsIntrinsic(inst, inOuterPrec);
+ auto getElementInst = static_cast<IRGetElement*>(inst);
+
+ IRInst* baseInst = getElementInst->getBase();
+ IRType* baseType = as<IRPtrTypeBase>(baseInst->getDataType())->getValueType();
+ if (as<IRVectorType>(baseType))
+ {
+ m_writer->emit("_slang_vector_get_element_ptr(");
+ emitOperand(baseInst, getInfo(EmitOp::General));
+ m_writer->emit(", ");
+ emitOperand(getElementInst->getIndex(), getInfo(EmitOp::General));
+ m_writer->emit(")");
+ return true;
+ }
+ else if (as<IRMatrixType>(baseType))
+ {
+ m_writer->emit("&(");
+ auto outerPrec = getInfo(EmitOp::General);
+ auto prec = getInfo(EmitOp::Postfix);
+ emitOperand(baseInst, leftSide(outerPrec, prec));
+ m_writer->emit("->rows[");
+ emitOperand(getElementInst->getIndex(), getInfo(EmitOp::General));
+ m_writer->emit("]");
+ m_writer->emit(")");
+ return true;
+ }
+ return false;
}
case kIROp_swizzle:
{
@@ -2430,8 +1315,79 @@ bool CPPSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOut
return true;
}
}
- // try doing automatically
- return _tryEmitInstExprAsIntrinsic(inst, inOuterPrec);
+
+ {
+ // Currently only works for C++ (we use {} constuction) - which means we don't need to generate a function.
+ // For C we need to generate suitable construction function
+
+ const Index elementCount = Index(swizzleInst->getElementCount());
+
+ IRType* srcType = swizzleInst->getBase()->getDataType();
+ IRVectorType* srcVecType = as<IRVectorType>(srcType);
+
+ const UnownedStringSlice* elemNames = nullptr;
+ if (srcVecType)
+ elemNames = getVectorElementNames(srcVecType);
+
+ IRType* retType = swizzleInst->getFullType();
+ emitType(retType);
+ m_writer->emit("{");
+
+ for (Index i = 0; i < elementCount; ++i)
+ {
+ if (i > 0)
+ {
+ m_writer->emit(", ");
+ }
+
+ auto outerPrec = getInfo(EmitOp::General);
+
+ auto prec = getInfo(EmitOp::Postfix);
+ emitOperand(swizzleInst->getBase(), leftSide(outerPrec, prec));
+
+ if (srcVecType)
+ {
+ m_writer->emit(".");
+
+ IRInst* irElementIndex = swizzleInst->getElementIndex(i);
+ SLANG_RELEASE_ASSERT(irElementIndex->getOp() == kIROp_IntLit);
+ IRConstant* irConst = (IRConstant*)irElementIndex;
+ UInt elementIndex = (UInt)irConst->value.intVal;
+ SLANG_RELEASE_ASSERT(elementIndex < 4);
+
+ m_writer->emit(elemNames[elementIndex]);
+ }
+ }
+
+ m_writer->emit("}");
+ }
+ return true;
+ }
+ case kIROp_FRem:
+ {
+ if (auto basicType = as<IRBasicType>(inst->getDataType()))
+ {
+ switch (basicType->getOp())
+ {
+ case kIROp_HalfType:
+ m_writer->emit("F16_fmod(");
+ break;
+ case kIROp_FloatType:
+ m_writer->emit("F32_fmod(");
+ break;
+ case kIROp_DoubleType:
+ m_writer->emit("F64_fmod(");
+ break;
+ default:
+ return false;
+ }
+ emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
+ m_writer->emit(", ");
+ emitOperand(inst->getOperand(1), getInfo(EmitOp::General));
+ m_writer->emit(")");
+ return true;
+ }
+ return false;
}
case kIROp_Call:
{
@@ -2441,7 +1397,7 @@ bool CPPSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOut
handleRequiredCapabilities(funcValue);
// try doing automatically
- return _tryEmitInstExprAsIntrinsic(inst, inOuterPrec);
+ return false;
}
case kIROp_LookupWitness:
{
@@ -2562,29 +1518,6 @@ bool CPPSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOut
}
}
-// We want order of built in types (typically output nothing), vector, matrix, other types
-// Types that aren't output have negative indices
-static Index _calcTypeOrder(IRType* a)
-{
- switch (a->getOp())
- {
- case kIROp_FuncType:
- {
- return -2;
- }
- case kIROp_VectorType: return 1;
- case kIROp_MatrixType: return 2;
- default:
- {
- if (as<IRBasicType>(a))
- {
- return -1;
- }
- return 3;
- }
- }
-}
-
void CPPSourceEmitter::emitPreModuleImpl()
{
if (m_target == CodeGenTarget::CPPSource)
@@ -2604,45 +1537,6 @@ void CPPSourceEmitter::emitPreModuleImpl()
m_writer->emit("using namespace SLANG_PRELUDE_NAMESPACE;\n");
m_writer->emit("#endif\n\n");
}
-
- // Emit generated functions and types
-
- if (m_target == CodeGenTarget::CSource)
- {
- // For C output we need to emit type definitions.
- List<IRType*> types;
- m_typeSet.getTypes(types);
-
- // Remove ones we don't need to emit
- for (Index i = 0; i < types.getCount(); ++i)
- {
- if (_calcTypeOrder(types[i]) < 0)
- {
- types.fastRemoveAt(i);
- --i;
- }
- }
-
- // Sort them so that vectors come before matrices and everything else after that
- types.sort([&](IRType* a, IRType* b) { return _calcTypeOrder(a) < _calcTypeOrder(b); });
-
- // Emit the type definitions
- for (auto type : types)
- {
- emitTypeDefinition(type);
- }
- }
-
- {
- List<const HLSLIntrinsic*> intrinsics;
- m_intrinsicSet.getIntrinsics(intrinsics);
-
- // Emit all the intrinsics that were used
- for (auto intrinsic : intrinsics)
- {
- _maybeEmitSpecializedOperationDefinition(intrinsic);
- }
- }
}
@@ -2980,11 +1874,6 @@ void CPPSourceEmitter::emitModuleImpl(IRModule* module, DiagnosticSink* sink)
{
SLANG_UNUSED(sink);
- // Setup all built in types used in the module
- m_typeSet.addAllBuiltinTypes(module);
- // If any matrix types are used, then we need appropriate vector types too.
- m_typeSet.addVectorForMatrixTypes();
-
List<EmitAction> actions;
computeEmitActions(module, actions);
diff --git a/source/slang/slang-emit-cpp.h b/source/slang/slang-emit-cpp.h
index c5b9f3d9c..ec70b02b8 100644
--- a/source/slang/slang-emit-cpp.h
+++ b/source/slang/slang-emit-cpp.h
@@ -39,9 +39,6 @@ public:
};
virtual void useType(IRType* type);
- virtual void emitCall(const HLSLIntrinsic* specOp, IRInst* inst, const IRUse* operands, int numOperands, const EmitOpInfo& inOuterPrec);
- virtual void emitTypeDefinition(IRType* type);
- virtual void emitSpecializedOperationDefinition(const HLSLIntrinsic* specOp);
static UnownedStringSlice getBuiltinTypeName(IROp op);
@@ -78,43 +75,21 @@ protected:
virtual void emitVarDecorationsImpl(IRInst* var) SLANG_OVERRIDE;
virtual void emitGlobalInstImpl(IRInst* inst) SLANG_OVERRIDE;
- virtual const UnownedStringSlice* getVectorElementNames(BaseType elemType, Index elemCount);
+ const UnownedStringSlice* getVectorElementNames(BaseType elemType, Index elemCount);
// Replaceable for classes derived from CPPSourceEmitter
virtual SlangResult calcTypeName(IRType* type, CodeGenTarget target, StringBuilder& out);
- virtual SlangResult calcFuncName(const HLSLIntrinsic* specOp, StringBuilder& out);
- virtual SlangResult calcScalarFuncName(HLSLIntrinsic::Op op, IRBasicType* type, StringBuilder& outBuilder);
const UnownedStringSlice* getVectorElementNames(IRVectorType* vectorType);
- void _maybeEmitSpecializedOperationDefinition(const HLSLIntrinsic* specOp);
-
void _emitForwardDeclarations(const List<EmitAction>& actions);
- void _emitAryDefinition(const HLSLIntrinsic* specOp);
-
- // Really we don't want any of these defined like they are here, they should be defined in slang stdlib
- void _emitAnyAllDefinition(const UnownedStringSlice& funcName, const HLSLIntrinsic* specOp);
- void _emitConstructConvertDefinition(const UnownedStringSlice& funcName, const HLSLIntrinsic* specOp);
- void _emitConstructFromScalarDefinition(const UnownedStringSlice& funcName, const HLSLIntrinsic* specOp);
- void _emitGetAtDefinition(const UnownedStringSlice& funcName, const HLSLIntrinsic* specOp);
- void _emitInitDefinition(const UnownedStringSlice& funcName, const HLSLIntrinsic* specOp);
-
- void _emitSignature(const UnownedStringSlice& funcName, const HLSLIntrinsic* specOp);
-
void _emitInOutParamType(IRType* type, String const& name, IRType* valueType);
-
- UnownedStringSlice _getAndEmitSpecializedOperationDefinition(HLSLIntrinsic::Op op, IRType*const* argTypes, Int argCount, IRType* retType);
-
static TypeDimension _getTypeDimension(IRType* type, bool vecSwap);
void _emitAccess(const UnownedStringSlice& name, const TypeDimension& dimension, int row, int col, SourceWriter* writer);
- UnownedStringSlice _getScalarFuncName(HLSLIntrinsic::Op operation, IRBasicType* scalarType);
-
- UnownedStringSlice _getFuncName(const HLSLIntrinsic* specOp);
-
UnownedStringSlice _getTypeName(IRType* type);
SlangResult _calcCPPTextureTypeName(IRTextureTypeBase* texType, StringBuilder& outName);
@@ -126,8 +101,6 @@ protected:
void _emitInitAxisValues(const Int sizeAlongAxis[kThreadGroupAxisCount], const UnownedStringSlice& mulName, const UnownedStringSlice& addName);
- bool _tryEmitInstExprAsIntrinsic(IRInst* inst, const EmitOpInfo& inOuterPrec);
-
// Emit the actual definition (including intializer list)
// of all the witness table objects in `pendingWitnessTableDefinitions`.
void _emitWitnessTableDefinitions();
@@ -136,18 +109,9 @@ protected:
void _getExportStyle(IRInst* inst, bool& outIsExport, bool& outIsExternC);
void _maybeEmitExportLike(IRInst* inst);
- HLSLIntrinsic* _addIntrinsic(HLSLIntrinsic::Op op, IRType* returnType, IRType*const* argTypes, Index argTypeCount);
-
static bool _isVariable(IROp op);
Dictionary<IRType*, StringSlicePool::Handle> m_typeNameMap;
- Dictionary<const HLSLIntrinsic*, StringSlicePool::Handle> m_intrinsicNameMap;
-
- IRTypeSet m_typeSet;
- RefPtr<HLSLIntrinsicOpLookup> m_opLookup;
- HLSLIntrinsicSet m_intrinsicSet;
-
- HashSet<const HLSLIntrinsic*> m_intrinsicEmitted;
HashSet<IRInterfaceType*> m_interfaceTypesEmitted;
diff --git a/source/slang/slang-emit-cuda.cpp b/source/slang/slang-emit-cuda.cpp
index 284652682..a151ab0e2 100644
--- a/source/slang/slang-emit-cuda.cpp
+++ b/source/slang/slang-emit-cuda.cpp
@@ -123,131 +123,6 @@ SlangResult CUDASourceEmitter::_calcCUDATextureTypeName(IRTextureTypeBase* texTy
return SLANG_FAIL;
}
-SlangResult CUDASourceEmitter::calcScalarFuncName(HLSLIntrinsic::Op op, IRBasicType* type, StringBuilder& outBuilder)
-{
- typedef HLSLIntrinsic::Op Op;
-
- UnownedStringSlice funcName;
-
- switch (op)
- {
- case Op::FRem:
- {
- if (type->getOp() == kIROp_FloatType || type->getOp() == kIROp_DoubleType)
- {
- funcName = HLSLIntrinsic::getInfo(op).funcName;
- }
- break;
- }
- default: break;
- }
-
- if (funcName.getLength())
- {
- outBuilder << funcName;
- if (type->getOp() == kIROp_FloatType)
- {
- outBuilder << "f";
- }
- return SLANG_OK;
- }
-
- // Defer to the supers impl
- return Super::calcScalarFuncName(op, type, outBuilder);
-}
-
-void CUDASourceEmitter::emitSpecializedOperationDefinition(const HLSLIntrinsic* specOp)
-{
- typedef HLSLIntrinsic::Op Op;
-
- if (auto vecType = as <IRVectorType>(specOp->returnType))
- {
- // Converting to or from half vector types is implemented prelude as convert___half functions
- // Get the from type -> if it's half we ignore
-
- if (specOp->op == Op::ConstructConvert)
- {
- auto signatureType = specOp->signatureType;
-
- // Need to have impl of convert_float, double, int, uint, in prelude
-
- const auto paramCount = signatureType->getParamCount();
- SLANG_UNUSED(paramCount);
-
- // We have 2 'params' and param 1 is the source type
- SLANG_ASSERT(paramCount == 2);
- IRType* paramType = signatureType->getParamType(1);
-
- auto vecParamType = as<IRVectorType>(paramType);
-
- if (auto baseType = as<IRBasicType>(vecParamType->getElementType()))
- {
- if (baseType->getBaseType() == BaseType::Half)
- {
- return;
- }
- }
- }
-
- if (auto baseType = as<IRBasicType>(vecType->getElementType()))
- {
- if (baseType->getBaseType() == BaseType::Half)
- {
- switch (specOp->op)
- {
- case Op::Init:
-
- case Op::Add:
- case Op::Mul:
- case Op::Div:
- case Op::Sub:
-
- case Op::Neg:
-
- case Op::ConstructFromScalar:
- case Op::ConstructConvert:
-
- case Op::Leq:
- case Op::Less:
- case Op::Greater:
- case Op::Geq:
- case Op::Neq:
- case Op::Eql:
- {
- return;
- }
- }
- }
- }
- }
-
- switch (specOp->op)
- {
- case Op::Init:
- {
- // Special case handling
- auto returnType = specOp->returnType;
-
- if (auto vecType = as <IRVectorType>(returnType))
- {
- if (auto baseType = as<IRBasicType>(vecType->getElementType()))
- {
- if (baseType->getBaseType() == BaseType::Half)
- {
- // Defined already in cuda-prelude.h
- return;
- }
- }
- }
-
- break;
- }
- default: break;
- }
-
- Super::emitSpecializedOperationDefinition(specOp);
-}
-
SlangResult CUDASourceEmitter::calcTypeName(IRType* type, CodeGenTarget target, StringBuilder& out)
{
SLANG_UNUSED(target);
@@ -322,25 +197,6 @@ SlangResult CUDASourceEmitter::calcTypeName(IRType* type, CodeGenTarget target,
return Super::calcTypeName(type, target, out);
}
-const UnownedStringSlice* CUDASourceEmitter::getVectorElementNames(BaseType baseType, Index elemCount)
-{
- static const UnownedStringSlice normal[] = { UnownedStringSlice::fromLiteral("x"), UnownedStringSlice::fromLiteral("y"), UnownedStringSlice::fromLiteral("z"), UnownedStringSlice::fromLiteral("w") };
- static const UnownedStringSlice half3[] = { UnownedStringSlice::fromLiteral("xy.x"), UnownedStringSlice::fromLiteral("xy.y"), UnownedStringSlice::fromLiteral("z") };
- static const UnownedStringSlice half4[] = { UnownedStringSlice::fromLiteral("xy.x"), UnownedStringSlice::fromLiteral("xy.y"), UnownedStringSlice::fromLiteral("zw.x"), UnownedStringSlice::fromLiteral("zw.y")};
-
- if (baseType == BaseType::Half)
- {
- switch (elemCount)
- {
- default: break;
- case 3: return half3;
- case 4: return half4;
- }
- }
-
- return normal;
-}
-
void CUDASourceEmitter::emitLayoutSemanticsImpl(IRInst* inst, char const* uniformSemanticSpelling)
{
Super::emitLayoutSemanticsImpl(inst, uniformSemanticSpelling);
@@ -436,49 +292,6 @@ void CUDASourceEmitter::emitGlobalRTTISymbolPrefix()
m_writer->emit("__constant__ ");
}
-void CUDASourceEmitter::emitCall(const HLSLIntrinsic* specOp, IRInst* inst, const IRUse* operands, int numOperands, const EmitOpInfo& inOuterPrec)
-{
- switch (specOp->op)
- {
- case HLSLIntrinsic::Op::Init:
- {
- // For CUDA vector types we construct with make_
-
- auto writer = m_writer;
-
- IRType* retType = specOp->returnType;
-
- if (IRVectorType* vecType = as<IRVectorType>(retType))
- {
- if (numOperands == getIntVal(vecType->getElementCount()))
- {
- // Get the type name
- writer->emit("make_");
- emitType(retType);
- writer->emitChar('(');
-
- for (int i = 0; i < numOperands; ++i)
- {
- if (i > 0)
- {
- writer->emit(", ");
- }
- emitOperand(operands[i].get(), getInfo(EmitOp::General));
- }
-
- writer->emitChar(')');
- return;
- }
- }
- // Just use the default
- break;
- }
- default: break;
- }
-
- return Super::emitCall(specOp, inst, operands, numOperands, inOuterPrec);
-}
-
void CUDASourceEmitter::emitLoopControlDecorationImpl(IRLoopControlDecoration* decl)
{
if (decl->getMode() == kIRLoopControl_Unroll)
@@ -487,59 +300,25 @@ void CUDASourceEmitter::emitLoopControlDecorationImpl(IRLoopControlDecoration* d
}
}
-static bool _areEquivalent(IRType* a, IRType* b)
-{
- if (a == b)
- {
- return true;
- }
- if (a->getOp() != b->getOp())
- {
- return false;
- }
-
- switch (a->getOp())
- {
- case kIROp_VectorType:
- {
- IRVectorType* vecA = static_cast<IRVectorType*>(a);
- IRVectorType* vecB = static_cast<IRVectorType*>(b);
- return getIntVal(vecA->getElementCount()) == getIntVal(vecB->getElementCount()) &&
- _areEquivalent(vecA->getElementType(), vecB->getElementType());
- }
- case kIROp_MatrixType:
- {
- IRMatrixType* matA = static_cast<IRMatrixType*>(a);
- IRMatrixType* matB = static_cast<IRMatrixType*>(b);
- return getIntVal(matA->getColumnCount()) == getIntVal(matB->getColumnCount()) &&
- getIntVal(matA->getRowCount()) == getIntVal(matB->getRowCount()) &&
- _areEquivalent(matA->getElementType(), matB->getElementType());
- }
- default:
- {
- return as<IRBasicType>(a) != nullptr;
- }
- }
-}
-
void CUDASourceEmitter::_emitInitializerListValue(IRType* dstType, IRInst* value)
{
// When constructing a matrix or vector from a single value this is handled by the default path
switch (value->getOp())
{
- case kIROp_MakeMatrix:
case kIROp_MakeVector:
+ case kIROp_MakeMatrix:
{
IRType* type = value->getDataType();
// If the types are the same, we can can just break down and use
- if (_areEquivalent(dstType, type))
+ if (dstType == type)
{
if (auto vecType = as<IRVectorType>(type))
{
if (UInt(getIntVal(vecType->getElementCount())) == value->getOperandCount())
{
+ emitType(type);
_emitInitializerList(vecType->getElementType(), value->getOperands(), value->getOperandCount());
return;
}
@@ -551,20 +330,25 @@ void CUDASourceEmitter::_emitInitializerListValue(IRType* dstType, IRInst* value
// TODO(JS): If num cols = 1, then it *doesn't* actually return a vector.
// That could be argued is an error because we want swizzling or [] to work.
- IRType* rowType = m_typeSet.addVectorType(matType->getElementType(), int(colCount));
- IRVectorType* rowVectorType = as<IRVectorType>(rowType);
+ IRBuilder builder(matType->getModule());
+ builder.setInsertBefore(matType);
const Index operandCount = Index(value->getOperandCount());
// Can init, with vectors.
// For now special case if the rowVectorType is not actually a vector (when elementSize == 1)
- if (operandCount == rowCount || rowVectorType == nullptr)
+ if (operandCount == rowCount)
{
- // We have to output vectors
-
- // Emit the braces for the Matrix struct, contains an row array.
+ // Emit the braces for the Matrix struct, and then each row vector in its own line.
+ emitType(matType);
m_writer->emit("{\n");
m_writer->indent();
- _emitInitializerList(rowType, value->getOperands(), rowCount);
+ for (Index i = 0; i < rowCount; ++i)
+ {
+ if (i != 0) m_writer->emit(",\n");
+ emitType(matType->getElementType());
+ m_writer->emit(colCount);
+ _emitInitializerList(matType->getElementType(), value->getOperand(i)->getOperands(), colCount);
+ }
m_writer->dedent();
m_writer->emit("\n}");
return;
@@ -575,21 +359,18 @@ void CUDASourceEmitter::_emitInitializerListValue(IRType* dstType, IRInst* value
IRType* elementType = matType->getElementType();
IRUse* operands = value->getOperands();
- // Emit the braces for the Matrix struct, and the array of rows
- m_writer->emit("{\n");
- m_writer->indent();
+ // Emit the braces for the Matrix struct, and the elements of each row in its own line.
+ emitType(matType);
m_writer->emit("{\n");
m_writer->indent();
for (Index i = 0; i < rowCount; ++i)
{
- if (i != 0) m_writer->emit(", ");
- _emitInitializerList(elementType, operands, colCount);
+ if (i != 0) m_writer->emit(",\n");
+ _emitInitializerListContent(elementType, operands, colCount);
operands += colCount;
}
m_writer->dedent();
m_writer->emit("\n}");
- m_writer->dedent();
- m_writer->emit("\n}");
return;
}
}
@@ -603,116 +384,157 @@ void CUDASourceEmitter::_emitInitializerListValue(IRType* dstType, IRInst* value
emitOperand(value, getInfo(EmitOp::General));
}
-void CUDASourceEmitter::_emitInitializerList(IRType* elementType, IRUse* operands, Index operandCount)
+void CUDASourceEmitter::_emitInitializerListContent(IRType* elementType, IRUse* operands, Index operandCount)
{
- m_writer->emit("{\n");
- m_writer->indent();
-
for (Index i = 0; i < operandCount; ++i)
{
if (i != 0) m_writer->emit(", ");
_emitInitializerListValue(elementType, operands[i].get());
}
-
- m_writer->dedent();
- m_writer->emit("\n}");
}
-void CUDASourceEmitter::_emitGetHalfVectorElement(IRInst* base, Index index, Index vecSize, const EmitOpInfo& inOuterPrec)
-{
- SLANG_ASSERT(index < vecSize);
-
- EmitOpInfo outerPrec = inOuterPrec;
-
- auto prec = getInfo(EmitOp::Postfix);
- const bool needClose = maybeEmitParens(outerPrec, prec);
- emitOperand(base, leftSide(outerPrec, prec));
+void CUDASourceEmitter::_emitInitializerList(IRType* elementType, IRUse* operands, Index operandCount)
+{
+ m_writer->emit("{\n");
+ m_writer->indent();
- m_writer->emit(".");
+ _emitInitializerListContent(elementType, operands, operandCount);
- switch (vecSize)
- {
- default:
- {
- char const* kComponents[] = { "x", "y", "z", "w" };
- m_writer->emit(kComponents[index]);
- break;
- }
- case 3:
- {
- char const* kComponents[] = { "xy.x", "xy.y", "z"};
- m_writer->emit(kComponents[index]);
- break;
- }
- case 4:
- {
- char const* kComponents[] = { "xy.x", "xy.y", "zw.x", "zw.y" };
- m_writer->emit(kComponents[index]);
- break;
- }
- }
+ m_writer->dedent();
+ m_writer->emit("\n}");
+}
- maybeCloseParens(needClose);
+void CUDASourceEmitter::emitIntrinsicCallExprImpl(IRCall* inst, IRTargetIntrinsicDecoration* targetIntrinsic, EmitOpInfo const& inOuterPrec)
+{
+ if (targetIntrinsic->getDefinition().startsWith("__half"))
+ m_extensionTracker->requireBaseType(BaseType::Half);
+ Super::emitIntrinsicCallExprImpl(inst, targetIntrinsic, inOuterPrec);
}
bool CUDASourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOuterPrec)
{
switch(inst->getOp())
{
- case kIROp_swizzle:
+ case kIROp_MakeVector:
+ case kIROp_MakeVectorFromScalar:
{
- // We need to special case for half types.
- auto swizzleInst = static_cast<IRSwizzle*>(inst);
-
- IRInst* baseInst = swizzleInst->getBase();
- IRType* baseType = baseInst->getDataType();
-
- // If we are swizzling from a built in type,
- if (as<IRBasicType>(baseType))
+ m_writer->emit("make_");
+ emitType(inst->getDataType());
+ m_writer->emit("(");
+ bool isFirst = true;
+ char xyzwNames[] = "xyzw";
+ for (UInt i = 0; i < inst->getOperandCount(); i++)
{
- // Just use the default behavior
+ auto arg = inst->getOperand(i);
+ if (auto vectorType = as<IRVectorType>(arg->getDataType()))
+ {
+ for (int j = 0; j < cast<IRIntLit>(vectorType->getElementCount())->getValue(); j++)
+ {
+ if (isFirst)
+ isFirst = false;
+ else
+ m_writer->emit(", ");
+ auto outerPrec = getInfo(EmitOp::General);
+ auto prec = getInfo(EmitOp::Postfix);
+ emitOperand(arg, leftSide(outerPrec, prec));
+ m_writer->emit(".");
+ m_writer->emitChar(xyzwNames[j]);
+ }
+ }
+ else
+ {
+ if (isFirst)
+ isFirst = false;
+ else
+ m_writer->emit(", ");
+ emitOperand(arg, getInfo(EmitOp::General));
+ }
}
- else if (auto vecType = as<IRVectorType>(baseType))
+ m_writer->emit(")");
+ return true;
+ }
+ case kIROp_FloatCast:
+ case kIROp_CastIntToFloat:
+ case kIROp_IntCast:
+ case kIROp_CastFloatToInt:
+ {
+ if (auto dstVectorType = as<IRVectorType>(inst->getDataType()))
{
- if (auto basicType = as<IRBasicType>(vecType->getElementType()))
+ m_writer->emit("make_");
+ emitType(inst->getDataType());
+ m_writer->emit("(");
+ bool isFirst = true;
+ char xyzwNames[] = "xyzw";
+ for (UInt i = 0; i < inst->getOperandCount(); i++)
{
- if (basicType->getBaseType() == BaseType::Half)
+ auto arg = inst->getOperand(i);
+ if (auto vectorType = as<IRVectorType>(arg->getDataType()))
{
- const Index vecElementCount = Index(getIntVal(vecType->getElementCount()));
-
- const Index elementCount = Index(swizzleInst->getElementCount());
- if (elementCount == 1)
- {
- const Index index = Index(getIntVal(swizzleInst->getElementIndex(0)));
- _emitGetHalfVectorElement(baseInst, index, vecElementCount, inOuterPrec);
- }
- else
+ for (int j = 0; j < cast<IRIntLit>(vectorType->getElementCount())->getValue(); j++)
{
- auto outerPrec = getInfo(EmitOp::General);
-
- m_writer->emit("make___half");
- m_writer->emitInt64(elementCount);
+ if (isFirst)
+ isFirst = false;
+ else
+ m_writer->emit(", ");
m_writer->emit("(");
-
- for (Index i = 0; i < elementCount; ++i)
- {
- if (i)
- {
- m_writer->emit(", ");
- }
-
- const Index index = Index(getIntVal(swizzleInst->getElementIndex(i)));
- _emitGetHalfVectorElement(baseInst, index, vecElementCount, outerPrec);
- }
-
+ emitType(dstVectorType->getElementType());
m_writer->emit(")");
+ auto outerPrec = getInfo(EmitOp::General);
+ auto prec = getInfo(EmitOp::Postfix);
+ emitOperand(arg, leftSide(outerPrec, prec));
+ m_writer->emit(".");
+ m_writer->emitChar(xyzwNames[j]);
}
- return true;
+ }
+ else
+ {
+ if (isFirst)
+ isFirst = false;
+ else
+ m_writer->emit(", ");
+ m_writer->emit("(");
+ emitType(dstVectorType->getElementType());
+ m_writer->emit(")");
+ emitOperand(arg, getInfo(EmitOp::General));
}
}
+ m_writer->emit(")");
+ return true;
}
- break;
+ else if (auto matrixType = as<IRMatrixType>(inst->getDataType()))
+ {
+ m_writer->emit("make");
+ emitType(inst->getDataType());
+ m_writer->emit("(");
+ for (UInt i = 0; i < inst->getOperandCount(); i++)
+ {
+ auto arg = inst->getOperand(i);
+ if (i > 0)
+ m_writer->emit(", ");
+ emitOperand(arg, getInfo(EmitOp::General));
+ }
+ m_writer->emit(")");
+ return true;
+ }
+ return false;
+ }
+ case kIROp_MakeMatrix:
+ case kIROp_MakeMatrixFromScalar:
+ case kIROp_MatrixReshape:
+ {
+ m_writer->emit("make");
+ emitType(inst->getDataType());
+ m_writer->emit("(");
+ for (UInt i = 0; i < inst->getOperandCount(); i++)
+ {
+ auto arg = inst->getOperand(i);
+ if (i > 0)
+ m_writer->emit(", ");
+ emitOperand(arg, getInfo(EmitOp::General));
+ }
+ m_writer->emit(")");
+ return true;
}
case kIROp_MakeArray:
{
@@ -722,13 +544,9 @@ bool CUDASourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOu
IRType* elementType = arrayType->getElementType();
// Emit braces for the FixedArray struct.
- m_writer->emit("{\n");
- m_writer->indent();
_emitInitializerList(elementType, inst->getOperands(), Index(inst->getOperandCount()));
- m_writer->dedent();
- m_writer->emit("\n}");
return true;
}
case kIROp_WaveMaskBallot:
@@ -820,7 +638,19 @@ void CUDASourceEmitter::emitVectorTypeNameImpl(IRType* elementType, IRIntegerVal
void CUDASourceEmitter::emitSimpleTypeImpl(IRType* type)
{
- m_writer->emit(_getTypeName(type));
+ switch (type->getOp())
+ {
+ case kIROp_VectorType:
+ {
+ auto vectorType = as<IRVectorType>(type);
+ m_writer->emit(getVectorPrefix(vectorType->getElementType()->getOp()));
+ m_writer->emit(as<IRIntLit>(vectorType->getElementCount())->getValue());
+ break;
+ }
+ default:
+ m_writer->emit(_getTypeName(type));
+ break;
+ }
}
void CUDASourceEmitter::emitRateQualifiersImpl(IRRate* rate)
@@ -907,27 +737,6 @@ void CUDASourceEmitter::emitPreModuleImpl()
// Emit generated types/functions
writer->emit("\n");
-
- {
- List<IRType*> types;
- m_typeSet.getTypes(IRTypeSet::Kind::Matrix, types);
-
- // Emit the type definitions
- for (auto type : types)
- {
- emitTypeDefinition(type);
- }
- }
-
- {
- List<const HLSLIntrinsic*> intrinsics;
- m_intrinsicSet.getIntrinsics(intrinsics);
- // Emit all the intrinsics that were used
- for (auto intrinsic : intrinsics)
- {
- _maybeEmitSpecializedOperationDefinition(intrinsic);
- }
- }
}
@@ -951,22 +760,6 @@ bool CUDASourceEmitter::tryEmitGlobalParamImpl(IRGlobalParam* varDecl, IRType* v
void CUDASourceEmitter::emitModuleImpl(IRModule* module, DiagnosticSink* sink)
{
- // Setup all built in types used in the module
- m_typeSet.addAllBuiltinTypes(module);
- // If any matrix types are used, then we need appropriate vector types too.
- m_typeSet.addVectorForMatrixTypes();
-
- // We need to add some vector intrinsics - used for calculating thread ids
- {
- IRType* type = m_typeSet.addVectorType(m_typeSet.getBuilder().getBasicType(BaseType::UInt), 3);
- IRType* args[] = { type, type };
-
- _addIntrinsic(HLSLIntrinsic::Op::Add, type, args, SLANG_COUNT_OF(args));
- _addIntrinsic(HLSLIntrinsic::Op::Mul, type, args, SLANG_COUNT_OF(args));
- }
-
- // TODO(JS): We may need to generate types (for example for matrices)
-
CLikeSourceEmitter::emitModuleImpl(module, sink);
// Emit all witness table definitions.
diff --git a/source/slang/slang-emit-cuda.h b/source/slang/slang-emit-cuda.h
index ff947fe58..8a907dc7c 100644
--- a/source/slang/slang-emit-cuda.h
+++ b/source/slang/slang-emit-cuda.h
@@ -78,12 +78,9 @@ protected:
virtual void emitVectorTypeNameImpl(IRType* elementType, IRIntegerValue elementCount) SLANG_OVERRIDE;
virtual void emitVarDecorationsImpl(IRInst* varDecl) SLANG_OVERRIDE;
virtual void emitMatrixLayoutModifiersImpl(IRVarLayout* layout) SLANG_OVERRIDE;
- virtual void emitCall(const HLSLIntrinsic* specOp, IRInst* inst, const IRUse* operands, int numOperands, const EmitOpInfo& inOuterPrec) SLANG_OVERRIDE;
virtual void emitFunctionPreambleImpl(IRInst* inst) SLANG_OVERRIDE;
virtual String generateEntryPointNameImpl(IREntryPointDecoration* entryPointDecor) SLANG_OVERRIDE;
- virtual const UnownedStringSlice* getVectorElementNames(BaseType baseType, Index elemCount) SLANG_OVERRIDE;
-
virtual void emitGlobalRTTISymbolPrefix() SLANG_OVERRIDE;
virtual void emitLoopControlDecorationImpl(IRLoopControlDecoration* decl) SLANG_OVERRIDE;
@@ -92,23 +89,19 @@ protected:
virtual bool tryEmitGlobalParamImpl(IRGlobalParam* varDecl, IRType* varType) SLANG_OVERRIDE;
virtual bool tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOuterPrec) SLANG_OVERRIDE;
-
+ virtual void emitIntrinsicCallExprImpl(IRCall* inst, IRTargetIntrinsicDecoration* targetIntrinsic, EmitOpInfo const& inOuterPrec) SLANG_OVERRIDE;
virtual void emitModuleImpl(IRModule* module, DiagnosticSink* sink) SLANG_OVERRIDE;
// CPPSourceEmitter overrides
virtual SlangResult calcTypeName(IRType* type, CodeGenTarget target, StringBuilder& out) SLANG_OVERRIDE;
- virtual SlangResult calcScalarFuncName(HLSLIntrinsic::Op op, IRBasicType* type, StringBuilder& outBuilder) SLANG_OVERRIDE;
-
- virtual void emitSpecializedOperationDefinition(const HLSLIntrinsic* specOp) SLANG_OVERRIDE;
SlangResult _calcCUDATextureTypeName(IRTextureTypeBase* texType, StringBuilder& outName);
void _emitInitializerList(IRType* elementType, IRUse* operands, Index operandCount);
+ void _emitInitializerListContent(IRType* elementType, IRUse* operands, Index operandCount);
void _emitInitializerListValue(IRType* elementType, IRInst* value);
- void _emitGetHalfVectorElement(IRInst* baseInst, Index index, Index vecSize, const EmitOpInfo& inOuterPrec);
-
RefPtr<CUDAExtensionTracker> m_extensionTracker;
};
diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp
index c49265fe7..ef0d062bb 100644
--- a/source/slang/slang-emit.cpp
+++ b/source/slang/slang-emit.cpp
@@ -1022,7 +1022,7 @@ SlangResult CodeGenContext::emitEntryPointsSourceFromIR(ComPtr<IArtifact>& outAr
auto irModule = linkedIR.module;
// Perform final simplifications to help emit logic to generate more compact code.
- simplifyForEmit(irModule);
+ simplifyForEmit(irModule, targetRequest);
metadata = linkedIR.metadata;
diff --git a/source/slang/slang-hlsl-intrinsic-set.cpp b/source/slang/slang-hlsl-intrinsic-set.cpp
index ea3476473..e69de29bb 100644
--- a/source/slang/slang-hlsl-intrinsic-set.cpp
+++ b/source/slang/slang-hlsl-intrinsic-set.cpp
@@ -1,590 +0,0 @@
-// slang-hlsl-intrinsic-set.cpp
-#include "slang-hlsl-intrinsic-set.h"
-
-#include "slang-ir.h"
-#include "slang-ir-insts.h"
-
-namespace Slang
-{
-
-/* static */const HLSLIntrinsic::Info HLSLIntrinsic::s_operationInfos[] =
-{
-#define SLANG_HLSL_INTRINSIC_OP_INFO(x, funcName, numOperands) { UnownedStringSlice::fromLiteral(#x), UnownedStringSlice::fromLiteral(funcName), int8_t(numOperands) },
- SLANG_HLSL_INTRINSIC_OP(SLANG_HLSL_INTRINSIC_OP_INFO)
-};
-
-// !!!!!!!!!!!!!!!!!!!!!!!!!!!!! HLSLIntrinsicSet !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
-
-HLSLIntrinsicSet::HLSLIntrinsicSet(IRTypeSet* typeSet, HLSLIntrinsicOpLookup* lookup):
- m_intrinsicFreeList(sizeof(HLSLIntrinsic), SLANG_ALIGN_OF(HLSLIntrinsic), 1024),
- m_typeSet(typeSet),
- m_opLookup(lookup)
-{
-}
-
-static IRBasicType* _getElementType(IRType* type)
-{
- switch (type->getOp())
- {
- case kIROp_VectorType: type = static_cast<IRVectorType*>(type)->getElementType(); break;
- case kIROp_MatrixType: type = static_cast<IRMatrixType*>(type)->getElementType(); break;
- default: break;
- }
- return dynamicCast<IRBasicType>(type);
-}
-
-void HLSLIntrinsicSet::_calcIntrinsic(HLSLIntrinsic::Op op, IRType* returnType, IRType*const* inArgs, Index argsCount, HLSLIntrinsic& out)
-{
- IRBuilder& builder = m_typeSet->getBuilder();
-
- // Check all types belong to the module
-
- IRModule* module = builder.getModule();
-
- SLANG_UNUSED(module);
- SLANG_ASSERT(returnType->getModule() == module);
-
- for (Index i = 0; i < argsCount; ++i)
- {
- SLANG_ASSERT(inArgs[i]->getModule() == module);
- }
-
- // Set up the out
- out.op = op;
- out.returnType = returnType;
-
- switch (op)
- {
- case Op::GetAt:
- {
- IRType* argTypes[3];
-
- SLANG_ASSERT(argsCount == 2 || argsCount == 3);
- // TODO(JS):
- // HACK! GetAt can be from getElementPtr or from getElement. Get element ptr means the return type will be
- // a pointer. We don't want to deal with that, so strip it
- if (returnType->getOp() == kIROp_PtrType)
- {
- returnType = as<IRType>(returnType->getOperand(0));
- }
-
- // TODO(JS): Similarly for the input parameters
- for (Index i = 0; i < argsCount; ++i)
- {
- IRType* argType = inArgs[i];
-
- if (argType->getOp() == kIROp_PtrType)
- {
- argType = as<IRType>(argType->getOperand(0));
- }
- argTypes[i] = argType;
- }
-
- out.returnType = returnType;
- out.signatureType = builder.getFuncType(argsCount, argTypes, builder.getVoidType());
- break;
- }
- case Op::ConstructFromScalar:
- {
- //SLANG_ASSERT(argsCount == 1);
- SLANG_ASSERT(argsCount == 1);
- IRType* srcType = _getElementType(returnType);
- IRType* argTypes[2] = { returnType, srcType };
-
- out.signatureType = builder.getFuncType(2, argTypes, builder.getVoidType());
- break;
- }
- case Op::ConstructConvert:
- {
- // Make the return type a parameter, to make the signature take into account
- SLANG_ASSERT(argsCount == 1);
- IRType* argTypes[2] = { returnType, inArgs[0] };
-
- out.signatureType = builder.getFuncType(2, argTypes, builder.getVoidType());
- break;
- }
- default:
- {
- out.signatureType = builder.getFuncType(argsCount, inArgs, builder.getVoidType());
- break;
- }
- }
-}
-
-void HLSLIntrinsicSet::calcIntrinsic(HLSLIntrinsic::Op op, IRType* returnType, IRType*const* inArgTypes, Index argCount, HLSLIntrinsic& out)
-{
- returnType = m_typeSet->getType(returnType);
-
- if (argCount <= 8)
- {
- IRType* args[8];
- for (Index i = 0; i < argCount; ++i)
- {
- args[i] = m_typeSet->getType(inArgTypes[i]);
- }
- _calcIntrinsic(op, returnType, args, argCount, out);
- }
- else
- {
- List<IRType*> args;
- args.setCount(argCount);
-
- for (Index i = 0; i < argCount; ++i)
- {
- args[i] = m_typeSet->getType(inArgTypes[i]);
- }
- _calcIntrinsic(op, returnType, args.getBuffer(), argCount, out);
- }
-}
-
-void HLSLIntrinsicSet::calcIntrinsic(HLSLIntrinsic::Op op, IRInst* inst, Index operandCount, HLSLIntrinsic& out)
-{
- IRType* returnType = m_typeSet->getType(inst->getDataType());
- if (operandCount <= 8)
- {
- IRType* argTypes[8];
- for (Index i = 0; i < operandCount; ++i)
- {
- auto operand = inst->getOperand(i);
- argTypes[i] = m_typeSet->getType(operand->getDataType());
- }
- _calcIntrinsic(op, returnType, argTypes, operandCount, out);
- }
- else
- {
- List<IRType*> argTypes;
- argTypes.setCount(operandCount);
-
- for (Index i = 0; i < operandCount; ++i)
- {
- auto operand = inst->getOperand(i);
- argTypes[i] = m_typeSet->getType(operand->getDataType());
- }
- _calcIntrinsic(op, returnType, argTypes.getBuffer(), operandCount, out);
- }
-}
-
-void HLSLIntrinsicSet::calcIntrinsic(HLSLIntrinsic::Op op, IRType* returnType, IRUse* inArgs, Index argCount, HLSLIntrinsic& out)
-{
- returnType = m_typeSet->getType(returnType);
-
- if (argCount <= 8)
- {
- IRType* argTypes[8];
-
- for (Index i = 0; i < argCount; ++i)
- {
- auto operand = inArgs[i].get();
- argTypes[i] = m_typeSet->getType(operand->getDataType());
- }
- _calcIntrinsic(op, returnType, argTypes, argCount, out);
- }
- else
- {
- List<IRType*> argTypes;
- argTypes.setCount(argCount);
-
- for (Index i = 0; i < argCount; ++i)
- {
- auto operand = inArgs[i].get();
- argTypes[i] = m_typeSet->getType(operand->getDataType());
- }
- _calcIntrinsic(op, returnType, argTypes.getBuffer(), argCount, out);
- }
-}
-
-HLSLIntrinsic* HLSLIntrinsicSet::add(IRInst* inst)
-{
- HLSLIntrinsic intrinsic;
- if (SLANG_SUCCEEDED(makeIntrinsic(inst, intrinsic)))
- {
- return add(intrinsic);
- }
- return nullptr;
-}
-
-SlangResult HLSLIntrinsicSet::makeIntrinsic(IRInst* inst, HLSLIntrinsic& out)
-{
- // Mark as invalid...
- out.op = Op::Invalid;
-
- {
- // See if we can just directly convert
- Op op = HLSLIntrinsicOpLookup::getOpForIROp(inst->getOp());
-
-
- // HACK: some cases we want to stop handling via the synthesis
- // path, but only for vector and matrix types (not scalars).
- //
- switch( op )
- {
- default: break;
-
- case Op::AsFloat:
- case Op::AsInt:
- case Op::AsUInt:
- // Note: the `any()`/`all()` case can't be handled via a stdlib definition
- // right now because `bool` vectors map to `int` vectors on the CUDA
- // path, so that the generated `geAt` operation is incorrect.
- //
-// case Op::Any:
-// case Op::All:
- {
- IRType* srcType = inst->getOperand(0)->getDataType();
- switch( srcType->getOp() )
- {
- default:
- break;
-
- case kIROp_VectorType:
- case kIROp_MatrixType:
- return SLANG_FAIL;
- }
- }
- break;
- }
-
-
- if (op != Op::Invalid)
- {
- calcIntrinsic(op, inst, inst->getOperandCount(), out);
- return SLANG_OK;
- }
- }
-
- // All the special cases
- switch (inst->getOp())
- {
- case kIROp_MakeVectorFromScalar:
- case kIROp_MakeMatrixFromScalar:
- {
- SLANG_ASSERT(inst->getOperandCount() == 1);
- calcIntrinsic(Op::ConstructFromScalar, inst, 1, out);
- return SLANG_OK;
- }
- case kIROp_IntCast:
- case kIROp_FloatCast:
- case kIROp_CastIntToFloat:
- case kIROp_CastFloatToInt:
- {
- IRType* dstType = inst->getDataType();
- IRType* srcType = inst->getOperand(0)->getDataType();
-
- if ((dstType->getOp() == kIROp_VectorType || dstType->getOp() == kIROp_MatrixType) &&
- inst->getOperandCount() == 1)
- {
- if (as<IRBasicType>(srcType))
- {
- calcIntrinsic(Op::ConstructFromScalar, inst, out);
- }
- else
- {
- SLANG_ASSERT(m_typeSet->getType(dstType) != m_typeSet->getType(srcType));
- // If it's constructed from a type conversion
- calcIntrinsic(Op::ConstructConvert, inst, out);
- }
- return SLANG_OK;
- }
- else
- {
- // If we are constructing a basic type, we don't need an Op::Init
- if (!IRBasicType::isaImpl(dstType->getOp()))
- {
- // Emit the 'init' intrinsic
- calcIntrinsic(Op::Init, inst, inst->getOperandCount(), out);
- return SLANG_OK;
- }
- }
- return SLANG_FAIL;
- }
- case kIROp_MakeVector:
- case kIROp_VectorReshape:
- {
- if (inst->getOperandCount() == 1 && as<IRBasicType>(inst->getOperand(0)->getDataType()))
- {
- // This is make from scalar
- calcIntrinsic(Op::ConstructFromScalar, inst, out);
- }
- else
- {
- calcIntrinsic(Op::Init, inst, inst->getOperandCount(), out);
- }
- return SLANG_OK;
- }
- case kIROp_MakeMatrix:
- case kIROp_MatrixReshape:
- {
- // We only emit as if it has one operand, but we can tell how many it actually has from the return type
- calcIntrinsic(Op::Init, inst, inst->getOperandCount(), out);
- return SLANG_OK;
- }
- case kIROp_swizzle:
- {
- // We don't need to add swizzle function, but we do output the need for some other functions
-
- // For C++ we don't need to emit a swizzle function
- // For C we need a construction function
- auto swizzleInst = static_cast<IRSwizzle*>(inst);
-
- IRInst* baseInst = swizzleInst->getBase();
- IRType* baseType = baseInst->getDataType();
-
- // If we are swizzling from a built in type,
- if (as<IRBasicType>(baseType))
- {
- // We can swizzle a scalar type to be a vector, or just a scalar
- IRType* dstType = swizzleInst->getDataType();
- if (!as<IRBasicType>(dstType))
- {
- // If it's a scalar make sure we have construct from scalar, because we will want to use that
- SLANG_ASSERT(dstType->getOp() == kIROp_VectorType);
- IRType* argTypes[] = { baseType };
- calcIntrinsic(Op::ConstructFromScalar, inst->getDataType(), argTypes, 1, out);
- return SLANG_OK;
- }
- }
- else
- {
- const Index elementCount = Index(swizzleInst->getElementCount());
- if (elementCount >= 1)
- {
- // Will need to generate a swizzle method
- calcIntrinsic(Op::Swizzle, inst, out);
- return SLANG_OK;
- }
- }
- break;
- }
- case kIROp_GetElement:
- {
- IRInst* target = inst->getOperand(0);
- IRType* targetType = target->getDataType();
- if (targetType->getOp() == kIROp_VectorType || targetType->getOp() == kIROp_MatrixType)
- {
- // Specially handle this
- calcIntrinsic(Op::GetAt, inst, out);
- return SLANG_OK;
- }
- break;
- }
- case kIROp_GetElementPtr:
- {
- IRInst* target = inst->getOperand(0);
- IRType* targetType = target->getDataType();
-
- if (auto ptrType = as<IRPtrType>(targetType))
- {
- targetType = as<IRType>(ptrType->getOperand(0));
- if (targetType->getOp() == kIROp_VectorType || targetType->getOp() == kIROp_MatrixType)
- {
- // Specially handle this
- calcIntrinsic(Op::GetAt, inst, out);
- return SLANG_OK;
- }
- }
- break;
- }
- case kIROp_Call:
- {
- IRCall* callInst = (IRCall*)inst;
- auto funcValue = callInst->getCallee();
-
- const Op op = m_opLookup->getOpFromTargetDecoration(funcValue);
- if (op != Op::Invalid)
- {
- calcIntrinsic(op, inst->getDataType(), callInst->getArgs(), callInst->getArgCount(), out);
- return SLANG_OK;
- }
- break;
- }
-
- default: break;
- }
-
- return SLANG_FAIL;
-}
-
-void HLSLIntrinsicSet::getIntrinsics(List<const HLSLIntrinsic*>& out) const
-{
- for (auto& intrinsic : m_intrinsicsList)
- {
- out.add(intrinsic);
- }
-}
-
-HLSLIntrinsic* HLSLIntrinsicSet::add(const HLSLIntrinsic& intrinsic)
-{
- // Make sure it's valid(!)
- SLANG_ASSERT(intrinsic.op != Op::Invalid);
-
- HLSLIntrinsic* copy = (HLSLIntrinsic*)m_intrinsicFreeList.allocate();
- *copy = intrinsic;
- HLSLIntrinsicRef ref(copy);
- HLSLIntrinsic** found = m_intrinsicsDict.TryGetValueOrAdd(ref, copy);
- if (found)
- {
- // If we have found an intrinsic, we can free the copy
- m_intrinsicFreeList.deallocate(copy);
- return *found;
- }
-
- // If we are adding an intrinsic for the first time,
- // it should be added to the deduplicated list
- m_intrinsicsList.add(copy);
-
- return copy;
-}
-
-// !!!!!!!!!!!!!!!!!!!!!!!!!!!!! HLSLIntrinsicOpLookup !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
-
-HLSLIntrinsicOpLookup::HLSLIntrinsicOpLookup():
- m_slicePool(StringSlicePool::Style::Default)
-{
- // Add all the operations with names (not ops like -, / etc) to the lookup map
- for (int i = 0; i < SLANG_COUNT_OF(HLSLIntrinsic::s_operationInfos); ++i)
- {
- const auto& info = HLSLIntrinsic::getInfo(Op(i));
- UnownedStringSlice slice = info.funcName;
-
- if (slice.getLength() > 0 && slice[0] >= 'a' && slice[0] <= 'z')
- {
- auto handle = m_slicePool.add(slice);
- Index index = Index(handle);
- // Make sure there is space
- if (index >= m_sliceToOpMap.getCount())
- {
- Index oldSize = m_sliceToOpMap.getCount();
- m_sliceToOpMap.setCount(index + 1);
- for (Index j = oldSize; j < index; j++)
- {
- m_sliceToOpMap[j] = Op::Invalid;
- }
- }
- m_sliceToOpMap[index] = Op(i);
- }
- }
-}
-
-HLSLIntrinsic::Op HLSLIntrinsicOpLookup::getOpByName(const UnownedStringSlice& slice)
-{
- const Index index = m_slicePool.findIndex(slice);
- return (index >= 0 && index < m_sliceToOpMap.getCount()) ? m_sliceToOpMap[index] : Op::Invalid;
-}
-
-static IRInst* _getSpecializedValue(IRSpecialize* specInst)
-{
- auto base = specInst->getBase();
- auto baseGeneric = as<IRGeneric>(base);
- if (!baseGeneric)
- return base;
-
- auto lastBlock = baseGeneric->getLastBlock();
- if (!lastBlock)
- return base;
-
- auto returnInst = as<IRReturn>(lastBlock->getTerminator());
- if (!returnInst)
- return base;
-
- return returnInst->getVal();
-}
-
-HLSLIntrinsic::Op HLSLIntrinsicOpLookup::getOpFromTargetDecoration(IRInst* inInst)
-{
- // An intrinsic generic function will be invoked through a `specialize` instruction,
- // so the callee won't directly be the thing that is decorated. We will look up
- // through specializations until we can see the actual thing being called.
- //
- IRInst* inst = inInst;
- while (auto specInst = as<IRSpecialize>(inst))
- {
- inst = _getSpecializedValue(specInst);
-
- // If `getSpecializedValue` can't find the result value
- // of the generic being specialized, then it returns
- // the original instruction. This would be a disaster
- // for use because this loop would go on forever.
- //
- // This case should never happen if the stdlib is well-formed
- // and the compiler is doing its job right.
- //
- SLANG_ASSERT(inst != specInst);
- }
-
- // We are just looking for the original name so we can match against it
- for (auto dd : inst->getDecorations())
- {
- if (auto decor = as<IRTargetIntrinsicDecoration>(dd))
- {
- // TODO(JS): Should confirm that we'll always have this entry - which we need for lookups to work (we need the name
- // not a targets transformation)
- //
- // It turns out that addCatchAllIntrinsicDecorationIfNeeded will add a target intrinsic with the
- // original HLSL name, which has an empty `CapabilitySet`.
- //
- // It's not 100% clear this covers all the cases, but for now lets go with that
- if (decor->getTargetCaps().isEmpty())
- {
- Op op = getOpByName(decor->getDefinition());
- if (op != Op::Invalid)
- {
- return op;
- }
- }
- }
- }
-
- return Op::Invalid;
-}
-
-HLSLIntrinsic::Op HLSLIntrinsicOpLookup::getOpForIROp(IRInst* inst)
-{
- switch (inst->getOp())
- {
- case kIROp_Call:
- {
- return getOpFromTargetDecoration(inst);
- }
- default: break;
- }
- return getOpForIROp(inst->getOp());
-}
-
-/* static */HLSLIntrinsic::Op HLSLIntrinsicOpLookup::getOpForIROp(IROp op)
-{
- switch (op)
- {
- case kIROp_Add: return Op::Add;
- case kIROp_Mul: return Op::Mul;
- case kIROp_Sub: return Op::Sub;
- case kIROp_Div: return Op::Div;
- case kIROp_Lsh: return Op::Lsh;
- case kIROp_Rsh: return Op::Rsh;
- case kIROp_IRem: return Op::IRem;
- case kIROp_FRem: return Op::FRem;
-
- case kIROp_Eql: return Op::Eql;
- case kIROp_Neq: return Op::Neq;
- case kIROp_Greater: return Op::Greater;
- case kIROp_Less: return Op::Less;
- case kIROp_Geq: return Op::Geq;
- case kIROp_Leq: return Op::Leq;
-
- case kIROp_BitAnd: return Op::BitAnd;
- case kIROp_BitXor: return Op::BitXor;
- case kIROp_BitOr: return Op::BitOr;
-
- case kIROp_And: return Op::And;
- case kIROp_Or: return Op::Or;
-
- case kIROp_Neg: return Op::Neg;
- case kIROp_Not: return Op::Not;
- case kIROp_BitNot: return Op::BitNot;
-
- case kIROp_MakeVectorFromScalar: return Op::ConstructFromScalar;
-
- default: return Op::Invalid;
- }
-}
-
-}
diff --git a/source/slang/slang-hlsl-intrinsic-set.h b/source/slang/slang-hlsl-intrinsic-set.h
index 3dc2996c1..8368491db 100644
--- a/source/slang/slang-hlsl-intrinsic-set.h
+++ b/source/slang/slang-hlsl-intrinsic-set.h
@@ -11,217 +11,5 @@
namespace Slang
{
-/* TODO(JS): Note that there are multiple methods to handle 'construction' operations. That is because 'construct' is used as a kind of
-generic 'construction' for built in types including vectors and matrices.
-
-For the moment the cpp emit code, determines what kind of construct is needed, and has special handling for ConstructConvert and
-ConstructFromScalar.
-
-That currently we do not see MakeVectorFromScalar - for example when we do...
-
-int2 fromScalar = 1;
-
-This appears as a construction from an int.
-
-That the better thing to do would be that there were IR instructions for the specific types of construction. I suppose there is a question
-about whether there should be separate instructions for vector/matrix, or emit code should just use the destination type. In practice I think
-it's fine that there isn't an instruction separating vector/matrix. That being the case I guess we arguably don't need MakeVectorFromScalar,
-just constructXXXFromScalar. Would be good if there was a suitable name to encompass vector/matrix.
-*/
-#define SLANG_HLSL_INTRINSIC_OP(x) \
- x(Invalid, "", -1) \
- x(Init, "", -1) \
- \
- x(Mul, "*", 2) \
- x(Div, "/", 2) \
- x(Add, "+", 2) \
- x(Sub, "-", 2) \
- x(Lsh, "<<", 2) \
- x(Rsh, ">>", 2) \
- x(IRem, "%", 2) \
- x(FRem, "fmod", 2) \
- \
- x(Eql, "==", 2) \
- x(Neq, "!=", 2) \
- x(Greater, ">", 2) \
- x(Less, "<", 2) \
- x(Geq, ">=", 2) \
- x(Leq, "<=", 2) \
- \
- x(BitAnd, "&", 2) \
- x(BitXor, "^", 2) \
- x(BitOr, "|" , 2) \
- \
- x(And, "&&", 2) \
- x(Or, "||", 2) \
- \
- x(Neg, "-", 1) \
- x(Not, "!", 1) \
- x(BitNot, "~", 1) \
- \
- x(Any, "any", 1) \
- x(All, "all", 1) \
- \
- x(Swizzle, "", -1) \
- \
- x(AsFloat, "asfloat", 1) \
- x(AsInt, "asint", -1) \
- x(AsUInt, "asuint", -1) \
- x(AsDouble, "asdouble", 2) \
- \
- x(ConstructConvert, "", 1) \
- x(ConstructFromScalar, "", 1) \
- \
- x(GetAt, "", 2) \
- /* end */
-
-struct HLSLIntrinsic
-{
- typedef HLSLIntrinsic ThisType;
-
- enum class Op : uint8_t
- {
-#define SLANG_HLSL_INTRINSIC_OP_ENUM(name, hlslName, numOperands) name,
- SLANG_HLSL_INTRINSIC_OP(SLANG_HLSL_INTRINSIC_OP_ENUM)
- };
-
- struct Info
- {
- UnownedStringSlice name; ///< The enum name
- UnownedStringSlice funcName; ///< The HLSL function name (if there is one)
- int8_t numOperands; ///< -1 if can't be handled automatically via amount of params
- };
-
- bool operator==(const ThisType& rhs) const { return op == rhs.op && returnType == rhs.returnType && signatureType == rhs.signatureType; }
- bool operator!=(const ThisType& rhs) const { return !(*this == rhs); }
-
- static bool isTypeScalar(IRType* type)
- {
- // Strip off ptr if it's an operand type
- if (type->getOp() == kIROp_PtrType)
- {
- type = as<IRType>(type->getOperand(0));
- }
- // If any are vec or matrix, then we
- return !(type->getOp() == kIROp_MatrixType || type->getOp() == kIROp_VectorType);
- }
-
- bool isScalar() const
- {
- Index paramCount = Index(signatureType->getParamCount());
- for (Index i = 0; i < paramCount; ++i)
- {
- if (!isTypeScalar(signatureType->getParamType(i)))
- {
- return false;
- }
- }
- return isTypeScalar(returnType);
- }
-
- HashCode getHashCode() const { return combineHash(int(op), combineHash(Slang::getHashCode(returnType), Slang::getHashCode(signatureType))); }
-
- static const Info& getInfo(Op op) { return s_operationInfos[Index(op)]; }
- static const Info s_operationInfos[];
-
- Op op;
- IRType* returnType;
- IRFuncType* signatureType; // Same as funcType, but has return type of void
-};
-
-/* A helper type that allows comparing pointers to HLSLIntrinsic types as if they are the values */
-struct HLSLIntrinsicRef
-{
- typedef HLSLIntrinsicRef ThisType;
-
- HashCode getHashCode() const { return m_intrinsic->getHashCode(); }
- bool operator==(const ThisType& rhs) const { return m_intrinsic == rhs.m_intrinsic || (*m_intrinsic == *rhs.m_intrinsic); }
- bool operator!=(const ThisType& rhs) const { return !(*this == rhs); }
-
- HLSLIntrinsicRef():m_intrinsic(nullptr) {}
- HLSLIntrinsicRef(const ThisType& rhs):m_intrinsic(rhs.m_intrinsic) {}
- HLSLIntrinsicRef(const HLSLIntrinsic* intrinsic): m_intrinsic(intrinsic) {}
- void operator=(const ThisType& rhs) { m_intrinsic = rhs.m_intrinsic; }
-
- const HLSLIntrinsic* m_intrinsic;
-};
-
-class HLSLIntrinsicOpLookup : public RefObject
-{
-public:
- typedef HLSLIntrinsic::Op Op;
-
- Op getOpFromTargetDecoration(IRInst* inInst);
- Op getOpByName(const UnownedStringSlice& slice);
-
- Op getOpForIROp(IRInst* inst);
-
- HLSLIntrinsicOpLookup();
-
- /// Given an IROp returns the Op equivalent or Op::Invalid if not found
- static Op getOpForIROp(IROp op);
-
-protected:
-
- StringSlicePool m_slicePool;
- List<Op> m_sliceToOpMap;
-};
-
-
-/* This is used so as to try and use slangs type system to uniquely identify types and specializations on intrinsic.
-That we want to have a pointer to a type be unique, and slang supports this through the m_sharedIRBuilder. BUT for this to
-work all work on the module must use the same sharedIRBuilder, and that appears to not be the case in terms
-of other passes.
-Even if it was the case when we may want to add types as part of emitting, we can't use the previously used
-shared builder, so again we end up with pointers to the same things not being the same thing.
-
-To work around this we clone types we want to use as keys into the 'unique module'.
-This is not necessary for all types though - as we assume nominal types *must* have unique pointers (that is the
-definition of nominal).
-
-This could be handled in other ways (for example not testing equality on pointer equality). Anyway for now this
-works, but probably needs to be handled in a better way. The better way may involve having guarantees about equality
-enabled in other code generation and making de-duping possible in emit code.
-
-Note that one pro for this approach is that it does not alter the source module. That as it stands it's not necessary
-for the source module to be immutable, because it is created for emitting and then discarded.
- */
-class HLSLIntrinsicSet
-{
-public:
- typedef HLSLIntrinsic::Op Op;
-
- /* Note that calculating an intrinsic, the types will be added to the type set. That might mean subsequent code will
- emit those types being required, which may not be the case */
-
- void calcIntrinsic(Op op, IRType* returnType, IRType*const* args, Index argsCount, HLSLIntrinsic& out);
- void calcIntrinsic(Op op, IRInst* inst, Index argsCount, HLSLIntrinsic& out);
- void calcIntrinsic(Op op, IRType* returnType, IRUse* args, Index argCount, HLSLIntrinsic& out);
- void calcIntrinsic(Op op, IRInst* inst, HLSLIntrinsic& out) { calcIntrinsic(op, inst, Index(inst->getOperandCount()), out); }
-
- SlangResult makeIntrinsic(IRInst* inst, HLSLIntrinsic& out);
-
- HLSLIntrinsic* add(const HLSLIntrinsic& intrinsic);
-
- /// Returns the intrinsic constructed if there is one from the inst. If not possible to construct returns nullptr.
- HLSLIntrinsic* add(IRInst* inst);
-
- void getIntrinsics(List<const HLSLIntrinsic*>& out) const;
-
- HLSLIntrinsicSet(IRTypeSet* typeSet, HLSLIntrinsicOpLookup* lookup);
-
-protected:
- // All calcs must go through this choke point for some special case handling.
- // NOTE that this function must only be called with unique types (ie from the m_typeSet)
- void _calcIntrinsic(HLSLIntrinsic::Op op, IRType* returnType, IRType*const* inArgs, Index argsCount, HLSLIntrinsic& out);
-
- List<HLSLIntrinsic*> m_intrinsicsList;
- Dictionary<HLSLIntrinsicRef, HLSLIntrinsic*> m_intrinsicsDict;
-
- FreeList m_intrinsicFreeList; ///< the storage for the intrinsics when they are in the map
-
- HLSLIntrinsicOpLookup* m_opLookup;
- IRTypeSet* m_typeSet;
-};
} // namespace Slang
diff --git a/source/slang/slang-ir-address-analysis.cpp b/source/slang/slang-ir-address-analysis.cpp
index aba59e1de..1473bc466 100644
--- a/source/slang/slang-ir-address-analysis.cpp
+++ b/source/slang/slang-ir-address-analysis.cpp
@@ -79,9 +79,8 @@ namespace Slang
// Deduplicate and move known address insts.
for (auto block : func->getBlocks())
{
- for (auto inst = block->getFirstChild(); inst;)
+ for (auto inst : block->getModifiableChildren())
{
- auto next = inst->getNextInst();
switch (inst->getOp())
{
case kIROp_Var:
@@ -151,7 +150,6 @@ namespace Slang
}
break;
}
- inst = next;
}
}
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp
index b5d3dba10..1f599a344 100644
--- a/source/slang/slang-ir-autodiff-fwd.cpp
+++ b/source/slang/slang-ir-autodiff-fwd.cpp
@@ -170,40 +170,36 @@ InstPair ForwardDiffTranscriber::transcribeBinaryLogic(IRBuilder* builder, IRIns
{
SLANG_ASSERT(origLogic->getOperandCount() == 2);
- // TODO: Check other boolean cases.
- if (as<IRBoolType>(origLogic->getDataType()))
- {
- // Boolean operations are not differentiable. For the linearization
- // pass, we do not need to do anything but copy them over to the ne
- // function.
- auto primalLogic = maybeCloneForPrimalInst(builder, origLogic);
- return InstPair(primalLogic, nullptr);
- }
-
- SLANG_UNEXPECTED("Logical operation with non-boolean result");
+ // Boolean operations are not differentiable. For the linearization
+ // pass, we do not need to do anything but copy them over to the ne
+ // function.
+ auto primalLogic = maybeCloneForPrimalInst(builder, origLogic);
+ return InstPair(primalLogic, nullptr);
}
InstPair ForwardDiffTranscriber::transcribeLoad(IRBuilder* builder, IRLoad* origLoad)
{
auto origPtr = origLoad->getPtr();
auto primalPtr = lookupPrimalInst(builder, origPtr, nullptr);
- auto primalPtrValueType = as<IRPtrTypeBase>(primalPtr->getFullType())->getValueType();
-
- if (auto diffPairType = as<IRDifferentialPairType>(primalPtrValueType))
+ auto primalPtrType = as<IRPtrTypeBase>(primalPtr->getFullType());
+ if (primalPtrType)
{
- // Special case load from an `out` param, which will not have corresponding `diff` and
- // `primal` insts yet.
-
- // TODO: Could we move this load to _after_ DifferentialPairGetPrimal,
- // and DifferentialPairGetDifferential?
- //
- auto load = builder->emitLoad(primalPtr);
- builder->markInstAsMixedDifferential(load, diffPairType);
+ if (auto diffPairType = as<IRDifferentialPairType>(primalPtrType->getValueType()))
+ {
+ // Special case load from an `out` param, which will not have corresponding `diff` and
+ // `primal` insts yet.
- auto primalElement = builder->emitDifferentialPairGetPrimal(load);
- auto diffElement = builder->emitDifferentialPairGetDifferential(
- (IRType*)pairBuilder->getDiffTypeFromPairType(builder, diffPairType), load);
- return InstPair(primalElement, diffElement);
+ // TODO: Could we move this load to _after_ DifferentialPairGetPrimal,
+ // and DifferentialPairGetDifferential?
+ //
+ auto load = builder->emitLoad(primalPtr);
+ builder->markInstAsMixedDifferential(load, diffPairType);
+
+ auto primalElement = builder->emitDifferentialPairGetPrimal(load);
+ auto diffElement = builder->emitDifferentialPairGetDifferential(
+ (IRType*)pairBuilder->getDiffTypeFromPairType(builder, diffPairType), load);
+ return InstPair(primalElement, diffElement);
+ }
}
auto primalLoad = maybeCloneForPrimalInst(builder, origLoad);
@@ -492,7 +488,6 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig
if (!diffReturnType)
{
- SLANG_RELEASE_ASSERT(origCall->getFullType()->getOp() == kIROp_VoidType);
diffReturnType = argBuilder.getVoidType();
}
@@ -1364,6 +1359,8 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst*
case kIROp_Or:
case kIROp_Geq:
case kIROp_Leq:
+ case kIROp_Eql:
+ case kIROp_Neq:
return transcribeBinaryLogic(builder, origInst);
case kIROp_CastIntToFloat:
@@ -1452,7 +1449,27 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst*
case kIROp_undefined:
return transcribeUndefined(builder, origInst);
+ case kIROp_Not:
+ case kIROp_BitAnd:
+ case kIROp_BitNot:
+ case kIROp_BitXor:
+ case kIROp_BitCast:
+ case kIROp_Lsh:
+ case kIROp_Rsh:
+ case kIROp_IRem:
+ case kIROp_ByteAddressBufferLoad:
+ case kIROp_ByteAddressBufferStore:
+ case kIROp_StructuredBufferLoad:
+ case kIROp_StructuredBufferStore:
+ case kIROp_Reinterpret:
+ case kIROp_IsType:
+ case kIROp_ImageSubscript:
+ case kIROp_ImageLoad:
+ case kIROp_ImageStore:
case kIROp_CreateExistentialObject:
+ case kIROp_PackAnyValue:
+ case kIROp_UnpackAnyValue:
+ case kIROp_GetNativePtr:
// A call to createDynamicObject<T>(arbitraryData) cannot provide a diff value,
// so we treat this inst as non differentiable.
// We can extend the frontend and IR with a separate op-code that can provide an explicit diff value.
diff --git a/source/slang/slang-ir-autodiff-unzip.h b/source/slang/slang-ir-autodiff-unzip.h
index d83ff57e4..d10a9349d 100644
--- a/source/slang/slang-ir-autodiff-unzip.h
+++ b/source/slang/slang-ir-autodiff-unzip.h
@@ -1256,10 +1256,8 @@ struct DiffUnzipPass
diffBuilder.setInsertInto(diffBlock);
List<IRInst*> splitInsts;
- for (auto child = block->getFirstChild(); child;)
+ for (auto child : block->getModifiableChildren())
{
- IRInst* nextChild = child->getNextInst();
-
if (auto getDiffInst = as<IRDifferentialPairGetDifferential>(child))
{
// Replace GetDiff(A) with A.d
@@ -1267,7 +1265,6 @@ struct DiffUnzipPass
{
getDiffInst->replaceUsesWith(lookupDiffInst(getDiffInst->getBase()));
getDiffInst->removeAndDeallocate();
- child = nextChild;
continue;
}
}
@@ -1278,7 +1275,6 @@ struct DiffUnzipPass
{
getPrimalInst->replaceUsesWith(lookupPrimalInst(getPrimalInst->getBase()));
getPrimalInst->removeAndDeallocate();
- child = nextChild;
continue;
}
}
@@ -1296,8 +1292,6 @@ struct DiffUnzipPass
{
child->insertAtEnd(primalBlock);
}
-
- child = nextChild;
}
// Remove insts that were split.
diff --git a/source/slang/slang-ir-byte-address-legalize.cpp b/source/slang/slang-ir-byte-address-legalize.cpp
index 3a8d1852a..721efadaf 100644
--- a/source/slang/slang-ir-byte-address-legalize.cpp
+++ b/source/slang/slang-ir-byte-address-legalize.cpp
@@ -66,11 +66,8 @@ struct ByteAddressBufferLegalizationContext
break;
}
-
- IRInst* nextChild = nullptr;
- for( IRInst* child = inst->getFirstChild(); child; child = nextChild )
+ for( IRInst* child : inst->getModifiableChildren())
{
- nextChild = child->getNextInst();
processInstRec(child);
}
}
diff --git a/source/slang/slang-ir-clone.cpp b/source/slang/slang-ir-clone.cpp
index dbeb1e934..8b8b28f09 100644
--- a/source/slang/slang-ir-clone.cpp
+++ b/source/slang/slang-ir-clone.cpp
@@ -72,29 +72,29 @@ IRInst* cloneInstAndOperands(
auto oldType = oldInst->getFullType();
auto newType = (IRType*) findCloneForOperand(env, oldType);
- // Next we will create an empty shell of the instruction,
- // with space for the operands, but no actual operand
- // values attached.
- //
- UInt operandCount = oldInst->getOperandCount();
- auto newInst = builder->emitIntrinsicInst(
- newType,
- oldInst->getOp(),
- operandCount,
- nullptr);
-
- // Finally we will iterate over the operands of `oldInst`
+ // Next we will iterate over the operands of `oldInst`
// to find their replacements and install them as
// the operands of `newInst`.
//
- for(UInt ii = 0; ii < operandCount; ++ii)
+ UInt operandCount = oldInst->getOperandCount();
+
+ ShortList<IRInst*> newOperands;
+ newOperands.setCount(operandCount);
+ for (UInt ii = 0; ii < operandCount; ++ii)
{
auto oldOperand = oldInst->getOperand(ii);
auto newOperand = findCloneForOperand(env, oldOperand);
- newInst->getOperands()[ii].init(newInst, newOperand);
+ newOperands[ii] = newOperand;
}
+ // Finally we create the inst with the updated operands.
+ auto newInst = builder->emitIntrinsicInst(
+ newType,
+ oldInst->getOp(),
+ operandCount,
+ newOperands.getArrayView().getBuffer());
+
newInst->sourceLoc = oldInst->sourceLoc;
return newInst;
diff --git a/source/slang/slang-ir-collect-global-uniforms.cpp b/source/slang/slang-ir-collect-global-uniforms.cpp
index ca5e56b53..ad0dfda91 100644
--- a/source/slang/slang-ir-collect-global-uniforms.cpp
+++ b/source/slang/slang-ir-collect-global-uniforms.cpp
@@ -192,7 +192,8 @@ struct CollectGlobalUniformParametersContext
// per-field layout information to reference the key we created
// instead of the existing parameter (which we will be removing).
//
- fieldLayoutAttr->setOperand(0, fieldKey);
+ fieldLayoutAttr = as<IRStructFieldLayoutAttr>(
+ builder->replaceOperand(fieldLayoutAttr->getOperands(), fieldKey));
// If the given parameter doesn't contribute to uniform/ordinary usage, then
// we can safely leave it at the global scope and potentially avoid a lot
@@ -266,7 +267,7 @@ struct CollectGlobalUniformParametersContext
//
if(auto layoutAttr = as<IRStructFieldLayoutAttr>(user))
{
- layoutAttr->setOperand(0, fieldKey);
+ builder->replaceOperand(layoutAttr->getOperands(), fieldKey);
continue;
}
diff --git a/source/slang/slang-ir-com-interface.cpp b/source/slang/slang-ir-com-interface.cpp
index 3e52054cd..0684cc8e6 100644
--- a/source/slang/slang-ir-com-interface.cpp
+++ b/source/slang/slang-ir-com-interface.cpp
@@ -105,7 +105,7 @@ void lowerComInterfaces(IRModule* module, ArtifactStyle artifactStyle, Diagnosti
for (auto use : uses)
{
// Do the replacement
- use->set(result);
+ builder.replaceOperand(use, result);
}
}
}
diff --git a/source/slang/slang-ir-dce.cpp b/source/slang/slang-ir-dce.cpp
index 05c10b317..251b473e0 100644
--- a/source/slang/slang-ir-dce.cpp
+++ b/source/slang/slang-ir-dce.cpp
@@ -237,14 +237,16 @@ struct DeadCodeEliminationContext
// might still be dead.
//
// The biggest wrinkle is that we walk the linked list of
- // children/decorations a bit carefully, using a temporary
- // to hold the next node, in case we eliminate one of
- // the children as we go.
+ // children/decorations a bit carefully, because eliminating one inst
+ // may cause the other nodes to be hoisted out of the current scope.
+ // We need to cache all children in a work list to ensure they are
+ // properly traversed.
//
- IRInst* next = nullptr;
- for( IRInst* child = inst->getFirstDecorationOrChild(); child; child = next )
+ List<IRInst*> children;
+ for (auto child : inst->getDecorationsAndChildren())
+ children.add(child);
+ for(IRInst* child : children)
{
- next = child->getNextInst();
changed |= eliminateDeadInstsRec(child);
}
}
diff --git a/source/slang/slang-ir-deduplicate.cpp b/source/slang/slang-ir-deduplicate.cpp
index 51a677627..74efc3cb3 100644
--- a/source/slang/slang-ir-deduplicate.cpp
+++ b/source/slang/slang-ir-deduplicate.cpp
@@ -2,116 +2,84 @@
namespace Slang
{
- struct DeduplicateContext
+ void SharedIRBuilder::deduplicateAndRebuildGlobalNumberingMap()
{
- SharedIRBuilder* builder;
- IRInst* addValue(IRInst* value)
- {
- if (!value) return nullptr;
- if (as<IRType>(value))
- return addTypeValue(value);
- if (auto constValue = as<IRConstant>(value))
- return addConstantValue(constValue);
- return value;
- }
- IRInst* addConstantValue(IRConstant* value)
- {
- IRConstantKey key = { value };
- value->setFullType((IRType*)addValue(value->getFullType()));
- if (auto newValue = builder->getConstantMap().TryGetValue(key))
- return *newValue;
- builder->getConstantMap()[key] = value;
- return value;
- }
- IRInst* addTypeValue(IRInst* value)
- {
- // Do not deduplicate struct or interface types.
- switch (value->getOp())
- {
- case kIROp_StructType:
- case kIROp_InterfaceType:
- return value;
- default:
- break;
- }
+ }
- for (UInt i = 0; i < value->getOperandCount(); i++)
- {
- value->setOperand(i, addValue(value->getOperand(i)));
- }
- value->setFullType((IRType*)addValue(value->getFullType()));
- IRInstKey key = { value };
- if (auto newValue = builder->getGlobalValueNumberingMap().TryGetValue(key))
- return *newValue;
- builder->getGlobalValueNumberingMap()[key] = value;
- return value;
- }
- };
- void SharedIRBuilder::deduplicateAndRebuildGlobalNumberingMap()
+ void SharedIRBuilder::replaceGlobalInst(IRInst* oldInst, IRInst* newInst)
+ {
+ oldInst->replaceUsesWith(newInst);
+ }
+
+ void SharedIRBuilder::removeHoistableInstFromGlobalNumberingMap(IRInst* instToRemove)
{
- DeduplicateContext context;
- context.builder = this;
- m_constantMap.Clear();
- m_globalValueNumberingMap.Clear();
- List<IRInst*> instToRemove;
- for (auto inst : m_module->getGlobalInsts())
+ HashSet<IRInst*> userWorkListSet;
+ List<IRInst*> userWorkList;
+ auto addToWorkList = [&](IRInst* i)
{
- if (auto constVal = as<IRConstant>(inst))
- {
- auto newConst = context.addConstantValue(constVal);
- if (newConst != constVal)
- {
- constVal->replaceUsesWith(newConst);
- instToRemove.add(constVal);
- }
- }
- }
- for (auto inst : m_module->getGlobalInsts())
+ if (userWorkListSet.Add(i))
+ userWorkList.add(i);
+ };
+ addToWorkList(instToRemove);
+ for (Index i = 0; i < userWorkList.getCount(); i++)
{
- if (as<IRType>(inst) || as<IRSpecialize>(inst))
+ auto inst = userWorkList[i];
+ if (getIROpInfo(inst->getOp()).isHoistable())
{
- auto newInst = context.addTypeValue(inst);
- if (newInst != inst)
+ _removeGlobalNumberingEntry(inst);
+ for (auto use = inst->firstUse; use; use = use->nextUse)
{
- inst->replaceUsesWith(newInst);
- instToRemove.add(inst);
+ addToWorkList(use->getUser());
}
}
}
- for (auto inst : instToRemove)
- inst->removeAndDeallocate();
}
- void SharedIRBuilder::replaceGlobalInst(IRInst* oldInst, IRInst* newInst)
+ void addHoistableInst(
+ IRBuilder* builder,
+ IRInst* inst);
+
+ void SharedIRBuilder::tryHoistInst(IRInst* inst)
{
- List<IRUse*> uses;
- for (auto use = oldInst->firstUse; use; use = use->nextUse)
- {
- uses.add(use);
- }
+ List<IRInst*> workList;
+ HashSet<IRInst*> workListSet;
+ workList.add(inst);
+ workListSet.Add(inst);
+ IRBuilder builder(inst->getModule());
- bool shouldUpdateGlobalNumberedCache = false;
- for (auto use : uses)
+ for (Index i = 0; i < workList.getCount(); i++)
{
- use->set(newInst);
- // depending on the type of the user inst, we may need to rebuild and update the global
- // numbering cache.
- if (isGloballyNumberedInst(use->getUser()))
+ auto item = workList[i];
+
+ // Does inst no longer depend on anything defined locally?
+ // If so we should hoist it.
+ bool shouldHoist = false;
+ for (UInt a = 0; a < item->getOperandCount(); a++)
{
- shouldUpdateGlobalNumberedCache = true;
+ auto opParent = item->getOperand(a)->getParent();
+ if (opParent != item->getParent())
+ {
+ shouldHoist = true;
+ break;
+ }
}
- }
- oldInst->removeAndDeallocate();
- if (shouldUpdateGlobalNumberedCache)
- {
- deduplicateAndRebuildGlobalNumberingMap();
- }
- }
- bool SharedIRBuilder::isGloballyNumberedInst(IRInst* inst)
- {
- if (!inst->getParent() || inst->getParent()->getOp() != kIROp_Module)
- return false;
- return m_globalValueNumberingMap.ContainsKey(IRInstKey{inst});
+ // Hoisting this inst
+ if (shouldHoist)
+ {
+ item->removeFromParent();
+ addHoistableInst(&builder, item);
+
+ // Continue to consider all users for hoisting.
+ for (auto use = item->firstUse; use; use = use->nextUse)
+ {
+ if (getIROpInfo(use->getUser()->getOp()).isHoistable())
+ {
+ if (workListSet.Add(use->getUser()))
+ workList.add(use->getUser());
+ }
+ }
+ }
+ }
}
}
diff --git a/source/slang/slang-ir-glsl-legalize.cpp b/source/slang/slang-ir-glsl-legalize.cpp
index 0dcd437fe..55d120228 100644
--- a/source/slang/slang-ir-glsl-legalize.cpp
+++ b/source/slang/slang-ir-glsl-legalize.cpp
@@ -1791,7 +1791,7 @@ void legalizeMeshOutputParam(
// the writes may only be writing to parts of the output struct, or may not
// be writes at all (i.e. being passed as an out paramter).
//
- traverseUses(g, [&](IRInst* u)
+ traverseUsers(g, [&](IRInst* u)
{
auto l = as<IRLoad>(u);
SLANG_EXPECT(l, "Mesh Output sentinel parameter wasn't used in a load");
@@ -1811,7 +1811,7 @@ void legalizeMeshOutputParam(
return;
}
// Otherwise, go through the uses one by one and see what we can do
- traverseUses(a, [&](IRInst* s)
+ traverseUsers(a, [&](IRInst* s)
{
IRBuilderInsertLocScope locScope{builder};
builder->setInsertBefore(s);
@@ -2022,7 +2022,7 @@ void legalizeMeshOutputParam(
for(auto builtin : builtins)
{
- traverseUses(builtin.param, [&](IRInst* u)
+ traverseUsers(builtin.param, [&](IRInst* u)
{
auto p = as<IRGetElementPtr>(u);
SLANG_EXPECT(p, "Mesh Output sentinel parameter wasn't used as an array");
diff --git a/source/slang/slang-ir-inline.cpp b/source/slang/slang-ir-inline.cpp
index 7fc977170..643acdbb8 100644
--- a/source/slang/slang-ir-inline.cpp
+++ b/source/slang/slang-ir-inline.cpp
@@ -53,10 +53,8 @@ struct InliningPassBase
// so that even if `child` gets removed (because of inlining)
// we automatically start at the next instruction after it.
//
- IRInst* next = nullptr;
- for( auto child = inst->getFirstChild(); child; child = next )
+ for (auto child : inst->getModifiableChildren())
{
- next = child->getNextInst();
changed |= considerAllCallSitesRec(child);
}
return changed;
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index 788e02c90..35877d680 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -10,6 +10,8 @@
#define PARENT kIROpFlag_Parent
#define USE_OTHER kIROpFlag_UseOther
+#define HOISTABLE kIROpFlag_Hoistable
+#define GLOBAL kIROpFlag_Global
INST(Nop, nop, 0, 0)
@@ -17,7 +19,7 @@ INST(Nop, nop, 0, 0)
/* Basic Types */
- #define DEFINE_BASE_TYPE_INST(NAME) INST(NAME ## Type, NAME, 0, 0)
+ #define DEFINE_BASE_TYPE_INST(NAME) INST(NAME ## Type, NAME, 0, HOISTABLE)
FOREACH_BASE_TYPE(DEFINE_BASE_TYPE_INST)
#undef DEFINE_BASE_TYPE_INST
INST(AfterBaseType, afterBaseType, 0, 0)
@@ -25,42 +27,42 @@ INST(Nop, nop, 0, 0)
INST_RANGE(BasicType, VoidType, AfterBaseType)
/* StringTypeBase */
- INST(StringType, String, 0, 0)
- INST(NativeStringType, NativeString, 0, 0)
+ INST(StringType, String, 0, HOISTABLE)
+ INST(NativeStringType, NativeString, 0, HOISTABLE)
INST_RANGE(StringTypeBase, StringType, NativeStringType)
- INST(CapabilitySetType, CapabilitySet, 0, 0)
+ INST(CapabilitySetType, CapabilitySet, 0, HOISTABLE)
- INST(DynamicType, DynamicType, 0, 0)
+ INST(DynamicType, DynamicType, 0, HOISTABLE)
- INST(AnyValueType, AnyValueType, 1, 0)
+ INST(AnyValueType, AnyValueType, 1, HOISTABLE)
- INST(RawPointerType, RawPointerType, 0, 0)
- INST(RTTIPointerType, RTTIPointerType, 1, 0)
+ INST(RawPointerType, RawPointerType, 0, HOISTABLE)
+ INST(RTTIPointerType, RTTIPointerType, 1, HOISTABLE)
INST(AfterRawPointerTypeBase, AfterRawPointerTypeBase, 0, 0)
INST_RANGE(RawPointerTypeBase, RawPointerType, AfterRawPointerTypeBase)
/* ArrayTypeBase */
- INST(ArrayType, Array, 2, 0)
- INST(UnsizedArrayType, UnsizedArray, 1, 0)
+ INST(ArrayType, Array, 2, HOISTABLE)
+ INST(UnsizedArrayType, UnsizedArray, 1, HOISTABLE)
INST_RANGE(ArrayTypeBase, ArrayType, UnsizedArrayType)
- INST(FuncType, Func, 0, 0)
- INST(BasicBlockType, BasicBlock, 0, 0)
+ INST(FuncType, Func, 0, HOISTABLE)
+ INST(BasicBlockType, BasicBlock, 0, HOISTABLE)
- INST(VectorType, Vec, 2, 0)
- INST(MatrixType, Mat, 3, 0)
+ INST(VectorType, Vec, 2, HOISTABLE)
+ INST(MatrixType, Mat, 3, HOISTABLE)
- INST(TaggedUnionType, TaggedUnion, 0, 0)
+ INST(TaggedUnionType, TaggedUnion, 0, HOISTABLE)
- INST(ConjunctionType, Conjunction, 0, 0)
- INST(AttributedType, Attributed, 0, 0)
- INST(ResultType, Result, 2, 0)
- INST(OptionalType, Optional, 1, 0)
+ INST(ConjunctionType, Conjunction, 0, HOISTABLE)
+ INST(AttributedType, Attributed, 0, HOISTABLE)
+ INST(ResultType, Result, 2, HOISTABLE)
+ INST(OptionalType, Optional, 1, HOISTABLE)
- INST(DifferentialPairType, DiffPair, 1, 0)
- INST(BackwardDiffIntermediateContextType, BwdDiffIntermediateCtxType, 1, 0)
+ INST(DifferentialPairType, DiffPair, 1, HOISTABLE)
+ INST(BackwardDiffIntermediateContextType, BwdDiffIntermediateCtxType, 1, HOISTABLE)
/* BindExistentialsTypeBase */
@@ -70,58 +72,58 @@ INST(Nop, nop, 0, 0)
// where each `Ti, wi` pair represents the concrete type
// and witness table to plug in for parameter `i`.
//
- INST(BindExistentialsType, BindExistentials, 1, 0)
+ INST(BindExistentialsType, BindExistentials, 1, HOISTABLE)
// An `BindInterface<B, T0, w0>` represents the special case
// of a `BindExistentials` where the type `B` is known to be
// an interface type.
//
- INST(BoundInterfaceType, BoundInterface, 3, 0)
+ INST(BoundInterfaceType, BoundInterface, 3, HOISTABLE)
INST_RANGE(BindExistentialsTypeBase, BindExistentialsType, BoundInterfaceType)
/* Rate */
- INST(ConstExprRate, ConstExpr, 0, 0)
- INST(GroupSharedRate, GroupShared, 0, 0)
- INST(ActualGlobalRate, ActualGlobalRate, 0, 0)
+ INST(ConstExprRate, ConstExpr, 0, HOISTABLE)
+ INST(GroupSharedRate, GroupShared, 0, HOISTABLE)
+ INST(ActualGlobalRate, ActualGlobalRate, 0, HOISTABLE)
INST_RANGE(Rate, ConstExprRate, GroupSharedRate)
- INST(RateQualifiedType, RateQualified, 2, 0)
+ INST(RateQualifiedType, RateQualified, 2, HOISTABLE)
// Kinds represent the "types of types."
// They should not really be nested under `IRType`
// in the overall hierarchy, but we can fix that later.
//
/* Kind */
- INST(TypeKind, Type, 0, 0)
- INST(RateKind, Rate, 0, 0)
- INST(GenericKind, Generic, 0, 0)
+ INST(TypeKind, Type, 0, HOISTABLE)
+ INST(RateKind, Rate, 0, HOISTABLE)
+ INST(GenericKind, Generic, 0, HOISTABLE)
INST_RANGE(Kind, TypeKind, GenericKind)
/* PtrTypeBase */
- INST(PtrType, Ptr, 1, 0)
- INST(RefType, Ref, 1, 0)
+ INST(PtrType, Ptr, 1, HOISTABLE)
+ INST(RefType, Ref, 1, HOISTABLE)
// A `PsuedoPtr<T>` logically represents a pointer to a value of type
// `T` on a platform that cannot support pointers. The expectation
// is that the "pointer" will be legalized away by storing a value
// of type `T` somewhere out-of-line.
- INST(PseudoPtrType, PseudoPtr, 1, 0)
+ INST(PseudoPtrType, PseudoPtr, 1, HOISTABLE)
/* OutTypeBase */
- INST(OutType, Out, 1, 0)
- INST(InOutType, InOut, 1, 0)
+ INST(OutType, Out, 1, HOISTABLE)
+ INST(InOutType, InOut, 1, HOISTABLE)
INST_RANGE(OutTypeBase, OutType, InOutType)
INST_RANGE(PtrTypeBase, PtrType, InOutType)
// A ComPtr<T> type is treated as a opaque type that represents a reference-counted handle to a COM object.
- INST(ComPtrType, ComPtr, 1, 0)
+ INST(ComPtrType, ComPtr, 1, HOISTABLE)
// A NativePtr<T> type represents a native pointer to a managed resource.
- INST(NativePtrType, NativePtr, 1, 0)
+ INST(NativePtrType, NativePtr, 1, HOISTABLE)
/* SamplerStateTypeBase */
- INST(SamplerStateType, SamplerState, 0, 0)
- INST(SamplerComparisonStateType, SamplerComparisonState, 0, 0)
+ INST(SamplerStateType, SamplerState, 0, HOISTABLE)
+ INST(SamplerComparisonStateType, SamplerComparisonState, 0, HOISTABLE)
INST_RANGE(SamplerStateTypeBase, SamplerStateType, SamplerComparisonStateType)
// TODO: Why do we have all this hierarchy here, when everything
@@ -131,11 +133,11 @@ INST(Nop, nop, 0, 0)
/* TextureTypeBase */
// NOTE! TextureFlavor::Flavor is stored in 'other' bits for these types.
/* TextureType */
- INST(TextureType, TextureType, 0, USE_OTHER)
+ INST(TextureType, TextureType, 0, USE_OTHER | HOISTABLE)
/* TextureSamplerType */
- INST(TextureSamplerType, TextureSamplerType, 0, USE_OTHER)
+ INST(TextureSamplerType, TextureSamplerType, 0, USE_OTHER | HOISTABLE)
/* GLSLImageType */
- INST(GLSLImageType, GLSLImageType, 0, USE_OTHER)
+ INST(GLSLImageType, GLSLImageType, 0, USE_OTHER | HOISTABLE)
INST_RANGE(TextureTypeBase, TextureType, GLSLImageType)
INST_RANGE(ResourceType, TextureType, GLSLImageType)
INST_RANGE(ResourceTypeBase, TextureType, GLSLImageType)
@@ -143,53 +145,53 @@ INST(Nop, nop, 0, 0)
/* UntypedBufferResourceType */
/* ByteAddressBufferTypeBase */
- INST(HLSLByteAddressBufferType, ByteAddressBuffer, 0, 0)
- INST(HLSLRWByteAddressBufferType, RWByteAddressBuffer, 0, 0)
- INST(HLSLRasterizerOrderedByteAddressBufferType, RasterizerOrderedByteAddressBuffer, 0, 0)
+ INST(HLSLByteAddressBufferType, ByteAddressBuffer, 0, HOISTABLE)
+ INST(HLSLRWByteAddressBufferType, RWByteAddressBuffer, 0, HOISTABLE)
+ INST(HLSLRasterizerOrderedByteAddressBufferType, RasterizerOrderedByteAddressBuffer, 0, HOISTABLE)
INST_RANGE(ByteAddressBufferTypeBase, HLSLByteAddressBufferType, HLSLRasterizerOrderedByteAddressBufferType)
- INST(RaytracingAccelerationStructureType, RaytracingAccelerationStructure, 0, 0)
+ INST(RaytracingAccelerationStructureType, RaytracingAccelerationStructure, 0, HOISTABLE)
INST_RANGE(UntypedBufferResourceType, HLSLByteAddressBufferType, RaytracingAccelerationStructureType)
/* HLSLPatchType */
- INST(HLSLInputPatchType, InputPatch, 2, 0)
- INST(HLSLOutputPatchType, OutputPatch, 2, 0)
+ INST(HLSLInputPatchType, InputPatch, 2, HOISTABLE)
+ INST(HLSLOutputPatchType, OutputPatch, 2, HOISTABLE)
INST_RANGE(HLSLPatchType, HLSLInputPatchType, HLSLOutputPatchType)
- INST(GLSLInputAttachmentType, GLSLInputAttachment, 0, 0)
+ INST(GLSLInputAttachmentType, GLSLInputAttachment, 0, HOISTABLE)
/* BuiltinGenericType */
/* HLSLStreamOutputType */
- INST(HLSLPointStreamType, PointStream, 1, 0)
- INST(HLSLLineStreamType, LineStream, 1, 0)
- INST(HLSLTriangleStreamType, TriangleStream, 1, 0)
+ INST(HLSLPointStreamType, PointStream, 1, HOISTABLE)
+ INST(HLSLLineStreamType, LineStream, 1, HOISTABLE)
+ INST(HLSLTriangleStreamType, TriangleStream, 1, HOISTABLE)
INST_RANGE(HLSLStreamOutputType, HLSLPointStreamType, HLSLTriangleStreamType)
/* MeshOutputType */
- INST(VerticesType, Vertices, 2, 0)
- INST(IndicesType, Indices, 2, 0)
- INST(PrimitivesType, Primitives, 2, 0)
+ INST(VerticesType, Vertices, 2, HOISTABLE)
+ INST(IndicesType, Indices, 2, HOISTABLE)
+ INST(PrimitivesType, Primitives, 2, HOISTABLE)
INST_RANGE(MeshOutputType, VerticesType, PrimitivesType)
/* HLSLStructuredBufferTypeBase */
- INST(HLSLStructuredBufferType, StructuredBuffer, 0, 0)
- INST(HLSLRWStructuredBufferType, RWStructuredBuffer, 0, 0)
- INST(HLSLRasterizerOrderedStructuredBufferType, RasterizerOrderedStructuredBuffer, 0, 0)
- INST(HLSLAppendStructuredBufferType, AppendStructuredBuffer, 0, 0)
- INST(HLSLConsumeStructuredBufferType, ConsumeStructuredBuffer, 0, 0)
+ INST(HLSLStructuredBufferType, StructuredBuffer, 0, HOISTABLE)
+ INST(HLSLRWStructuredBufferType, RWStructuredBuffer, 0, HOISTABLE)
+ INST(HLSLRasterizerOrderedStructuredBufferType, RasterizerOrderedStructuredBuffer, 0, HOISTABLE)
+ INST(HLSLAppendStructuredBufferType, AppendStructuredBuffer, 0, HOISTABLE)
+ INST(HLSLConsumeStructuredBufferType, ConsumeStructuredBuffer, 0, HOISTABLE)
INST_RANGE(HLSLStructuredBufferTypeBase, HLSLStructuredBufferType, HLSLConsumeStructuredBufferType)
/* PointerLikeType */
/* ParameterGroupType */
/* UniformParameterGroupType */
- INST(ConstantBufferType, ConstantBuffer, 1, 0)
- INST(TextureBufferType, TextureBuffer, 1, 0)
- INST(ParameterBlockType, ParameterBlock, 1, 0)
- INST(GLSLShaderStorageBufferType, GLSLShaderStorageBuffer, 0, 0)
+ INST(ConstantBufferType, ConstantBuffer, 1, HOISTABLE)
+ INST(TextureBufferType, TextureBuffer, 1, HOISTABLE)
+ INST(ParameterBlockType, ParameterBlock, 1, HOISTABLE)
+ INST(GLSLShaderStorageBufferType, GLSLShaderStorageBuffer, 0, HOISTABLE)
INST_RANGE(UniformParameterGroupType, ConstantBufferType, GLSLShaderStorageBufferType)
/* VaryingParameterGroupType */
- INST(GLSLInputParameterGroupType, GLSLInputParameterGroup, 0, 0)
- INST(GLSLOutputParameterGroupType, GLSLOutputParameterGroup, 0, 0)
+ INST(GLSLInputParameterGroupType, GLSLInputParameterGroup, 0, HOISTABLE)
+ INST(GLSLOutputParameterGroupType, GLSLOutputParameterGroup, 0, HOISTABLE)
INST_RANGE(VaryingParameterGroupType, GLSLInputParameterGroupType, GLSLOutputParameterGroupType)
INST_RANGE(ParameterGroupType, ConstantBufferType, GLSLOutputParameterGroupType)
INST_RANGE(PointerLikeType, ConstantBufferType, GLSLOutputParameterGroupType)
@@ -209,28 +211,28 @@ INST(Nop, nop, 0, 0)
//
INST(StructType, struct, 0, PARENT)
INST(ClassType, class, 0, PARENT)
-INST(InterfaceType, interface, 0, 0)
-INST(AssociatedType, associated_type, 0, 0)
-INST(ThisType, this_type, 0, 0)
-INST(RTTIType, rtti_type, 0, 0)
-INST(RTTIHandleType, rtti_handle_type, 0, 0)
-INST(TupleType, tuple_type, 0, 0)
+INST(InterfaceType, interface, 0, GLOBAL)
+INST(AssociatedType, associated_type, 0, HOISTABLE)
+INST(ThisType, this_type, 0, HOISTABLE)
+INST(RTTIType, rtti_type, 0, HOISTABLE)
+INST(RTTIHandleType, rtti_handle_type, 0, HOISTABLE)
+INST(TupleType, tuple_type, 0, HOISTABLE)
// A type that identifies it's contained type as being emittable as `spirv_literal.
-INST(SPIRVLiteralType, spirvLiteralType, 1, 0)
+INST(SPIRVLiteralType, spirvLiteralType, 1, HOISTABLE)
// A TypeType-typed IRValue represents a IRType.
// It is used to represent a type parameter/argument in a generics.
-INST(TypeType, type_t, 0, 0)
+INST(TypeType, type_t, 0, HOISTABLE)
/*IRWitnessTableTypeBase*/
// An `IRWitnessTable` has type `WitnessTableType`.
- INST(WitnessTableType, witness_table_t, 1, 0)
+ INST(WitnessTableType, witness_table_t, 1, HOISTABLE)
// An integer type representing a witness table for targets where
// witness tables are represented as integer IDs. This type is used
// during the lower-generics pass while generating dynamic dispatch
// code and will eventually lower into an uint type.
- INST(WitnessTableIDType, witness_table_id_t, 1, 0)
+ INST(WitnessTableIDType, witness_table_id_t, 1, HOISTABLE)
INST_RANGE(WitnessTableTypeBase, WitnessTableType, WitnessTableIDType)
INST_RANGE(Type, VoidType, WitnessTableIDType)
@@ -240,14 +242,14 @@ INST_RANGE(Type, VoidType, WitnessTableIDType)
INST(Generic, generic, 0, PARENT)
INST_RANGE(GlobalValueWithParams, Func, Generic)
- INST(GlobalVar, global_var, 0, 0)
+ INST(GlobalVar, global_var, 0, GLOBAL)
INST_RANGE(GlobalValueWithCode, Func, GlobalVar)
-INST(GlobalParam, global_param, 0, 0)
-INST(GlobalConstant, globalConstant, 0, 0)
+INST(GlobalParam, global_param, 0, GLOBAL)
+INST(GlobalConstant, globalConstant, 0, GLOBAL)
-INST(StructKey, key, 0, 0)
-INST(GlobalGenericParam, global_generic_param, 0, 0)
+INST(StructKey, key, 0, GLOBAL)
+INST(GlobalGenericParam, global_generic_param, 0, GLOBAL)
INST(WitnessTable, witness_table, 0, 0)
INST(GlobalHashedStringLiterals, global_hashed_string_literals, 0, 0)
@@ -265,7 +267,7 @@ INST(Block, block, 0, PARENT)
INST(VoidLit, void_constant, 0, 0)
INST_RANGE(Constant, BoolLit, VoidLit)
-INST(CapabilitySet, capabilitySet, 0, 0)
+INST(CapabilitySet, capabilitySet, 0, HOISTABLE)
INST(undefined, undefined, 0, 0)
@@ -279,10 +281,9 @@ INST(MakeDifferentialPair, MakeDiffPair, 2, 0)
INST(DifferentialPairGetDifferential, GetDifferential, 1, 0)
INST(DifferentialPairGetPrimal, GetPrimal, 1, 0)
-INST(Specialize, specialize, 2, 0)
-INST(LookupWitness, lookupWitness, 2, 0)
+INST(Specialize, specialize, 2, HOISTABLE)
+INST(LookupWitness, lookupWitness, 2, HOISTABLE)
INST(GetSequentialID, GetSequentialID, 1, 0)
-INST(lookup_witness_table, lookup_witness_table, 2, 0)
INST(BindGlobalGenericParam, bind_global_generic_param, 2, 0)
INST(AllocObj, allocObj, 0, 0)
@@ -317,7 +318,7 @@ INST(PackAnyValue, packAnyValue, 1, 0)
INST(UnpackAnyValue, unpackAnyValue, 1, 0)
INST(WitnessTableEntry, witness_table_entry, 2, 0)
-INST(InterfaceRequirementEntry, interface_req_entry, 2, 0)
+INST(InterfaceRequirementEntry, interface_req_entry, 2, GLOBAL)
INST(Param, param, 0, 0)
INST(StructField, field, 2, 0)
@@ -558,8 +559,6 @@ INST(BitNot, bitnot, 1, 0)
INST(Select, select, 3, 0)
-INST(Dot, dot, 2, 0)
-
INST(GetStringHash, getStringHash, 1, 0)
INST(WaveGetActiveMask, waveGetActiveMask, 0, 0)
@@ -880,40 +879,40 @@ INST(BackwardDifferentiate, BackwardDifferentiate, 1, 0)
INST(GetEquivalentStructuredBuffer, getEquivalentStructuredBuffer, 1, 0)
/* Layout */
- INST(VarLayout, varLayout, 1, 0)
+ INST(VarLayout, varLayout, 1, HOISTABLE)
/* TypeLayout */
- INST(TypeLayoutBase, typeLayout, 0, 0)
- INST(ParameterGroupTypeLayout, parameterGroupTypeLayout, 2, 0)
- INST(ArrayTypeLayout, arrayTypeLayout, 1, 0)
- INST(StreamOutputTypeLayout, streamOutputTypeLayout, 1, 0)
- INST(MatrixTypeLayout, matrixTypeLayout, 1, 0)
- INST(TaggedUnionTypeLayout, taggedUnionTypeLayout, 0, 0)
- INST(ExistentialTypeLayout, existentialTypeLayout, 0, 0)
- INST(StructTypeLayout, structTypeLayout, 0, 0)
+ INST(TypeLayoutBase, typeLayout, 0, HOISTABLE)
+ INST(ParameterGroupTypeLayout, parameterGroupTypeLayout, 2, HOISTABLE)
+ INST(ArrayTypeLayout, arrayTypeLayout, 1, HOISTABLE)
+ INST(StreamOutputTypeLayout, streamOutputTypeLayout, 1, HOISTABLE)
+ INST(MatrixTypeLayout, matrixTypeLayout, 1, HOISTABLE)
+ INST(TaggedUnionTypeLayout, taggedUnionTypeLayout, 0, HOISTABLE)
+ INST(ExistentialTypeLayout, existentialTypeLayout, 0, HOISTABLE)
+ INST(StructTypeLayout, structTypeLayout, 0, HOISTABLE)
INST_RANGE(TypeLayout, TypeLayoutBase, StructTypeLayout)
- INST(EntryPointLayout, EntryPointLayout, 1, 0)
+ INST(EntryPointLayout, EntryPointLayout, 1, HOISTABLE)
INST_RANGE(Layout, VarLayout, EntryPointLayout)
/* Attr */
- INST(PendingLayoutAttr, pendingLayout, 1, 0)
- INST(StageAttr, stage, 1, 0)
- INST(StructFieldLayoutAttr, fieldLayout, 2, 0)
- INST(CaseTypeLayoutAttr, caseLayout, 1, 0)
- INST(UNormAttr, unorm, 0, 0)
- INST(SNormAttr, snorm, 0, 0)
- INST(NoDiffAttr, no_diff, 0, 0)
+ INST(PendingLayoutAttr, pendingLayout, 1, HOISTABLE)
+ INST(StageAttr, stage, 1, HOISTABLE)
+ INST(StructFieldLayoutAttr, fieldLayout, 2, HOISTABLE)
+ INST(CaseTypeLayoutAttr, caseLayout, 1, HOISTABLE)
+ INST(UNormAttr, unorm, 0, HOISTABLE)
+ INST(SNormAttr, snorm, 0, HOISTABLE)
+ INST(NoDiffAttr, no_diff, 0, HOISTABLE)
/* SemanticAttr */
- INST(UserSemanticAttr, userSemantic, 2, 0)
- INST(SystemValueSemanticAttr, systemValueSemantic, 2, 0)
+ INST(UserSemanticAttr, userSemantic, 2, HOISTABLE)
+ INST(SystemValueSemanticAttr, systemValueSemantic, 2, HOISTABLE)
INST_RANGE(SemanticAttr, UserSemanticAttr, SystemValueSemanticAttr)
/* LayoutResourceInfoAttr */
- INST(TypeSizeAttr, size, 2, 0)
- INST(VarOffsetAttr, offset, 2, 0)
+ INST(TypeSizeAttr, size, 2, HOISTABLE)
+ INST(VarOffsetAttr, offset, 2, HOISTABLE)
INST_RANGE(LayoutResourceInfoAttr, TypeSizeAttr, VarOffsetAttr)
- INST(FuncThrowTypeAttr, FuncThrowType, 1, 0)
+ INST(FuncThrowTypeAttr, FuncThrowType, 1, HOISTABLE)
INST_RANGE(Attr, PendingLayoutAttr, FuncThrowTypeAttr)
/* Liveness */
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index 7bc711f97..7a2e1f0e2 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -2436,106 +2436,37 @@ struct IRLiveRangeEnd : IRLiveRangeMarker
IR_LEAF_ISA(LiveRangeEnd);
};
-// Description of an instruction to be used for global value numbering
-struct IRInstKey
-{
- IRInst* inst;
-
- HashCode getHashCode();
-};
-
-bool operator==(IRInstKey const& left, IRInstKey const& right);
-
-struct IRConstantKey
-{
- IRConstant* inst;
-
- bool operator==(const IRConstantKey& rhs) const { return inst->equal(rhs.inst); }
- HashCode getHashCode() const { return inst->getHashCode(); }
-};
-
-struct SharedIRBuilder
-{
-public:
- SharedIRBuilder()
- {}
-
- explicit SharedIRBuilder(IRModule* module)
- {
- init(module);
- }
-
- void init(IRModule* module)
- {
- m_module = module;
- m_session = module->getSession();
-
- m_globalValueNumberingMap.Clear();
- m_constantMap.Clear();
- }
-
- IRModule* getModule()
- {
- return m_module;
- }
-
- Session* getSession()
- {
- return m_session;
- }
-
- void insertBlockAlongEdge(IREdge const& edge);
-
- // Rebuilds `globalValueNumberingMap`. This is necessary if any existing
- // keys are modified (thus its hash code is changed).
- void deduplicateAndRebuildGlobalNumberingMap();
-
- // Replaces all uses of oldInst with newInst, and ensures the global numbering map is valid after the replacement.
- void replaceGlobalInst(IRInst* oldInst, IRInst* newInst);
-
- typedef Dictionary<IRInstKey, IRInst*> GlobalValueNumberingMap;
- typedef Dictionary<IRConstantKey, IRConstant*> ConstantMap;
-
- GlobalValueNumberingMap& getGlobalValueNumberingMap() { return m_globalValueNumberingMap; }
- ConstantMap& getConstantMap() { return m_constantMap; }
-
- bool isGloballyNumberedInst(IRInst* inst);
-
-private:
- // The module that will own all of the IR
- IRModule* m_module;
-
- // The parent compilation session
- Session* m_session;
-
- GlobalValueNumberingMap m_globalValueNumberingMap;
- ConstantMap m_constantMap;
-};
-
struct IRBuilderSourceLocRAII;
struct IRBuilder
{
private:
- /// Shared state for all IR builders working on the same module
- SharedIRBuilder* m_sharedBuilder = nullptr;
+ /// Shared state for all IR builders working on the same module
+ SharedIRBuilder* m_sharedBuilder = nullptr;
- /// Default location for inserting new instructions as they are emitted
+ IRModule* m_module = nullptr;
+
+ /// Default location for inserting new instructions as they are emitted
IRInsertLoc m_insertLoc;
- /// Information that controls how source locations are associatd with instructions that get emitted
+ /// Information that controls how source locations are associatd with instructions that get emitted
IRBuilderSourceLocRAII* m_sourceLocInfo = nullptr;
public:
IRBuilder()
{}
+ explicit IRBuilder(IRModule* module)
+ : m_module(module)
+ , m_sharedBuilder(module->getSharedBuilder())
+ {}
+
explicit IRBuilder(SharedIRBuilder* sharedBuilder)
- : m_sharedBuilder(sharedBuilder)
+ : IRBuilder(sharedBuilder->getModule())
{}
explicit IRBuilder(SharedIRBuilder& sharedBuilder)
- : m_sharedBuilder(&sharedBuilder)
+ : IRBuilder(sharedBuilder.getModule())
{}
void init(SharedIRBuilder* sharedBuilder)
@@ -2550,17 +2481,17 @@ public:
SharedIRBuilder* getSharedBuilder() const
{
- return m_sharedBuilder;
+ return m_module->getSharedBuilder();
}
Session* getSession() const
{
- return m_sharedBuilder->getSession();
+ return m_module->getSession();
}
IRModule* getModule() const
{
- return m_sharedBuilder->getModule();
+ return m_module;
}
IRInsertLoc const& getInsertLoc() const { return m_insertLoc; }
@@ -2597,6 +2528,18 @@ public:
IRConstant* _findOrEmitConstant(
IRConstant& keyInst);
+ /// Implements a special case of inst creation (intended only for calling from `_createInst`)
+ /// that returns an matching existing hoistable inst if it exists, otherwise it creates the inst and
+ /// add it to the global numbering map.
+ IRInst* _findOrEmitHoistableInst(
+ IRType* type,
+ IROp op,
+ Int fixedArgCount,
+ IRInst* const* fixedArgs,
+ Int varArgListCount,
+ Int const* listArgCounts,
+ IRInst* const* const* listArgs);
+
/// Create a new instruction with the given `type` and `op`, with an allocated
/// size of at least `minSizeInBytes`, and with its operand list initialized
/// from the provided lists of "fixed" and "variable" operands.
@@ -2615,7 +2558,8 @@ public:
/// size.
///
/// Note: This is an extremely low-level operation and clients of an `IRBuilder`
- /// should not be using it when other options are available.
+ /// should not be using it when other options are available. This is also where
+ /// all insts creation are bottlenecked through.
///
IRInst* _createInst(
size_t minSizeInBytes,
@@ -2654,6 +2598,12 @@ public:
void addInst(IRInst* inst);
+ // Replace the operand of a potentially hoistable inst.
+ // If the hoistable inst become duplicate of an existing inst,
+ // all uses of the original user will be replaced with the existing inst.
+ // The function returns the new user after any potential updates.
+ IRInst* replaceOperand(IRUse* use, IRInst* newValue);
+
IRInst* getBoolValue(bool value);
IRInst* getIntValue(IRType* type, IRIntegerValue value);
IRInst* getFloatValue(IRType* type, IRFloatingPointValue value);
@@ -2918,6 +2868,20 @@ public:
UInt argCount,
IRInst* const* args);
+ IRInst* createIntrinsicInst(
+ IRType* type,
+ IROp op,
+ IRInst* operand,
+ UInt operandCount,
+ IRInst* const* operands);
+
+ IRInst* createIntrinsicInst(
+ IRType* type,
+ IROp op,
+ UInt operandListCount,
+ UInt const* listOperandCounts,
+ IRInst* const* const* listOperands);
+
IRInst* emitIntrinsicInst(
IRType* type,
IROp op,
@@ -3001,6 +2965,10 @@ public:
UInt argCount,
IRInst* const* args);
+ IRInst* emitMakeMatrixFromScalar(
+ IRType* type,
+ IRInst* scalarValue);
+
IRInst* emitMakeArray(
IRType* type,
UInt argCount,
@@ -3066,31 +3034,6 @@ public:
IRInst* emitReinterpret(IRInst* type, IRInst* value);
- IRInst* findOrAddInst(
- IRType* type,
- IROp op,
- UInt operandListCount,
- UInt const* listOperandCounts,
- IRInst* const* const* listOperands);
-
- IRInst* findOrEmitHoistableInst(
- IRType* type,
- IROp op,
- UInt operandListCount,
- UInt const* listOperandCounts,
- IRInst* const* const* listOperands);
- IRInst* findOrEmitHoistableInst(
- IRType* type,
- IROp op,
- UInt operandCount,
- IRInst* const* operands);
- IRInst* findOrEmitHoistableInst(
- IRType* type,
- IROp op,
- IRInst* operand,
- UInt operandCount,
- IRInst* const* operands);
-
IRFunc* createFunc();
IRGlobalVar* createGlobalVar(
IRType* valueType);
@@ -3841,10 +3784,6 @@ public:
}
};
-void addHoistableInst(
- IRBuilder* builder,
- IRInst* inst);
-
// Helper to establish the source location that will be used
// by an IRBuilder.
struct IRBuilderSourceLocRAII
diff --git a/source/slang/slang-ir-legalize-mesh-outputs.cpp b/source/slang/slang-ir-legalize-mesh-outputs.cpp
index 7c6d256ab..db4d74ddb 100644
--- a/source/slang/slang-ir-legalize-mesh-outputs.cpp
+++ b/source/slang/slang-ir-legalize-mesh-outputs.cpp
@@ -25,7 +25,7 @@ void legalizeMeshOutputTypes(IRModule* module)
: as<IRPrimitivesType>(meshOutput) ? kIROp_PrimitivesDecoration
: (SLANG_UNREACHABLE("Missing case for IRMeshOutputType"), IROp(0));
// Ensure that all params are marked up as vertices/indices/primitives
- traverseUses<IRParam>(meshOutput, [&](IRParam* i)
+ traverseUsers<IRParam>(meshOutput, [&](IRParam* i)
{
builder.addMeshOutputDecoration(decorationOp, i, maxCount);
});
diff --git a/source/slang/slang-ir-legalize-types.cpp b/source/slang/slang-ir-legalize-types.cpp
index 38503155d..d916fa691 100644
--- a/source/slang/slang-ir-legalize-types.cpp
+++ b/source/slang/slang-ir-legalize-types.cpp
@@ -1861,14 +1861,27 @@ static LegalVal legalizeInst(
// While the operands are all "simple," they might not necessarily
// be equal to the operands we started with.
//
+ ShortList<IRInst*> newArgs;
+ newArgs.setCount(argCount);
+ bool recreate = false;
for (UInt aa = 0; aa < argCount; ++aa)
{
auto legalArg = legalArgs[aa];
- inst->setOperand(aa, legalArg.getSimple());
+ newArgs[aa] = legalArg.getSimple();
+ if (newArgs[aa] != inst->getOperand(aa))
+ recreate = true;
+ }
+ if (recreate)
+ {
+ IRBuilder builder(inst->getModule());
+ builder.setInsertBefore(inst);
+ auto newInst = builder.emitIntrinsicInst(legalType.getSimple(), inst->getOp(), argCount, newArgs.getArrayView().getBuffer());
+ inst->replaceUsesWith(newInst);
+ inst->removeFromParent();
+ context->replacedInstructions.add(inst);
+ return LegalVal::simple(newInst);
}
-
inst->setFullType(legalType.getSimple());
-
return LegalVal::simple(inst);
}
@@ -1888,6 +1901,10 @@ static LegalVal legalizeInst(
legalType,
legalArgs.getBuffer());
+ if (legalVal.flavor == LegalVal::Flavor::simple)
+ {
+ inst->replaceUsesWith(legalVal.getSimple());
+ }
// After we are done, we will eliminate the
// original instruction by removing it from
// the IR.
diff --git a/source/slang/slang-ir-link.cpp b/source/slang/slang-ir-link.cpp
index 80f974536..55048484f 100644
--- a/source/slang/slang-ir-link.cpp
+++ b/source/slang/slang-ir-link.cpp
@@ -229,11 +229,14 @@ IRInst* IRSpecContext::maybeCloneValue(IRInst* originalValue)
switch (originalValue->getOp())
{
case kIROp_StructType:
+ case kIROp_ClassType:
case kIROp_Func:
case kIROp_Generic:
case kIROp_GlobalVar:
case kIROp_GlobalParam:
+ case kIROp_GlobalConstant:
case kIROp_StructKey:
+ case kIROp_InterfaceRequirementEntry:
case kIROp_GlobalGenericParam:
case kIROp_WitnessTable:
case kIROp_InterfaceType:
@@ -277,26 +280,34 @@ IRInst* IRSpecContext::maybeCloneValue(IRInst* originalValue)
}
break;
+ case kIROp_VoidLit:
+ {
+ return builder->getVoidValue();
+ }
+ break;
+
default:
{
// In the default case, assume that we have some sort of "hoistable"
// instruction that requires us to create a clone of it.
UInt argCount = originalValue->getOperandCount();
- IRInst* clonedValue = builder->createIntrinsicInst(
- cloneType(this, originalValue->getFullType()),
- originalValue->getOp(),
- argCount, nullptr);
- registerClonedValue(this, clonedValue, originalValue);
+ ShortList<IRInst*> newArgs;
+ newArgs.setCount(argCount);
for (UInt aa = 0; aa < argCount; ++aa)
{
IRInst* originalArg = originalValue->getOperand(aa);
IRInst* clonedArg = cloneValue(this, originalArg);
- clonedValue->getOperands()[aa].init(clonedValue, clonedArg);
+ newArgs[aa] = clonedArg;
}
+ IRInst* clonedValue = builder->createIntrinsicInst(
+ cloneType(this, originalValue->getFullType()),
+ originalValue->getOp(),
+ argCount, newArgs.getArrayView().getBuffer());
+ registerClonedValue(this, clonedValue, originalValue);
+
cloneDecorationsAndChildren(this, clonedValue, originalValue);
-
- addHoistableInst(builder, clonedValue);
+ builder->addInst(clonedValue);
return clonedValue;
}
@@ -524,6 +535,8 @@ IRGlobalConstant* cloneGlobalConstantImpl(
IRGlobalConstant* originalVal,
IROriginalValuesForClone const& originalValues)
{
+ auto oldBuilder = context->builder;
+ context->builder = builder;
auto clonedType = cloneType(context, originalVal->getFullType());
IRGlobalConstant* clonedVal = nullptr;
if(auto originalInitVal = originalVal->getValue())
@@ -537,7 +550,7 @@ IRGlobalConstant* cloneGlobalConstantImpl(
}
cloneSimpleGlobalValueImpl(context, originalVal, originalValues, clonedVal);
-
+ context->builder = oldBuilder;
return clonedVal;
}
@@ -1174,21 +1187,24 @@ IRInst* cloneInst(
// instruction with the right number of operands, intialize
// it, and then add it to the sequence.
UInt argCount = originalInst->getOperandCount();
- IRInst* clonedInst = builder->createIntrinsicInst(
- cloneType(context, originalInst->getFullType()),
- originalInst->getOp(),
- argCount, nullptr);
- registerClonedValue(context, clonedInst, originalValues);
+ ShortList<IRInst*> newArgs;
+ newArgs.setCount(argCount);
auto oldBuilder = context->builder;
context->builder = builder;
for (UInt aa = 0; aa < argCount; ++aa)
{
IRInst* originalArg = originalInst->getOperand(aa);
IRInst* clonedArg = cloneValue(context, originalArg);
- clonedInst->getOperands()[aa].init(clonedInst, clonedArg);
+ newArgs[aa] = clonedArg;
}
- builder->addInst(clonedInst);
context->builder = oldBuilder;
+
+ IRInst* clonedInst = builder->createIntrinsicInst(
+ cloneType(context, originalInst->getFullType()),
+ originalInst->getOp(),
+ argCount, newArgs.getArrayView().getBuffer());
+ builder->addInst(clonedInst);
+ registerClonedValue(context, clonedInst, originalValues);
cloneDecorationsAndChildren(context, clonedInst, originalInst);
cloneExtraDecorations(context, clonedInst, originalValues);
return clonedInst;
diff --git a/source/slang/slang-ir-lower-generic-function.cpp b/source/slang/slang-ir-lower-generic-function.cpp
index 6f412d579..f2d7159d4 100644
--- a/source/slang/slang-ir-lower-generic-function.cpp
+++ b/source/slang/slang-ir-lower-generic-function.cpp
@@ -56,25 +56,51 @@ namespace Slang
lowerGenericFuncType(&builder, genericParent, cast<IRFuncType>(func->getFullType()));
SLANG_ASSERT(loweredGenericType);
loweredFunc->setFullType(loweredGenericType);
- List<IRInst*> clonedParams;
+ List<IRInst*> childrenToDemote;
+ List<IRInst*> clonedParams;
for (auto genericChild : genericParent->getFirstBlock()->getChildren())
{
- if (genericChild == func)
+ switch (genericChild->getOp())
+ {
+ case kIROp_Func:
continue;
- if (genericChild->getOp() == kIROp_Return)
+ case kIROp_Return:
continue;
+ }
// Process all generic parameters and local type definitions.
auto clonedChild = cloneInst(&cloneEnv, &builder, genericChild);
- if (clonedChild->getOp() == kIROp_Param)
+ switch (clonedChild->getOp())
{
- auto paramType = clonedChild->getFullType();
- auto loweredParamType = sharedContext->lowerType(&builder, paramType);
- if (loweredParamType != paramType)
+ case kIROp_Param:
{
- clonedChild->setFullType((IRType*)loweredParamType);
+ auto paramType = clonedChild->getFullType();
+ auto loweredParamType = sharedContext->lowerType(&builder, paramType);
+ if (loweredParamType != paramType)
+ {
+ clonedChild->setFullType((IRType*)loweredParamType);
+ }
+ clonedParams.add(clonedChild);
+ }
+ break;
+
+ case kIROp_LookupWitness:
+ case kIROp_Specialize:
+ {
+ childrenToDemote.add(clonedChild);
+ // Make sure all uses are from the function body.
+ for (auto use = genericChild->firstUse; use; use = use->nextUse)
+ {
+ if (use->getUser()->getParent() == genericChild->getParent())
+ {
+ // This specialize/lookup is used as operand to some other
+ // global inst in the generic. This is not supported now.
+ SLANG_UNIMPLEMENTED_X(
+ "Unsupported use of specialize/lookupWitness in generic body.");
+ }
+ }
+ continue;
}
- clonedParams.add(clonedChild);
}
}
cloneInstDecorationsAndChildren(&cloneEnv, &sharedContext->sharedBuilderStorage, func, loweredFunc);
@@ -85,6 +111,15 @@ namespace Slang
param->removeFromParent();
block->addParam(as<IRParam>(param));
}
+
+ // Demote specialize and lookupWitness insts and their dependents down to function body.
+ auto insertPoint = block->getFirstOrdinaryInst();
+ for (Index i = childrenToDemote.getCount() - 1; i >= 0; i--)
+ {
+ auto child = childrenToDemote[i];
+ child->insertBefore(insertPoint);
+ }
+
// Lower generic typed parameters into AnyValueType.
auto firstInst = loweredFunc->getFirstOrdinaryInst();
builder.setInsertBefore(firstInst);
@@ -292,7 +327,8 @@ namespace Slang
loweredFunc = lowerGenericFunction(funcToSpecialize);
if (loweredFunc != funcToSpecialize)
{
- specializeInst->setOperand(0, loweredFunc);
+ IRBuilder builder;
+ builder.replaceOperand(specializeInst->getOperands(), loweredFunc);
}
}
}
diff --git a/source/slang/slang-ir-redundancy-removal.cpp b/source/slang/slang-ir-redundancy-removal.cpp
index 176142601..f3996fc01 100644
--- a/source/slang/slang-ir-redundancy-removal.cpp
+++ b/source/slang/slang-ir-redundancy-removal.cpp
@@ -38,8 +38,6 @@ struct RedundancyRemovalContext
case kIROp_GetElement:
case kIROp_GetElementPtr:
case kIROp_UpdateElement:
- case kIROp_LookupWitness:
- case kIROp_Specialize:
case kIROp_OptionalHasValue:
case kIROp_GetOptionalValue:
case kIROp_MakeOptionalValue:
diff --git a/source/slang/slang-ir-simplify-for-emit.cpp b/source/slang/slang-ir-simplify-for-emit.cpp
index 5e5f61a4a..67d95c59f 100644
--- a/source/slang/slang-ir-simplify-for-emit.cpp
+++ b/source/slang/slang-ir-simplify-for-emit.cpp
@@ -5,12 +5,16 @@
namespace Slang
{
+bool isCPUTarget(TargetRequest* targetReq);
+bool isCUDATarget(TargetRequest* targetReq);
+
struct SimplifyForEmitContext : public InstPassBase
{
- SimplifyForEmitContext(IRModule* inModule)
- : InstPassBase(inModule)
+ SimplifyForEmitContext(IRModule* inModule, TargetRequest* inTargetReq)
+ : InstPassBase(inModule), targetReq(inTargetReq)
{}
+ TargetRequest* targetReq;
List<IRInst*> followUpWorkList;
HashSet<IRInst*> followUpWorkListSet;
@@ -134,7 +138,7 @@ struct SimplifyForEmitContext : public InstPassBase
IRBuilder builder(sharedBuilderStorage);
builder.setInsertBefore(user);
auto newLoad = builder.emitLoad(load->getPtr());
- use->set(newLoad);
+ builder.replaceOperand(use, newLoad);
}
void processLoad(IRLoad* inst)
@@ -330,8 +334,115 @@ struct SimplifyForEmitContext : public InstPassBase
processInst(followUpWorkList[i]);
}
+ void unifyBinaryExprOperands(IRGlobalValueWithCode* func)
+ {
+ IRBuilder builder(func->getModule());
+
+ for (auto block : func->getBlocks())
+ {
+ for (auto inst = block->getFirstInst(); inst; inst = inst->getNextInst())
+ {
+ switch (inst->getOp())
+ {
+ case kIROp_Add:
+ case kIROp_Sub:
+ case kIROp_Mul:
+ case kIROp_Div:
+ case kIROp_IRem:
+ case kIROp_FRem:
+ case kIROp_And:
+ case kIROp_Or:
+ case kIROp_BitAnd:
+ case kIROp_BitOr:
+ case kIROp_BitXor:
+ case kIROp_Leq:
+ case kIROp_Less:
+ case kIROp_Geq:
+ case kIROp_Greater:
+ case kIROp_Eql:
+ case kIROp_Neq:
+ case kIROp_Lsh:
+ case kIROp_Rsh:
+ builder.setInsertBefore(inst);
+ SLANG_ASSERT(inst->getOperandCount() == 2);
+ if (as<IRVectorType>(inst->getDataType()))
+ {
+ for (UInt a = 0; a < 2; a++)
+ {
+ if (as<IRBasicType>(inst->getOperand(a)->getDataType()))
+ {
+ auto v = builder.emitMakeVectorFromScalar(
+ inst->getOperand(1 - a)->getDataType(), inst->getOperand(a));
+ inst->setOperand(a, v);
+ }
+ }
+ }
+ else if (as<IRMatrixType>(inst->getDataType()))
+ {
+ for (UInt a = 0; a < 2; a++)
+ {
+ if (as<IRBasicType>(inst->getOperand(a)->getDataType()))
+ {
+ auto v = builder.emitMakeMatrixFromScalar(
+ inst->getOperand(1 - a)->getDataType(), inst->getOperand(a));
+ inst->setOperand(a, v);
+ }
+ }
+ }
+
+ break;
+ }
+ }
+ }
+ }
+
+ // Turn single element vector values into scalars before using it to call an intrinsic func.
+ void lowerTrivialVector(IRGlobalValueWithCode* func)
+ {
+ IRBuilder builder(func->getModule());
+ List<IRInst*> instsToProcess;
+ for (auto block : func->getBlocks())
+ {
+ for (auto inst = block->getFirstInst(); inst; inst = inst->getNextInst())
+ {
+ switch (inst->getOp())
+ {
+ case kIROp_Call:
+ {
+ // If we are calling an intrinsic with any vector<T,1> argument, replace it with T.
+ auto callInst = as<IRCall>(inst);
+ if (getResolvedInstForDecorations(callInst->getCallee())->findDecoration<IRTargetIntrinsicDecoration>())
+ {
+ for (UInt a = 0; a < callInst->getArgCount(); a++)
+ {
+ auto arg = callInst->getArg(a);
+ if (auto argVectorType = as<IRVectorType>(arg->getDataType()))
+ {
+ if (cast<IRIntLit>(argVectorType->getElementCount())->getValue() == 1)
+ {
+ builder.setInsertBefore(callInst);
+ UInt idx = 0;
+ auto newArg = builder.emitSwizzle(argVectorType->getElementType(), arg, 1, &idx);
+ callInst->setOperand(a + 1, newArg);
+ }
+ }
+ }
+ }
+ }
+ break;
+ }
+ }
+ }
+ }
+
+
void processFunc(IRGlobalValueWithCode* func)
{
+ if (isCPUTarget(targetReq) || isCUDATarget(targetReq))
+ {
+ unifyBinaryExprOperands(func);
+ lowerTrivialVector(func);
+ }
eliminateCompositeConstruct(func);
deferAndDuplicateElementExtract(func);
deferAndDuplicateLoad(func);
@@ -345,9 +456,9 @@ struct SimplifyForEmitContext : public InstPassBase
}
};
-void simplifyForEmit(IRModule* module)
+void simplifyForEmit(IRModule* module, TargetRequest* targetRequest)
{
- SimplifyForEmitContext context(module);
+ SimplifyForEmitContext context(module, targetRequest);
context.processModule();
}
diff --git a/source/slang/slang-ir-simplify-for-emit.h b/source/slang/slang-ir-simplify-for-emit.h
index a6cf3bad8..e35c74841 100644
--- a/source/slang/slang-ir-simplify-for-emit.h
+++ b/source/slang/slang-ir-simplify-for-emit.h
@@ -4,6 +4,7 @@
namespace Slang
{
struct IRModule;
+ class TargetRequest;
- void simplifyForEmit(IRModule* inModule);
+ void simplifyForEmit(IRModule* inModule, TargetRequest* req);
}
diff --git a/source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.cpp b/source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.cpp
index 39edaeb16..cfc9d9c76 100644
--- a/source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.cpp
+++ b/source/slang/slang-ir-specialize-dynamic-associatedtype-lookup.cpp
@@ -200,23 +200,22 @@ struct AssociatedTypeLookupSpecializationContext
if (!seqId)
return;
// Insert code to pack sequential ID into an uint2 at all use sites.
- IRUse* nextUse = nullptr;
- for (auto use = inst->firstUse; use; use = nextUse)
+ traverseUses(inst, [&](IRUse* use)
{
- nextUse = use->nextUse;
if (as<IRCOMWitnessDecoration>(use->getUser()))
- continue;
+ {
+ return;
+ }
IRBuilder builder(sharedContext->sharedBuilderStorage);
builder.setInsertBefore(use->getUser());
auto uint2Type = builder.getVectorType(
builder.getUIntType(), builder.getIntValue(builder.getIntType(), 2));
IRInst* uint2Args[] = {
seqId->getSequentialIDOperand(),
- builder.getIntValue(builder.getUIntType(), 0)};
+ builder.getIntValue(builder.getUIntType(), 0) };
auto uint2seqID = builder.emitMakeVector(uint2Type, 2, uint2Args);
- use->set(uint2seqID);
- use = nextUse;
- }
+ builder.replaceOperand(use, uint2seqID);
+ });
}
});
@@ -229,14 +228,12 @@ struct AssociatedTypeLookupSpecializationContext
builder.setInsertBefore(globalInst);
auto witnessTableIDType = builder.getWitnessTableIDType(
(IRType*)cast<IRWitnessTableType>(globalInst)->getConformanceType());
- IRUse* nextUse = nullptr;
- for (auto use = globalInst->firstUse; use; use = nextUse)
+ traverseUses(globalInst, [&](IRUse* use)
{
- nextUse = use->nextUse;
if (use->getUser()->getOp() == kIROp_WitnessTable)
- continue;
- use->set(witnessTableIDType);
- }
+ return;
+ builder.replaceOperand(use, witnessTableIDType);
+ });
sharedContext->sharedBuilderStorage.deduplicateAndRebuildGlobalNumberingMap();
}
}
diff --git a/source/slang/slang-ir-specialize-resources.cpp b/source/slang/slang-ir-specialize-resources.cpp
index e4ccf40d5..03eda0d99 100644
--- a/source/slang/slang-ir-specialize-resources.cpp
+++ b/source/slang/slang-ir-specialize-resources.cpp
@@ -256,16 +256,16 @@ struct ResourceOutputSpecializationPass
// the aid of this pass.
//
List<IRCall*> calls;
- for( auto use = oldFunc->firstUse; use; use = use->nextUse )
- {
- auto user = use->getUser();
- auto call = as<IRCall>(user);
- if(!call)
- continue;
- if(call->getCallee() != oldFunc)
- continue;
- calls.add(call);
- }
+ traverseUses(oldFunc, [&](IRUse* use)
+ {
+ auto user = use->getUser();
+ auto call = as<IRCall>(user);
+ if (!call)
+ return;
+ if (call->getCallee() != oldFunc)
+ return;
+ calls.add(call);
+ });
// Once we have identified the calls to `oldFunc`, we will set about replacing
// them with calls to `newFunc`.
@@ -833,16 +833,16 @@ struct ResourceOutputSpecializationPass
// `out`/`inout` parameters that doesn't have as many "gotcha" cases.
//
List<IRStore*> stores;
- for( auto use = param->firstUse; use; use = use->nextUse )
- {
- auto user = use->getUser();
- auto store = as<IRStore>(user);
- if(!store)
- continue;
- if(store->ptr.get() != param)
- continue;
- stores.add(store);
- }
+ traverseUses(param, [&](IRUse* use)
+ {
+ auto user = use->getUser();
+ auto store = as<IRStore>(user);
+ if (!store)
+ return;
+ if (store->ptr.get() != param)
+ return;
+ stores.add(store);
+ });
// Having identified the places where a value is stored to
// the output parameter, we iterate over those values to
@@ -1194,16 +1194,16 @@ bool specializeResourceUsage(
// Inline unspecializable resource output functions and then continue trying.
for (auto func : unspecializableFuncs)
{
- for (auto use = func->firstUse; use; use = use->nextUse)
+ traverseUses(func, [&](IRUse* use)
{
auto user = use->getUser();
auto call = as<IRCall>(user);
if (!call)
- continue;
+ return;
if (call->getCallee() != func)
- continue;
+ return;
inlineCall(call);
- }
+ });
}
simplifyIR(irModule);
}
diff --git a/source/slang/slang-ir-specialize.cpp b/source/slang/slang-ir-specialize.cpp
index cf7acd46c..0044e5745 100644
--- a/source/slang/slang-ir-specialize.cpp
+++ b/source/slang/slang-ir-specialize.cpp
@@ -897,7 +897,8 @@ struct SpecializationContext
// specialization opportunities (generic specialization,
// existential specialization, simplifications, etc.)
//
- iterChanged |= maybeSpecializeInst(inst);
+ if (inst->hasUses() || inst->mightHaveSideEffects())
+ iterChanged |= maybeSpecializeInst(inst);
// Finally, we need to make our logic recurse through
// the whole IR module, so we want to add the children
@@ -1041,7 +1042,6 @@ struct SpecializationContext
// The old callee should be in the form of `specialize(.operator[], IInterfaceType)`,
// we should update it to be `specialize(.operator[], elementType)`, so the return type
// of the load call is `elementType`.
- auto oldCallee = inst->getCallee();
// A subscript operation on mutable buffers returns a ptr type instead of a value type.
// We need to make sure the pointer-ness is preserved correctly.
@@ -1057,9 +1057,6 @@ struct SpecializationContext
inst->replaceUsesWith(newWrapExistential);
workList.Remove(inst);
inst->removeAndDeallocate();
- SLANG_ASSERT(!oldCallee->hasUses());
- workList.Remove(oldCallee);
- oldCallee->removeAndDeallocate();
addUsersToWorkList(newWrapExistential);
workList.Remove(wrapExistential);
diff --git a/source/slang/slang-ir-ssa.cpp b/source/slang/slang-ir-ssa.cpp
index 3f250e31e..b195af2cc 100644
--- a/source/slang/slang-ir-ssa.cpp
+++ b/source/slang/slang-ir-ssa.cpp
@@ -923,6 +923,15 @@ IRBlock* IREdge::getSuccessor() const
return cast<IRBlock>(getUse()->get());
}
+void SharedIRBuilder::init(IRModule* module)
+{
+ m_module = module;
+ m_session = module->getSession();
+
+ m_globalValueNumberingMap.Clear();
+ m_constantMap.Clear();
+}
+
void SharedIRBuilder::insertBlockAlongEdge(
IREdge const& edge)
{
diff --git a/source/slang/slang-ir-type-set.cpp b/source/slang/slang-ir-type-set.cpp
index 0cfe69e42..7ac617bda 100644
--- a/source/slang/slang-ir-type-set.cpp
+++ b/source/slang/slang-ir-type-set.cpp
@@ -7,313 +7,4 @@
namespace Slang
{
-IRTypeSet::IRTypeSet(Session* session)
-{
- m_module = IRModule::create(session);
-
- m_sharedBuilder.init(m_module);
- m_builder.init(m_sharedBuilder);
-
- m_builder.setInsertInto(m_module->getModuleInst());
-}
-
-IRTypeSet::~IRTypeSet()
-{
- _clearTypes();
-}
-
-void IRTypeSet::clear()
-{
- _clearTypes();
-
- m_cloneMap.Clear();
-
- m_module = IRModule::create(m_sharedBuilder.getSession());
-
- m_sharedBuilder.init(m_module);
- m_builder.init(m_sharedBuilder);
-
- m_builder.setInsertInto(m_module->getModuleInst());
-}
-
-void IRTypeSet::_clearTypes()
-{
- List<IRType*> types;
- getTypes(types);
-
- for (auto type : types)
- {
- // We need to destroy references to instructions in other modules
- if (type->getModule() == m_module)
- {
- // We want to remove arguments because an argument *could* be an instruction in another module,
- // and we don't want to those modules insts to have uses, in this module which is being destroyed
- type->removeArguments();
- }
- }
-}
-
-IRInst* IRTypeSet::cloneInst(IRInst* inst)
-{
- if (inst == nullptr)
- {
- return nullptr;
- }
-
- // See if it's already cloned
- if (IRInst*const* newInstPtr = m_cloneMap.TryGetValue(inst))
- {
- return *newInstPtr;
- }
-
- IRModule* module = inst->getModule();
- // All inst's must belong to a module
- SLANG_ASSERT(module);
-
- // If it's in this module then we don't need to clone
- if (module == m_module)
- {
- return inst;
- }
-
- if (isNominalOp(inst->getOp()))
- {
- // We can clone without any definition, and add the linkage
-
- // TODO(JS)
- // This is arguably problematic - I'm adding an instruction from another module to the map, to be it's self.
- // I did have code which created a copy of the nominal instruction and name hint, but because nominality means
- // 'same address' other code would generate a different name for that instruction (say as compared to being a member in
- // the original instruction)
- //
- // Because I use findOrAddInst which doesn't hoist instructions, the hoisting doesn't rely on parenting, that would
- // break.
-
- // If nominal, we just use the original inst
- m_cloneMap.Add(inst, inst);
- return inst;
- }
-
- // It would be nice if I could use ir-clone.cpp to do this -> but it doesn't clone
- // operands. We wouldn't want to clone decorations, and it can't clone IRConstant(!) so
- // it's no use
-
- IRInst* clone = nullptr;
- switch (inst->getOp())
- {
- case kIROp_IntLit:
- {
- auto intLit = static_cast<IRConstant*>(inst);
- IRType* clonedType = cloneType(intLit->getDataType());
- clone = m_builder.getIntValue(clonedType, intLit->value.intVal);
- break;
- }
- case kIROp_StringLit:
- {
- auto stringLit = static_cast<IRStringLit*>(inst);
- clone = m_builder.getStringValue(stringLit->getStringSlice());
- break;
- }
- case kIROp_VectorType:
- {
- auto vecType = static_cast<IRVectorType*>(inst);
- const Index elementCount = Index(getIntVal(vecType->getElementCount()));
-
- if (elementCount <= 1)
- {
- clone = cloneType(vecType->getElementType());
- }
- break;
- }
- case kIROp_MatrixType:
- {
- auto matType = static_cast<IRMatrixType*>(inst);
- const Index columnCount = Index(getIntVal(matType->getColumnCount()));
- const Index rowCount = Index(getIntVal(matType->getRowCount()));
-
- if (columnCount <= 1 && rowCount <= 1)
- {
- clone = cloneType(matType->getElementType());
- }
- break;
- }
- default: break;
- }
-
- if (!clone)
- {
- if (IRBasicType::isaImpl(inst->getOp()))
- {
- clone = m_builder.getType(inst->getOp());
- }
- else
- {
- IRType* irType = dynamicCast<IRType>(inst);
- if (irType)
- {
- auto clonedType = cloneType(inst->getFullType());
- Index operandCount = Index(inst->getOperandCount());
-
- List<IRInst*> cloneOperands;
- cloneOperands.setCount(operandCount);
-
- for (Index i = 0; i < operandCount; ++i)
- {
- cloneOperands[i] = cloneInst(inst->getOperand(i));
- }
-
- //clone = m_irBuilder.findOrEmitHoistableInst(cloneType, inst->op, operandCount, cloneOperands.getBuffer());
-
- UInt operandCounts[1] = { UInt(operandCount) };
- IRInst*const* listOperands[1] = { cloneOperands.getBuffer() };
-
- clone = m_builder.findOrAddInst(clonedType, inst->getOp(), 1, operandCounts, listOperands);
- }
- else
- {
- // This cloning style only works on insts that are not unique
- auto clonedType = cloneType(inst->getFullType());
-
- Index operandCount = Index(inst->getOperandCount());
- clone = m_builder.emitIntrinsicInst(clonedType, inst->getOp(), operandCount, nullptr);
- for (Index i = 0; i < operandCount; ++i)
- {
- auto cloneOperand = cloneInst(inst->getOperand(i));
- clone->getOperands()[i].init(clone, cloneOperand);
- }
- }
- }
- }
-
- m_cloneMap.Add(inst, clone);
- return clone;
-}
-
-IRType* IRTypeSet::add(IRType* irType)
-{
- if (irType->getModule() == m_module)
- {
- return irType;
- }
- // We need to clone the type
- return cloneType(irType);
-}
-
-void IRTypeSet::getTypes(List<IRType*>& outTypes) const
-{
- outTypes.clear();
- for (auto inst : m_module->getModuleInst()->getChildren())
- {
- if (IRType* type = as<IRType>(inst))
- {
- outTypes.add(type);
- }
- }
-}
-
-void IRTypeSet::getTypes(Kind kind, List<IRType*>& outTypes) const
-{
- outTypes.clear();
-
- for (auto inst : m_module->getModuleInst()->getChildren())
- {
- IRType* type = nullptr;
-
- switch (kind)
- {
- case Kind::Scalar:
- {
- type = as<IRBasicType>(inst);
- break;
- }
- case Kind::Vector:
- {
- type = as<IRVectorType>(inst);
- break;
- }
- case Kind::Matrix:
- {
- type = as<IRMatrixType>(inst);
- break;
- }
- default: break;
- }
-
- if (type)
- {
- outTypes.add(type);
- }
- }
-}
-
-IRType* IRTypeSet::addVectorType(IRType* inElementType, int colsCount)
-{
- IRType* elementType = cloneType(inElementType);
- if (colsCount == 1)
- {
- return elementType;
- }
- return m_builder.getVectorType(elementType, m_builder.getIntValue(m_builder.getIntType(), colsCount));
-}
-
-void IRTypeSet::addVectorForMatrixTypes()
-{
- // Make a copy so we can alter m_types dictionary
- List<IRType*> types;
- getTypes(Kind::Matrix, types);
- for (IRType* type : types)
- {
- SLANG_ASSERT(as<IRMatrixType>(type));
- IRMatrixType* matType = static_cast<IRMatrixType*>(type);
- m_builder.getVectorType(matType->getElementType(), matType->getColumnCount());
- }
-}
-
-static bool _hasNominalOperand(IRInst* inst)
-{
- const Index operandCount = Index(inst->getOperandCount());
- auto operands = inst->getOperands();
-
- for (Index i = 0; i < operandCount; ++i)
- {
- IRInst* operand = operands[i].get();
- if (isNominalOp(operand->getOp()))
- {
- return true;
- }
- }
-
- return false;
-}
-
-void IRTypeSet::_addAllBuiltinTypesRec(IRInst* inst)
-{
- for (IRInst* child = inst->getFirstDecorationOrChild(); child; child = child->getNextInst())
- {
- IRType* type = nullptr;
-
- if (auto vectorType = as<IRVectorType>(child))
- {
- type = vectorType;
- }
- else if (auto matrixType = as<IRMatrixType>(child))
- {
- type = matrixType;
- }
- if (type && !_hasNominalOperand(type))
- {
- add(type);
- }
- else
- {
- _addAllBuiltinTypesRec(child);
- }
- }
-}
-
-void IRTypeSet::addAllBuiltinTypes(IRModule* module)
-{
- _addAllBuiltinTypesRec(module->getModuleInst());
-}
-
}
diff --git a/source/slang/slang-ir-type-set.h b/source/slang/slang-ir-type-set.h
index 958d71cf1..f60088fcd 100644
--- a/source/slang/slang-ir-type-set.h
+++ b/source/slang/slang-ir-type-set.h
@@ -9,85 +9,4 @@
namespace Slang
{
-/*
-NOTE! This type set is only designed to work for emitting code to determine unique types. It is envisaged in the
-future that it will not be needed because types will be made unique within a module, and thus the pointer to a type
-will uniquely identify the type.
-
-The other reason this type exists, is to allow an IRModule for emit to be immutable. That is not currently possible
-within emit code because it may be necessary in order to emit to be able to create other types that needed (for example
-vector types required for a matrix type implementation).
-
-This is used so as to try and use slangs type system to uniquely identify types and specializations on intrinsic.
-That we want to have a pointer to a type be unique, and slang supports this through the m_sharedIRBuilder. BUT for this to
-work all work on the module must use the same sharedIRBuilder, and that appears to not be the case in terms
-of other passes.
-Even if it was the case when we may want to add types as part of emitting, we can't use the previously used
-shared builder, so again we end up with pointers to the same things not being the same thing.
-
-To work around this we clone types we want to use as keys into the 'unique module'.
-This is not necessary for all types though - as we assume nominal types *must* have unique pointers (that is the
-definition of nominal).
-
-This could be handled in other ways (for example not testing equality on pointer equality). Anyway for now this
-works, but probably needs to be handled in a better way. The better way may involve having guarantees about equality
-enabled in other code generation and making de-duping possible in emit code.
-
-Note that one pro for this approach is that it does not alter the source module. That as it stands it's not necessary
-for the source module to be immutable, because it is created for emitting and then discarded.
-
-NOTE! That Vector<X, 1> or Matrix<X, 1, 1> will be turned into the type X.
-
- */
-class IRTypeSet
-{
-public:
- enum class Kind
- {
- Scalar,
- Vector,
- Matrix,
- CountOf,
- };
-
- IRType* add(IRType* type);
- IRType* addVectorType(IRType* elementType, int colsCount);
-
- void addAllBuiltinTypes(IRModule* module);
-
- void addVectorForMatrixTypes();
-
- void getTypes(List<IRType*>& outTypes) const;
- void getTypes(Kind kind, List<IRType*>& outTypes) const;
-
- IRType* getType(IRType* type) { return cloneType(type); }
-
- IRType* cloneType(IRType* type) { return (IRType*)cloneInst((IRInst*)type); }
- IRInst* cloneInst(IRInst* inst);
-
- /// Returns true if the type belongs and is created on the module owned by the set
- bool isOwned(IRType* type) { return type->getModule() == m_module; }
-
- IRBuilder& getBuilder() { return m_builder; }
- IRModule* getModule() const { return m_module; }
-
- void clear();
-
- IRTypeSet(Session* session);
- ~IRTypeSet();
-
-protected:
- void _addAllBuiltinTypesRec(IRInst* inst);
- void _clearTypes();
-
- // Maps insts from source modules into m_module.
- // NOTE! That nominal types are not cloned, as they are identified by pointer. They are just
- Dictionary<IRInst*, IRInst*> m_cloneMap;
-
- // Can find all types by traversing the types in the m_module
- SharedIRBuilder m_sharedBuilder;
- IRBuilder m_builder;
- RefPtr<IRModule> m_module;
-};
-
} // namespace Slang
diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp
index 1ea426715..253686aa5 100644
--- a/source/slang/slang-ir-util.cpp
+++ b/source/slang/slang-ir-util.cpp
@@ -549,6 +549,8 @@ struct GenericChildrenMigrationContextImpl
}
if (as<IRConstant>(inst))
return false;
+ if (getIROpInfo(inst->getOp()).isHoistable())
+ return false;
return true;
});
}
diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h
index efd38f7b7..2f1ac2d1a 100644
--- a/source/slang/slang-ir-util.h
+++ b/source/slang/slang-ir-util.h
@@ -49,7 +49,7 @@ struct DeduplicateContext
return *newValue;
for (UInt i = 0; i < value->getOperandCount(); i++)
{
- value->setOperand(i, deduplicate(value->getOperand(i), shouldDeduplicate));
+ value->unsafeSetOperand(i, deduplicate(value->getOperand(i), shouldDeduplicate));
}
value->setFullType((IRType*)deduplicate(value->getFullType(), shouldDeduplicate));
if (auto newValue = deduplicateMap.TryGetValue(key))
diff --git a/source/slang/slang-ir-validate.cpp b/source/slang/slang-ir-validate.cpp
index 03db96ac5..d5c0aa432 100644
--- a/source/slang/slang-ir-validate.cpp
+++ b/source/slang/slang-ir-validate.cpp
@@ -186,6 +186,28 @@ namespace Slang
if (pp == operandParent)
return;
}
+
+ // We allow out-of-order def-use in global scope.
+ bool allInGlobalScope = inst->getParent() && inst->getParent()->getOp() == kIROp_Module;
+ if (allInGlobalScope)
+ {
+ for (UInt i = 0; i < inst->getOperandCount(); i++)
+ {
+ auto op = inst->getOperand(i);
+ if (!op)
+ continue;
+ if (!op->getParent())
+ continue;
+ if (op->getParent()->getOp() != kIROp_Module)
+ {
+ allInGlobalScope = false;
+ break;
+ }
+ }
+ }
+ if (allInGlobalScope)
+ return;
+
//
// We failed to find `operandParent` while walking the ancestors of `inst`,
// so something had gone wrong.
diff --git a/source/slang/slang-ir-wrap-structured-buffers.cpp b/source/slang/slang-ir-wrap-structured-buffers.cpp
index 2ad09aa90..53671fa7f 100644
--- a/source/slang/slang-ir-wrap-structured-buffers.cpp
+++ b/source/slang/slang-ir-wrap-structured-buffers.cpp
@@ -134,7 +134,7 @@ struct WrapStructuredBuffersContext
// scanning through its IR uses, since values of that
// type are using it as a (type) operand.
//
- for( auto typeUse = newStructuredBufferType->firstUse; typeUse; typeUse = typeUse->nextUse )
+ traverseUses(newStructuredBufferType, [&](IRUse* typeUse)
{
// There might be uses of `newStructuredBufferType` where
// it isn't being used as the type of a value, so we
@@ -142,7 +142,7 @@ struct WrapStructuredBuffersContext
//
auto valueOfStructuredBufferType = typeUse->getUser();
if(valueOfStructuredBufferType->getFullType() != newStructuredBufferType)
- continue;
+ return;
// Now we have some `valueOfStructuredBufferType`. In our running
// example, this might be `gBuffer`, which is an `IRGlobalParam`.
@@ -155,7 +155,7 @@ struct WrapStructuredBuffersContext
// because these could be calls to intrinsic functions like
// `RWStructuredBuffer.Load`
//
- for( auto valueUse = valueOfStructuredBufferType->firstUse; valueUse; valueUse = valueUse->nextUse )
+ traverseUses(valueOfStructuredBufferType, [&](IRUse* valueUse)
{
// we are only interested in instructions that are calls,
// with at least one argument, where the first argument
@@ -165,11 +165,11 @@ struct WrapStructuredBuffersContext
//
auto call = as<IRCall>(valueUse->getUser());
if(!call)
- continue;
+ return;
if(call->getArgCount() == 0)
- continue;
+ return;
if(call->getArg(0) != valueOfStructuredBufferType)
- continue;
+ return;
// At this point we have a candidate `call` instruction,
// but we need to determine whether it is a call to
@@ -196,7 +196,7 @@ struct WrapStructuredBuffersContext
//
auto callee = call->getCallee();
if(!as<IRSpecialize>(callee))
- continue;
+ return;
// At this point it seems likely we have one of the calls
// we want to rewrite, but there are still intrinsics
@@ -285,8 +285,8 @@ struct WrapStructuredBuffersContext
newVal->setOperand(0, call);
}
}
- }
- }
+ });
+ });
}
/// Get the struture field "key" to use for generated wrappers
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index 1b16bfe1f..6cf0f09a5 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -148,7 +148,6 @@ namespace Slang
void IRUse::init(IRInst* u, IRInst* v)
{
clear();
-
user = u;
usedValue = v;
if(v)
@@ -170,6 +169,9 @@ namespace Slang
void IRUse::set(IRInst* uv)
{
+ // Normally we should never be modifying the operand of an hoistable inst.
+ // They can be modified by `replaceUsesWith`, or to be replaced by a new inst.
+ SLANG_ASSERT(!getIROpInfo(user->getOp()).isHoistable() || uv == usedValue);
init(user, uv);
}
@@ -1196,11 +1198,57 @@ namespace Slang
return as<IRGlobalValueWithCode>(pp);
}
+ void addHoistableInst(
+ IRBuilder* builder,
+ IRInst* inst);
+
// Add an instruction into the current scope
void IRBuilder::addInst(
IRInst* inst)
{
- inst->insertAt(m_insertLoc);
+ if (getIROpInfo(inst->getOp()).isGlobal())
+ {
+ addHoistableInst(this, inst);
+ return;
+ }
+
+ if (!inst->parent)
+ inst->insertAt(m_insertLoc);
+ }
+
+ IRInst* IRBuilder::replaceOperand(IRUse* use, IRInst* newValue)
+ {
+ auto user = use->getUser();
+ if (user->getModule())
+ {
+ user->getModule()->getSharedBuilder()->getInstReplacementMap().TryGetValue(newValue, newValue);
+ }
+
+ if (!getIROpInfo(user->getOp()).isHoistable())
+ {
+ use->set(newValue);
+ return user;
+ }
+
+ // If user is hoistable, we need to remove it from the global number map first,
+ // perform the update, then try to reinsert it back to the global number map.
+ // If we find an equivalent entry already exists in the global number map,
+ // we return the existing entry.
+ auto builder = user->getModule()->getSharedBuilder();
+ builder->_removeGlobalNumberingEntry(user);
+ use->init(user, newValue);
+
+ IRInst* existingVal = nullptr;
+ if (builder->getGlobalValueNumberingMap().TryGetValue(IRInstKey{ user }, existingVal))
+ {
+ user->replaceUsesWith(existingVal);
+ return existingVal;
+ }
+ else
+ {
+ builder->_addGlobalNumberingEntry(user);
+ return user;
+ }
}
// Given two parent instructions, pick the better one to use as as
@@ -1645,6 +1693,13 @@ namespace Slang
Int const* listArgCounts,
IRInst* const* const* listArgs)
{
+ m_sharedBuilder->getInstReplacementMap().TryGetValue((IRInst*)(type), *(IRInst**)&type);
+
+ if (getIROpInfo(op).flags & kIROpFlag_Hoistable)
+ {
+ return _findOrEmitHoistableInst(type, op, fixedArgCount, fixedArgs, varArgListCount, listArgCounts, listArgs);
+ }
+
Int varArgCount = 0;
for (Int ii = 0; ii < varArgListCount; ++ii)
{
@@ -1671,7 +1726,9 @@ namespace Slang
{
if (fixedArgs)
{
- operand->init(inst, fixedArgs[aa]);
+ auto arg = fixedArgs[aa];
+ m_sharedBuilder->getInstReplacementMap().TryGetValue(arg, arg);
+ operand->init(inst, arg);
}
else
{
@@ -1687,7 +1744,9 @@ namespace Slang
{
if (listArgs[ii])
{
- operand->init(inst, listArgs[ii][jj]);
+ auto arg = listArgs[ii][jj];
+ m_sharedBuilder->getInstReplacementMap().TryGetValue(arg, arg);
+ operand->init(inst, arg);
}
else
{
@@ -2309,21 +2368,23 @@ namespace Slang
args.add(getIntValue(capabilityAtomType, Int(atom)));
}
- return findOrEmitHoistableInst(
+ return createIntrinsicInst(
capabilitySetType, kIROp_CapabilitySet, args.getCount(), args.getBuffer());
}
- IRInst* IRBuilder::findOrEmitHoistableInst(
- IRType* type,
- IROp op,
- UInt operandListCount,
- UInt const* listOperandCounts,
- IRInst* const* const* listOperands)
- {
- UInt operandCount = 0;
- for (UInt ii = 0; ii < operandListCount; ++ii)
+ IRInst* IRBuilder::_findOrEmitHoistableInst(
+ IRType* type,
+ IROp op,
+ Int fixedArgCount,
+ IRInst* const* fixedArgs,
+ Int varArgListCount,
+ Int const* listArgCounts,
+ IRInst* const* const* listArgs)
+ {
+ UInt operandCount = fixedArgCount;
+ for (Int ii = 0; ii < varArgListCount; ++ii)
{
- operandCount += listOperandCounts[ii];
+ operandCount += listArgCounts[ii];
}
auto& memoryArena = getModule()->getMemoryArena();
@@ -2350,102 +2411,21 @@ namespace Slang
// Don't link up as we may free (if we already have this key)
{
IRUse* operand = inst->getOperands();
- for (UInt ii = 0; ii < operandListCount; ++ii)
+ for (Int ii = 0; ii < fixedArgCount; ++ii)
{
- UInt listOperandCount = listOperandCounts[ii];
- for (UInt jj = 0; jj < listOperandCount; ++jj)
- {
- operand->usedValue = listOperands[ii][jj];
- operand++;
- }
- }
- }
-
- // Find or add the key/inst
- {
- IRInstKey key = { inst };
-
- // Ideally we would add if not found, else return if was found instead of testing & then adding.
- IRInst** found = getSharedBuilder()->getGlobalValueNumberingMap().TryGetValueOrAdd(key, inst);
- SLANG_ASSERT(endCursor == memoryArena.getCursor());
- // If it's found, just return, and throw away the instruction
- if (found)
- {
- memoryArena.rewindToCursor(cursor);
- return *found;
- }
- }
-
- // Make the lookup 'inst' instruction into 'proper' instruction. Equivalent to
- // IRInst* inst = createInstImpl<IRInst>(builder, op, type, 0, nullptr, operandListCount, listOperandCounts, listOperands);
- {
- if (type)
- {
- inst->typeUse.usedValue = nullptr;
- inst->typeUse.init(inst, type);
- }
-
- _maybeSetSourceLoc(inst);
-
- IRUse*const operands = inst->getOperands();
- for (UInt i = 0; i < operandCount; ++i)
- {
- IRUse& operand = operands[i];
- auto value = operand.usedValue;
-
- operand.usedValue = nullptr;
- operand.init(inst, value);
+ auto arg = fixedArgs[ii];
+ m_sharedBuilder->getInstReplacementMap().TryGetValue(arg, arg);
+ operand->usedValue = arg;
+ operand++;
}
- }
-
- addHoistableInst(this, inst);
-
- return inst;
- }
-
- IRInst* IRBuilder::findOrAddInst(
- IRType* type,
- IROp op,
- UInt operandListCount,
- UInt const* listOperandCounts,
- IRInst* const* const* listOperands)
- {
- UInt operandCount = 0;
- for (UInt ii = 0; ii < operandListCount; ++ii)
- {
- operandCount += listOperandCounts[ii];
- }
-
- auto& memoryArena = getModule()->getMemoryArena();
- void* cursor = memoryArena.getCursor();
-
- // We are going to create a 'dummy' instruction on the memoryArena
- // which can be used as a key for lookup, so see if we
- // already have an equivalent instruction available to use.
- size_t keySize = sizeof(IRInst) + operandCount * sizeof(IRUse);
- IRInst* inst = (IRInst*)memoryArena.allocateAndZero(keySize);
-
- void* endCursor = memoryArena.getCursor();
- // Mark as 'unused' cos it is unused on release builds.
- SLANG_UNUSED(endCursor);
-
- new(inst) IRInst();
-#if SLANG_ENABLE_IR_BREAK_ALLOC
- inst->_debugUID = _debugGetAndIncreaseInstCounter();
-#endif
- inst->m_op = op;
- inst->typeUse.usedValue = type;
- inst->operandCount = (uint32_t)operandCount;
-
- // Don't link up as we may free (if we already have this key)
- {
- IRUse* operand = inst->getOperands();
- for (UInt ii = 0; ii < operandListCount; ++ii)
+ for (Int ii = 0; ii < varArgListCount; ++ii)
{
- UInt listOperandCount = listOperandCounts[ii];
+ UInt listOperandCount = listArgCounts[ii];
for (UInt jj = 0; jj < listOperandCount; ++jj)
{
- operand->usedValue = listOperands[ii][jj];
+ auto arg = listArgs[ii][jj];
+ m_sharedBuilder->getInstReplacementMap().TryGetValue(arg, arg);
+ operand->usedValue = arg;
operand++;
}
}
@@ -2488,50 +2468,17 @@ namespace Slang
}
}
- addInst(inst);
- return inst;
- }
-
-
- IRInst* IRBuilder::findOrEmitHoistableInst(
- IRType* type,
- IROp op,
- UInt operandCount,
- IRInst* const* operands)
- {
- return findOrEmitHoistableInst(
- type,
- op,
- 1,
- &operandCount,
- &operands);
- }
-
- IRInst* IRBuilder::findOrEmitHoistableInst(
- IRType* type,
- IROp op,
- IRInst* operand,
- UInt operandCount,
- IRInst* const* operands)
- {
- UInt counts[] = { 1, operandCount };
- IRInst* const* lists[] = { &operand, operands };
+ addHoistableInst(this, inst);
- return findOrEmitHoistableInst(
- type,
- op,
- 2,
- counts,
- lists);
+ return inst;
}
-
IRType* IRBuilder::getType(
IROp op,
UInt operandCount,
IRInst* const* operands)
{
- return (IRType*) findOrEmitHoistableInst(
+ return (IRType*)createIntrinsicInst(
nullptr,
op,
operandCount,
@@ -2831,7 +2778,7 @@ namespace Slang
IRType* const* paramTypes,
IRType* resultType)
{
- return (IRFuncType*) findOrEmitHoistableInst(
+ return (IRFuncType*)createIntrinsicInst(
nullptr,
kIROp_FuncType,
resultType,
@@ -2844,13 +2791,13 @@ namespace Slang
{
UInt counts[3] = {1, paramCount, 1};
IRInst** lists[3] = {(IRInst**)&resultType, (IRInst**)paramTypes, (IRInst**)&attribute};
- return (IRFuncType*)findOrEmitHoistableInst(nullptr, kIROp_FuncType, 3, counts, lists);
+ return (IRFuncType*)createIntrinsicInst(nullptr, kIROp_FuncType, 3, counts, lists);
}
IRWitnessTableType* IRBuilder::getWitnessTableType(
IRType* baseType)
{
- return (IRWitnessTableType*)findOrEmitHoistableInst(
+ return (IRWitnessTableType*)createIntrinsicInst(
nullptr,
kIROp_WitnessTableType,
1,
@@ -2860,7 +2807,7 @@ namespace Slang
IRWitnessTableIDType* IRBuilder::getWitnessTableIDType(
IRType* baseType)
{
- return (IRWitnessTableIDType*)findOrEmitHoistableInst(
+ return (IRWitnessTableIDType*)createIntrinsicInst(
nullptr,
kIROp_WitnessTableIDType,
1,
@@ -2914,7 +2861,7 @@ namespace Slang
UInt caseCount,
IRType* const* caseTypes)
{
- return (IRType*) findOrEmitHoistableInst(
+ return (IRType*)createIntrinsicInst(
getTypeKind(),
kIROp_TaggedUnionType,
caseCount,
@@ -2947,7 +2894,7 @@ namespace Slang
}
}
- return (IRType*) findOrEmitHoistableInst(
+ return (IRType*)createIntrinsicInst(
getTypeKind(),
kIROp_BindExistentialsType,
baseType,
@@ -3197,7 +3144,7 @@ namespace Slang
if (as<IRWitnessTable>(innerReturnVal))
{
- return findOrEmitHoistableInst(
+ return createIntrinsicInst(
type,
kIROp_Specialize,
genericVal,
@@ -3214,7 +3161,8 @@ namespace Slang
argCount,
args);
- addInst(inst);
+ if (!inst->parent)
+ addInst(inst);
return inst;
}
@@ -3233,7 +3181,7 @@ namespace Slang
IRInst* args[] = {witnessTableVal, interfaceMethodVal};
- return findOrEmitHoistableInst(
+ return createIntrinsicInst(
type,
kIROp_LookupWitness,
2,
@@ -3331,6 +3279,17 @@ namespace Slang
args);
}
+ IRInst* IRBuilder::createIntrinsicInst(
+ IRType* type, IROp op, IRInst* operand, UInt operandCount, IRInst* const* operands)
+ {
+ return createInstWithTrailingArgs<IRInst>(this, op, type, operand, operandCount, operands);
+ }
+
+ IRInst* IRBuilder::createIntrinsicInst(IRType* type, IROp op, UInt operandListCount, UInt const* listOperandCounts, IRInst* const* const* listOperands)
+ {
+ return createInstImpl<IRInst>(this, op, type, 0, nullptr, (Int)operandListCount, (Int const* )listOperandCounts, listOperands);
+ }
+
IRInst* IRBuilder::emitIntrinsicInst(
IRType* type,
@@ -3343,7 +3302,8 @@ namespace Slang
op,
argCount,
args);
- addInst(inst);
+ if (!inst->parent)
+ addInst(inst);
return inst;
}
@@ -3772,6 +3732,13 @@ namespace Slang
return emitIntrinsicInst(type, kIROp_MakeMatrix, argCount, args);
}
+ IRInst* IRBuilder::emitMakeMatrixFromScalar(
+ IRType* type,
+ IRInst* scalarValue)
+ {
+ return emitIntrinsicInst(type, kIROp_MakeMatrixFromScalar, 1, &scalarValue);
+ }
+
IRInst* IRBuilder::emitMakeArray(
IRType* type,
UInt argCount,
@@ -3938,7 +3905,7 @@ namespace Slang
value->insertAtEnd(parent);
}
}
-
+
IRInst* IRBuilder::addDifferentiableTypeDictionaryDecoration(IRInst* target)
{
return addDecoration(target, kIROp_DifferentiableTypeDictionaryDecoration);
@@ -5056,7 +5023,7 @@ namespace Slang
this,
kIROp_GlobalConstant,
type);
- addInst(inst);
+ addGlobalValue(this, inst);
return inst;
}
@@ -5069,7 +5036,7 @@ namespace Slang
kIROp_GlobalConstant,
type,
val);
- addInst(inst);
+ addGlobalValue(this, inst);
return inst;
}
@@ -5349,7 +5316,7 @@ namespace Slang
IRInst* operands[] = { kindInst, sizeInst };
- return cast<IRTypeSizeAttr>(findOrEmitHoistableInst(
+ return cast<IRTypeSizeAttr>(createIntrinsicInst(
getVoidType(),
kIROp_TypeSizeAttr,
SLANG_COUNT_OF(operands),
@@ -5376,7 +5343,7 @@ namespace Slang
operands[operandCount++] = spaceInst;
}
- return cast<IRVarOffsetAttr>(findOrEmitHoistableInst(
+ return cast<IRVarOffsetAttr>(createIntrinsicInst(
getVoidType(),
kIROp_VarOffsetAttr,
operandCount,
@@ -5388,7 +5355,7 @@ namespace Slang
{
IRInst* operands[] = { pendingLayout };
- return cast<IRPendingLayoutAttr>(findOrEmitHoistableInst(
+ return cast<IRPendingLayoutAttr>(createIntrinsicInst(
getVoidType(),
kIROp_PendingLayoutAttr,
SLANG_COUNT_OF(operands),
@@ -5401,7 +5368,7 @@ namespace Slang
{
IRInst* operands[] = { key, layout };
- return cast<IRStructFieldLayoutAttr>(findOrEmitHoistableInst(
+ return cast<IRStructFieldLayoutAttr>(createIntrinsicInst(
getVoidType(),
kIROp_StructFieldLayoutAttr,
SLANG_COUNT_OF(operands),
@@ -5413,7 +5380,7 @@ namespace Slang
{
IRInst* operands[] = { layout };
- return cast<IRCaseTypeLayoutAttr>(findOrEmitHoistableInst(
+ return cast<IRCaseTypeLayoutAttr>(createIntrinsicInst(
getVoidType(),
kIROp_CaseTypeLayoutAttr,
SLANG_COUNT_OF(operands),
@@ -5430,7 +5397,7 @@ namespace Slang
IRInst* operands[] = { nameInst, indexInst };
- return cast<IRSemanticAttr>(findOrEmitHoistableInst(
+ return cast<IRSemanticAttr>(createIntrinsicInst(
getVoidType(),
op,
SLANG_COUNT_OF(operands),
@@ -5441,7 +5408,7 @@ namespace Slang
{
auto stageInst = getIntValue(getIntType(), IRIntegerValue(stage));
IRInst* operands[] = { stageInst };
- return cast<IRStageAttr>(findOrEmitHoistableInst(
+ return cast<IRStageAttr>(createIntrinsicInst(
getVoidType(),
kIROp_StageAttr,
SLANG_COUNT_OF(operands),
@@ -5450,7 +5417,7 @@ namespace Slang
IRAttr* IRBuilder::getAttr(IROp op, UInt operandCount, IRInst* const* operands)
{
- return cast<IRAttr>(findOrEmitHoistableInst(
+ return cast<IRAttr>(createIntrinsicInst(
getVoidType(),
op,
operandCount,
@@ -5461,7 +5428,7 @@ namespace Slang
IRTypeLayout* IRBuilder::getTypeLayout(IROp op, List<IRInst*> const& operands)
{
- return cast<IRTypeLayout>(findOrEmitHoistableInst(
+ return cast<IRTypeLayout>(createIntrinsicInst(
getVoidType(),
op,
operands.getCount(),
@@ -5470,7 +5437,7 @@ namespace Slang
IRVarLayout* IRBuilder::getVarLayout(List<IRInst*> const& operands)
{
- return cast<IRVarLayout>(findOrEmitHoistableInst(
+ return cast<IRVarLayout>(createIntrinsicInst(
getVoidType(),
kIROp_VarLayout,
operands.getCount(),
@@ -5483,7 +5450,7 @@ namespace Slang
{
IRInst* operands[] = { paramsLayout, resultLayout };
- return cast<IREntryPointLayout>(findOrEmitHoistableInst(
+ return cast<IREntryPointLayout>(createIntrinsicInst(
getVoidType(),
kIROp_EntryPointLayout,
SLANG_COUNT_OF(operands),
@@ -6528,70 +6495,146 @@ namespace Slang
void validateIRInstOperands(IRInst*);
- void IRInst::replaceUsesWith(IRInst* other)
+ static void _replaceInstUsesWith(IRInst* thisInst, IRInst* other)
{
- // Safety check: don't try to replace something with itself.
- if(other == this)
- return;
+ SharedIRBuilder* sharedBuilder = nullptr;
- // We will walk through the list of uses for the current
- // instruction, and make them point to the other inst.
- IRUse* ff = firstUse;
+ struct WorkItem
+ {
+ IRInst* thisInst;
+ IRInst* otherInst;
+ };
- // No uses? Nothing to do.
- if(!ff)
- return;
+ // A work list of hoistable users for which we need
+ // to deduplicate/update their entry in the global numbering map.
+ List<WorkItem> workList;
+ HashSet<IRInst*> workListSet;
- ff->debugValidate();
+ auto addToWorkList = [&](IRInst* src, IRInst* target)
+ {
+ if (workListSet.Add(src))
+ {
+ WorkItem item;
+ item.thisInst = src;
+ item.otherInst = target;
+ workList.add(item);
+ }
+ };
- IRUse* uu = ff;
- for(;;)
+ addToWorkList(thisInst, other);
+
+ for (Index i = 0; i < workList.getCount(); i++)
{
- // The uses had better all be uses of this
- // instruction, or invariants are broken.
- SLANG_ASSERT(uu->get() == this);
+ auto workItem = workList[i];
+ thisInst = workItem.thisInst;
+ other = workItem.otherInst;
- // Swap this use over to use the other value.
- uu->usedValue = other;
+ // Safety check: don't try to replace something with itself.
+ if (other == thisInst)
+ continue;
- // Try to move to the next use, but bail
- // out if we are at the last one.
- IRUse* nn = uu->nextUse;
- if( !nn )
- break;
+ if (getIROpInfo(thisInst->getOp()).isHoistable())
+ {
+ if (!sharedBuilder)
+ {
+ SLANG_ASSERT(thisInst->getModule());
+ sharedBuilder = thisInst->getModule()->getSharedBuilder();
+ }
+ sharedBuilder->getInstReplacementMap()[thisInst] = other;
+ }
- uu = nn;
- }
+ // We will walk through the list of uses for the current
+ // instruction, and make them point to the other inst.
+ IRUse* ff = thisInst->firstUse;
- // We are at the last use (and there must
- // be at least one, because we handled
- // the case of an empty list earlier).
- SLANG_ASSERT(uu);
+ // No uses? Nothing to do.
+ if (!ff)
+ continue;
- // Our job at this point is to splice
- // our list of uses onto the other
- // value's uses.
- //
- // If the value already had uses, then
- // we need to patch our new list onto
- // the front.
- if( auto nn = other->firstUse )
- {
- uu->nextUse = nn;
- nn->prevLink = &uu->nextUse;
- }
+ //ff->debugValidate();
+
+ IRUse* uu = ff;
+ for (;;)
+ {
+ // The uses had better all be uses of this
+ // instruction, or invariants are broken.
+ SLANG_ASSERT(uu->get() == thisInst);
+
+ auto user = uu->getUser();
+ bool userIsHoistable = getIROpInfo(user->getOp()).isHoistable();
+ if (userIsHoistable)
+ {
+ if (!sharedBuilder)
+ {
+ SLANG_ASSERT(user->getModule());
+ sharedBuilder = user->getModule()->getSharedBuilder();
+ }
+ sharedBuilder->_removeGlobalNumberingEntry(user);
+ }
+
+ // Swap this use over to use the other value.
+ uu->usedValue = other;
+
+ if (userIsHoistable)
+ {
+ // Is the updated inst already exists in the global numbering map?
+ // If so, we need to continue work on replacing the updated inst with the existing value.
+ IRInst* existingVal = nullptr;
+ if (sharedBuilder->getGlobalValueNumberingMap().TryGetValue(IRInstKey{ user }, existingVal))
+ {
+ addToWorkList(user, existingVal);
+ }
+ else
+ {
+ sharedBuilder->_addGlobalNumberingEntry(user);
+ }
+ }
+
+ // Try to move to the next use, but bail
+ // out if we are at the last one.
+ IRUse* nn = uu->nextUse;
+ if (!nn)
+ break;
+
+ uu = nn;
+ }
- // No matter what, our list of
- // uses will become the start
- // of the list of uses for
- // `other`
- other->firstUse = ff;
- ff->prevLink = &other->firstUse;
+ // We are at the last use (and there must
+ // be at least one, because we handled
+ // the case of an empty list earlier).
+ SLANG_ASSERT(uu);
- // And `this` will have no uses any more.
- this->firstUse = nullptr;
+ // Our job at this point is to splice
+ // our list of uses onto the other
+ // value's uses.
+ //
+ // If the value already had uses, then
+ // we need to patch our new list onto
+ // the front.
+ if (auto nn = other->firstUse)
+ {
+ uu->nextUse = nn;
+ nn->prevLink = &uu->nextUse;
+ }
+
+ // No matter what, our list of
+ // uses will become the start
+ // of the list of uses for
+ // `other`
+ other->firstUse = ff;
+ ff->prevLink = &other->firstUse;
+
+ // And `this` will have no uses any more.
+ thisInst->firstUse = nullptr;
+
+ ff->debugValidate();
+ }
- ff->debugValidate();
+ }
+
+ void IRInst::replaceUsesWith(IRInst* other)
+ {
+ _replaceInstUsesWith(this, other);
}
// Insert this instruction into the same basic block
@@ -6750,9 +6793,21 @@ namespace Slang
// and then destroy it (it had better have no uses!)
void IRInst::removeAndDeallocate()
{
- removeFromParent();
+ if (auto module = getModule())
+ {
+ if (getIROpInfo(getOp()).isHoistable())
+ {
+ module->getSharedBuilder()->removeHoistableInstFromGlobalNumberingMap(this);
+ }
+ else if (auto constInst = as<IRConstant>(this))
+ {
+ module->getSharedBuilder()->getConstantMap().Remove(IRConstantKey{ constInst });
+ }
+ module->getSharedBuilder()->getInstReplacementMap().Remove(this);
+ }
removeArguments();
removeAndDeallocateAllDecorationsAndChildren();
+ removeFromParent();
// Run destructor to be sure...
this->~IRInst();
@@ -6919,7 +6974,6 @@ namespace Slang
case kIROp_Not:
case kIROp_BitNot:
case kIROp_Select:
- case kIROp_Dot:
case kIROp_MakeExistential:
case kIROp_ExtractExistentialType:
case kIROp_ExtractExistentialValue:
diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h
index 41b140972..9b8aa5cb7 100644
--- a/source/slang/slang-ir.h
+++ b/source/slang/slang-ir.h
@@ -37,12 +37,14 @@ enum : IROpFlags
kIROpFlags_None = 0,
kIROpFlag_Parent = 1 << 0, ///< This op is a parent op
kIROpFlag_UseOther = 1 << 1, ///< If set this op can use 'other bits' to store information
+ kIROpFlag_Hoistable = 1 << 2, ///< If set this op is a hoistable inst that needs to be deduplicated.
+ kIROpFlag_Global = 1 << 3, ///< If set this op should always be hoisted but should never be deduplicated.
};
/* Bit usage of IROp is a follows
MainOp | Other
-Bit range: 0-7 | Remaining bits
+Bit range: 0-10 | Remaining bits
For doing range checks (for example for doing isa tests), the value is masked by kIROpMeta_OpMask, such that the Other bits don't interfere.
The other bits can be used for storage for anything that needs to identify as a different 'op' or 'type'. It is currently
@@ -92,6 +94,9 @@ struct IROpInfo
// Flags to control how we emit additional info
IROpFlags flags;
+
+ bool isHoistable() const { return (flags & kIROpFlag_Hoistable) != 0; }
+ bool isGlobal() const { return (flags & kIROpFlag_Global) != 0; }
};
// Look up the info for an op
@@ -206,6 +211,43 @@ struct IRInstList : IRInstListBase
};
template<typename T>
+struct IRModifiableInstList
+{
+ IRInst* parent;
+ List<IRInst*> workList;
+
+ IRModifiableInstList() {}
+
+ IRModifiableInstList(T* parent, T* first, T* last);
+
+ T* getFirst() { return workList.getCount() ? (T*)workList.getFirst() : nullptr; }
+ T* getLast() { return workList.getCount() ? (T*)workList.getLast() : nullptr; }
+
+ struct Iterator
+ {
+ IRModifiableInstList<T>* list;
+ Index position = 0;
+
+ Iterator() {}
+ Iterator(IRModifiableInstList<T>* inList, Index inPos) : list(inList), position(inPos) {}
+
+ T* operator*()
+ {
+ return (T*)(list->workList[position]);
+ }
+ void operator++();
+
+ bool operator!=(Iterator const& i)
+ {
+ return i.list != list || i.position != position;
+ }
+ };
+
+ Iterator begin() { return Iterator(this, 0); }
+ Iterator end() { return Iterator(this, workList.getCount()); }
+};
+
+template<typename T>
struct IRFilteredInstList : IRInstListBase
{
IRFilteredInstList() {}
@@ -591,6 +633,14 @@ struct IRInst
getLastChild());
}
+ IRModifiableInstList<IRInst> getModifiableChildren()
+ {
+ return IRModifiableInstList<IRInst>(
+ this,
+ getFirstChild(),
+ getLastChild());
+ }
+
/// A doubly-linked list containing any decorations and then any children of this instruction.
///
/// We store both the decorations and children of an instruction
@@ -607,7 +657,13 @@ struct IRInst
IRInst* getFirstDecorationOrChild() { return m_decorationsAndChildren.first; }
IRInst* getLastDecorationOrChild() { return m_decorationsAndChildren.last; }
IRInstListBase getDecorationsAndChildren() { return m_decorationsAndChildren; }
-
+ IRModifiableInstList<IRInst> getModifiableDecorationsAndChildren()
+ {
+ return IRModifiableInstList<IRInst>(
+ this,
+ m_decorationsAndChildren.first,
+ m_decorationsAndChildren.last);
+ }
void removeAndDeallocateAllDecorationsAndChildren();
#ifdef SLANG_ENABLE_IR_BREAK_ALLOC
@@ -647,6 +703,12 @@ struct IRInst
getOperands()[index].set(value);
}
+ void unsafeSetOperand(UInt index, IRInst* value)
+ {
+ SLANG_ASSERT(getOperands()[index].user != nullptr);
+ getOperands()[index].init(this, value);
+ }
+
//
@@ -773,6 +835,39 @@ typename IRInstList<T>::Iterator IRInstList<T>::end()
}
template<typename T>
+IRModifiableInstList<T>::IRModifiableInstList(T* inParent, T* first, T* last)
+{
+ parent = inParent;
+ for (auto item = first; item; item = item->next)
+ {
+ workList.add(item);
+ if (item == last)
+ break;
+ }
+}
+
+template<typename T>
+void IRModifiableInstList<T>::Iterator::operator++()
+{
+ position++;
+ while (position < list->workList.getCount())
+ {
+ auto inst = list->workList[position];
+ if (!as<T>(inst))
+ {
+ // Skip insts that are not of type T.
+ }
+ else if (list->parent != inst->parent)
+ {
+ // Skip insts that are no longer in its original parent.
+ }
+ else
+ break;
+ position++;
+ }
+}
+
+template<typename T>
IRFilteredInstList<T>::IRFilteredInstList(IRInst* fst, IRInst* lst)
{
first = fst;
@@ -1796,6 +1891,104 @@ struct IRModuleInst : IRInst
IR_LEAF_ISA(Module)
};
+struct IRModule;
+
+// Description of an instruction to be used for global value numbering
+struct IRInstKey
+{
+ IRInst* inst;
+
+ HashCode getHashCode();
+};
+
+bool operator==(IRInstKey const& left, IRInstKey const& right);
+
+struct IRConstantKey
+{
+ IRConstant* inst;
+
+ bool operator==(const IRConstantKey& rhs) const { return inst->equal(rhs.inst); }
+ HashCode getHashCode() const { return inst->getHashCode(); }
+};
+
+struct SharedIRBuilder
+{
+public:
+ SharedIRBuilder()
+ {}
+
+ explicit SharedIRBuilder(IRModule* module)
+ {
+ init(module);
+ }
+
+ void init(IRModule* module);
+
+ IRModule* getModule()
+ {
+ return m_module;
+ }
+
+ Session* getSession()
+ {
+ return m_session;
+ }
+
+ void insertBlockAlongEdge(IREdge const& edge);
+
+ // Rebuilds `globalValueNumberingMap`. This is necessary if any existing
+ // keys are modified (thus its hash code is changed).
+ void deduplicateAndRebuildGlobalNumberingMap();
+
+ // Replaces all uses of oldInst with newInst, and ensures the global numbering map is valid after the replacement.
+ void replaceGlobalInst(IRInst* oldInst, IRInst* newInst);
+
+ void removeHoistableInstFromGlobalNumberingMap(IRInst* inst);
+
+ void tryHoistInst(IRInst* inst);
+
+ typedef Dictionary<IRInstKey, IRInst*> GlobalValueNumberingMap;
+ typedef Dictionary<IRConstantKey, IRConstant*> ConstantMap;
+
+ GlobalValueNumberingMap& getGlobalValueNumberingMap() { return m_globalValueNumberingMap; }
+ Dictionary<IRInst*, IRInst*>& getInstReplacementMap() { return m_instReplacementMap; }
+
+ void _addGlobalNumberingEntry(IRInst* inst)
+ {
+ m_globalValueNumberingMap.Add(IRInstKey{ inst }, inst);
+ m_instReplacementMap.Remove(inst);
+ tryHoistInst(inst);
+ }
+ void _removeGlobalNumberingEntry(IRInst* inst)
+ {
+ IRInst* value = nullptr;
+ if (m_globalValueNumberingMap.TryGetValue(IRInstKey{ inst }, value))
+ {
+ if (value == inst)
+ {
+ m_globalValueNumberingMap.Remove(IRInstKey{ inst });
+ }
+ }
+ }
+
+ ConstantMap& getConstantMap() { return m_constantMap; }
+
+private:
+ // The module that will own all of the IR
+ IRModule* m_module;
+
+ // The parent compilation session
+ Session* m_session;
+
+ GlobalValueNumberingMap m_globalValueNumberingMap;
+
+ // Duplicate insts that are still alive and needs to be replaced in m_globalValueNumberMap
+ // when used as an operand to create another inst.
+ Dictionary<IRInst*, IRInst*> m_instReplacementMap;
+
+ ConstantMap m_constantMap;
+};
+
struct IRModule : RefObject
{
public:
@@ -1810,6 +2003,8 @@ public:
SLANG_FORCE_INLINE IRModuleInst* getModuleInst() const { return m_moduleInst; }
SLANG_FORCE_INLINE MemoryArena& getMemoryArena() { return m_memoryArena; }
+ SharedIRBuilder* getSharedBuilder() const { return &m_sharedBuilder; }
+
IRInstListBase getGlobalInsts() const { return getModuleInst()->getChildren(); }
/// Create an empty instruction with the `op` opcode and space for
@@ -1853,6 +2048,7 @@ private:
IRModule(Session* session)
: m_session(session)
, m_memoryArena(kMemoryArenaBlockSize)
+ , m_sharedBuilder(this)
{
}
@@ -1870,6 +2066,9 @@ private:
/// The memory arena from which all IR instructions (and any associated state) in this module are allocated.
MemoryArena m_memoryArena;
+
+ /// Shared contexts for constructing and maintaining the IR.
+ mutable SharedIRBuilder m_sharedBuilder;
};
struct IRSpecializationDictionaryItem : public IRInst
@@ -1943,13 +2142,17 @@ uint32_t& _debugGetIRAllocCounter();
// TODO: Ellie, comment and move somewhere more appropriate?
template<typename I = IRInst, typename F>
-static void traverseUses(IRInst* inst, F f)
+static void traverseUsers(IRInst* inst, F f)
{
- auto n = inst->firstUse;
- IRUse* u;
- while((u = n) != nullptr)
+ List<IRUse*> uses;
+ for (auto use = inst->firstUse; use; use = use->nextUse)
{
- n = u->nextUse;
+ uses.add(use);
+ }
+ for (auto u : uses)
+ {
+ if (u->usedValue != inst)
+ continue;
if(auto s = as<I>(u->getUser()))
{
f(s);
@@ -1957,6 +2160,22 @@ static void traverseUses(IRInst* inst, F f)
}
}
+template<typename F>
+static void traverseUses(IRInst* inst, F f)
+{
+ List<IRUse*> uses;
+ for (auto use = inst->firstUse; use; use = use->nextUse)
+ {
+ uses.add(use);
+ }
+ for (auto u : uses)
+ {
+ if (u->usedValue != inst)
+ continue;
+ f(u);
+ }
+}
+
namespace detail
{
// A helper to get the singular pointer argument of something callable