diff options
| author | Yong He <yonghe@outlook.com> | 2022-08-24 10:56:53 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-08-24 10:56:53 -0700 |
| commit | d245c72f2a92a74ccda83f41758c1948ae5132d3 (patch) | |
| tree | f036e1f2afb7febe2de9b09990bcde6c04f3bad1 /source/slang/slang-ast-val.cpp | |
| parent | 0b808453407f8feef8574cae99afd90771712185 (diff) | |
Compiler time evaluation of all int and bool operators. (#2376)
* Compiler time evaluation of all int and bool operators.
* Fix linux compile error.
* Fix.
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-ast-val.cpp')
| -rw-r--r-- | source/slang/slang-ast-val.cpp | 316 |
1 files changed, 265 insertions, 51 deletions
diff --git a/source/slang/slang-ast-val.cpp b/source/slang/slang-ast-val.cpp index 16c649f3a..4ed69e282 100644 --- a/source/slang/slang-ast-val.cpp +++ b/source/slang/slang-ast-val.cpp @@ -4,7 +4,7 @@ #include <typeinfo> #include "slang-generated-ast-macro.h" - +#include "slang-diagnostics.h" #include "slang-syntax.h" namespace Slang { @@ -693,17 +693,18 @@ Val* PolynomialIntVal::_substituteImplOverride(ASTBuilder* astBuilder, Substitut for (auto& factor : term->paramFactors) { auto substResult = factor->param->substituteImpl(astBuilder, subst, &diff); - if (auto genericVal = as<GenericParamIntVal>(substResult)) + + if (auto constantVal = as<ConstantIntVal>(substResult)) + { + evaluatedTermConstFactor *= constantVal->value; + } + else if (auto intResult = as<IntVal>(substResult)) { auto newFactor = astBuilder->create<PolynomialIntValFactor>(); - newFactor->param = genericVal; + newFactor->param = intResult; newFactor->power = factor->power; evaluatedTermParamFactors.add(newFactor); } - else if (auto constantVal = as<ConstantIntVal>(substResult)) - { - evaluatedTermConstFactor *= constantVal->value; - } } if (evaluatedTermParamFactors.getCount() == 0) { @@ -727,28 +728,16 @@ Val* PolynomialIntVal::_substituteImplOverride(ASTBuilder* astBuilder, Substitut auto newPolynomial = astBuilder->create<PolynomialIntVal>(); newPolynomial->constantTerm = evaluatedConstantTerm; newPolynomial->terms = _Move(evaluatedTerms); - newPolynomial->canonicalize(astBuilder); - return newPolynomial; + return newPolynomial->canonicalize(astBuilder); } - return nullptr; + return this; } // compute val += opreand*multiplier; bool addToPolynomialTerm(ASTBuilder* astBuilder, PolynomialIntVal* val, IntVal* operand, IntegerLiteralValue multiplier) { - if (auto genVal = as<GenericParamIntVal>(operand)) - { - auto term = astBuilder->create<PolynomialIntValTerm>(); - term->constFactor = multiplier; - auto factor = astBuilder->create<PolynomialIntValFactor>(); - factor->power = 1; - factor->param = genVal; - term->paramFactors.add(factor); - val->terms.add(term); - return true; - } - else if (auto c = as<ConstantIntVal>(operand)) + if (auto c = as<ConstantIntVal>(operand)) { val->constantTerm += c->value * multiplier; return true; @@ -765,6 +754,17 @@ bool addToPolynomialTerm(ASTBuilder* astBuilder, PolynomialIntVal* val, IntVal* } return true; } + else if (auto genVal = as<IntVal>(operand)) + { + auto term = astBuilder->create<PolynomialIntValTerm>(); + term->constFactor = multiplier; + auto factor = astBuilder->create<PolynomialIntValFactor>(); + factor->power = 1; + factor->param = genVal; + term->paramFactors.add(factor); + val->terms.add(term); + return true; + } return false; } @@ -845,24 +845,15 @@ PolynomialIntVal* PolynomialIntVal::mul(ASTBuilder* astBuilder, IntVal* op0, Int result->canonicalize(astBuilder); return result; } - else if (auto val1 = as<GenericParamIntVal>(op1)) + else if (auto cVal1 = as<ConstantIntVal>(op1)) { auto result = astBuilder->create<PolynomialIntVal>(); - result->constantTerm = 0; + result->constantTerm = poly0->constantTerm * cVal1->value; auto factor1 = astBuilder->create<PolynomialIntValFactor>(); - factor1->power = 1; - factor1->param = val1; - if (poly0->constantTerm != 0) - { - auto term0 = astBuilder->create<PolynomialIntValTerm>(); - term0->constFactor = poly0->constantTerm; - term0->paramFactors.add(factor1); - result->terms.add(term0); - } for (auto term : poly0->terms) { auto newTerm = astBuilder->create<PolynomialIntValTerm>(); - newTerm->constFactor = term->constFactor; + newTerm->constFactor = term->constFactor * cVal1->value; newTerm->paramFactors.addRange(term->paramFactors); newTerm->paramFactors.add(factor1); result->terms.add(newTerm); @@ -870,15 +861,24 @@ PolynomialIntVal* PolynomialIntVal::mul(ASTBuilder* astBuilder, IntVal* op0, Int result->canonicalize(astBuilder); return result; } - else if (auto cVal1 = as<ConstantIntVal>(op1)) + else if (auto val1 = as<IntVal>(op1)) { auto result = astBuilder->create<PolynomialIntVal>(); - result->constantTerm = poly0->constantTerm * cVal1->value; + result->constantTerm = 0; auto factor1 = astBuilder->create<PolynomialIntValFactor>(); + factor1->power = 1; + factor1->param = val1; + if (poly0->constantTerm != 0) + { + auto term0 = astBuilder->create<PolynomialIntValTerm>(); + term0->constFactor = poly0->constantTerm; + term0->paramFactors.add(factor1); + result->terms.add(term0); + } for (auto term : poly0->terms) { auto newTerm = astBuilder->create<PolynomialIntValTerm>(); - newTerm->constFactor = term->constFactor * cVal1->value; + newTerm->constFactor = term->constFactor; newTerm->paramFactors.addRange(term->paramFactors); newTerm->paramFactors.add(factor1); result->terms.add(newTerm); @@ -889,51 +889,51 @@ PolynomialIntVal* PolynomialIntVal::mul(ASTBuilder* astBuilder, IntVal* op0, Int else return nullptr; } - else if (auto val0 = as<GenericParamIntVal>(op0)) + else if (as<ConstantIntVal>(op0)) + { + return mul(astBuilder, op1, op0); + } + else if (auto val0 = as<IntVal>(op0)) { if (auto poly1 = as<PolynomialIntVal>(op1)) { return mul(astBuilder, op1, op0); } - else if (auto val1 = as<GenericParamIntVal>(op1)) + else if (auto cVal1 = as<ConstantIntVal>(op1)) { auto result = astBuilder->create<PolynomialIntVal>(); auto term = astBuilder->create<PolynomialIntValTerm>(); - term->constFactor = 1; + term->constFactor = cVal1->value; auto factor0 = astBuilder->create<PolynomialIntValFactor>(); factor0->power = 1; factor0->param = val0; term->paramFactors.add(factor0); - auto factor1 = astBuilder->create<PolynomialIntValFactor>(); - factor1->power = 1; - factor1->param = val1; - term->paramFactors.add(factor1); result->terms.add(term); result->canonicalize(astBuilder); return result; } - else if (auto cVal1 = as<ConstantIntVal>(op1)) + else if (auto val1 = as<IntVal>(op1)) { auto result = astBuilder->create<PolynomialIntVal>(); auto term = astBuilder->create<PolynomialIntValTerm>(); - term->constFactor = cVal1->value; + term->constFactor = 1; auto factor0 = astBuilder->create<PolynomialIntValFactor>(); factor0->power = 1; factor0->param = val0; term->paramFactors.add(factor0); + auto factor1 = astBuilder->create<PolynomialIntValFactor>(); + factor1->power = 1; + factor1->param = val1; + term->paramFactors.add(factor1); result->terms.add(term); result->canonicalize(astBuilder); return result; } } - else if (as<ConstantIntVal>(op0)) - { - return mul(astBuilder, op1, op0); - } return nullptr; } -void PolynomialIntVal::canonicalize(ASTBuilder* builder) +IntVal* PolynomialIntVal::canonicalize(ASTBuilder* builder) { List<PolynomialIntValTerm*> newTerms; IntegerLiteralValue newConstantTerm = constantTerm; @@ -1028,6 +1028,220 @@ void PolynomialIntVal::canonicalize(ASTBuilder* builder) } newTerms2.sort([](PolynomialIntValTerm* t1, PolynomialIntValTerm* t2) {return *t1 < *t2; }); terms = _Move(newTerms2); + constantTerm = newConstantTerm; + if (terms.getCount() == 1 && constantTerm == 0 && terms[0]->constFactor == 1 && terms[0]->paramFactors.getCount() == 1 && + terms[0]->paramFactors[0]->power == 1) + { + return terms[0]->paramFactors[0]->param; + } + if (terms.getCount() == 0) + return builder->create<ConstantIntVal>(constantTerm); + return this; +} + +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! SomeIntVal !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + +bool SomeIntVal::_equalsValOverride(Val* val) +{ + if (auto someIntVal = as<SomeIntVal>(val)) + { + if (!funcDeclRef.equals(someIntVal->funcDeclRef)) + return false; + if (args.getCount() != someIntVal->args.getCount()) + return false; + for (Index i = 0; i < args.getCount(); i++) + { + if (!args[i]->equalsVal(someIntVal->args[i])) + return false; + } + return true; + } + return false; +} + +void SomeIntVal::_toTextOverride(StringBuilder& out) +{ + auto argToText = [&](int index) + { + if (as<PolynomialIntVal>(args[index]) || as<SomeIntVal>(args[index])) + { + out << "("; + args[index]->toText(out); + out << ")"; + } + else + { + args[index]->toText(out); + } + }; + Name* name = funcDeclRef.getName(); + if (args.getCount() == 2) + { + argToText(0); + out << (name ? name->text : ""); + argToText(1);; + } + else if (args.getCount() == 1) + { + out << (name ? name->text : ""); + argToText(0); + } + else if (name && name->text == "?:") + { + argToText(0); + out << "?"; + argToText(1); + out << ":"; + argToText(2); + } + else + { + if (name) + { + out << name->text; + } + out << "("; + for (Index i = 0; i < args.getCount(); i++) + { + if (i > 0) out << ", "; + args[i]->toText(out); + } + out << ")"; + } +} + +HashCode SomeIntVal::_getHashCodeOverride() +{ + HashCode result = funcDeclRef.getHashCode(); + for (auto arg : args) + { + result = combineHash(result, arg->getHashCode()); + } + return result; +} + +static bool nameIs(Name* name, const char* val) +{ + if (name && name->text.getUnownedSlice() == val) + return true; + return false; +} + +Val* SomeIntVal::tryFoldImpl(ASTBuilder* astBuilder, DeclRef<Decl> newFuncDecl, List<IntVal*>& newArgs, DiagnosticSink* sink) +{ + // Are all args const now? + List<ConstantIntVal*> constArgs; + bool allConst = true; + for (auto arg : newArgs) + { + if (auto c = as<ConstantIntVal>(arg)) + { + constArgs.add(c); + } + else + { + allConst = false; + break; + } + } + if (allConst) + { + // Evaluate the function. + auto opName = newFuncDecl.getName(); + IntegerLiteralValue resultValue = 0; + if (nameIs(opName, "==")) + { + resultValue = constArgs[0]->value / constArgs[1]->value; + } +#define BINARY_OPERATOR_CASE(op) \ + else if (nameIs(opName, #op)) \ + { \ + resultValue = constArgs[0]->value op constArgs[1]->value; \ + } + BINARY_OPERATOR_CASE(>=) + BINARY_OPERATOR_CASE(<=) + BINARY_OPERATOR_CASE(>) + BINARY_OPERATOR_CASE(<) + BINARY_OPERATOR_CASE(!=) + BINARY_OPERATOR_CASE(<<) + BINARY_OPERATOR_CASE(>>) + BINARY_OPERATOR_CASE(&) + BINARY_OPERATOR_CASE(|) + BINARY_OPERATOR_CASE(^) +#undef BINARY_OPERATOR_CASE +#define DIV_OPERATOR_CASE(op) \ + else if (nameIs(opName, #op)) \ + { \ + if (constArgs[1]->value == 0) \ + { \ + if (sink) \ + sink->diagnose(newFuncDecl.getLoc(), Diagnostics::divideByZero); \ + return nullptr; \ + } \ + resultValue = constArgs[0]->value op constArgs[1]->value; \ + } + DIV_OPERATOR_CASE(/) + DIV_OPERATOR_CASE(%) +#undef DIV_OPERATOR_CASE +#define LOGICAL_OPERATOR_CASE(op) \ + else if (nameIs(opName, #op)) \ + { \ + resultValue = (((constArgs[0]->value!=0) op (constArgs[1]->value!=0)) ? 1 : 0); \ + } + LOGICAL_OPERATOR_CASE(&&) + LOGICAL_OPERATOR_CASE(|| ) +#undef LOGICAL_OPERATOR_CASE + else if (nameIs(opName, "!")) + { + resultValue = ((constArgs[0]->value != 0) ? 1 : 0); + } + else if (nameIs(opName, "~")) + { + resultValue = ~constArgs[0]->value; + } + else if (nameIs(opName, "?:")) + { + resultValue = constArgs[0]->value != 0 ? constArgs[1]->value : constArgs[2]->value; + } + else + { + SLANG_UNREACHABLE("constant folding of SomeIntVal"); + } + return astBuilder->create<ConstantIntVal>(resultValue); + } + return nullptr; +} + +Val* SomeIntVal::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) +{ + int diff = 0; + auto newFuncDeclRef = funcDeclRef.substituteImpl(astBuilder, subst, &diff); + List<IntVal*> newArgs; + for (auto& arg : args) + { + auto substArg = arg->substituteImpl(astBuilder, subst, &diff); + if (substArg != arg) + diff++; + newArgs.add(as<IntVal>(substArg)); + } + *ioDiff += diff; + if (diff) + { + // TODO: report diagnostics back. + auto newVal = tryFoldImpl(astBuilder, newFuncDeclRef, newArgs, nullptr); + if (newVal) + return newVal; + else + { + auto result = astBuilder->create<SomeIntVal>(); + result->args = _Move(newArgs); + result->funcDeclRef = newFuncDeclRef; + result->funcType = funcType; + return result; + } + } + // Nothing found: don't substitute. + return this; } } // namespace Slang |
