summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2022-08-24 10:56:53 -0700
committerGitHub <noreply@github.com>2022-08-24 10:56:53 -0700
commitd245c72f2a92a74ccda83f41758c1948ae5132d3 (patch)
treef036e1f2afb7febe2de9b09990bcde6c04f3bad1
parent0b808453407f8feef8574cae99afd90771712185 (diff)
Compiler time evaluation of all int and bool operators. (#2376)
* Compiler time evaluation of all int and bool operators. * Fix linux compile error. * Fix. Co-authored-by: Yong He <yhe@nvidia.com>
-rw-r--r--build/visual-studio/slang/slang.vcxproj2
-rw-r--r--build/visual-studio/slang/slang.vcxproj.filters6
-rw-r--r--source/slang/slang-ast-type.cpp10
-rw-r--r--source/slang/slang-ast-val.cpp316
-rw-r--r--source/slang/slang-ast-val.h64
-rw-r--r--source/slang/slang-check-expr.cpp22
-rw-r--r--source/slang/slang-diagnostic-defs.h2
-rw-r--r--source/slang/slang-emit-c-like.cpp11
-rw-r--r--source/slang/slang-ir-hoist-constants.cpp96
-rw-r--r--source/slang/slang-ir-hoist-constants.h13
-rw-r--r--source/slang/slang-ir-ssa-simplification.cpp2
-rw-r--r--source/slang/slang-ir.h69
-rw-r--r--source/slang/slang-lower-to-ir.cpp19
-rw-r--r--source/slang/slang-parser.cpp1
-rw-r--r--tests/language-feature/generics/generic-value-constant-folding.slang13
-rw-r--r--tests/language-feature/generics/generic-value-constant-folding.slang.expected.txt8
-rw-r--r--tests/parser/generic-arg.slang15
-rw-r--r--tests/parser/generic-arg.slang.expected5
18 files changed, 597 insertions, 77 deletions
diff --git a/build/visual-studio/slang/slang.vcxproj b/build/visual-studio/slang/slang.vcxproj
index 8e4675262..9394806ef 100644
--- a/build/visual-studio/slang/slang.vcxproj
+++ b/build/visual-studio/slang/slang.vcxproj
@@ -355,6 +355,7 @@ IF EXIST ..\..\..\external\slang-glslang\bin\windows-aarch64\release\slang-glsla
<ClInclude Include="..\..\..\source\slang\slang-ir-generics-lowering-context.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-glsl-legalize.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-glsl-liveness.h" />
+ <ClInclude Include="..\..\..\source\slang\slang-ir-hoist-constants.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-hoist-local-types.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-inline.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-inst-defs.h" />
@@ -521,6 +522,7 @@ IF EXIST ..\..\..\external\slang-glslang\bin\windows-aarch64\release\slang-glsla
<ClCompile Include="..\..\..\source\slang\slang-ir-generics-lowering-context.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-glsl-legalize.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-glsl-liveness.cpp" />
+ <ClCompile Include="..\..\..\source\slang\slang-ir-hoist-constants.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-hoist-local-types.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-inline.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-layout.cpp" />
diff --git a/build/visual-studio/slang/slang.vcxproj.filters b/build/visual-studio/slang/slang.vcxproj.filters
index 2daf68e89..1115360d5 100644
--- a/build/visual-studio/slang/slang.vcxproj.filters
+++ b/build/visual-studio/slang/slang.vcxproj.filters
@@ -198,6 +198,9 @@
<ClInclude Include="..\..\..\source\slang\slang-ir-glsl-liveness.h">
<Filter>Header Files</Filter>
</ClInclude>
+ <ClInclude Include="..\..\..\source\slang\slang-ir-hoist-constants.h">
+ <Filter>Header Files</Filter>
+ </ClInclude>
<ClInclude Include="..\..\..\source\slang\slang-ir-hoist-local-types.h">
<Filter>Header Files</Filter>
</ClInclude>
@@ -692,6 +695,9 @@
<ClCompile Include="..\..\..\source\slang\slang-ir-glsl-liveness.cpp">
<Filter>Source Files</Filter>
</ClCompile>
+ <ClCompile Include="..\..\..\source\slang\slang-ir-hoist-constants.cpp">
+ <Filter>Source Files</Filter>
+ </ClCompile>
<ClCompile Include="..\..\..\source\slang\slang-ir-hoist-local-types.cpp">
<Filter>Source Files</Filter>
</ClCompile>
diff --git a/source/slang/slang-ast-type.cpp b/source/slang/slang-ast-type.cpp
index 664c940a8..a84f04a32 100644
--- a/source/slang/slang-ast-type.cpp
+++ b/source/slang/slang-ast-type.cpp
@@ -387,15 +387,19 @@ Val* ArrayExpressionType::_substituteImplOverride(ASTBuilder* astBuilder, Substi
{
int diff = 0;
auto elementType = as<Type>(baseType->substituteImpl(astBuilder, subst, &diff));
- auto arrlen = as<IntVal>(arrayLength->substituteImpl(astBuilder, subst, &diff));
- SLANG_ASSERT(arrlen);
+ IntVal* newArrayLength = nullptr;
+ if (arrayLength)
+ {
+ newArrayLength = as<IntVal>(arrayLength->substituteImpl(astBuilder, subst, &diff));
+ SLANG_ASSERT(newArrayLength);
+ }
if (diff)
{
*ioDiff = 1;
auto rsType = getArrayType(
astBuilder,
elementType,
- arrlen);
+ newArrayLength);
return rsType;
}
return this;
diff --git a/source/slang/slang-ast-val.cpp b/source/slang/slang-ast-val.cpp
index 16c649f3a..4ed69e282 100644
--- a/source/slang/slang-ast-val.cpp
+++ b/source/slang/slang-ast-val.cpp
@@ -4,7 +4,7 @@
#include <typeinfo>
#include "slang-generated-ast-macro.h"
-
+#include "slang-diagnostics.h"
#include "slang-syntax.h"
namespace Slang {
@@ -693,17 +693,18 @@ Val* PolynomialIntVal::_substituteImplOverride(ASTBuilder* astBuilder, Substitut
for (auto& factor : term->paramFactors)
{
auto substResult = factor->param->substituteImpl(astBuilder, subst, &diff);
- if (auto genericVal = as<GenericParamIntVal>(substResult))
+
+ if (auto constantVal = as<ConstantIntVal>(substResult))
+ {
+ evaluatedTermConstFactor *= constantVal->value;
+ }
+ else if (auto intResult = as<IntVal>(substResult))
{
auto newFactor = astBuilder->create<PolynomialIntValFactor>();
- newFactor->param = genericVal;
+ newFactor->param = intResult;
newFactor->power = factor->power;
evaluatedTermParamFactors.add(newFactor);
}
- else if (auto constantVal = as<ConstantIntVal>(substResult))
- {
- evaluatedTermConstFactor *= constantVal->value;
- }
}
if (evaluatedTermParamFactors.getCount() == 0)
{
@@ -727,28 +728,16 @@ Val* PolynomialIntVal::_substituteImplOverride(ASTBuilder* astBuilder, Substitut
auto newPolynomial = astBuilder->create<PolynomialIntVal>();
newPolynomial->constantTerm = evaluatedConstantTerm;
newPolynomial->terms = _Move(evaluatedTerms);
- newPolynomial->canonicalize(astBuilder);
- return newPolynomial;
+ return newPolynomial->canonicalize(astBuilder);
}
- return nullptr;
+ return this;
}
// compute val += opreand*multiplier;
bool addToPolynomialTerm(ASTBuilder* astBuilder, PolynomialIntVal* val, IntVal* operand, IntegerLiteralValue multiplier)
{
- if (auto genVal = as<GenericParamIntVal>(operand))
- {
- auto term = astBuilder->create<PolynomialIntValTerm>();
- term->constFactor = multiplier;
- auto factor = astBuilder->create<PolynomialIntValFactor>();
- factor->power = 1;
- factor->param = genVal;
- term->paramFactors.add(factor);
- val->terms.add(term);
- return true;
- }
- else if (auto c = as<ConstantIntVal>(operand))
+ if (auto c = as<ConstantIntVal>(operand))
{
val->constantTerm += c->value * multiplier;
return true;
@@ -765,6 +754,17 @@ bool addToPolynomialTerm(ASTBuilder* astBuilder, PolynomialIntVal* val, IntVal*
}
return true;
}
+ else if (auto genVal = as<IntVal>(operand))
+ {
+ auto term = astBuilder->create<PolynomialIntValTerm>();
+ term->constFactor = multiplier;
+ auto factor = astBuilder->create<PolynomialIntValFactor>();
+ factor->power = 1;
+ factor->param = genVal;
+ term->paramFactors.add(factor);
+ val->terms.add(term);
+ return true;
+ }
return false;
}
@@ -845,24 +845,15 @@ PolynomialIntVal* PolynomialIntVal::mul(ASTBuilder* astBuilder, IntVal* op0, Int
result->canonicalize(astBuilder);
return result;
}
- else if (auto val1 = as<GenericParamIntVal>(op1))
+ else if (auto cVal1 = as<ConstantIntVal>(op1))
{
auto result = astBuilder->create<PolynomialIntVal>();
- result->constantTerm = 0;
+ result->constantTerm = poly0->constantTerm * cVal1->value;
auto factor1 = astBuilder->create<PolynomialIntValFactor>();
- factor1->power = 1;
- factor1->param = val1;
- if (poly0->constantTerm != 0)
- {
- auto term0 = astBuilder->create<PolynomialIntValTerm>();
- term0->constFactor = poly0->constantTerm;
- term0->paramFactors.add(factor1);
- result->terms.add(term0);
- }
for (auto term : poly0->terms)
{
auto newTerm = astBuilder->create<PolynomialIntValTerm>();
- newTerm->constFactor = term->constFactor;
+ newTerm->constFactor = term->constFactor * cVal1->value;
newTerm->paramFactors.addRange(term->paramFactors);
newTerm->paramFactors.add(factor1);
result->terms.add(newTerm);
@@ -870,15 +861,24 @@ PolynomialIntVal* PolynomialIntVal::mul(ASTBuilder* astBuilder, IntVal* op0, Int
result->canonicalize(astBuilder);
return result;
}
- else if (auto cVal1 = as<ConstantIntVal>(op1))
+ else if (auto val1 = as<IntVal>(op1))
{
auto result = astBuilder->create<PolynomialIntVal>();
- result->constantTerm = poly0->constantTerm * cVal1->value;
+ result->constantTerm = 0;
auto factor1 = astBuilder->create<PolynomialIntValFactor>();
+ factor1->power = 1;
+ factor1->param = val1;
+ if (poly0->constantTerm != 0)
+ {
+ auto term0 = astBuilder->create<PolynomialIntValTerm>();
+ term0->constFactor = poly0->constantTerm;
+ term0->paramFactors.add(factor1);
+ result->terms.add(term0);
+ }
for (auto term : poly0->terms)
{
auto newTerm = astBuilder->create<PolynomialIntValTerm>();
- newTerm->constFactor = term->constFactor * cVal1->value;
+ newTerm->constFactor = term->constFactor;
newTerm->paramFactors.addRange(term->paramFactors);
newTerm->paramFactors.add(factor1);
result->terms.add(newTerm);
@@ -889,51 +889,51 @@ PolynomialIntVal* PolynomialIntVal::mul(ASTBuilder* astBuilder, IntVal* op0, Int
else
return nullptr;
}
- else if (auto val0 = as<GenericParamIntVal>(op0))
+ else if (as<ConstantIntVal>(op0))
+ {
+ return mul(astBuilder, op1, op0);
+ }
+ else if (auto val0 = as<IntVal>(op0))
{
if (auto poly1 = as<PolynomialIntVal>(op1))
{
return mul(astBuilder, op1, op0);
}
- else if (auto val1 = as<GenericParamIntVal>(op1))
+ else if (auto cVal1 = as<ConstantIntVal>(op1))
{
auto result = astBuilder->create<PolynomialIntVal>();
auto term = astBuilder->create<PolynomialIntValTerm>();
- term->constFactor = 1;
+ term->constFactor = cVal1->value;
auto factor0 = astBuilder->create<PolynomialIntValFactor>();
factor0->power = 1;
factor0->param = val0;
term->paramFactors.add(factor0);
- auto factor1 = astBuilder->create<PolynomialIntValFactor>();
- factor1->power = 1;
- factor1->param = val1;
- term->paramFactors.add(factor1);
result->terms.add(term);
result->canonicalize(astBuilder);
return result;
}
- else if (auto cVal1 = as<ConstantIntVal>(op1))
+ else if (auto val1 = as<IntVal>(op1))
{
auto result = astBuilder->create<PolynomialIntVal>();
auto term = astBuilder->create<PolynomialIntValTerm>();
- term->constFactor = cVal1->value;
+ term->constFactor = 1;
auto factor0 = astBuilder->create<PolynomialIntValFactor>();
factor0->power = 1;
factor0->param = val0;
term->paramFactors.add(factor0);
+ auto factor1 = astBuilder->create<PolynomialIntValFactor>();
+ factor1->power = 1;
+ factor1->param = val1;
+ term->paramFactors.add(factor1);
result->terms.add(term);
result->canonicalize(astBuilder);
return result;
}
}
- else if (as<ConstantIntVal>(op0))
- {
- return mul(astBuilder, op1, op0);
- }
return nullptr;
}
-void PolynomialIntVal::canonicalize(ASTBuilder* builder)
+IntVal* PolynomialIntVal::canonicalize(ASTBuilder* builder)
{
List<PolynomialIntValTerm*> newTerms;
IntegerLiteralValue newConstantTerm = constantTerm;
@@ -1028,6 +1028,220 @@ void PolynomialIntVal::canonicalize(ASTBuilder* builder)
}
newTerms2.sort([](PolynomialIntValTerm* t1, PolynomialIntValTerm* t2) {return *t1 < *t2; });
terms = _Move(newTerms2);
+ constantTerm = newConstantTerm;
+ if (terms.getCount() == 1 && constantTerm == 0 && terms[0]->constFactor == 1 && terms[0]->paramFactors.getCount() == 1 &&
+ terms[0]->paramFactors[0]->power == 1)
+ {
+ return terms[0]->paramFactors[0]->param;
+ }
+ if (terms.getCount() == 0)
+ return builder->create<ConstantIntVal>(constantTerm);
+ return this;
+}
+
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! SomeIntVal !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+
+bool SomeIntVal::_equalsValOverride(Val* val)
+{
+ if (auto someIntVal = as<SomeIntVal>(val))
+ {
+ if (!funcDeclRef.equals(someIntVal->funcDeclRef))
+ return false;
+ if (args.getCount() != someIntVal->args.getCount())
+ return false;
+ for (Index i = 0; i < args.getCount(); i++)
+ {
+ if (!args[i]->equalsVal(someIntVal->args[i]))
+ return false;
+ }
+ return true;
+ }
+ return false;
+}
+
+void SomeIntVal::_toTextOverride(StringBuilder& out)
+{
+ auto argToText = [&](int index)
+ {
+ if (as<PolynomialIntVal>(args[index]) || as<SomeIntVal>(args[index]))
+ {
+ out << "(";
+ args[index]->toText(out);
+ out << ")";
+ }
+ else
+ {
+ args[index]->toText(out);
+ }
+ };
+ Name* name = funcDeclRef.getName();
+ if (args.getCount() == 2)
+ {
+ argToText(0);
+ out << (name ? name->text : "");
+ argToText(1);;
+ }
+ else if (args.getCount() == 1)
+ {
+ out << (name ? name->text : "");
+ argToText(0);
+ }
+ else if (name && name->text == "?:")
+ {
+ argToText(0);
+ out << "?";
+ argToText(1);
+ out << ":";
+ argToText(2);
+ }
+ else
+ {
+ if (name)
+ {
+ out << name->text;
+ }
+ out << "(";
+ for (Index i = 0; i < args.getCount(); i++)
+ {
+ if (i > 0) out << ", ";
+ args[i]->toText(out);
+ }
+ out << ")";
+ }
+}
+
+HashCode SomeIntVal::_getHashCodeOverride()
+{
+ HashCode result = funcDeclRef.getHashCode();
+ for (auto arg : args)
+ {
+ result = combineHash(result, arg->getHashCode());
+ }
+ return result;
+}
+
+static bool nameIs(Name* name, const char* val)
+{
+ if (name && name->text.getUnownedSlice() == val)
+ return true;
+ return false;
+}
+
+Val* SomeIntVal::tryFoldImpl(ASTBuilder* astBuilder, DeclRef<Decl> newFuncDecl, List<IntVal*>& newArgs, DiagnosticSink* sink)
+{
+ // Are all args const now?
+ List<ConstantIntVal*> constArgs;
+ bool allConst = true;
+ for (auto arg : newArgs)
+ {
+ if (auto c = as<ConstantIntVal>(arg))
+ {
+ constArgs.add(c);
+ }
+ else
+ {
+ allConst = false;
+ break;
+ }
+ }
+ if (allConst)
+ {
+ // Evaluate the function.
+ auto opName = newFuncDecl.getName();
+ IntegerLiteralValue resultValue = 0;
+ if (nameIs(opName, "=="))
+ {
+ resultValue = constArgs[0]->value / constArgs[1]->value;
+ }
+#define BINARY_OPERATOR_CASE(op) \
+ else if (nameIs(opName, #op)) \
+ { \
+ resultValue = constArgs[0]->value op constArgs[1]->value; \
+ }
+ BINARY_OPERATOR_CASE(>=)
+ BINARY_OPERATOR_CASE(<=)
+ BINARY_OPERATOR_CASE(>)
+ BINARY_OPERATOR_CASE(<)
+ BINARY_OPERATOR_CASE(!=)
+ BINARY_OPERATOR_CASE(<<)
+ BINARY_OPERATOR_CASE(>>)
+ BINARY_OPERATOR_CASE(&)
+ BINARY_OPERATOR_CASE(|)
+ BINARY_OPERATOR_CASE(^)
+#undef BINARY_OPERATOR_CASE
+#define DIV_OPERATOR_CASE(op) \
+ else if (nameIs(opName, #op)) \
+ { \
+ if (constArgs[1]->value == 0) \
+ { \
+ if (sink) \
+ sink->diagnose(newFuncDecl.getLoc(), Diagnostics::divideByZero); \
+ return nullptr; \
+ } \
+ resultValue = constArgs[0]->value op constArgs[1]->value; \
+ }
+ DIV_OPERATOR_CASE(/)
+ DIV_OPERATOR_CASE(%)
+#undef DIV_OPERATOR_CASE
+#define LOGICAL_OPERATOR_CASE(op) \
+ else if (nameIs(opName, #op)) \
+ { \
+ resultValue = (((constArgs[0]->value!=0) op (constArgs[1]->value!=0)) ? 1 : 0); \
+ }
+ LOGICAL_OPERATOR_CASE(&&)
+ LOGICAL_OPERATOR_CASE(|| )
+#undef LOGICAL_OPERATOR_CASE
+ else if (nameIs(opName, "!"))
+ {
+ resultValue = ((constArgs[0]->value != 0) ? 1 : 0);
+ }
+ else if (nameIs(opName, "~"))
+ {
+ resultValue = ~constArgs[0]->value;
+ }
+ else if (nameIs(opName, "?:"))
+ {
+ resultValue = constArgs[0]->value != 0 ? constArgs[1]->value : constArgs[2]->value;
+ }
+ else
+ {
+ SLANG_UNREACHABLE("constant folding of SomeIntVal");
+ }
+ return astBuilder->create<ConstantIntVal>(resultValue);
+ }
+ return nullptr;
+}
+
+Val* SomeIntVal::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff)
+{
+ int diff = 0;
+ auto newFuncDeclRef = funcDeclRef.substituteImpl(astBuilder, subst, &diff);
+ List<IntVal*> newArgs;
+ for (auto& arg : args)
+ {
+ auto substArg = arg->substituteImpl(astBuilder, subst, &diff);
+ if (substArg != arg)
+ diff++;
+ newArgs.add(as<IntVal>(substArg));
+ }
+ *ioDiff += diff;
+ if (diff)
+ {
+ // TODO: report diagnostics back.
+ auto newVal = tryFoldImpl(astBuilder, newFuncDeclRef, newArgs, nullptr);
+ if (newVal)
+ return newVal;
+ else
+ {
+ auto result = astBuilder->create<SomeIntVal>();
+ result->args = _Move(newArgs);
+ result->funcDeclRef = newFuncDeclRef;
+ result->funcType = funcType;
+ return result;
+ }
+ }
+ // Nothing found: don't substitute.
+ return this;
}
} // namespace Slang
diff --git a/source/slang/slang-ast-val.h b/source/slang/slang-ast-val.h
index 6eaaa8eb1..64f04abf9 100644
--- a/source/slang/slang-ast-val.h
+++ b/source/slang/slang-ast-val.h
@@ -52,29 +52,70 @@ protected:
{}
};
+// An compile time int val as result of some general computation.
+class SomeIntVal : public IntVal
+{
+ SLANG_AST_CLASS(SomeIntVal)
+
+ bool _equalsValOverride(Val* val);
+ void _toTextOverride(StringBuilder& out);
+ HashCode _getHashCodeOverride();
+ Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
+
+ DeclRef<Decl> funcDeclRef;
+ List<IntVal*> args;
+ Type* funcType;
+ static Val* tryFoldImpl(ASTBuilder* astBuilder, DeclRef<Decl> newFuncDecl, List<IntVal*> &newArgs, DiagnosticSink* sink);
+};
+
// polynomial expression "2*a*b^3 + 1" will be represented as:
// { constantTerm:1, terms: [ { constFactor:2, paramFactors:[{"a", 1}, {"b", 3}] } ] }
class PolynomialIntValFactor : public NodeBase
{
SLANG_AST_CLASS(PolynomialIntValFactor)
public:
- GenericParamIntVal* param;
+ IntVal* param;
IntegerLiteralValue power;
// for sorting only.
bool operator<(const PolynomialIntValFactor& other) const
{
- if (param->declRef.decl < other.param->declRef.decl)
- return true;
- else if (param->declRef.decl == other.param->declRef.decl)
- return power < other.power;
- return false;
+ if (auto thisGenParam = as<GenericParamIntVal>(param))
+ {
+ if (auto thatGenParam = as<GenericParamIntVal>(other.param))
+ {
+ if (thisGenParam->equalsVal(thatGenParam))
+ return power < other.power;
+ else
+ return thisGenParam->declRef.decl < thatGenParam->declRef.decl;
+ }
+ else
+ {
+ return true;
+ }
+ }
+ else
+ {
+ if (auto thatGenParam = as<GenericParamIntVal>(other.param))
+ {
+ return false;
+ }
+ return param == other.param ? power < other.power : param < other.param;
+ }
+
}
// for sorting only.
bool operator==(const PolynomialIntValFactor& other) const
{
- if (param->declRef.decl == other.param->declRef.decl && power == other.power)
- return true;
- return false;
+ if (auto thisGenParam = as<GenericParamIntVal>(param))
+ {
+ if (auto thatGenParam = as<GenericParamIntVal>(other.param))
+ {
+ if (param->equalsVal(other.param) && power == other.power)
+ return true;
+ }
+ return false;
+ }
+ return power == other.power && param == other.param;
}
bool equals(const PolynomialIntValFactor& other) const
{
@@ -132,7 +173,10 @@ public:
IntegerLiteralValue constantTerm = 0;
bool isConstant() { return terms.getCount() == 0; }
- void canonicalize(ASTBuilder* builder);
+ // Canonicalize the polynomial. If the polynomial can be simplified to a constant or a genericparam,
+ // the method returns the value simplified to.
+ // Otherwise, in-place modifications are performed and returns this.
+ IntVal* canonicalize(ASTBuilder* builder);
// Overrides should be public so base classes can access
bool _equalsValOverride(Val* val);
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index 7257790af..8e14af72a 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -977,6 +977,19 @@ namespace Slang
return PolynomialIntVal::mul(m_astBuilder, argVals[0], argVals[1]);
}
}
+ else if (opName == getName("/") || opName == getName("==") || opName == getName(">=") || opName == getName("<=") || opName == getName("!=")
+ || opName == getName(">") || opName == getName("<") || opName == getName("&&") || opName == getName("||") || opName == getName("!")
+ || opName == getName("|") || opName == getName("&") || opName == getName("^") || opName == getName("~") || opName == getName("%") ||
+ opName == getName("?:") || opName == getName("<<") || opName == getName(">>"))
+ {
+ auto result = m_astBuilder->create<SomeIntVal>();
+ result->args.addRange(argVals, argCount);
+ result->funcDeclRef = funcDeclRef;
+ result->funcType = as<Type>(funcDeclRefExpr.getExpr()->type->substitute(
+ m_astBuilder, funcDeclRefExpr.getSubsts()));
+ SLANG_RELEASE_ASSERT(result->funcType);
+ return result;
+ }
return nullptr;
}
@@ -1062,6 +1075,15 @@ namespace Slang
CASE(/);
CASE(%);
#undef CASE
+ else if (opName == getName("?:"))
+ {
+ if (argCount != 3)
+ return nullptr;
+ if (constArgVals[0] != 0)
+ resultValue = constArgVals[1];
+ else
+ resultValue = constArgVals[2];
+ }
// TODO(tfoley): more cases
else
{
diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h
index 14be426d6..27076681d 100644
--- a/source/slang/slang-diagnostic-defs.h
+++ b/source/slang/slang-diagnostic-defs.h
@@ -234,7 +234,7 @@ DIAGNOSTIC(20014, Error, classIsReservedKeyword, "'class' is a reserved keyword
//
// 3xxxx - Semantic analysis
//
-
+DIAGNOSTIC(30002, Error, divideByZero, "divide by zero")
DIAGNOSTIC(30003, Error, breakOutsideLoop, "'break' must appear inside loop constructs.")
DIAGNOSTIC(30004, Error, continueOutsideLoop, "'continue' must appear inside loop constructs.")
DIAGNOSTIC(30005, Error, whilePredicateTypeError, "'while': expression must evaluate to int.")
diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp
index e98c8c6f3..7a7951ba1 100644
--- a/source/slang/slang-emit-c-like.cpp
+++ b/source/slang/slang-emit-c-like.cpp
@@ -3381,6 +3381,17 @@ void CLikeSourceEmitter::ensureInstOperandsRec(ComputeEmitActionsContext* ctx, I
case kIROp_NativePtrType:
requiredLevel = EmitAction::ForwardDeclaration;
break;
+ case kIROp_lookup_interface_method:
+ case kIROp_FieldExtract:
+ case kIROp_FieldAddress:
+ {
+ auto opType = inst->getOperand(0)->getDataType();
+ if (auto nativePtrType = as<IRNativePtrType>(opType))
+ {
+ ensureInstOperand(ctx, nativePtrType->getValueType(), requiredLevel);
+ }
+ break;
+ }
default:
break;
}
diff --git a/source/slang/slang-ir-hoist-constants.cpp b/source/slang/slang-ir-hoist-constants.cpp
new file mode 100644
index 000000000..87d3487a3
--- /dev/null
+++ b/source/slang/slang-ir-hoist-constants.cpp
@@ -0,0 +1,96 @@
+// slang-ir-hoist-constants.cpp
+#include "slang-ir-hoist-constants.h"
+#include "slang-ir-inst-pass-base.h"
+
+namespace Slang
+{
+
+struct HoistConstantPass : InstPassBase
+{
+ HoistConstantPass(IRModule* module) : InstPassBase(module)
+ {}
+
+ bool changed = false;
+
+ void processModule()
+ {
+ sharedBuilderStorage.init(module);
+
+ processAllInsts([this](IRInst* inst)
+ {
+
+ if (inst->getParent() == module->getModuleInst() || !inst->getParent())
+ return;
+ auto parent = inst->getParent();
+ auto p = parent;
+ while (p)
+ {
+ if (as<IRGlobalValueWithCode>(p))
+ return;
+ p = p->parent;
+ }
+ while (parent && parent->parent != module->getModuleInst())
+ parent = parent->parent;
+ if (!parent)
+ return;
+ switch (inst->getOp())
+ {
+ default:
+ return;
+ case kIROp_Add:
+ case kIROp_Sub:
+ case kIROp_Mul:
+ case kIROp_Div:
+ case kIROp_Module:
+ case kIROp_Neg:
+ case kIROp_And:
+ case kIROp_Or:
+ case kIROp_Not:
+ case kIROp_BitAnd:
+ case kIROp_BitNot:
+ case kIROp_BitOr:
+ case kIROp_BitXor:
+ case kIROp_Select:
+ case kIROp_Greater:
+ case kIROp_Less:
+ case kIROp_Leq:
+ case kIROp_Geq:
+ case kIROp_Eql:
+ case kIROp_Neq:
+ case kIROp_BitCast:
+ case kIROp_Lsh:
+ case kIROp_Rsh:
+ case kIROp_Construct:
+ case kIROp_makeVector:
+ case kIROp_MakeMatrix:
+ case kIROp_swizzle:
+ case kIROp_IntLit:
+ case kIROp_BoolLit:
+ case kIROp_ArrayType:
+ case kIROp_Specialize:
+ case kIROp_VectorType:
+ break;
+ }
+ if (inst->typeUse.get() && inst->typeUse.get()->parent != module->getModuleInst())
+ return;
+ for (UInt i = 0; i < inst->getOperandCount(); i++)
+ {
+ if (inst->getOperand(i)->parent != module->getModuleInst())
+ return;
+ }
+ // all operands are in global scope, we can move this inst to global scope as well.
+ inst->insertBefore(parent);
+ changed = true;
+ });
+ }
+};
+
+bool hoistConstants(
+ IRModule* module)
+{
+ HoistConstantPass context(module);
+ context.processModule();
+ return context.changed;
+}
+
+}
diff --git a/source/slang/slang-ir-hoist-constants.h b/source/slang/slang-ir-hoist-constants.h
new file mode 100644
index 000000000..28d4a0c6b
--- /dev/null
+++ b/source/slang/slang-ir-hoist-constants.h
@@ -0,0 +1,13 @@
+// slang-ir-hoist-constants.h
+#pragma once
+
+namespace Slang
+{
+struct IRModule;
+
+ /// A (specialized) generic type may contain insts that computes compile-time constants defined within
+ /// the type. We should hoist them to global scope so they can be SCCP'd when possible.
+bool hoistConstants(
+ IRModule* module);
+
+}
diff --git a/source/slang/slang-ir-ssa-simplification.cpp b/source/slang/slang-ir-ssa-simplification.cpp
index 22aea8d36..f723325c4 100644
--- a/source/slang/slang-ir-ssa-simplification.cpp
+++ b/source/slang/slang-ir-ssa-simplification.cpp
@@ -6,6 +6,7 @@
#include "slang-ir-dce.h"
#include "slang-ir-simplify-cfg.h"
#include "slang-ir-peephole.h"
+#include "slang-ir-hoist-constants.h"
namespace Slang
{
@@ -21,6 +22,7 @@ namespace Slang
while (changed && iterationCounter < kMaxIterations)
{
changed = false;
+ changed |= hoistConstants(module);
changed |= applySparseConditionalConstantPropagation(module);
changed |= peepholeOptimize(module);
changed |= simplifyCFG(module);
diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h
index aea425a9b..c48f4b378 100644
--- a/source/slang/slang-ir.h
+++ b/source/slang/slang-ir.h
@@ -203,6 +203,36 @@ struct IRInstList : IRInstListBase
Iterator end();
};
+template<typename T>
+struct IRFilteredInstList : IRInstListBase
+{
+ IRFilteredInstList() {}
+
+ IRFilteredInstList(IRInst* fst, IRInst* lst);
+
+ explicit IRFilteredInstList(IRInstListBase const& list)
+ : IRFilteredInstList(list.first, list.last)
+ {}
+
+ T* getFirst() { return (T*)first; }
+ T* getLast() { return (T*)last; }
+
+ struct Iterator : public IRInstListBase::Iterator
+ {
+ IRInst* exclusiveLast;
+ Iterator() {}
+ Iterator(IRInst* inst, IRInst* lastIter) : IRInstListBase::Iterator(inst), exclusiveLast(lastIter) {}
+ void operator++();
+ T* operator*()
+ {
+ return (T*)inst;
+ }
+ };
+
+ Iterator begin();
+ Iterator end();
+};
+
/// A list of contiguous operands that can be iterated over as `IRInst`s.
struct IROperandListBase
{
@@ -741,6 +771,41 @@ typename IRInstList<T>::Iterator IRInstList<T>::end()
return Iterator(last ? last->next : nullptr);
}
+template<typename T>
+IRFilteredInstList<T>::IRFilteredInstList(IRInst* fst, IRInst* lst)
+{
+ first = fst;
+ last = lst;
+
+ auto lastIter = last ? last->next : nullptr;
+ while (first != lastIter && !as<T>(first))
+ first = first->next;
+ while (last && last != first && !as<T>(last))
+ last = last->prev;
+}
+
+template<typename T>
+void IRFilteredInstList<T>::Iterator::operator++()
+{
+ inst = inst->next;
+ while (inst != exclusiveLast && !as<T>(inst))
+ {
+ inst = inst->next;
+ }
+}
+template<typename T>
+typename IRFilteredInstList<T>::Iterator IRFilteredInstList<T>::begin()
+{
+ auto lastIter = last ? last->next : nullptr;
+ return IRFilteredInstList<T>::Iterator(first, lastIter);
+}
+
+template<typename T>
+typename IRFilteredInstList<T>::Iterator IRFilteredInstList<T>::end()
+{
+ auto lastIter = last ? last->next : nullptr;
+ return IRFilteredInstList<T>::Iterator(lastIter, lastIter);
+}
// Types
@@ -1419,14 +1484,14 @@ struct IRStructField : IRInst
//
struct IRStructType : IRType
{
- IRInstList<IRStructField> getFields() { return IRInstList<IRStructField>(getChildren()); }
+ IRFilteredInstList<IRStructField> getFields() { return IRFilteredInstList<IRStructField>(getChildren()); }
IR_LEAF_ISA(StructType)
};
struct IRClassType : IRType
{
- IRInstList<IRStructField> getFields() { return IRInstList<IRStructField>(getChildren()); }
+ IRFilteredInstList<IRStructField> getFields() { return IRFilteredInstList<IRStructField>(getChildren()); }
IR_LEAF_ISA(ClassType)
};
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index 3d4f0df5d..b351bfe21 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -1287,6 +1287,25 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
lowerType(context, getType(context->astBuilder, val->declRef)));
}
+ LoweredValInfo visitSomeIntVal(SomeIntVal* val)
+ {
+ TryClauseEnvironment tryEnv;
+ List<IRInst*> args;
+ for (auto arg : val->args)
+ {
+ auto loweredArg = lowerVal(context, arg);
+ args.add(loweredArg.val);
+ }
+ auto funcType = lowerType(context, val->funcType);
+ return emitCallToDeclRef(
+ context,
+ as<IRFuncType>(funcType)->getResultType(),
+ val->funcDeclRef,
+ funcType,
+ args,
+ tryEnv);
+ }
+
LoweredValInfo visitPolynomialIntVal(PolynomialIntVal* val)
{
auto irBuilder = getBuilder();
diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp
index 33f2e251b..594ca4cc3 100644
--- a/source/slang/slang-parser.cpp
+++ b/source/slang/slang-parser.cpp
@@ -1997,6 +1997,7 @@ namespace Slang
case TokenType::Semicolon:
case TokenType::OpEql:
case TokenType::OpNeq:
+ case TokenType::OpGreater:
{
return parseGenericApp(parser, base);
}
diff --git a/tests/language-feature/generics/generic-value-constant-folding.slang b/tests/language-feature/generics/generic-value-constant-folding.slang
index 1d6781889..f7525e6fd 100644
--- a/tests/language-feature/generics/generic-value-constant-folding.slang
+++ b/tests/language-feature/generics/generic-value-constant-folding.slang
@@ -1,13 +1,13 @@
//TEST(compute):COMPARE_COMPUTE: -shaderobj -output-using-type
struct PlusOne<let v : int>
-{
- static const int value = v + 1;
+{
+ static const int value = v > 0? v + 1 : v - 1;
}
struct GetConst<let v : int, let u : int>
{
- static const int value = (u+v)*(u+v) + PlusOne<u-1>.value;
+ static const int value = (u/2+v)*(u+v) / PlusOne<u-1>.value;
int arr[value];
}
@@ -19,7 +19,8 @@ void computeMain(int3 dispatchThreadID : SV_DispatchThreadID)
{
int tid = dispatchThreadID.x;
int inVal = tid;
- int arr[GetConst<5,2>.value + 1];
- arr[0] = GetConst<5,3>.value + 1;
- outputBuffer[tid] = arr[0];
+ int arr[GetConst<5, 2>.value + 1];
+ GetConst<5, 3> obj;
+ obj.arr[0] = GetConst<5, 3>.value + 1;
+ outputBuffer[tid] = obj.arr[0];
}
diff --git a/tests/language-feature/generics/generic-value-constant-folding.slang.expected.txt b/tests/language-feature/generics/generic-value-constant-folding.slang.expected.txt
index 2ba17a828..ebb66dc8d 100644
--- a/tests/language-feature/generics/generic-value-constant-folding.slang.expected.txt
+++ b/tests/language-feature/generics/generic-value-constant-folding.slang.expected.txt
@@ -1,5 +1,5 @@
type: int32_t
-68
-68
-68
-68
+17
+17
+17
+17 \ No newline at end of file
diff --git a/tests/parser/generic-arg.slang b/tests/parser/generic-arg.slang
new file mode 100644
index 000000000..6917ad4db
--- /dev/null
+++ b/tests/parser/generic-arg.slang
@@ -0,0 +1,15 @@
+// generic-arg.slang
+
+//DIAGNOSTIC_TEST:SIMPLE:
+
+// Test disambiguation of expression and generic app.
+
+namespace NS
+{
+ struct MyType<let u : int, let v : int>
+ {
+ int arr[u /(v+1-1)];
+ }
+}
+
+StructuredBuffer<NS.MyType<1, 3>> buffer;
diff --git a/tests/parser/generic-arg.slang.expected b/tests/parser/generic-arg.slang.expected
new file mode 100644
index 000000000..4c32e2510
--- /dev/null
+++ b/tests/parser/generic-arg.slang.expected
@@ -0,0 +1,5 @@
+result code = 0
+standard error = {
+}
+standard output = {
+}