summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-ast-val.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2022-08-24 10:56:53 -0700
committerGitHub <noreply@github.com>2022-08-24 10:56:53 -0700
commitd245c72f2a92a74ccda83f41758c1948ae5132d3 (patch)
treef036e1f2afb7febe2de9b09990bcde6c04f3bad1 /source/slang/slang-ast-val.cpp
parent0b808453407f8feef8574cae99afd90771712185 (diff)
Compiler time evaluation of all int and bool operators. (#2376)
* Compiler time evaluation of all int and bool operators. * Fix linux compile error. * Fix. Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-ast-val.cpp')
-rw-r--r--source/slang/slang-ast-val.cpp316
1 files changed, 265 insertions, 51 deletions
diff --git a/source/slang/slang-ast-val.cpp b/source/slang/slang-ast-val.cpp
index 16c649f3a..4ed69e282 100644
--- a/source/slang/slang-ast-val.cpp
+++ b/source/slang/slang-ast-val.cpp
@@ -4,7 +4,7 @@
#include <typeinfo>
#include "slang-generated-ast-macro.h"
-
+#include "slang-diagnostics.h"
#include "slang-syntax.h"
namespace Slang {
@@ -693,17 +693,18 @@ Val* PolynomialIntVal::_substituteImplOverride(ASTBuilder* astBuilder, Substitut
for (auto& factor : term->paramFactors)
{
auto substResult = factor->param->substituteImpl(astBuilder, subst, &diff);
- if (auto genericVal = as<GenericParamIntVal>(substResult))
+
+ if (auto constantVal = as<ConstantIntVal>(substResult))
+ {
+ evaluatedTermConstFactor *= constantVal->value;
+ }
+ else if (auto intResult = as<IntVal>(substResult))
{
auto newFactor = astBuilder->create<PolynomialIntValFactor>();
- newFactor->param = genericVal;
+ newFactor->param = intResult;
newFactor->power = factor->power;
evaluatedTermParamFactors.add(newFactor);
}
- else if (auto constantVal = as<ConstantIntVal>(substResult))
- {
- evaluatedTermConstFactor *= constantVal->value;
- }
}
if (evaluatedTermParamFactors.getCount() == 0)
{
@@ -727,28 +728,16 @@ Val* PolynomialIntVal::_substituteImplOverride(ASTBuilder* astBuilder, Substitut
auto newPolynomial = astBuilder->create<PolynomialIntVal>();
newPolynomial->constantTerm = evaluatedConstantTerm;
newPolynomial->terms = _Move(evaluatedTerms);
- newPolynomial->canonicalize(astBuilder);
- return newPolynomial;
+ return newPolynomial->canonicalize(astBuilder);
}
- return nullptr;
+ return this;
}
// compute val += opreand*multiplier;
bool addToPolynomialTerm(ASTBuilder* astBuilder, PolynomialIntVal* val, IntVal* operand, IntegerLiteralValue multiplier)
{
- if (auto genVal = as<GenericParamIntVal>(operand))
- {
- auto term = astBuilder->create<PolynomialIntValTerm>();
- term->constFactor = multiplier;
- auto factor = astBuilder->create<PolynomialIntValFactor>();
- factor->power = 1;
- factor->param = genVal;
- term->paramFactors.add(factor);
- val->terms.add(term);
- return true;
- }
- else if (auto c = as<ConstantIntVal>(operand))
+ if (auto c = as<ConstantIntVal>(operand))
{
val->constantTerm += c->value * multiplier;
return true;
@@ -765,6 +754,17 @@ bool addToPolynomialTerm(ASTBuilder* astBuilder, PolynomialIntVal* val, IntVal*
}
return true;
}
+ else if (auto genVal = as<IntVal>(operand))
+ {
+ auto term = astBuilder->create<PolynomialIntValTerm>();
+ term->constFactor = multiplier;
+ auto factor = astBuilder->create<PolynomialIntValFactor>();
+ factor->power = 1;
+ factor->param = genVal;
+ term->paramFactors.add(factor);
+ val->terms.add(term);
+ return true;
+ }
return false;
}
@@ -845,24 +845,15 @@ PolynomialIntVal* PolynomialIntVal::mul(ASTBuilder* astBuilder, IntVal* op0, Int
result->canonicalize(astBuilder);
return result;
}
- else if (auto val1 = as<GenericParamIntVal>(op1))
+ else if (auto cVal1 = as<ConstantIntVal>(op1))
{
auto result = astBuilder->create<PolynomialIntVal>();
- result->constantTerm = 0;
+ result->constantTerm = poly0->constantTerm * cVal1->value;
auto factor1 = astBuilder->create<PolynomialIntValFactor>();
- factor1->power = 1;
- factor1->param = val1;
- if (poly0->constantTerm != 0)
- {
- auto term0 = astBuilder->create<PolynomialIntValTerm>();
- term0->constFactor = poly0->constantTerm;
- term0->paramFactors.add(factor1);
- result->terms.add(term0);
- }
for (auto term : poly0->terms)
{
auto newTerm = astBuilder->create<PolynomialIntValTerm>();
- newTerm->constFactor = term->constFactor;
+ newTerm->constFactor = term->constFactor * cVal1->value;
newTerm->paramFactors.addRange(term->paramFactors);
newTerm->paramFactors.add(factor1);
result->terms.add(newTerm);
@@ -870,15 +861,24 @@ PolynomialIntVal* PolynomialIntVal::mul(ASTBuilder* astBuilder, IntVal* op0, Int
result->canonicalize(astBuilder);
return result;
}
- else if (auto cVal1 = as<ConstantIntVal>(op1))
+ else if (auto val1 = as<IntVal>(op1))
{
auto result = astBuilder->create<PolynomialIntVal>();
- result->constantTerm = poly0->constantTerm * cVal1->value;
+ result->constantTerm = 0;
auto factor1 = astBuilder->create<PolynomialIntValFactor>();
+ factor1->power = 1;
+ factor1->param = val1;
+ if (poly0->constantTerm != 0)
+ {
+ auto term0 = astBuilder->create<PolynomialIntValTerm>();
+ term0->constFactor = poly0->constantTerm;
+ term0->paramFactors.add(factor1);
+ result->terms.add(term0);
+ }
for (auto term : poly0->terms)
{
auto newTerm = astBuilder->create<PolynomialIntValTerm>();
- newTerm->constFactor = term->constFactor * cVal1->value;
+ newTerm->constFactor = term->constFactor;
newTerm->paramFactors.addRange(term->paramFactors);
newTerm->paramFactors.add(factor1);
result->terms.add(newTerm);
@@ -889,51 +889,51 @@ PolynomialIntVal* PolynomialIntVal::mul(ASTBuilder* astBuilder, IntVal* op0, Int
else
return nullptr;
}
- else if (auto val0 = as<GenericParamIntVal>(op0))
+ else if (as<ConstantIntVal>(op0))
+ {
+ return mul(astBuilder, op1, op0);
+ }
+ else if (auto val0 = as<IntVal>(op0))
{
if (auto poly1 = as<PolynomialIntVal>(op1))
{
return mul(astBuilder, op1, op0);
}
- else if (auto val1 = as<GenericParamIntVal>(op1))
+ else if (auto cVal1 = as<ConstantIntVal>(op1))
{
auto result = astBuilder->create<PolynomialIntVal>();
auto term = astBuilder->create<PolynomialIntValTerm>();
- term->constFactor = 1;
+ term->constFactor = cVal1->value;
auto factor0 = astBuilder->create<PolynomialIntValFactor>();
factor0->power = 1;
factor0->param = val0;
term->paramFactors.add(factor0);
- auto factor1 = astBuilder->create<PolynomialIntValFactor>();
- factor1->power = 1;
- factor1->param = val1;
- term->paramFactors.add(factor1);
result->terms.add(term);
result->canonicalize(astBuilder);
return result;
}
- else if (auto cVal1 = as<ConstantIntVal>(op1))
+ else if (auto val1 = as<IntVal>(op1))
{
auto result = astBuilder->create<PolynomialIntVal>();
auto term = astBuilder->create<PolynomialIntValTerm>();
- term->constFactor = cVal1->value;
+ term->constFactor = 1;
auto factor0 = astBuilder->create<PolynomialIntValFactor>();
factor0->power = 1;
factor0->param = val0;
term->paramFactors.add(factor0);
+ auto factor1 = astBuilder->create<PolynomialIntValFactor>();
+ factor1->power = 1;
+ factor1->param = val1;
+ term->paramFactors.add(factor1);
result->terms.add(term);
result->canonicalize(astBuilder);
return result;
}
}
- else if (as<ConstantIntVal>(op0))
- {
- return mul(astBuilder, op1, op0);
- }
return nullptr;
}
-void PolynomialIntVal::canonicalize(ASTBuilder* builder)
+IntVal* PolynomialIntVal::canonicalize(ASTBuilder* builder)
{
List<PolynomialIntValTerm*> newTerms;
IntegerLiteralValue newConstantTerm = constantTerm;
@@ -1028,6 +1028,220 @@ void PolynomialIntVal::canonicalize(ASTBuilder* builder)
}
newTerms2.sort([](PolynomialIntValTerm* t1, PolynomialIntValTerm* t2) {return *t1 < *t2; });
terms = _Move(newTerms2);
+ constantTerm = newConstantTerm;
+ if (terms.getCount() == 1 && constantTerm == 0 && terms[0]->constFactor == 1 && terms[0]->paramFactors.getCount() == 1 &&
+ terms[0]->paramFactors[0]->power == 1)
+ {
+ return terms[0]->paramFactors[0]->param;
+ }
+ if (terms.getCount() == 0)
+ return builder->create<ConstantIntVal>(constantTerm);
+ return this;
+}
+
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! SomeIntVal !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+
+bool SomeIntVal::_equalsValOverride(Val* val)
+{
+ if (auto someIntVal = as<SomeIntVal>(val))
+ {
+ if (!funcDeclRef.equals(someIntVal->funcDeclRef))
+ return false;
+ if (args.getCount() != someIntVal->args.getCount())
+ return false;
+ for (Index i = 0; i < args.getCount(); i++)
+ {
+ if (!args[i]->equalsVal(someIntVal->args[i]))
+ return false;
+ }
+ return true;
+ }
+ return false;
+}
+
+void SomeIntVal::_toTextOverride(StringBuilder& out)
+{
+ auto argToText = [&](int index)
+ {
+ if (as<PolynomialIntVal>(args[index]) || as<SomeIntVal>(args[index]))
+ {
+ out << "(";
+ args[index]->toText(out);
+ out << ")";
+ }
+ else
+ {
+ args[index]->toText(out);
+ }
+ };
+ Name* name = funcDeclRef.getName();
+ if (args.getCount() == 2)
+ {
+ argToText(0);
+ out << (name ? name->text : "");
+ argToText(1);;
+ }
+ else if (args.getCount() == 1)
+ {
+ out << (name ? name->text : "");
+ argToText(0);
+ }
+ else if (name && name->text == "?:")
+ {
+ argToText(0);
+ out << "?";
+ argToText(1);
+ out << ":";
+ argToText(2);
+ }
+ else
+ {
+ if (name)
+ {
+ out << name->text;
+ }
+ out << "(";
+ for (Index i = 0; i < args.getCount(); i++)
+ {
+ if (i > 0) out << ", ";
+ args[i]->toText(out);
+ }
+ out << ")";
+ }
+}
+
+HashCode SomeIntVal::_getHashCodeOverride()
+{
+ HashCode result = funcDeclRef.getHashCode();
+ for (auto arg : args)
+ {
+ result = combineHash(result, arg->getHashCode());
+ }
+ return result;
+}
+
+static bool nameIs(Name* name, const char* val)
+{
+ if (name && name->text.getUnownedSlice() == val)
+ return true;
+ return false;
+}
+
+Val* SomeIntVal::tryFoldImpl(ASTBuilder* astBuilder, DeclRef<Decl> newFuncDecl, List<IntVal*>& newArgs, DiagnosticSink* sink)
+{
+ // Are all args const now?
+ List<ConstantIntVal*> constArgs;
+ bool allConst = true;
+ for (auto arg : newArgs)
+ {
+ if (auto c = as<ConstantIntVal>(arg))
+ {
+ constArgs.add(c);
+ }
+ else
+ {
+ allConst = false;
+ break;
+ }
+ }
+ if (allConst)
+ {
+ // Evaluate the function.
+ auto opName = newFuncDecl.getName();
+ IntegerLiteralValue resultValue = 0;
+ if (nameIs(opName, "=="))
+ {
+ resultValue = constArgs[0]->value / constArgs[1]->value;
+ }
+#define BINARY_OPERATOR_CASE(op) \
+ else if (nameIs(opName, #op)) \
+ { \
+ resultValue = constArgs[0]->value op constArgs[1]->value; \
+ }
+ BINARY_OPERATOR_CASE(>=)
+ BINARY_OPERATOR_CASE(<=)
+ BINARY_OPERATOR_CASE(>)
+ BINARY_OPERATOR_CASE(<)
+ BINARY_OPERATOR_CASE(!=)
+ BINARY_OPERATOR_CASE(<<)
+ BINARY_OPERATOR_CASE(>>)
+ BINARY_OPERATOR_CASE(&)
+ BINARY_OPERATOR_CASE(|)
+ BINARY_OPERATOR_CASE(^)
+#undef BINARY_OPERATOR_CASE
+#define DIV_OPERATOR_CASE(op) \
+ else if (nameIs(opName, #op)) \
+ { \
+ if (constArgs[1]->value == 0) \
+ { \
+ if (sink) \
+ sink->diagnose(newFuncDecl.getLoc(), Diagnostics::divideByZero); \
+ return nullptr; \
+ } \
+ resultValue = constArgs[0]->value op constArgs[1]->value; \
+ }
+ DIV_OPERATOR_CASE(/)
+ DIV_OPERATOR_CASE(%)
+#undef DIV_OPERATOR_CASE
+#define LOGICAL_OPERATOR_CASE(op) \
+ else if (nameIs(opName, #op)) \
+ { \
+ resultValue = (((constArgs[0]->value!=0) op (constArgs[1]->value!=0)) ? 1 : 0); \
+ }
+ LOGICAL_OPERATOR_CASE(&&)
+ LOGICAL_OPERATOR_CASE(|| )
+#undef LOGICAL_OPERATOR_CASE
+ else if (nameIs(opName, "!"))
+ {
+ resultValue = ((constArgs[0]->value != 0) ? 1 : 0);
+ }
+ else if (nameIs(opName, "~"))
+ {
+ resultValue = ~constArgs[0]->value;
+ }
+ else if (nameIs(opName, "?:"))
+ {
+ resultValue = constArgs[0]->value != 0 ? constArgs[1]->value : constArgs[2]->value;
+ }
+ else
+ {
+ SLANG_UNREACHABLE("constant folding of SomeIntVal");
+ }
+ return astBuilder->create<ConstantIntVal>(resultValue);
+ }
+ return nullptr;
+}
+
+Val* SomeIntVal::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff)
+{
+ int diff = 0;
+ auto newFuncDeclRef = funcDeclRef.substituteImpl(astBuilder, subst, &diff);
+ List<IntVal*> newArgs;
+ for (auto& arg : args)
+ {
+ auto substArg = arg->substituteImpl(astBuilder, subst, &diff);
+ if (substArg != arg)
+ diff++;
+ newArgs.add(as<IntVal>(substArg));
+ }
+ *ioDiff += diff;
+ if (diff)
+ {
+ // TODO: report diagnostics back.
+ auto newVal = tryFoldImpl(astBuilder, newFuncDeclRef, newArgs, nullptr);
+ if (newVal)
+ return newVal;
+ else
+ {
+ auto result = astBuilder->create<SomeIntVal>();
+ result->args = _Move(newArgs);
+ result->funcDeclRef = newFuncDeclRef;
+ result->funcType = funcType;
+ return result;
+ }
+ }
+ // Nothing found: don't substitute.
+ return this;
}
} // namespace Slang