diff options
| author | Yong He <yonghe@outlook.com> | 2022-08-22 09:43:05 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-08-22 09:43:05 -0700 |
| commit | 393185196ed65a9eeaf9502edbf3dcce87337d81 (patch) | |
| tree | 91c9fa14ddb21d15e6cedf83f7aa6b649e99db86 /source/slang/slang-ast-val.cpp | |
| parent | 15055d20c143cb398bd3e269541eebf24777390a (diff) | |
Support compile-time constant int val in the form of polynomials. (#2372)
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-ast-val.cpp')
| -rw-r--r-- | source/slang/slang-ast-val.cpp | 454 |
1 files changed, 454 insertions, 0 deletions
diff --git a/source/slang/slang-ast-val.cpp b/source/slang/slang-ast-val.cpp index 0f3bd2b3a..16c649f3a 100644 --- a/source/slang/slang-ast-val.cpp +++ b/source/slang/slang-ast-val.cpp @@ -576,4 +576,458 @@ Val* SNormModifierVal::_substituteImplOverride(ASTBuilder* astBuilder, Substitut return this; } +// PolynomialIntVal + +bool PolynomialIntVal::_equalsValOverride(Val* val) +{ + if (auto genericParamVal = as<GenericParamIntVal>(val)) + { + return constantTerm == 0 && terms.getCount() == 1 && + terms[0]->paramFactors.getCount() == 1 && terms[0]->constFactor == 1 && + terms[0]->paramFactors[0]->param->equalsVal(genericParamVal) && + terms[0]->paramFactors[0]->power == 1; + } + else if (auto otherPolynomial = as<PolynomialIntVal>(val)) + { + if (constantTerm != otherPolynomial->constantTerm) + return false; + if (terms.getCount() != otherPolynomial->terms.getCount()) + return false; + for (Index i = 0; i < terms.getCount(); i++) + { + auto& thisTerm = *(terms[i]); + auto& thatTerm = *(otherPolynomial->terms[i]); + if (thisTerm.constFactor != thatTerm.constFactor) + return false; + if (thisTerm.paramFactors.getCount() != thatTerm.paramFactors.getCount()) + return false; + for (Index j = 0; j < thisTerm.paramFactors.getCount(); j++) + { + if (thisTerm.paramFactors[j]->power != thatTerm.paramFactors[j]->power) + return false; + if (!thisTerm.paramFactors[j]->param->equalsVal(thatTerm.paramFactors[j]->param)) + return false; + } + } + return true; + } + return false; +} + +void PolynomialIntVal::_toTextOverride(StringBuilder& out) +{ + for (Index i = 0; i < terms.getCount(); i++) + { + auto& term = *(terms[i]); + if (i > 0) + { + if (term.constFactor > 0) + out << "+"; + else + out << "-"; + } + bool isFirstFactor = true; + if (term.constFactor != 1 || term.paramFactors.getCount() == 0) + { + out << abs(term.constFactor); + isFirstFactor = false; + } + for (Index j = 0; j < term.paramFactors.getCount(); j++) + { + auto factor = term.paramFactors[j]; + if (isFirstFactor) + { + isFirstFactor = false; + } + else + { + out << "*"; + } + factor->param->toText(out); + if (factor->power != 1) + { + out << "^^" << factor->power; + } + } + } + if (constantTerm > 0) + { + if (terms.getCount() > 0) + { + out << "+"; + } + out << constantTerm; + } + else if (constantTerm < 0) + { + out << constantTerm; + } +} + +HashCode PolynomialIntVal::_getHashCodeOverride() +{ + HashCode result = (HashCode)constantTerm; + for (auto& term : terms) + { + if (!term) continue; + result = combineHash(result, (HashCode)term->constFactor); + for (auto& factor : term->paramFactors) + { + result = combineHash(result, factor->param->getHashCode()); + result = combineHash(result, (HashCode)factor->power); + } + } + return result; +} + +Val* PolynomialIntVal::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) +{ + int diff = 0; + IntegerLiteralValue evaluatedConstantTerm = constantTerm; + List<PolynomialIntValTerm*> evaluatedTerms; + for (auto& term : terms) + { + IntegerLiteralValue evaluatedTermConstFactor; + List<PolynomialIntValFactor*> evaluatedTermParamFactors; + evaluatedTermConstFactor = term->constFactor; + for (auto& factor : term->paramFactors) + { + auto substResult = factor->param->substituteImpl(astBuilder, subst, &diff); + if (auto genericVal = as<GenericParamIntVal>(substResult)) + { + auto newFactor = astBuilder->create<PolynomialIntValFactor>(); + newFactor->param = genericVal; + newFactor->power = factor->power; + evaluatedTermParamFactors.add(newFactor); + } + else if (auto constantVal = as<ConstantIntVal>(substResult)) + { + evaluatedTermConstFactor *= constantVal->value; + } + } + if (evaluatedTermParamFactors.getCount() == 0) + { + evaluatedConstantTerm += evaluatedTermConstFactor; + } + else + { + auto newTerm = astBuilder->create<PolynomialIntValTerm>(); + newTerm->paramFactors = _Move(evaluatedTermParamFactors); + newTerm->constFactor = evaluatedTermConstFactor; + evaluatedTerms.add(newTerm); + } + } + + *ioDiff += diff; + + if (evaluatedTerms.getCount() == 0) + return astBuilder->create<ConstantIntVal>(evaluatedConstantTerm); + if (diff != 0) + { + auto newPolynomial = astBuilder->create<PolynomialIntVal>(); + newPolynomial->constantTerm = evaluatedConstantTerm; + newPolynomial->terms = _Move(evaluatedTerms); + newPolynomial->canonicalize(astBuilder); + return newPolynomial; + } + return nullptr; +} + + +// compute val += opreand*multiplier; +bool addToPolynomialTerm(ASTBuilder* astBuilder, PolynomialIntVal* val, IntVal* operand, IntegerLiteralValue multiplier) +{ + if (auto genVal = as<GenericParamIntVal>(operand)) + { + auto term = astBuilder->create<PolynomialIntValTerm>(); + term->constFactor = multiplier; + auto factor = astBuilder->create<PolynomialIntValFactor>(); + factor->power = 1; + factor->param = genVal; + term->paramFactors.add(factor); + val->terms.add(term); + return true; + } + else if (auto c = as<ConstantIntVal>(operand)) + { + val->constantTerm += c->value * multiplier; + return true; + } + else if (auto poly = as<PolynomialIntVal>(operand)) + { + val->constantTerm += poly->constantTerm * multiplier; + for (auto term : poly->terms) + { + auto newTerm = astBuilder->create<PolynomialIntValTerm>(); + newTerm->constFactor = multiplier * term->constFactor; + newTerm->paramFactors = term->paramFactors; + val->terms.add(newTerm); + } + return true; + } + return false; +} + +PolynomialIntVal* PolynomialIntVal::neg(ASTBuilder* astBuilder, IntVal* base) +{ + auto result = astBuilder->create<PolynomialIntVal>(); + if (!addToPolynomialTerm(astBuilder, result, base, -1)) + return nullptr; + result->canonicalize(astBuilder); + return result; +} + +PolynomialIntVal* PolynomialIntVal::sub(ASTBuilder* astBuilder, IntVal* op0, IntVal* op1) +{ + auto result = astBuilder->create<PolynomialIntVal>(); + if (!addToPolynomialTerm(astBuilder, result, op0, 1)) + return nullptr; + if (!addToPolynomialTerm(astBuilder, result, op1, -1)) + return nullptr; + result->canonicalize(astBuilder); + return result; +} + +PolynomialIntVal* PolynomialIntVal::add(ASTBuilder* astBuilder, IntVal* op0, IntVal* op1) +{ + auto result = astBuilder->create<PolynomialIntVal>(); + if (!addToPolynomialTerm(astBuilder, result, op0, 1)) + return nullptr; + if (!addToPolynomialTerm(astBuilder, result, op1, 1)) + return nullptr; + result->canonicalize(astBuilder); + return result; +} + +PolynomialIntVal* PolynomialIntVal::mul(ASTBuilder* astBuilder, IntVal* op0, IntVal* op1) +{ + if (auto poly0 = as<PolynomialIntVal>(op0)) + { + if (auto poly1 = as<PolynomialIntVal>(op1)) + { + auto result = astBuilder->create<PolynomialIntVal>(); + // add poly0.constant * poly1.constant + result->constantTerm = poly0->constantTerm * poly1->constantTerm; + // add poly0.constant * poly1.terms + if (poly0->constantTerm != 0) + { + for (auto term : poly1->terms) + { + auto newTerm = astBuilder->create<PolynomialIntValTerm>(); + newTerm->constFactor = poly0->constantTerm * term->constFactor; + newTerm->paramFactors.addRange(term->paramFactors); + result->terms.add(newTerm); + } + } + // add poly1.constant * poly0.terms + if (poly1->constantTerm != 0) + { + for (auto term : poly0->terms) + { + auto newTerm = astBuilder->create<PolynomialIntValTerm>(); + newTerm->constFactor = poly1->constantTerm * term->constFactor; + newTerm->paramFactors.addRange(term->paramFactors); + result->terms.add(newTerm); + } + } + // add poly1.terms * poly0.terms + for (auto term0 : poly0->terms) + { + for (auto term1 : poly1->terms) + { + auto newTerm = astBuilder->create<PolynomialIntValTerm>(); + newTerm->constFactor = term0->constFactor * term1->constFactor; + newTerm->paramFactors.addRange(term0->paramFactors); + newTerm->paramFactors.addRange(term1->paramFactors); + result->terms.add(newTerm); + } + } + result->canonicalize(astBuilder); + return result; + } + else if (auto val1 = as<GenericParamIntVal>(op1)) + { + auto result = astBuilder->create<PolynomialIntVal>(); + result->constantTerm = 0; + auto factor1 = astBuilder->create<PolynomialIntValFactor>(); + factor1->power = 1; + factor1->param = val1; + if (poly0->constantTerm != 0) + { + auto term0 = astBuilder->create<PolynomialIntValTerm>(); + term0->constFactor = poly0->constantTerm; + term0->paramFactors.add(factor1); + result->terms.add(term0); + } + for (auto term : poly0->terms) + { + auto newTerm = astBuilder->create<PolynomialIntValTerm>(); + newTerm->constFactor = term->constFactor; + newTerm->paramFactors.addRange(term->paramFactors); + newTerm->paramFactors.add(factor1); + result->terms.add(newTerm); + } + result->canonicalize(astBuilder); + return result; + } + else if (auto cVal1 = as<ConstantIntVal>(op1)) + { + auto result = astBuilder->create<PolynomialIntVal>(); + result->constantTerm = poly0->constantTerm * cVal1->value; + auto factor1 = astBuilder->create<PolynomialIntValFactor>(); + for (auto term : poly0->terms) + { + auto newTerm = astBuilder->create<PolynomialIntValTerm>(); + newTerm->constFactor = term->constFactor * cVal1->value; + newTerm->paramFactors.addRange(term->paramFactors); + newTerm->paramFactors.add(factor1); + result->terms.add(newTerm); + } + result->canonicalize(astBuilder); + return result; + } + else + return nullptr; + } + else if (auto val0 = as<GenericParamIntVal>(op0)) + { + if (auto poly1 = as<PolynomialIntVal>(op1)) + { + return mul(astBuilder, op1, op0); + } + else if (auto val1 = as<GenericParamIntVal>(op1)) + { + auto result = astBuilder->create<PolynomialIntVal>(); + auto term = astBuilder->create<PolynomialIntValTerm>(); + term->constFactor = 1; + auto factor0 = astBuilder->create<PolynomialIntValFactor>(); + factor0->power = 1; + factor0->param = val0; + term->paramFactors.add(factor0); + auto factor1 = astBuilder->create<PolynomialIntValFactor>(); + factor1->power = 1; + factor1->param = val1; + term->paramFactors.add(factor1); + result->terms.add(term); + result->canonicalize(astBuilder); + return result; + } + else if (auto cVal1 = as<ConstantIntVal>(op1)) + { + auto result = astBuilder->create<PolynomialIntVal>(); + auto term = astBuilder->create<PolynomialIntValTerm>(); + term->constFactor = cVal1->value; + auto factor0 = astBuilder->create<PolynomialIntValFactor>(); + factor0->power = 1; + factor0->param = val0; + term->paramFactors.add(factor0); + result->terms.add(term); + result->canonicalize(astBuilder); + return result; + } + } + else if (as<ConstantIntVal>(op0)) + { + return mul(astBuilder, op1, op0); + } + return nullptr; +} + +void PolynomialIntVal::canonicalize(ASTBuilder* builder) +{ + List<PolynomialIntValTerm*> newTerms; + IntegerLiteralValue newConstantTerm = constantTerm; + auto addTerm = [&](PolynomialIntValTerm* newTerm) + { + for (auto term : newTerms) + { + if (term->canCombineWith(*newTerm)) + { + term->constFactor += newTerm->constFactor; + return; + } + } + newTerms.add(newTerm); + }; + for (auto term : terms) + { + if (term->constFactor == 0) + continue; + List<PolynomialIntValFactor*> newFactors; + List<bool> factorIsDifferent; + for (Index i = 0; i < term->paramFactors.getCount(); i++) + { + auto factor = term->paramFactors[i]; + bool factorFound = false; + for (Index j = 0; j < newFactors.getCount(); j++) + { + auto& newFactor = newFactors[j]; + if (factor->param->equalsVal(newFactor->param)) + { + if (!factorIsDifferent[j]) + { + factorIsDifferent[j] = true; + auto clonedFactor = builder->create<PolynomialIntValFactor>(); + clonedFactor->param = newFactor->param; + clonedFactor->power = newFactor->power; + newFactor = clonedFactor; + } + newFactor->power += factor->power; + factorFound = true; + break; + } + } + if (!factorFound) + { + newFactors.add(factor); + factorIsDifferent.add(false); + } + } + List<PolynomialIntValFactor*> newFactors2; + for (auto factor : newFactors) + { + if (factor->power != 0) + newFactors2.add(factor); + } + if (newFactors2.getCount() == 0) + { + newConstantTerm += term->constFactor; + continue; + } + newFactors2.sort([](PolynomialIntValFactor* t1, PolynomialIntValFactor* t2) {return *t1 < *t2; }); + bool isDifferent = false; + if (newFactors2.getCount() != term->paramFactors.getCount()) + isDifferent = true; + if (!isDifferent) + { + for (Index i = 0; i < term->paramFactors.getCount(); i++) + if (term->paramFactors[i] != newFactors2[i]) + { + isDifferent = true; + break; + } + } + if (!isDifferent) + { + addTerm(term); + } + else + { + auto newTerm = builder->create<PolynomialIntValTerm>(); + newTerm->constFactor = term->constFactor; + newTerm->paramFactors = _Move(newFactors2); + addTerm(newTerm); + } + } + List<PolynomialIntValTerm*> newTerms2; + for (auto term : newTerms) + { + if (term->constFactor == 0) + continue; + newTerms2.add(term); + } + newTerms2.sort([](PolynomialIntValTerm* t1, PolynomialIntValTerm* t2) {return *t1 < *t2; }); + terms = _Move(newTerms2); +} + } // namespace Slang |
