From 393185196ed65a9eeaf9502edbf3dcce87337d81 Mon Sep 17 00:00:00 2001 From: Yong He Date: Mon, 22 Aug 2022 09:43:05 -0700 Subject: Support compile-time constant int val in the form of polynomials. (#2372) Co-authored-by: Yong He --- source/slang/slang-ast-val.cpp | 454 +++++++++++++++++++++++++++++++++ source/slang/slang-ast-val.h | 95 +++++++ source/slang/slang-check-expr.cpp | 47 +++- source/slang/slang-check-type.cpp | 9 +- source/slang/slang-language-server.cpp | 9 +- source/slang/slang-lower-to-ir.cpp | 21 ++ source/slang/slang-type-layout.cpp | 12 + source/slang/slang.natvis | 18 ++ 8 files changed, 660 insertions(+), 5 deletions(-) (limited to 'source/slang') diff --git a/source/slang/slang-ast-val.cpp b/source/slang/slang-ast-val.cpp index 0f3bd2b3a..16c649f3a 100644 --- a/source/slang/slang-ast-val.cpp +++ b/source/slang/slang-ast-val.cpp @@ -576,4 +576,458 @@ Val* SNormModifierVal::_substituteImplOverride(ASTBuilder* astBuilder, Substitut return this; } +// PolynomialIntVal + +bool PolynomialIntVal::_equalsValOverride(Val* val) +{ + if (auto genericParamVal = as(val)) + { + return constantTerm == 0 && terms.getCount() == 1 && + terms[0]->paramFactors.getCount() == 1 && terms[0]->constFactor == 1 && + terms[0]->paramFactors[0]->param->equalsVal(genericParamVal) && + terms[0]->paramFactors[0]->power == 1; + } + else if (auto otherPolynomial = as(val)) + { + if (constantTerm != otherPolynomial->constantTerm) + return false; + if (terms.getCount() != otherPolynomial->terms.getCount()) + return false; + for (Index i = 0; i < terms.getCount(); i++) + { + auto& thisTerm = *(terms[i]); + auto& thatTerm = *(otherPolynomial->terms[i]); + if (thisTerm.constFactor != thatTerm.constFactor) + return false; + if (thisTerm.paramFactors.getCount() != thatTerm.paramFactors.getCount()) + return false; + for (Index j = 0; j < thisTerm.paramFactors.getCount(); j++) + { + if (thisTerm.paramFactors[j]->power != thatTerm.paramFactors[j]->power) + return false; + if (!thisTerm.paramFactors[j]->param->equalsVal(thatTerm.paramFactors[j]->param)) + return false; + } + } + return true; + } + return false; +} + +void PolynomialIntVal::_toTextOverride(StringBuilder& out) +{ + for (Index i = 0; i < terms.getCount(); i++) + { + auto& term = *(terms[i]); + if (i > 0) + { + if (term.constFactor > 0) + out << "+"; + else + out << "-"; + } + bool isFirstFactor = true; + if (term.constFactor != 1 || term.paramFactors.getCount() == 0) + { + out << abs(term.constFactor); + isFirstFactor = false; + } + for (Index j = 0; j < term.paramFactors.getCount(); j++) + { + auto factor = term.paramFactors[j]; + if (isFirstFactor) + { + isFirstFactor = false; + } + else + { + out << "*"; + } + factor->param->toText(out); + if (factor->power != 1) + { + out << "^^" << factor->power; + } + } + } + if (constantTerm > 0) + { + if (terms.getCount() > 0) + { + out << "+"; + } + out << constantTerm; + } + else if (constantTerm < 0) + { + out << constantTerm; + } +} + +HashCode PolynomialIntVal::_getHashCodeOverride() +{ + HashCode result = (HashCode)constantTerm; + for (auto& term : terms) + { + if (!term) continue; + result = combineHash(result, (HashCode)term->constFactor); + for (auto& factor : term->paramFactors) + { + result = combineHash(result, factor->param->getHashCode()); + result = combineHash(result, (HashCode)factor->power); + } + } + return result; +} + +Val* PolynomialIntVal::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) +{ + int diff = 0; + IntegerLiteralValue evaluatedConstantTerm = constantTerm; + List evaluatedTerms; + for (auto& term : terms) + { + IntegerLiteralValue evaluatedTermConstFactor; + List evaluatedTermParamFactors; + evaluatedTermConstFactor = term->constFactor; + for (auto& factor : term->paramFactors) + { + auto substResult = factor->param->substituteImpl(astBuilder, subst, &diff); + if (auto genericVal = as(substResult)) + { + auto newFactor = astBuilder->create(); + newFactor->param = genericVal; + newFactor->power = factor->power; + evaluatedTermParamFactors.add(newFactor); + } + else if (auto constantVal = as(substResult)) + { + evaluatedTermConstFactor *= constantVal->value; + } + } + if (evaluatedTermParamFactors.getCount() == 0) + { + evaluatedConstantTerm += evaluatedTermConstFactor; + } + else + { + auto newTerm = astBuilder->create(); + newTerm->paramFactors = _Move(evaluatedTermParamFactors); + newTerm->constFactor = evaluatedTermConstFactor; + evaluatedTerms.add(newTerm); + } + } + + *ioDiff += diff; + + if (evaluatedTerms.getCount() == 0) + return astBuilder->create(evaluatedConstantTerm); + if (diff != 0) + { + auto newPolynomial = astBuilder->create(); + newPolynomial->constantTerm = evaluatedConstantTerm; + newPolynomial->terms = _Move(evaluatedTerms); + newPolynomial->canonicalize(astBuilder); + return newPolynomial; + } + return nullptr; +} + + +// compute val += opreand*multiplier; +bool addToPolynomialTerm(ASTBuilder* astBuilder, PolynomialIntVal* val, IntVal* operand, IntegerLiteralValue multiplier) +{ + if (auto genVal = as(operand)) + { + auto term = astBuilder->create(); + term->constFactor = multiplier; + auto factor = astBuilder->create(); + factor->power = 1; + factor->param = genVal; + term->paramFactors.add(factor); + val->terms.add(term); + return true; + } + else if (auto c = as(operand)) + { + val->constantTerm += c->value * multiplier; + return true; + } + else if (auto poly = as(operand)) + { + val->constantTerm += poly->constantTerm * multiplier; + for (auto term : poly->terms) + { + auto newTerm = astBuilder->create(); + newTerm->constFactor = multiplier * term->constFactor; + newTerm->paramFactors = term->paramFactors; + val->terms.add(newTerm); + } + return true; + } + return false; +} + +PolynomialIntVal* PolynomialIntVal::neg(ASTBuilder* astBuilder, IntVal* base) +{ + auto result = astBuilder->create(); + if (!addToPolynomialTerm(astBuilder, result, base, -1)) + return nullptr; + result->canonicalize(astBuilder); + return result; +} + +PolynomialIntVal* PolynomialIntVal::sub(ASTBuilder* astBuilder, IntVal* op0, IntVal* op1) +{ + auto result = astBuilder->create(); + if (!addToPolynomialTerm(astBuilder, result, op0, 1)) + return nullptr; + if (!addToPolynomialTerm(astBuilder, result, op1, -1)) + return nullptr; + result->canonicalize(astBuilder); + return result; +} + +PolynomialIntVal* PolynomialIntVal::add(ASTBuilder* astBuilder, IntVal* op0, IntVal* op1) +{ + auto result = astBuilder->create(); + if (!addToPolynomialTerm(astBuilder, result, op0, 1)) + return nullptr; + if (!addToPolynomialTerm(astBuilder, result, op1, 1)) + return nullptr; + result->canonicalize(astBuilder); + return result; +} + +PolynomialIntVal* PolynomialIntVal::mul(ASTBuilder* astBuilder, IntVal* op0, IntVal* op1) +{ + if (auto poly0 = as(op0)) + { + if (auto poly1 = as(op1)) + { + auto result = astBuilder->create(); + // add poly0.constant * poly1.constant + result->constantTerm = poly0->constantTerm * poly1->constantTerm; + // add poly0.constant * poly1.terms + if (poly0->constantTerm != 0) + { + for (auto term : poly1->terms) + { + auto newTerm = astBuilder->create(); + newTerm->constFactor = poly0->constantTerm * term->constFactor; + newTerm->paramFactors.addRange(term->paramFactors); + result->terms.add(newTerm); + } + } + // add poly1.constant * poly0.terms + if (poly1->constantTerm != 0) + { + for (auto term : poly0->terms) + { + auto newTerm = astBuilder->create(); + newTerm->constFactor = poly1->constantTerm * term->constFactor; + newTerm->paramFactors.addRange(term->paramFactors); + result->terms.add(newTerm); + } + } + // add poly1.terms * poly0.terms + for (auto term0 : poly0->terms) + { + for (auto term1 : poly1->terms) + { + auto newTerm = astBuilder->create(); + newTerm->constFactor = term0->constFactor * term1->constFactor; + newTerm->paramFactors.addRange(term0->paramFactors); + newTerm->paramFactors.addRange(term1->paramFactors); + result->terms.add(newTerm); + } + } + result->canonicalize(astBuilder); + return result; + } + else if (auto val1 = as(op1)) + { + auto result = astBuilder->create(); + result->constantTerm = 0; + auto factor1 = astBuilder->create(); + factor1->power = 1; + factor1->param = val1; + if (poly0->constantTerm != 0) + { + auto term0 = astBuilder->create(); + term0->constFactor = poly0->constantTerm; + term0->paramFactors.add(factor1); + result->terms.add(term0); + } + for (auto term : poly0->terms) + { + auto newTerm = astBuilder->create(); + newTerm->constFactor = term->constFactor; + newTerm->paramFactors.addRange(term->paramFactors); + newTerm->paramFactors.add(factor1); + result->terms.add(newTerm); + } + result->canonicalize(astBuilder); + return result; + } + else if (auto cVal1 = as(op1)) + { + auto result = astBuilder->create(); + result->constantTerm = poly0->constantTerm * cVal1->value; + auto factor1 = astBuilder->create(); + for (auto term : poly0->terms) + { + auto newTerm = astBuilder->create(); + newTerm->constFactor = term->constFactor * cVal1->value; + newTerm->paramFactors.addRange(term->paramFactors); + newTerm->paramFactors.add(factor1); + result->terms.add(newTerm); + } + result->canonicalize(astBuilder); + return result; + } + else + return nullptr; + } + else if (auto val0 = as(op0)) + { + if (auto poly1 = as(op1)) + { + return mul(astBuilder, op1, op0); + } + else if (auto val1 = as(op1)) + { + auto result = astBuilder->create(); + auto term = astBuilder->create(); + term->constFactor = 1; + auto factor0 = astBuilder->create(); + factor0->power = 1; + factor0->param = val0; + term->paramFactors.add(factor0); + auto factor1 = astBuilder->create(); + factor1->power = 1; + factor1->param = val1; + term->paramFactors.add(factor1); + result->terms.add(term); + result->canonicalize(astBuilder); + return result; + } + else if (auto cVal1 = as(op1)) + { + auto result = astBuilder->create(); + auto term = astBuilder->create(); + term->constFactor = cVal1->value; + auto factor0 = astBuilder->create(); + factor0->power = 1; + factor0->param = val0; + term->paramFactors.add(factor0); + result->terms.add(term); + result->canonicalize(astBuilder); + return result; + } + } + else if (as(op0)) + { + return mul(astBuilder, op1, op0); + } + return nullptr; +} + +void PolynomialIntVal::canonicalize(ASTBuilder* builder) +{ + List newTerms; + IntegerLiteralValue newConstantTerm = constantTerm; + auto addTerm = [&](PolynomialIntValTerm* newTerm) + { + for (auto term : newTerms) + { + if (term->canCombineWith(*newTerm)) + { + term->constFactor += newTerm->constFactor; + return; + } + } + newTerms.add(newTerm); + }; + for (auto term : terms) + { + if (term->constFactor == 0) + continue; + List newFactors; + List factorIsDifferent; + for (Index i = 0; i < term->paramFactors.getCount(); i++) + { + auto factor = term->paramFactors[i]; + bool factorFound = false; + for (Index j = 0; j < newFactors.getCount(); j++) + { + auto& newFactor = newFactors[j]; + if (factor->param->equalsVal(newFactor->param)) + { + if (!factorIsDifferent[j]) + { + factorIsDifferent[j] = true; + auto clonedFactor = builder->create(); + clonedFactor->param = newFactor->param; + clonedFactor->power = newFactor->power; + newFactor = clonedFactor; + } + newFactor->power += factor->power; + factorFound = true; + break; + } + } + if (!factorFound) + { + newFactors.add(factor); + factorIsDifferent.add(false); + } + } + List newFactors2; + for (auto factor : newFactors) + { + if (factor->power != 0) + newFactors2.add(factor); + } + if (newFactors2.getCount() == 0) + { + newConstantTerm += term->constFactor; + continue; + } + newFactors2.sort([](PolynomialIntValFactor* t1, PolynomialIntValFactor* t2) {return *t1 < *t2; }); + bool isDifferent = false; + if (newFactors2.getCount() != term->paramFactors.getCount()) + isDifferent = true; + if (!isDifferent) + { + for (Index i = 0; i < term->paramFactors.getCount(); i++) + if (term->paramFactors[i] != newFactors2[i]) + { + isDifferent = true; + break; + } + } + if (!isDifferent) + { + addTerm(term); + } + else + { + auto newTerm = builder->create(); + newTerm->constFactor = term->constFactor; + newTerm->paramFactors = _Move(newFactors2); + addTerm(newTerm); + } + } + List newTerms2; + for (auto term : newTerms) + { + if (term->constFactor == 0) + continue; + newTerms2.add(term); + } + newTerms2.sort([](PolynomialIntValTerm* t1, PolynomialIntValTerm* t2) {return *t1 < *t2; }); + terms = _Move(newTerms2); +} + } // namespace Slang diff --git a/source/slang/slang-ast-val.h b/source/slang/slang-ast-val.h index 5fd3e54f5..6eaaa8eb1 100644 --- a/source/slang/slang-ast-val.h +++ b/source/slang/slang-ast-val.h @@ -52,6 +52,101 @@ protected: {} }; +// polynomial expression "2*a*b^3 + 1" will be represented as: +// { constantTerm:1, terms: [ { constFactor:2, paramFactors:[{"a", 1}, {"b", 3}] } ] } +class PolynomialIntValFactor : public NodeBase +{ + SLANG_AST_CLASS(PolynomialIntValFactor) +public: + GenericParamIntVal* param; + IntegerLiteralValue power; + // for sorting only. + bool operator<(const PolynomialIntValFactor& other) const + { + if (param->declRef.decl < other.param->declRef.decl) + return true; + else if (param->declRef.decl == other.param->declRef.decl) + return power < other.power; + return false; + } + // for sorting only. + bool operator==(const PolynomialIntValFactor& other) const + { + if (param->declRef.decl == other.param->declRef.decl && power == other.power) + return true; + return false; + } + bool equals(const PolynomialIntValFactor& other) const + { + return power == other.power && param->equalsVal(other.param); + } +}; +class PolynomialIntValTerm : public NodeBase +{ + SLANG_AST_CLASS(PolynomialIntValTerm) +public: + IntegerLiteralValue constFactor; + List paramFactors; + bool canCombineWith(const PolynomialIntValTerm& other) const + { + if (paramFactors.getCount() != other.paramFactors.getCount()) + return false; + for (Index i = 0; i < paramFactors.getCount(); i++) + { + if (!paramFactors[i]->equals(*other.paramFactors[i])) + return false; + } + return true; + } + bool operator<(const PolynomialIntValTerm& other) const + { + if (constFactor < other.constFactor) + return true; + else if (constFactor == other.constFactor) + { + for (Index i = 0; i < paramFactors.getCount(); i++) + { + if (i >= other.paramFactors.getCount()) + return false; + if (*(paramFactors[i]) < *(other.paramFactors[i])) + return true; + if (*(paramFactors[i]) == *(other.paramFactors[i])) + { + } + else + { + return false; + } + } + } + return false; + } +}; + +class PolynomialIntVal : public IntVal +{ + SLANG_AST_CLASS(PolynomialIntVal) +public: + + List terms; + IntegerLiteralValue constantTerm = 0; + + bool isConstant() { return terms.getCount() == 0; } + void canonicalize(ASTBuilder* builder); + + // Overrides should be public so base classes can access + bool _equalsValOverride(Val* val); + void _toTextOverride(StringBuilder& out); + HashCode _getHashCodeOverride(); + Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); + + static PolynomialIntVal* neg(ASTBuilder* astBuilder, IntVal* base); + static PolynomialIntVal* add(ASTBuilder* astBuilder, IntVal* op0, IntVal* op1); + static PolynomialIntVal* sub(ASTBuilder* astBuilder, IntVal* op0, IntVal* op1); + static PolynomialIntVal* mul(ASTBuilder* astBuilder, IntVal* op0, IntVal* op1); + +}; + /// An unknown integer value indicating an erroneous sub-expression class ErrorIntVal : public IntVal { diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index c96db6f3b..7257790af 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -928,7 +928,7 @@ namespace Slang if (!allConst) { - // TODO(tfoley): We probably want to support a very limited number of operations + // We support a very limited number of operations // on "constants" that aren't actually known, to be able to handle a generic // that takes an integer `N` but then constructs a vector of size `N+1`. // @@ -937,7 +937,46 @@ namespace Slang // need inference to be smart enough to know that `2 + N` and `N + 2` are the // same value, as are `N + M + 1 + 1` and `M + 2 + N`. // - // For now we can just bail in this case. + // This is done by constructing a 'PolynomialIntVal' and rely on its + // `canonicalize` operation. + if (implicitCast) + { + // We cannot support casting in this case. + return nullptr; + } + + auto opName = funcDeclRef.getName(); + + // handle binary operators + if (opName == getName("-")) + { + if (argCount == 1) + { + return PolynomialIntVal::neg(m_astBuilder, argVals[0]); + } + else if (argCount == 2) + { + return PolynomialIntVal::sub(m_astBuilder, argVals[0], argVals[1]); + } + } + else if (opName == getName("+")) + { + if (argCount == 1) + { + return argVals[0]; + } + else if (argCount == 2) + { + return PolynomialIntVal::add(m_astBuilder, argVals[0], argVals[1]); + } + } + else if (opName == getName("*")) + { + if (argCount == 2) + { + return PolynomialIntVal::mul(m_astBuilder, argVals[0], argVals[1]); + } + } return nullptr; } @@ -1110,7 +1149,9 @@ namespace Slang if (auto genericValParamRef = declRef.as()) { // TODO(tfoley): handle the case of non-`int` value parameters... - return m_astBuilder->create(genericValParamRef); + Val* valResult = m_astBuilder->create(genericValParamRef); + valResult = valResult->substitute(m_astBuilder, expr.getSubsts()); + return as(valResult); } // We may also need to check for references to variables that diff --git a/source/slang/slang-check-type.cpp b/source/slang/slang-check-type.cpp index b8789dd76..aa2c69126 100644 --- a/source/slang/slang-check-type.cpp +++ b/source/slang/slang-check-type.cpp @@ -367,8 +367,15 @@ namespace Slang { return leftVar->declRef.equals(rightVar->declRef); } + else if (auto rightPoly = as(right)) + { + return right->equalsVal(leftVar); + } + } + if (auto leftVar = as(left)) + { + return leftVar->equalsVal(right); } - return false; } diff --git a/source/slang/slang-language-server.cpp b/source/slang/slang-language-server.cpp index da72ea65e..d3be5c4dc 100644 --- a/source/slang/slang-language-server.cpp +++ b/source/slang/slang-language-server.cpp @@ -243,7 +243,9 @@ String getDeclSignatureString(DeclRef declRef, WorkspaceVersion* version) DiagnosticSink sink; SharedSemanticsContext semanticContext(version->linkage, getModule(varDecl), &sink); SemanticsVisitor semanticsVisitor(&semanticContext); - if (auto intVal = semanticsVisitor.tryFoldIntegerConstantExpression(varDecl->initExpr, nullptr)) + if (auto intVal = semanticsVisitor.tryFoldIntegerConstantExpression( + declRef.substitute(version->linkage->getASTBuilder(), varDecl->initExpr), + nullptr)) { if (auto constantInt = as(intVal)) { @@ -257,6 +259,11 @@ String getDeclSignatureString(DeclRef declRef, WorkspaceVersion* version) sb << constantInt->value; } } + else + { + sb << " = "; + intVal->toText(sb); + } } } } diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 1285afca8..3d4f0df5d 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -1287,6 +1287,27 @@ struct ValLoweringVisitor : ValVisitorastBuilder, val->declRef))); } + LoweredValInfo visitPolynomialIntVal(PolynomialIntVal* val) + { + auto irBuilder = getBuilder(); + auto constTerm = irBuilder->getIntValue(irBuilder->getIntType(), val->constantTerm); + auto resultVal = constTerm; + for (auto term : val->terms) + { + auto termVal = irBuilder->getIntValue(irBuilder->getIntType(), term->constFactor); + for (auto factor : term->paramFactors) + { + auto factorVal = lowerVal(context, factor->param).val; + for (IntegerLiteralValue i = 0; i < factor->power; i++) + { + termVal = irBuilder->emitMul(factorVal->getDataType(), termVal, factorVal); + } + } + resultVal = irBuilder->emitAdd(termVal->getDataType(), resultVal, termVal); + } + return LoweredValInfo::simple(resultVal); + } + LoweredValInfo visitDeclaredSubtypeWitness(DeclaredSubtypeWitness* val) { return emitDeclRef(context, val->declRef, diff --git a/source/slang/slang-type-layout.cpp b/source/slang/slang-type-layout.cpp index 22af19271..73ccd4726 100644 --- a/source/slang/slang-type-layout.cpp +++ b/source/slang/slang-type-layout.cpp @@ -1479,6 +1479,18 @@ static LayoutSize GetElementCount(IntVal* val) // return 0; } + else if (auto polyIntVal = as(val)) + { + // TODO: We want to treat the case where the number of + // elements in an array depends on a generic parameter + // much like the case where the number of elements is + // unbounded, *but* we can't just blindly do that because + // an API might disallow unbounded arrays in various + // cases where a generic bound might work (because + // any concrete specialization will have a finite bound...) + // + return 0; + } SLANG_UNEXPECTED("unhandled integer literal kind"); UNREACHABLE_RETURN(LayoutSize(0)); } diff --git a/source/slang/slang.natvis b/source/slang/slang.natvis index d56fa885d..4d5bad2d3 100644 --- a/source/slang/slang.natvis +++ b/source/slang/slang.natvis @@ -447,10 +447,28 @@ (Slang::ThisType*)&astNodeType (Slang::AndType*)&astNodeType (Slang::ModifiedType*)&astNodeType + (Slang::Type*)this,nd + + {astNodeType} + + (Slang::GenericSubstitution*)&astNodeType + (Slang::ThisTypeSubstitution*)&astNodeType + + + + {astNodeType} + + + substitutions + outer + this + + + {nameAndLoc.name}: {astNodeType} -- cgit v1.2.3