diff options
| author | Yong He <yonghe@outlook.com> | 2022-08-22 09:43:05 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-08-22 09:43:05 -0700 |
| commit | 393185196ed65a9eeaf9502edbf3dcce87337d81 (patch) | |
| tree | 91c9fa14ddb21d15e6cedf83f7aa6b649e99db86 | |
| parent | 15055d20c143cb398bd3e269541eebf24777390a (diff) | |
Support compile-time constant int val in the form of polynomials. (#2372)
Co-authored-by: Yong He <yhe@nvidia.com>
| -rw-r--r-- | source/compiler-core/slang-diagnostic-sink.cpp | 11 | ||||
| -rw-r--r-- | source/slang/slang-ast-val.cpp | 454 | ||||
| -rw-r--r-- | source/slang/slang-ast-val.h | 95 | ||||
| -rw-r--r-- | source/slang/slang-check-expr.cpp | 47 | ||||
| -rw-r--r-- | source/slang/slang-check-type.cpp | 9 | ||||
| -rw-r--r-- | source/slang/slang-language-server.cpp | 9 | ||||
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 21 | ||||
| -rw-r--r-- | source/slang/slang-type-layout.cpp | 12 | ||||
| -rw-r--r-- | source/slang/slang.natvis | 18 | ||||
| -rw-r--r-- | tests/language-feature/generics/generic-value-constant-folding.slang | 25 | ||||
| -rw-r--r-- | tests/language-feature/generics/generic-value-constant-folding.slang.expected.txt | 5 |
11 files changed, 697 insertions, 9 deletions
diff --git a/source/compiler-core/slang-diagnostic-sink.cpp b/source/compiler-core/slang-diagnostic-sink.cpp index 34a3c4968..85c7792c0 100644 --- a/source/compiler-core/slang-diagnostic-sink.cpp +++ b/source/compiler-core/slang-diagnostic-sink.cpp @@ -405,10 +405,13 @@ static void formatDiagnostic( HumaneSourceLoc humaneLoc; const auto sourceLoc = diagnostic.loc; { - sourceView = sourceManager->findSourceViewRecursively(sourceLoc); - if (sourceView) + if (sourceManager) { - humaneLoc = sourceView->getHumaneLoc(sourceLoc); + sourceView = sourceManager->findSourceViewRecursively(sourceLoc); + if (sourceView) + { + humaneLoc = sourceView->getHumaneLoc(sourceLoc); + } } formatDiagnostic(humaneLoc, diagnostic, sink->getFlags(), sb); @@ -418,7 +421,7 @@ static void formatDiagnostic( while (currentView && currentView->getInitiatingSourceLoc().isValid() && currentView->getSourceFile()->getPathInfo().type == PathInfo::Type::TokenPaste) { - SourceView* initiatingView = sourceManager->findSourceView(currentView->getInitiatingSourceLoc()); + SourceView* initiatingView = sourceManager ? sourceManager->findSourceView(currentView->getInitiatingSourceLoc()) : nullptr; if (initiatingView == nullptr) { break; diff --git a/source/slang/slang-ast-val.cpp b/source/slang/slang-ast-val.cpp index 0f3bd2b3a..16c649f3a 100644 --- a/source/slang/slang-ast-val.cpp +++ b/source/slang/slang-ast-val.cpp @@ -576,4 +576,458 @@ Val* SNormModifierVal::_substituteImplOverride(ASTBuilder* astBuilder, Substitut return this; } +// PolynomialIntVal + +bool PolynomialIntVal::_equalsValOverride(Val* val) +{ + if (auto genericParamVal = as<GenericParamIntVal>(val)) + { + return constantTerm == 0 && terms.getCount() == 1 && + terms[0]->paramFactors.getCount() == 1 && terms[0]->constFactor == 1 && + terms[0]->paramFactors[0]->param->equalsVal(genericParamVal) && + terms[0]->paramFactors[0]->power == 1; + } + else if (auto otherPolynomial = as<PolynomialIntVal>(val)) + { + if (constantTerm != otherPolynomial->constantTerm) + return false; + if (terms.getCount() != otherPolynomial->terms.getCount()) + return false; + for (Index i = 0; i < terms.getCount(); i++) + { + auto& thisTerm = *(terms[i]); + auto& thatTerm = *(otherPolynomial->terms[i]); + if (thisTerm.constFactor != thatTerm.constFactor) + return false; + if (thisTerm.paramFactors.getCount() != thatTerm.paramFactors.getCount()) + return false; + for (Index j = 0; j < thisTerm.paramFactors.getCount(); j++) + { + if (thisTerm.paramFactors[j]->power != thatTerm.paramFactors[j]->power) + return false; + if (!thisTerm.paramFactors[j]->param->equalsVal(thatTerm.paramFactors[j]->param)) + return false; + } + } + return true; + } + return false; +} + +void PolynomialIntVal::_toTextOverride(StringBuilder& out) +{ + for (Index i = 0; i < terms.getCount(); i++) + { + auto& term = *(terms[i]); + if (i > 0) + { + if (term.constFactor > 0) + out << "+"; + else + out << "-"; + } + bool isFirstFactor = true; + if (term.constFactor != 1 || term.paramFactors.getCount() == 0) + { + out << abs(term.constFactor); + isFirstFactor = false; + } + for (Index j = 0; j < term.paramFactors.getCount(); j++) + { + auto factor = term.paramFactors[j]; + if (isFirstFactor) + { + isFirstFactor = false; + } + else + { + out << "*"; + } + factor->param->toText(out); + if (factor->power != 1) + { + out << "^^" << factor->power; + } + } + } + if (constantTerm > 0) + { + if (terms.getCount() > 0) + { + out << "+"; + } + out << constantTerm; + } + else if (constantTerm < 0) + { + out << constantTerm; + } +} + +HashCode PolynomialIntVal::_getHashCodeOverride() +{ + HashCode result = (HashCode)constantTerm; + for (auto& term : terms) + { + if (!term) continue; + result = combineHash(result, (HashCode)term->constFactor); + for (auto& factor : term->paramFactors) + { + result = combineHash(result, factor->param->getHashCode()); + result = combineHash(result, (HashCode)factor->power); + } + } + return result; +} + +Val* PolynomialIntVal::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) +{ + int diff = 0; + IntegerLiteralValue evaluatedConstantTerm = constantTerm; + List<PolynomialIntValTerm*> evaluatedTerms; + for (auto& term : terms) + { + IntegerLiteralValue evaluatedTermConstFactor; + List<PolynomialIntValFactor*> evaluatedTermParamFactors; + evaluatedTermConstFactor = term->constFactor; + for (auto& factor : term->paramFactors) + { + auto substResult = factor->param->substituteImpl(astBuilder, subst, &diff); + if (auto genericVal = as<GenericParamIntVal>(substResult)) + { + auto newFactor = astBuilder->create<PolynomialIntValFactor>(); + newFactor->param = genericVal; + newFactor->power = factor->power; + evaluatedTermParamFactors.add(newFactor); + } + else if (auto constantVal = as<ConstantIntVal>(substResult)) + { + evaluatedTermConstFactor *= constantVal->value; + } + } + if (evaluatedTermParamFactors.getCount() == 0) + { + evaluatedConstantTerm += evaluatedTermConstFactor; + } + else + { + auto newTerm = astBuilder->create<PolynomialIntValTerm>(); + newTerm->paramFactors = _Move(evaluatedTermParamFactors); + newTerm->constFactor = evaluatedTermConstFactor; + evaluatedTerms.add(newTerm); + } + } + + *ioDiff += diff; + + if (evaluatedTerms.getCount() == 0) + return astBuilder->create<ConstantIntVal>(evaluatedConstantTerm); + if (diff != 0) + { + auto newPolynomial = astBuilder->create<PolynomialIntVal>(); + newPolynomial->constantTerm = evaluatedConstantTerm; + newPolynomial->terms = _Move(evaluatedTerms); + newPolynomial->canonicalize(astBuilder); + return newPolynomial; + } + return nullptr; +} + + +// compute val += opreand*multiplier; +bool addToPolynomialTerm(ASTBuilder* astBuilder, PolynomialIntVal* val, IntVal* operand, IntegerLiteralValue multiplier) +{ + if (auto genVal = as<GenericParamIntVal>(operand)) + { + auto term = astBuilder->create<PolynomialIntValTerm>(); + term->constFactor = multiplier; + auto factor = astBuilder->create<PolynomialIntValFactor>(); + factor->power = 1; + factor->param = genVal; + term->paramFactors.add(factor); + val->terms.add(term); + return true; + } + else if (auto c = as<ConstantIntVal>(operand)) + { + val->constantTerm += c->value * multiplier; + return true; + } + else if (auto poly = as<PolynomialIntVal>(operand)) + { + val->constantTerm += poly->constantTerm * multiplier; + for (auto term : poly->terms) + { + auto newTerm = astBuilder->create<PolynomialIntValTerm>(); + newTerm->constFactor = multiplier * term->constFactor; + newTerm->paramFactors = term->paramFactors; + val->terms.add(newTerm); + } + return true; + } + return false; +} + +PolynomialIntVal* PolynomialIntVal::neg(ASTBuilder* astBuilder, IntVal* base) +{ + auto result = astBuilder->create<PolynomialIntVal>(); + if (!addToPolynomialTerm(astBuilder, result, base, -1)) + return nullptr; + result->canonicalize(astBuilder); + return result; +} + +PolynomialIntVal* PolynomialIntVal::sub(ASTBuilder* astBuilder, IntVal* op0, IntVal* op1) +{ + auto result = astBuilder->create<PolynomialIntVal>(); + if (!addToPolynomialTerm(astBuilder, result, op0, 1)) + return nullptr; + if (!addToPolynomialTerm(astBuilder, result, op1, -1)) + return nullptr; + result->canonicalize(astBuilder); + return result; +} + +PolynomialIntVal* PolynomialIntVal::add(ASTBuilder* astBuilder, IntVal* op0, IntVal* op1) +{ + auto result = astBuilder->create<PolynomialIntVal>(); + if (!addToPolynomialTerm(astBuilder, result, op0, 1)) + return nullptr; + if (!addToPolynomialTerm(astBuilder, result, op1, 1)) + return nullptr; + result->canonicalize(astBuilder); + return result; +} + +PolynomialIntVal* PolynomialIntVal::mul(ASTBuilder* astBuilder, IntVal* op0, IntVal* op1) +{ + if (auto poly0 = as<PolynomialIntVal>(op0)) + { + if (auto poly1 = as<PolynomialIntVal>(op1)) + { + auto result = astBuilder->create<PolynomialIntVal>(); + // add poly0.constant * poly1.constant + result->constantTerm = poly0->constantTerm * poly1->constantTerm; + // add poly0.constant * poly1.terms + if (poly0->constantTerm != 0) + { + for (auto term : poly1->terms) + { + auto newTerm = astBuilder->create<PolynomialIntValTerm>(); + newTerm->constFactor = poly0->constantTerm * term->constFactor; + newTerm->paramFactors.addRange(term->paramFactors); + result->terms.add(newTerm); + } + } + // add poly1.constant * poly0.terms + if (poly1->constantTerm != 0) + { + for (auto term : poly0->terms) + { + auto newTerm = astBuilder->create<PolynomialIntValTerm>(); + newTerm->constFactor = poly1->constantTerm * term->constFactor; + newTerm->paramFactors.addRange(term->paramFactors); + result->terms.add(newTerm); + } + } + // add poly1.terms * poly0.terms + for (auto term0 : poly0->terms) + { + for (auto term1 : poly1->terms) + { + auto newTerm = astBuilder->create<PolynomialIntValTerm>(); + newTerm->constFactor = term0->constFactor * term1->constFactor; + newTerm->paramFactors.addRange(term0->paramFactors); + newTerm->paramFactors.addRange(term1->paramFactors); + result->terms.add(newTerm); + } + } + result->canonicalize(astBuilder); + return result; + } + else if (auto val1 = as<GenericParamIntVal>(op1)) + { + auto result = astBuilder->create<PolynomialIntVal>(); + result->constantTerm = 0; + auto factor1 = astBuilder->create<PolynomialIntValFactor>(); + factor1->power = 1; + factor1->param = val1; + if (poly0->constantTerm != 0) + { + auto term0 = astBuilder->create<PolynomialIntValTerm>(); + term0->constFactor = poly0->constantTerm; + term0->paramFactors.add(factor1); + result->terms.add(term0); + } + for (auto term : poly0->terms) + { + auto newTerm = astBuilder->create<PolynomialIntValTerm>(); + newTerm->constFactor = term->constFactor; + newTerm->paramFactors.addRange(term->paramFactors); + newTerm->paramFactors.add(factor1); + result->terms.add(newTerm); + } + result->canonicalize(astBuilder); + return result; + } + else if (auto cVal1 = as<ConstantIntVal>(op1)) + { + auto result = astBuilder->create<PolynomialIntVal>(); + result->constantTerm = poly0->constantTerm * cVal1->value; + auto factor1 = astBuilder->create<PolynomialIntValFactor>(); + for (auto term : poly0->terms) + { + auto newTerm = astBuilder->create<PolynomialIntValTerm>(); + newTerm->constFactor = term->constFactor * cVal1->value; + newTerm->paramFactors.addRange(term->paramFactors); + newTerm->paramFactors.add(factor1); + result->terms.add(newTerm); + } + result->canonicalize(astBuilder); + return result; + } + else + return nullptr; + } + else if (auto val0 = as<GenericParamIntVal>(op0)) + { + if (auto poly1 = as<PolynomialIntVal>(op1)) + { + return mul(astBuilder, op1, op0); + } + else if (auto val1 = as<GenericParamIntVal>(op1)) + { + auto result = astBuilder->create<PolynomialIntVal>(); + auto term = astBuilder->create<PolynomialIntValTerm>(); + term->constFactor = 1; + auto factor0 = astBuilder->create<PolynomialIntValFactor>(); + factor0->power = 1; + factor0->param = val0; + term->paramFactors.add(factor0); + auto factor1 = astBuilder->create<PolynomialIntValFactor>(); + factor1->power = 1; + factor1->param = val1; + term->paramFactors.add(factor1); + result->terms.add(term); + result->canonicalize(astBuilder); + return result; + } + else if (auto cVal1 = as<ConstantIntVal>(op1)) + { + auto result = astBuilder->create<PolynomialIntVal>(); + auto term = astBuilder->create<PolynomialIntValTerm>(); + term->constFactor = cVal1->value; + auto factor0 = astBuilder->create<PolynomialIntValFactor>(); + factor0->power = 1; + factor0->param = val0; + term->paramFactors.add(factor0); + result->terms.add(term); + result->canonicalize(astBuilder); + return result; + } + } + else if (as<ConstantIntVal>(op0)) + { + return mul(astBuilder, op1, op0); + } + return nullptr; +} + +void PolynomialIntVal::canonicalize(ASTBuilder* builder) +{ + List<PolynomialIntValTerm*> newTerms; + IntegerLiteralValue newConstantTerm = constantTerm; + auto addTerm = [&](PolynomialIntValTerm* newTerm) + { + for (auto term : newTerms) + { + if (term->canCombineWith(*newTerm)) + { + term->constFactor += newTerm->constFactor; + return; + } + } + newTerms.add(newTerm); + }; + for (auto term : terms) + { + if (term->constFactor == 0) + continue; + List<PolynomialIntValFactor*> newFactors; + List<bool> factorIsDifferent; + for (Index i = 0; i < term->paramFactors.getCount(); i++) + { + auto factor = term->paramFactors[i]; + bool factorFound = false; + for (Index j = 0; j < newFactors.getCount(); j++) + { + auto& newFactor = newFactors[j]; + if (factor->param->equalsVal(newFactor->param)) + { + if (!factorIsDifferent[j]) + { + factorIsDifferent[j] = true; + auto clonedFactor = builder->create<PolynomialIntValFactor>(); + clonedFactor->param = newFactor->param; + clonedFactor->power = newFactor->power; + newFactor = clonedFactor; + } + newFactor->power += factor->power; + factorFound = true; + break; + } + } + if (!factorFound) + { + newFactors.add(factor); + factorIsDifferent.add(false); + } + } + List<PolynomialIntValFactor*> newFactors2; + for (auto factor : newFactors) + { + if (factor->power != 0) + newFactors2.add(factor); + } + if (newFactors2.getCount() == 0) + { + newConstantTerm += term->constFactor; + continue; + } + newFactors2.sort([](PolynomialIntValFactor* t1, PolynomialIntValFactor* t2) {return *t1 < *t2; }); + bool isDifferent = false; + if (newFactors2.getCount() != term->paramFactors.getCount()) + isDifferent = true; + if (!isDifferent) + { + for (Index i = 0; i < term->paramFactors.getCount(); i++) + if (term->paramFactors[i] != newFactors2[i]) + { + isDifferent = true; + break; + } + } + if (!isDifferent) + { + addTerm(term); + } + else + { + auto newTerm = builder->create<PolynomialIntValTerm>(); + newTerm->constFactor = term->constFactor; + newTerm->paramFactors = _Move(newFactors2); + addTerm(newTerm); + } + } + List<PolynomialIntValTerm*> newTerms2; + for (auto term : newTerms) + { + if (term->constFactor == 0) + continue; + newTerms2.add(term); + } + newTerms2.sort([](PolynomialIntValTerm* t1, PolynomialIntValTerm* t2) {return *t1 < *t2; }); + terms = _Move(newTerms2); +} + } // namespace Slang diff --git a/source/slang/slang-ast-val.h b/source/slang/slang-ast-val.h index 5fd3e54f5..6eaaa8eb1 100644 --- a/source/slang/slang-ast-val.h +++ b/source/slang/slang-ast-val.h @@ -52,6 +52,101 @@ protected: {} }; +// polynomial expression "2*a*b^3 + 1" will be represented as: +// { constantTerm:1, terms: [ { constFactor:2, paramFactors:[{"a", 1}, {"b", 3}] } ] } +class PolynomialIntValFactor : public NodeBase +{ + SLANG_AST_CLASS(PolynomialIntValFactor) +public: + GenericParamIntVal* param; + IntegerLiteralValue power; + // for sorting only. + bool operator<(const PolynomialIntValFactor& other) const + { + if (param->declRef.decl < other.param->declRef.decl) + return true; + else if (param->declRef.decl == other.param->declRef.decl) + return power < other.power; + return false; + } + // for sorting only. + bool operator==(const PolynomialIntValFactor& other) const + { + if (param->declRef.decl == other.param->declRef.decl && power == other.power) + return true; + return false; + } + bool equals(const PolynomialIntValFactor& other) const + { + return power == other.power && param->equalsVal(other.param); + } +}; +class PolynomialIntValTerm : public NodeBase +{ + SLANG_AST_CLASS(PolynomialIntValTerm) +public: + IntegerLiteralValue constFactor; + List<PolynomialIntValFactor*> paramFactors; + bool canCombineWith(const PolynomialIntValTerm& other) const + { + if (paramFactors.getCount() != other.paramFactors.getCount()) + return false; + for (Index i = 0; i < paramFactors.getCount(); i++) + { + if (!paramFactors[i]->equals(*other.paramFactors[i])) + return false; + } + return true; + } + bool operator<(const PolynomialIntValTerm& other) const + { + if (constFactor < other.constFactor) + return true; + else if (constFactor == other.constFactor) + { + for (Index i = 0; i < paramFactors.getCount(); i++) + { + if (i >= other.paramFactors.getCount()) + return false; + if (*(paramFactors[i]) < *(other.paramFactors[i])) + return true; + if (*(paramFactors[i]) == *(other.paramFactors[i])) + { + } + else + { + return false; + } + } + } + return false; + } +}; + +class PolynomialIntVal : public IntVal +{ + SLANG_AST_CLASS(PolynomialIntVal) +public: + + List<PolynomialIntValTerm*> terms; + IntegerLiteralValue constantTerm = 0; + + bool isConstant() { return terms.getCount() == 0; } + void canonicalize(ASTBuilder* builder); + + // Overrides should be public so base classes can access + bool _equalsValOverride(Val* val); + void _toTextOverride(StringBuilder& out); + HashCode _getHashCodeOverride(); + Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); + + static PolynomialIntVal* neg(ASTBuilder* astBuilder, IntVal* base); + static PolynomialIntVal* add(ASTBuilder* astBuilder, IntVal* op0, IntVal* op1); + static PolynomialIntVal* sub(ASTBuilder* astBuilder, IntVal* op0, IntVal* op1); + static PolynomialIntVal* mul(ASTBuilder* astBuilder, IntVal* op0, IntVal* op1); + +}; + /// An unknown integer value indicating an erroneous sub-expression class ErrorIntVal : public IntVal { diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index c96db6f3b..7257790af 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -928,7 +928,7 @@ namespace Slang if (!allConst) { - // TODO(tfoley): We probably want to support a very limited number of operations + // We support a very limited number of operations // on "constants" that aren't actually known, to be able to handle a generic // that takes an integer `N` but then constructs a vector of size `N+1`. // @@ -937,7 +937,46 @@ namespace Slang // need inference to be smart enough to know that `2 + N` and `N + 2` are the // same value, as are `N + M + 1 + 1` and `M + 2 + N`. // - // For now we can just bail in this case. + // This is done by constructing a 'PolynomialIntVal' and rely on its + // `canonicalize` operation. + if (implicitCast) + { + // We cannot support casting in this case. + return nullptr; + } + + auto opName = funcDeclRef.getName(); + + // handle binary operators + if (opName == getName("-")) + { + if (argCount == 1) + { + return PolynomialIntVal::neg(m_astBuilder, argVals[0]); + } + else if (argCount == 2) + { + return PolynomialIntVal::sub(m_astBuilder, argVals[0], argVals[1]); + } + } + else if (opName == getName("+")) + { + if (argCount == 1) + { + return argVals[0]; + } + else if (argCount == 2) + { + return PolynomialIntVal::add(m_astBuilder, argVals[0], argVals[1]); + } + } + else if (opName == getName("*")) + { + if (argCount == 2) + { + return PolynomialIntVal::mul(m_astBuilder, argVals[0], argVals[1]); + } + } return nullptr; } @@ -1110,7 +1149,9 @@ namespace Slang if (auto genericValParamRef = declRef.as<GenericValueParamDecl>()) { // TODO(tfoley): handle the case of non-`int` value parameters... - return m_astBuilder->create<GenericParamIntVal>(genericValParamRef); + Val* valResult = m_astBuilder->create<GenericParamIntVal>(genericValParamRef); + valResult = valResult->substitute(m_astBuilder, expr.getSubsts()); + return as<IntVal>(valResult); } // We may also need to check for references to variables that diff --git a/source/slang/slang-check-type.cpp b/source/slang/slang-check-type.cpp index b8789dd76..aa2c69126 100644 --- a/source/slang/slang-check-type.cpp +++ b/source/slang/slang-check-type.cpp @@ -367,8 +367,15 @@ namespace Slang { return leftVar->declRef.equals(rightVar->declRef); } + else if (auto rightPoly = as<PolynomialIntVal>(right)) + { + return right->equalsVal(leftVar); + } + } + if (auto leftVar = as<PolynomialIntVal>(left)) + { + return leftVar->equalsVal(right); } - return false; } diff --git a/source/slang/slang-language-server.cpp b/source/slang/slang-language-server.cpp index da72ea65e..d3be5c4dc 100644 --- a/source/slang/slang-language-server.cpp +++ b/source/slang/slang-language-server.cpp @@ -243,7 +243,9 @@ String getDeclSignatureString(DeclRef<Decl> declRef, WorkspaceVersion* version) DiagnosticSink sink; SharedSemanticsContext semanticContext(version->linkage, getModule(varDecl), &sink); SemanticsVisitor semanticsVisitor(&semanticContext); - if (auto intVal = semanticsVisitor.tryFoldIntegerConstantExpression(varDecl->initExpr, nullptr)) + if (auto intVal = semanticsVisitor.tryFoldIntegerConstantExpression( + declRef.substitute(version->linkage->getASTBuilder(), varDecl->initExpr), + nullptr)) { if (auto constantInt = as<ConstantIntVal>(intVal)) { @@ -257,6 +259,11 @@ String getDeclSignatureString(DeclRef<Decl> declRef, WorkspaceVersion* version) sb << constantInt->value; } } + else + { + sb << " = "; + intVal->toText(sb); + } } } } diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 1285afca8..3d4f0df5d 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -1287,6 +1287,27 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower lowerType(context, getType(context->astBuilder, val->declRef))); } + LoweredValInfo visitPolynomialIntVal(PolynomialIntVal* val) + { + auto irBuilder = getBuilder(); + auto constTerm = irBuilder->getIntValue(irBuilder->getIntType(), val->constantTerm); + auto resultVal = constTerm; + for (auto term : val->terms) + { + auto termVal = irBuilder->getIntValue(irBuilder->getIntType(), term->constFactor); + for (auto factor : term->paramFactors) + { + auto factorVal = lowerVal(context, factor->param).val; + for (IntegerLiteralValue i = 0; i < factor->power; i++) + { + termVal = irBuilder->emitMul(factorVal->getDataType(), termVal, factorVal); + } + } + resultVal = irBuilder->emitAdd(termVal->getDataType(), resultVal, termVal); + } + return LoweredValInfo::simple(resultVal); + } + LoweredValInfo visitDeclaredSubtypeWitness(DeclaredSubtypeWitness* val) { return emitDeclRef(context, val->declRef, diff --git a/source/slang/slang-type-layout.cpp b/source/slang/slang-type-layout.cpp index 22af19271..73ccd4726 100644 --- a/source/slang/slang-type-layout.cpp +++ b/source/slang/slang-type-layout.cpp @@ -1479,6 +1479,18 @@ static LayoutSize GetElementCount(IntVal* val) // return 0; } + else if (auto polyIntVal = as<PolynomialIntVal>(val)) + { + // TODO: We want to treat the case where the number of + // elements in an array depends on a generic parameter + // much like the case where the number of elements is + // unbounded, *but* we can't just blindly do that because + // an API might disallow unbounded arrays in various + // cases where a generic bound might work (because + // any concrete specialization will have a finite bound...) + // + return 0; + } SLANG_UNEXPECTED("unhandled integer literal kind"); UNREACHABLE_RETURN(LayoutSize(0)); } diff --git a/source/slang/slang.natvis b/source/slang/slang.natvis index d56fa885d..4d5bad2d3 100644 --- a/source/slang/slang.natvis +++ b/source/slang/slang.natvis @@ -447,10 +447,28 @@ <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::ThisType">(Slang::ThisType*)&astNodeType</ExpandedItem> <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::AndType">(Slang::AndType*)&astNodeType</ExpandedItem> <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::ModifiedType">(Slang::ModifiedType*)&astNodeType</ExpandedItem> + <Item Name="[Type]">(Slang::Type*)this,nd</Item> </Expand> </Type> + <Type Name="Slang::Substitutions"> + <DisplayString>{astNodeType}</DisplayString> + <Expand> + <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::GenericSubstitution">(Slang::GenericSubstitution*)&astNodeType</ExpandedItem> + <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::ThisTypeSubstitution">(Slang::ThisTypeSubstitution*)&astNodeType</ExpandedItem> + </Expand> + </Type> + <Type Name="Slang::SubstitutionSet"> + <DisplayString>{astNodeType}</DisplayString> + <Expand> + <LinkedListItems> + <HeadPointer>substitutions</HeadPointer> + <NextPointer>outer</NextPointer> + <ValueNode>this</ValueNode> + </LinkedListItems> + </Expand> + </Type> <Type Name="Slang::AggTypeDecl"> <DisplayString>{nameAndLoc.name}: {astNodeType}</DisplayString> <Expand> diff --git a/tests/language-feature/generics/generic-value-constant-folding.slang b/tests/language-feature/generics/generic-value-constant-folding.slang new file mode 100644 index 000000000..1d6781889 --- /dev/null +++ b/tests/language-feature/generics/generic-value-constant-folding.slang @@ -0,0 +1,25 @@ +//TEST(compute):COMPARE_COMPUTE: -shaderobj -output-using-type + +struct PlusOne<let v : int> +{ + static const int value = v + 1; +} + +struct GetConst<let v : int, let u : int> +{ + static const int value = (u+v)*(u+v) + PlusOne<u-1>.value; + int arr[value]; +} + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<int> outputBuffer; + +[numthreads(4, 1, 1)] +void computeMain(int3 dispatchThreadID : SV_DispatchThreadID) +{ + int tid = dispatchThreadID.x; + int inVal = tid; + int arr[GetConst<5,2>.value + 1]; + arr[0] = GetConst<5,3>.value + 1; + outputBuffer[tid] = arr[0]; +} diff --git a/tests/language-feature/generics/generic-value-constant-folding.slang.expected.txt b/tests/language-feature/generics/generic-value-constant-folding.slang.expected.txt new file mode 100644 index 000000000..2ba17a828 --- /dev/null +++ b/tests/language-feature/generics/generic-value-constant-folding.slang.expected.txt @@ -0,0 +1,5 @@ +type: int32_t +68 +68 +68 +68 |
