diff options
| author | Yong He <yonghe@outlook.com> | 2022-09-01 10:01:13 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-09-01 10:01:13 -0700 |
| commit | 4a94473eb34376dd8474f8ca3f2834b5c1daac14 (patch) | |
| tree | 218714e897a2821c2b09727590f364519afe3915 | |
| parent | 3c0177134d126956336865623ea3d6861be59cfa (diff) | |
Deduplicate consts and IRSpecialize in IR, propagate type info for `IntVal`. (#2388)
| -rw-r--r-- | source/slang/slang-ast-val.cpp | 30 | ||||
| -rw-r--r-- | source/slang/slang-ast-val.h | 26 | ||||
| -rw-r--r-- | source/slang/slang-check-conversion.cpp | 2 | ||||
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 5 | ||||
| -rw-r--r-- | source/slang/slang-check-expr.cpp | 17 | ||||
| -rw-r--r-- | source/slang/slang-check-shader.cpp | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-deduplicate.cpp | 11 | ||||
| -rw-r--r-- | source/slang/slang-ir-specialize.cpp | 1 | ||||
| -rw-r--r-- | source/slang/slang-ir.cpp | 3 | ||||
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 10 | ||||
| -rw-r--r-- | tests/bugs/generic-type-duplication.slang | 33 | ||||
| -rw-r--r-- | tests/bugs/generic-type-duplication.slang.expected.txt | 4 |
12 files changed, 103 insertions, 41 deletions
diff --git a/source/slang/slang-ast-val.cpp b/source/slang/slang-ast-val.cpp index dd5aff238..64d32f4b4 100644 --- a/source/slang/slang-ast-val.cpp +++ b/source/slang/slang-ast-val.cpp @@ -722,10 +722,10 @@ Val* PolynomialIntVal::_substituteImplOverride(ASTBuilder* astBuilder, Substitut *ioDiff += diff; if (evaluatedTerms.getCount() == 0) - return astBuilder->create<ConstantIntVal>(evaluatedConstantTerm); + return astBuilder->create<ConstantIntVal>(type, evaluatedConstantTerm); if (diff != 0) { - auto newPolynomial = astBuilder->create<PolynomialIntVal>(); + auto newPolynomial = astBuilder->create<PolynomialIntVal>(type); newPolynomial->constantTerm = evaluatedConstantTerm; newPolynomial->terms = _Move(evaluatedTerms); return newPolynomial->canonicalize(astBuilder); @@ -770,7 +770,7 @@ bool addToPolynomialTerm(ASTBuilder* astBuilder, PolynomialIntVal* val, IntVal* PolynomialIntVal* PolynomialIntVal::neg(ASTBuilder* astBuilder, IntVal* base) { - auto result = astBuilder->create<PolynomialIntVal>(); + auto result = astBuilder->create<PolynomialIntVal>(base->type); if (!addToPolynomialTerm(astBuilder, result, base, -1)) return nullptr; result->canonicalize(astBuilder); @@ -779,7 +779,7 @@ PolynomialIntVal* PolynomialIntVal::neg(ASTBuilder* astBuilder, IntVal* base) PolynomialIntVal* PolynomialIntVal::sub(ASTBuilder* astBuilder, IntVal* op0, IntVal* op1) { - auto result = astBuilder->create<PolynomialIntVal>(); + auto result = astBuilder->create<PolynomialIntVal>(op0->type); if (!addToPolynomialTerm(astBuilder, result, op0, 1)) return nullptr; if (!addToPolynomialTerm(astBuilder, result, op1, -1)) @@ -790,7 +790,7 @@ PolynomialIntVal* PolynomialIntVal::sub(ASTBuilder* astBuilder, IntVal* op0, Int PolynomialIntVal* PolynomialIntVal::add(ASTBuilder* astBuilder, IntVal* op0, IntVal* op1) { - auto result = astBuilder->create<PolynomialIntVal>(); + auto result = astBuilder->create<PolynomialIntVal>(op0->type); if (!addToPolynomialTerm(astBuilder, result, op0, 1)) return nullptr; if (!addToPolynomialTerm(astBuilder, result, op1, 1)) @@ -805,7 +805,7 @@ PolynomialIntVal* PolynomialIntVal::mul(ASTBuilder* astBuilder, IntVal* op0, Int { if (auto poly1 = as<PolynomialIntVal>(op1)) { - auto result = astBuilder->create<PolynomialIntVal>(); + auto result = astBuilder->create<PolynomialIntVal>(poly0->type); // add poly0.constant * poly1.constant result->constantTerm = poly0->constantTerm * poly1->constantTerm; // add poly0.constant * poly1.terms @@ -847,7 +847,7 @@ PolynomialIntVal* PolynomialIntVal::mul(ASTBuilder* astBuilder, IntVal* op0, Int } else if (auto cVal1 = as<ConstantIntVal>(op1)) { - auto result = astBuilder->create<PolynomialIntVal>(); + auto result = astBuilder->create<PolynomialIntVal>(poly0->type); result->constantTerm = poly0->constantTerm * cVal1->value; auto factor1 = astBuilder->create<PolynomialIntValFactor>(); for (auto term : poly0->terms) @@ -863,7 +863,7 @@ PolynomialIntVal* PolynomialIntVal::mul(ASTBuilder* astBuilder, IntVal* op0, Int } else if (auto val1 = as<IntVal>(op1)) { - auto result = astBuilder->create<PolynomialIntVal>(); + auto result = astBuilder->create<PolynomialIntVal>(poly0->type); result->constantTerm = 0; auto factor1 = astBuilder->create<PolynomialIntValFactor>(); factor1->power = 1; @@ -901,7 +901,7 @@ PolynomialIntVal* PolynomialIntVal::mul(ASTBuilder* astBuilder, IntVal* op0, Int } else if (auto cVal1 = as<ConstantIntVal>(op1)) { - auto result = astBuilder->create<PolynomialIntVal>(); + auto result = astBuilder->create<PolynomialIntVal>(val0->type); auto term = astBuilder->create<PolynomialIntValTerm>(); term->constFactor = cVal1->value; auto factor0 = astBuilder->create<PolynomialIntValFactor>(); @@ -914,7 +914,7 @@ PolynomialIntVal* PolynomialIntVal::mul(ASTBuilder* astBuilder, IntVal* op0, Int } else if (auto val1 = as<IntVal>(op1)) { - auto result = astBuilder->create<PolynomialIntVal>(); + auto result = astBuilder->create<PolynomialIntVal>(val0->type); auto term = astBuilder->create<PolynomialIntValTerm>(); term->constFactor = 1; auto factor0 = astBuilder->create<PolynomialIntValFactor>(); @@ -1035,7 +1035,7 @@ IntVal* PolynomialIntVal::canonicalize(ASTBuilder* builder) return terms[0]->paramFactors[0]->param; } if (terms.getCount() == 0) - return builder->create<ConstantIntVal>(constantTerm); + return builder->create<ConstantIntVal>(type, constantTerm); return this; } @@ -1127,7 +1127,7 @@ static bool nameIs(Name* name, const char* val) return false; } -Val* FuncCallIntVal::tryFoldImpl(ASTBuilder* astBuilder, DeclRef<Decl> newFuncDecl, List<IntVal*>& newArgs, DiagnosticSink* sink) +Val* FuncCallIntVal::tryFoldImpl(ASTBuilder* astBuilder, Type* resultType, DeclRef<Decl> newFuncDecl, List<IntVal*>& newArgs, DiagnosticSink* sink) { // Are all args const now? List<ConstantIntVal*> constArgs; @@ -1207,7 +1207,7 @@ Val* FuncCallIntVal::tryFoldImpl(ASTBuilder* astBuilder, DeclRef<Decl> newFuncDe { SLANG_UNREACHABLE("constant folding of FuncCallIntVal"); } - return astBuilder->create<ConstantIntVal>(resultValue); + return astBuilder->create<ConstantIntVal>(resultType, resultValue); } return nullptr; } @@ -1228,12 +1228,12 @@ Val* FuncCallIntVal::_substituteImplOverride(ASTBuilder* astBuilder, Substitutio if (diff) { // TODO: report diagnostics back. - auto newVal = tryFoldImpl(astBuilder, newFuncDeclRef, newArgs, nullptr); + auto newVal = tryFoldImpl(astBuilder, type, newFuncDeclRef, newArgs, nullptr); if (newVal) return newVal; else { - auto result = astBuilder->create<FuncCallIntVal>(); + auto result = astBuilder->create<FuncCallIntVal>(type); result->args = _Move(newArgs); result->funcDeclRef = newFuncDeclRef; result->funcType = funcType; diff --git a/source/slang/slang-ast-val.h b/source/slang/slang-ast-val.h index 259be75c3..69797b3a5 100644 --- a/source/slang/slang-ast-val.h +++ b/source/slang/slang-ast-val.h @@ -12,6 +12,13 @@ namespace Slang { class IntVal : public Val { SLANG_ABSTRACT_AST_CLASS(IntVal) + + Type* type; + + IntVal(Type* inType) + : type(inType) + {} + }; // Trivial case of a value that is just a constant integer @@ -27,8 +34,8 @@ class ConstantIntVal : public IntVal HashCode _getHashCodeOverride(); protected: - ConstantIntVal(IntegerLiteralValue inValue) - : value(inValue) + ConstantIntVal(Type* inType, IntegerLiteralValue inValue) + : IntVal(inType), value(inValue) {} }; @@ -47,8 +54,8 @@ class GenericParamIntVal : public IntVal Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); protected: - GenericParamIntVal(DeclRef<VarDeclBase> inDeclRef) - : declRef(inDeclRef) + GenericParamIntVal(Type* inType, DeclRef<VarDeclBase> inDeclRef) + : IntVal(inType), declRef(inDeclRef) {} }; @@ -66,7 +73,9 @@ class FuncCallIntVal : public IntVal Type* funcType; List<IntVal*> args; - static Val* tryFoldImpl(ASTBuilder* astBuilder, DeclRef<Decl> newFuncDecl, List<IntVal*>& newArgs, DiagnosticSink* sink); + FuncCallIntVal(Type* inType) : IntVal(inType) {} + + static Val* tryFoldImpl(ASTBuilder* astBuilder, Type* resultType, DeclRef<Decl> newFuncDecl, List<IntVal*>& newArgs, DiagnosticSink* sink); }; class WitnessLookupIntVal : public IntVal @@ -80,7 +89,8 @@ class WitnessLookupIntVal : public IntVal SubtypeWitness* witness; Decl* key; - Type* type; + + WitnessLookupIntVal(Type* inType) : IntVal(inType) {} static Val* tryFoldOrNull(ASTBuilder* astBuilder, SubtypeWitness* witness, Decl* key); @@ -140,6 +150,7 @@ public: { return power == other.power && param->equalsVal(other.param); } + }; class PolynomialIntValTerm : public NodeBase { @@ -207,6 +218,7 @@ public: static PolynomialIntVal* add(ASTBuilder* astBuilder, IntVal* op0, IntVal* op1); static PolynomialIntVal* sub(ASTBuilder* astBuilder, IntVal* op0, IntVal* op1); static PolynomialIntVal* mul(ASTBuilder* astBuilder, IntVal* op0, IntVal* op1); + PolynomialIntVal(Type* inType) : IntVal(inType) {} }; @@ -215,6 +227,8 @@ class ErrorIntVal : public IntVal { SLANG_AST_CLASS(ErrorIntVal) + ErrorIntVal(Type* inType) : IntVal(inType) {} + // TODO: We should probably eventually just have an `ErrorVal` here // and have all `Val`s that represent ordinary values hold their // `Type` so that we can have an `ErrorVal` of any type. diff --git a/source/slang/slang-check-conversion.cpp b/source/slang/slang-check-conversion.cpp index a3273f942..4efacd703 100644 --- a/source/slang/slang-check-conversion.cpp +++ b/source/slang/slang-check-conversion.cpp @@ -349,7 +349,7 @@ namespace Slang // We have a new type for the conversion, based on what // we learned. toType = m_astBuilder->getArrayType(toElementType, - m_astBuilder->create<ConstantIntVal>(elementCount)); + m_astBuilder->create<ConstantIntVal>(m_astBuilder->getIntType(), elementCount)); } } else if(auto toMatrixType = as<MatrixExpressionType>(toType)) diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 9e3fbaa8d..9f19023ee 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -536,7 +536,9 @@ namespace Slang } else if( auto genericValueParamDecl = as<GenericValueParamDecl>(mm) ) { - genericSubst->args.add(astBuilder->create<GenericParamIntVal>(DeclRef<GenericValueParamDecl>(genericValueParamDecl, outerSubst))); + genericSubst->args.add(astBuilder->create<GenericParamIntVal>( + genericValueParamDecl->getType(), + DeclRef<GenericValueParamDecl>(genericValueParamDecl, outerSubst))); } } @@ -4235,6 +4237,7 @@ namespace Slang else if (auto valueParam = as<GenericValueParamDecl>(dd)) { auto val = m_astBuilder->create<GenericParamIntVal>( + valueParam->getType(), makeDeclRef(valueParam)); subst->args.add(val); } diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index e33d26c0c..c7bfdd3a6 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -866,7 +866,7 @@ namespace Slang IntVal* SemanticsVisitor::getIntVal(IntegerLiteralExpr* expr) { // TODO(tfoley): don't keep allocating here! - return m_astBuilder->create<ConstantIntVal>(expr->value); + return m_astBuilder->create<ConstantIntVal>(expr->type.type, expr->value); } IntVal* SemanticsVisitor::tryConstantFoldExpr( @@ -982,7 +982,7 @@ namespace Slang || opName == getName("|") || opName == getName("&") || opName == getName("^") || opName == getName("~") || opName == getName("%") || opName == getName("?:") || opName == getName("<<") || opName == getName(">>")) { - auto result = m_astBuilder->create<FuncCallIntVal>(); + auto result = m_astBuilder->create<FuncCallIntVal>(invokeExpr.getExpr()->type.type); result->args.addRange(argVals, argCount); result->funcDeclRef = funcDeclRef; result->funcType = as<Type>(funcDeclRefExpr.getExpr()->type->substitute( @@ -1091,7 +1091,7 @@ namespace Slang } } - IntVal* result = m_astBuilder->create<ConstantIntVal>(resultValue); + IntVal* result = m_astBuilder->create<ConstantIntVal>(invokeExpr.getExpr()->type.type, resultValue); return result; } @@ -1176,7 +1176,7 @@ namespace Slang { // If it's a boolean, we allow promotion to int. const IntegerLiteralValue value = IntegerLiteralValue(boolLitExpr.getExpr()->value); - return m_astBuilder->create<ConstantIntVal>(value); + return m_astBuilder->create<ConstantIntVal>(m_astBuilder->getBoolType(), value); } // it is possible that we are referring to a generic value param @@ -1186,8 +1186,9 @@ namespace Slang if (auto genericValParamRef = declRef.as<GenericValueParamDecl>()) { - // TODO(tfoley): handle the case of non-`int` value parameters... - Val* valResult = m_astBuilder->create<GenericParamIntVal>(genericValParamRef); + Val* valResult = m_astBuilder->create<GenericParamIntVal>( + declRef.substitute(m_astBuilder, genericValParamRef.getDecl()->getType()), + genericValParamRef); valResult = valResult->substitute(m_astBuilder, expr.getSubsts()); return as<IntVal>(valResult); } @@ -2145,7 +2146,7 @@ namespace Slang // here if the input type had a sugared name... swizExpr->type = QualType(createVectorType( baseElementType, - m_astBuilder->create<ConstantIntVal>(elementCount))); + m_astBuilder->create<ConstantIntVal>(m_astBuilder->getIntType(), elementCount))); } // A swizzle can be used as an l-value as long as there @@ -2266,7 +2267,7 @@ namespace Slang // here if the input type had a sugared name... swizExpr->type = QualType(createVectorType( baseElementType, - m_astBuilder->create<ConstantIntVal>(elementCount))); + m_astBuilder->create<ConstantIntVal>(m_astBuilder->getIntType(), elementCount))); } // A swizzle can be used as an l-value as long as there diff --git a/source/slang/slang-check-shader.cpp b/source/slang/slang-check-shader.cpp index 184ff3350..2e577b6d5 100644 --- a/source/slang/slang-check-shader.cpp +++ b/source/slang/slang-check-shader.cpp @@ -1120,7 +1120,7 @@ namespace Slang if(!intVal) { sink->diagnose(param.loc, Diagnostics::expectedValueOfTypeForSpecializationArg, paramDecl->getType(), paramDecl); - intVal = getLinkage()->getASTBuilder()->create<ConstantIntVal>(0); + intVal = getLinkage()->getASTBuilder()->create<ConstantIntVal>(m_astBuilder->getIntType(), 0); } ModuleSpecializationInfo::GenericArgInfo expandedArg; diff --git a/source/slang/slang-ir-deduplicate.cpp b/source/slang/slang-ir-deduplicate.cpp index 8aef7736c..51a677627 100644 --- a/source/slang/slang-ir-deduplicate.cpp +++ b/source/slang/slang-ir-deduplicate.cpp @@ -53,17 +53,22 @@ namespace Slang context.builder = this; m_constantMap.Clear(); m_globalValueNumberingMap.Clear(); + List<IRInst*> instToRemove; for (auto inst : m_module->getGlobalInsts()) { if (auto constVal = as<IRConstant>(inst)) { - context.addConstantValue(constVal); + auto newConst = context.addConstantValue(constVal); + if (newConst != constVal) + { + constVal->replaceUsesWith(newConst); + instToRemove.add(constVal); + } } } - List<IRInst*> instToRemove; for (auto inst : m_module->getGlobalInsts()) { - if (as<IRType>(inst)) + if (as<IRType>(inst) || as<IRSpecialize>(inst)) { auto newInst = context.addTypeValue(inst); if (newInst != inst) diff --git a/source/slang/slang-ir-specialize.cpp b/source/slang/slang-ir-specialize.cpp index 1dd18ea47..b9a8f68ab 100644 --- a/source/slang/slang-ir-specialize.cpp +++ b/source/slang/slang-ir-specialize.cpp @@ -731,6 +731,7 @@ struct SpecializationContext // This prevents us from generating duplicated specializations // when this pass is invoked iteratively. readSpecializationDictionaries(); + sharedBuilder->deduplicateAndRebuildGlobalNumberingMap(); // The unspecialized IR we receive as input will have // `IRBindGlobalGenericParam` instructions that associate diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 20bf6060d..19d3ce59a 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -2146,6 +2146,9 @@ namespace Slang keyInst.value.intVal = static_cast<uint16_t>(inValue); break; case kIROp_BoolType: + keyInst.m_op = kIROp_BoolLit; + keyInst.value.intVal = ((inValue != 0) ? 1 : 0); + break; case kIROp_UIntType: keyInst.value.intVal = static_cast<uint32_t>(inValue); break; diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 0921d36fb..d9080169a 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -1321,11 +1321,12 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower LoweredValInfo visitPolynomialIntVal(PolynomialIntVal* val) { auto irBuilder = getBuilder(); - auto constTerm = irBuilder->getIntValue(irBuilder->getIntType(), val->constantTerm); + auto type = lowerType(context, val->type); + auto constTerm = irBuilder->getIntValue(type, val->constantTerm); auto resultVal = constTerm; for (auto term : val->terms) { - auto termVal = irBuilder->getIntValue(irBuilder->getIntType(), term->constFactor); + auto termVal = irBuilder->getIntValue(type, term->constFactor); for (auto factor : term->paramFactors) { auto factorVal = lowerVal(context, factor->param).val; @@ -1701,10 +1702,7 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower LoweredValInfo visitConstantIntVal(ConstantIntVal* val) { - // TODO: it is a bit messy here that the `ConstantIntVal` representation - // has no notion of a *type* associated with the value... - - auto type = getIntType(context); + auto type = lowerType(context, val->type); return LoweredValInfo::simple(getBuilder()->getIntValue(type, val->value)); } diff --git a/tests/bugs/generic-type-duplication.slang b/tests/bugs/generic-type-duplication.slang new file mode 100644 index 000000000..4117a7f81 --- /dev/null +++ b/tests/bugs/generic-type-duplication.slang @@ -0,0 +1,33 @@ +// Test that the same generic type specialization does not get emitted as different types in target code. + +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj +//TEST(compute,vulkan):COMPARE_COMPUTE_EX:-vk -slang -compute -shaderobj + +struct MyGeneric<let addOne: bool> +{ + int value; + + [mutating] + void load(RWStructuredBuffer<MyGeneric<addOne>> buffer) + { + var m = buffer.Load(0); + if (addOne) + value = m.value + 1; + else + value = m.value; + } +}; + +//TEST_INPUT:set myBuffer = ubuffer(data=[1],stride=4) +RWStructuredBuffer<MyGeneric<true>> myBuffer; + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name outputBuffer +RWStructuredBuffer<int> outputBuffer; + +[numthreads(4, 1, 1)] +void computeMain(int3 dispatchThreadID: SV_DispatchThreadID) +{ + MyGeneric<true> obj; + obj.load(myBuffer); + outputBuffer[dispatchThreadID.x] = obj.value; +} diff --git a/tests/bugs/generic-type-duplication.slang.expected.txt b/tests/bugs/generic-type-duplication.slang.expected.txt new file mode 100644 index 000000000..487b11653 --- /dev/null +++ b/tests/bugs/generic-type-duplication.slang.expected.txt @@ -0,0 +1,4 @@ +2 +2 +2 +2 |
