summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-ast-val.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2022-08-22 09:43:05 -0700
committerGitHub <noreply@github.com>2022-08-22 09:43:05 -0700
commit393185196ed65a9eeaf9502edbf3dcce87337d81 (patch)
tree91c9fa14ddb21d15e6cedf83f7aa6b649e99db86 /source/slang/slang-ast-val.cpp
parent15055d20c143cb398bd3e269541eebf24777390a (diff)
Support compile-time constant int val in the form of polynomials. (#2372)
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-ast-val.cpp')
-rw-r--r--source/slang/slang-ast-val.cpp454
1 files changed, 454 insertions, 0 deletions
diff --git a/source/slang/slang-ast-val.cpp b/source/slang/slang-ast-val.cpp
index 0f3bd2b3a..16c649f3a 100644
--- a/source/slang/slang-ast-val.cpp
+++ b/source/slang/slang-ast-val.cpp
@@ -576,4 +576,458 @@ Val* SNormModifierVal::_substituteImplOverride(ASTBuilder* astBuilder, Substitut
return this;
}
+// PolynomialIntVal
+
+bool PolynomialIntVal::_equalsValOverride(Val* val)
+{
+ if (auto genericParamVal = as<GenericParamIntVal>(val))
+ {
+ return constantTerm == 0 && terms.getCount() == 1 &&
+ terms[0]->paramFactors.getCount() == 1 && terms[0]->constFactor == 1 &&
+ terms[0]->paramFactors[0]->param->equalsVal(genericParamVal) &&
+ terms[0]->paramFactors[0]->power == 1;
+ }
+ else if (auto otherPolynomial = as<PolynomialIntVal>(val))
+ {
+ if (constantTerm != otherPolynomial->constantTerm)
+ return false;
+ if (terms.getCount() != otherPolynomial->terms.getCount())
+ return false;
+ for (Index i = 0; i < terms.getCount(); i++)
+ {
+ auto& thisTerm = *(terms[i]);
+ auto& thatTerm = *(otherPolynomial->terms[i]);
+ if (thisTerm.constFactor != thatTerm.constFactor)
+ return false;
+ if (thisTerm.paramFactors.getCount() != thatTerm.paramFactors.getCount())
+ return false;
+ for (Index j = 0; j < thisTerm.paramFactors.getCount(); j++)
+ {
+ if (thisTerm.paramFactors[j]->power != thatTerm.paramFactors[j]->power)
+ return false;
+ if (!thisTerm.paramFactors[j]->param->equalsVal(thatTerm.paramFactors[j]->param))
+ return false;
+ }
+ }
+ return true;
+ }
+ return false;
+}
+
+void PolynomialIntVal::_toTextOverride(StringBuilder& out)
+{
+ for (Index i = 0; i < terms.getCount(); i++)
+ {
+ auto& term = *(terms[i]);
+ if (i > 0)
+ {
+ if (term.constFactor > 0)
+ out << "+";
+ else
+ out << "-";
+ }
+ bool isFirstFactor = true;
+ if (term.constFactor != 1 || term.paramFactors.getCount() == 0)
+ {
+ out << abs(term.constFactor);
+ isFirstFactor = false;
+ }
+ for (Index j = 0; j < term.paramFactors.getCount(); j++)
+ {
+ auto factor = term.paramFactors[j];
+ if (isFirstFactor)
+ {
+ isFirstFactor = false;
+ }
+ else
+ {
+ out << "*";
+ }
+ factor->param->toText(out);
+ if (factor->power != 1)
+ {
+ out << "^^" << factor->power;
+ }
+ }
+ }
+ if (constantTerm > 0)
+ {
+ if (terms.getCount() > 0)
+ {
+ out << "+";
+ }
+ out << constantTerm;
+ }
+ else if (constantTerm < 0)
+ {
+ out << constantTerm;
+ }
+}
+
+HashCode PolynomialIntVal::_getHashCodeOverride()
+{
+ HashCode result = (HashCode)constantTerm;
+ for (auto& term : terms)
+ {
+ if (!term) continue;
+ result = combineHash(result, (HashCode)term->constFactor);
+ for (auto& factor : term->paramFactors)
+ {
+ result = combineHash(result, factor->param->getHashCode());
+ result = combineHash(result, (HashCode)factor->power);
+ }
+ }
+ return result;
+}
+
+Val* PolynomialIntVal::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff)
+{
+ int diff = 0;
+ IntegerLiteralValue evaluatedConstantTerm = constantTerm;
+ List<PolynomialIntValTerm*> evaluatedTerms;
+ for (auto& term : terms)
+ {
+ IntegerLiteralValue evaluatedTermConstFactor;
+ List<PolynomialIntValFactor*> evaluatedTermParamFactors;
+ evaluatedTermConstFactor = term->constFactor;
+ for (auto& factor : term->paramFactors)
+ {
+ auto substResult = factor->param->substituteImpl(astBuilder, subst, &diff);
+ if (auto genericVal = as<GenericParamIntVal>(substResult))
+ {
+ auto newFactor = astBuilder->create<PolynomialIntValFactor>();
+ newFactor->param = genericVal;
+ newFactor->power = factor->power;
+ evaluatedTermParamFactors.add(newFactor);
+ }
+ else if (auto constantVal = as<ConstantIntVal>(substResult))
+ {
+ evaluatedTermConstFactor *= constantVal->value;
+ }
+ }
+ if (evaluatedTermParamFactors.getCount() == 0)
+ {
+ evaluatedConstantTerm += evaluatedTermConstFactor;
+ }
+ else
+ {
+ auto newTerm = astBuilder->create<PolynomialIntValTerm>();
+ newTerm->paramFactors = _Move(evaluatedTermParamFactors);
+ newTerm->constFactor = evaluatedTermConstFactor;
+ evaluatedTerms.add(newTerm);
+ }
+ }
+
+ *ioDiff += diff;
+
+ if (evaluatedTerms.getCount() == 0)
+ return astBuilder->create<ConstantIntVal>(evaluatedConstantTerm);
+ if (diff != 0)
+ {
+ auto newPolynomial = astBuilder->create<PolynomialIntVal>();
+ newPolynomial->constantTerm = evaluatedConstantTerm;
+ newPolynomial->terms = _Move(evaluatedTerms);
+ newPolynomial->canonicalize(astBuilder);
+ return newPolynomial;
+ }
+ return nullptr;
+}
+
+
+// compute val += opreand*multiplier;
+bool addToPolynomialTerm(ASTBuilder* astBuilder, PolynomialIntVal* val, IntVal* operand, IntegerLiteralValue multiplier)
+{
+ if (auto genVal = as<GenericParamIntVal>(operand))
+ {
+ auto term = astBuilder->create<PolynomialIntValTerm>();
+ term->constFactor = multiplier;
+ auto factor = astBuilder->create<PolynomialIntValFactor>();
+ factor->power = 1;
+ factor->param = genVal;
+ term->paramFactors.add(factor);
+ val->terms.add(term);
+ return true;
+ }
+ else if (auto c = as<ConstantIntVal>(operand))
+ {
+ val->constantTerm += c->value * multiplier;
+ return true;
+ }
+ else if (auto poly = as<PolynomialIntVal>(operand))
+ {
+ val->constantTerm += poly->constantTerm * multiplier;
+ for (auto term : poly->terms)
+ {
+ auto newTerm = astBuilder->create<PolynomialIntValTerm>();
+ newTerm->constFactor = multiplier * term->constFactor;
+ newTerm->paramFactors = term->paramFactors;
+ val->terms.add(newTerm);
+ }
+ return true;
+ }
+ return false;
+}
+
+PolynomialIntVal* PolynomialIntVal::neg(ASTBuilder* astBuilder, IntVal* base)
+{
+ auto result = astBuilder->create<PolynomialIntVal>();
+ if (!addToPolynomialTerm(astBuilder, result, base, -1))
+ return nullptr;
+ result->canonicalize(astBuilder);
+ return result;
+}
+
+PolynomialIntVal* PolynomialIntVal::sub(ASTBuilder* astBuilder, IntVal* op0, IntVal* op1)
+{
+ auto result = astBuilder->create<PolynomialIntVal>();
+ if (!addToPolynomialTerm(astBuilder, result, op0, 1))
+ return nullptr;
+ if (!addToPolynomialTerm(astBuilder, result, op1, -1))
+ return nullptr;
+ result->canonicalize(astBuilder);
+ return result;
+}
+
+PolynomialIntVal* PolynomialIntVal::add(ASTBuilder* astBuilder, IntVal* op0, IntVal* op1)
+{
+ auto result = astBuilder->create<PolynomialIntVal>();
+ if (!addToPolynomialTerm(astBuilder, result, op0, 1))
+ return nullptr;
+ if (!addToPolynomialTerm(astBuilder, result, op1, 1))
+ return nullptr;
+ result->canonicalize(astBuilder);
+ return result;
+}
+
+PolynomialIntVal* PolynomialIntVal::mul(ASTBuilder* astBuilder, IntVal* op0, IntVal* op1)
+{
+ if (auto poly0 = as<PolynomialIntVal>(op0))
+ {
+ if (auto poly1 = as<PolynomialIntVal>(op1))
+ {
+ auto result = astBuilder->create<PolynomialIntVal>();
+ // add poly0.constant * poly1.constant
+ result->constantTerm = poly0->constantTerm * poly1->constantTerm;
+ // add poly0.constant * poly1.terms
+ if (poly0->constantTerm != 0)
+ {
+ for (auto term : poly1->terms)
+ {
+ auto newTerm = astBuilder->create<PolynomialIntValTerm>();
+ newTerm->constFactor = poly0->constantTerm * term->constFactor;
+ newTerm->paramFactors.addRange(term->paramFactors);
+ result->terms.add(newTerm);
+ }
+ }
+ // add poly1.constant * poly0.terms
+ if (poly1->constantTerm != 0)
+ {
+ for (auto term : poly0->terms)
+ {
+ auto newTerm = astBuilder->create<PolynomialIntValTerm>();
+ newTerm->constFactor = poly1->constantTerm * term->constFactor;
+ newTerm->paramFactors.addRange(term->paramFactors);
+ result->terms.add(newTerm);
+ }
+ }
+ // add poly1.terms * poly0.terms
+ for (auto term0 : poly0->terms)
+ {
+ for (auto term1 : poly1->terms)
+ {
+ auto newTerm = astBuilder->create<PolynomialIntValTerm>();
+ newTerm->constFactor = term0->constFactor * term1->constFactor;
+ newTerm->paramFactors.addRange(term0->paramFactors);
+ newTerm->paramFactors.addRange(term1->paramFactors);
+ result->terms.add(newTerm);
+ }
+ }
+ result->canonicalize(astBuilder);
+ return result;
+ }
+ else if (auto val1 = as<GenericParamIntVal>(op1))
+ {
+ auto result = astBuilder->create<PolynomialIntVal>();
+ result->constantTerm = 0;
+ auto factor1 = astBuilder->create<PolynomialIntValFactor>();
+ factor1->power = 1;
+ factor1->param = val1;
+ if (poly0->constantTerm != 0)
+ {
+ auto term0 = astBuilder->create<PolynomialIntValTerm>();
+ term0->constFactor = poly0->constantTerm;
+ term0->paramFactors.add(factor1);
+ result->terms.add(term0);
+ }
+ for (auto term : poly0->terms)
+ {
+ auto newTerm = astBuilder->create<PolynomialIntValTerm>();
+ newTerm->constFactor = term->constFactor;
+ newTerm->paramFactors.addRange(term->paramFactors);
+ newTerm->paramFactors.add(factor1);
+ result->terms.add(newTerm);
+ }
+ result->canonicalize(astBuilder);
+ return result;
+ }
+ else if (auto cVal1 = as<ConstantIntVal>(op1))
+ {
+ auto result = astBuilder->create<PolynomialIntVal>();
+ result->constantTerm = poly0->constantTerm * cVal1->value;
+ auto factor1 = astBuilder->create<PolynomialIntValFactor>();
+ for (auto term : poly0->terms)
+ {
+ auto newTerm = astBuilder->create<PolynomialIntValTerm>();
+ newTerm->constFactor = term->constFactor * cVal1->value;
+ newTerm->paramFactors.addRange(term->paramFactors);
+ newTerm->paramFactors.add(factor1);
+ result->terms.add(newTerm);
+ }
+ result->canonicalize(astBuilder);
+ return result;
+ }
+ else
+ return nullptr;
+ }
+ else if (auto val0 = as<GenericParamIntVal>(op0))
+ {
+ if (auto poly1 = as<PolynomialIntVal>(op1))
+ {
+ return mul(astBuilder, op1, op0);
+ }
+ else if (auto val1 = as<GenericParamIntVal>(op1))
+ {
+ auto result = astBuilder->create<PolynomialIntVal>();
+ auto term = astBuilder->create<PolynomialIntValTerm>();
+ term->constFactor = 1;
+ auto factor0 = astBuilder->create<PolynomialIntValFactor>();
+ factor0->power = 1;
+ factor0->param = val0;
+ term->paramFactors.add(factor0);
+ auto factor1 = astBuilder->create<PolynomialIntValFactor>();
+ factor1->power = 1;
+ factor1->param = val1;
+ term->paramFactors.add(factor1);
+ result->terms.add(term);
+ result->canonicalize(astBuilder);
+ return result;
+ }
+ else if (auto cVal1 = as<ConstantIntVal>(op1))
+ {
+ auto result = astBuilder->create<PolynomialIntVal>();
+ auto term = astBuilder->create<PolynomialIntValTerm>();
+ term->constFactor = cVal1->value;
+ auto factor0 = astBuilder->create<PolynomialIntValFactor>();
+ factor0->power = 1;
+ factor0->param = val0;
+ term->paramFactors.add(factor0);
+ result->terms.add(term);
+ result->canonicalize(astBuilder);
+ return result;
+ }
+ }
+ else if (as<ConstantIntVal>(op0))
+ {
+ return mul(astBuilder, op1, op0);
+ }
+ return nullptr;
+}
+
+void PolynomialIntVal::canonicalize(ASTBuilder* builder)
+{
+ List<PolynomialIntValTerm*> newTerms;
+ IntegerLiteralValue newConstantTerm = constantTerm;
+ auto addTerm = [&](PolynomialIntValTerm* newTerm)
+ {
+ for (auto term : newTerms)
+ {
+ if (term->canCombineWith(*newTerm))
+ {
+ term->constFactor += newTerm->constFactor;
+ return;
+ }
+ }
+ newTerms.add(newTerm);
+ };
+ for (auto term : terms)
+ {
+ if (term->constFactor == 0)
+ continue;
+ List<PolynomialIntValFactor*> newFactors;
+ List<bool> factorIsDifferent;
+ for (Index i = 0; i < term->paramFactors.getCount(); i++)
+ {
+ auto factor = term->paramFactors[i];
+ bool factorFound = false;
+ for (Index j = 0; j < newFactors.getCount(); j++)
+ {
+ auto& newFactor = newFactors[j];
+ if (factor->param->equalsVal(newFactor->param))
+ {
+ if (!factorIsDifferent[j])
+ {
+ factorIsDifferent[j] = true;
+ auto clonedFactor = builder->create<PolynomialIntValFactor>();
+ clonedFactor->param = newFactor->param;
+ clonedFactor->power = newFactor->power;
+ newFactor = clonedFactor;
+ }
+ newFactor->power += factor->power;
+ factorFound = true;
+ break;
+ }
+ }
+ if (!factorFound)
+ {
+ newFactors.add(factor);
+ factorIsDifferent.add(false);
+ }
+ }
+ List<PolynomialIntValFactor*> newFactors2;
+ for (auto factor : newFactors)
+ {
+ if (factor->power != 0)
+ newFactors2.add(factor);
+ }
+ if (newFactors2.getCount() == 0)
+ {
+ newConstantTerm += term->constFactor;
+ continue;
+ }
+ newFactors2.sort([](PolynomialIntValFactor* t1, PolynomialIntValFactor* t2) {return *t1 < *t2; });
+ bool isDifferent = false;
+ if (newFactors2.getCount() != term->paramFactors.getCount())
+ isDifferent = true;
+ if (!isDifferent)
+ {
+ for (Index i = 0; i < term->paramFactors.getCount(); i++)
+ if (term->paramFactors[i] != newFactors2[i])
+ {
+ isDifferent = true;
+ break;
+ }
+ }
+ if (!isDifferent)
+ {
+ addTerm(term);
+ }
+ else
+ {
+ auto newTerm = builder->create<PolynomialIntValTerm>();
+ newTerm->constFactor = term->constFactor;
+ newTerm->paramFactors = _Move(newFactors2);
+ addTerm(newTerm);
+ }
+ }
+ List<PolynomialIntValTerm*> newTerms2;
+ for (auto term : newTerms)
+ {
+ if (term->constFactor == 0)
+ continue;
+ newTerms2.add(term);
+ }
+ newTerms2.sort([](PolynomialIntValTerm* t1, PolynomialIntValTerm* t2) {return *t1 < *t2; });
+ terms = _Move(newTerms2);
+}
+
} // namespace Slang