diff options
| author | Yong He <yonghe@outlook.com> | 2022-08-24 10:56:53 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-08-24 10:56:53 -0700 |
| commit | d245c72f2a92a74ccda83f41758c1948ae5132d3 (patch) | |
| tree | f036e1f2afb7febe2de9b09990bcde6c04f3bad1 /source | |
| parent | 0b808453407f8feef8574cae99afd90771712185 (diff) | |
Compiler time evaluation of all int and bool operators. (#2376)
* Compiler time evaluation of all int and bool operators.
* Fix linux compile error.
* Fix.
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/slang-ast-type.cpp | 10 | ||||
| -rw-r--r-- | source/slang/slang-ast-val.cpp | 316 | ||||
| -rw-r--r-- | source/slang/slang-ast-val.h | 64 | ||||
| -rw-r--r-- | source/slang/slang-check-expr.cpp | 22 | ||||
| -rw-r--r-- | source/slang/slang-diagnostic-defs.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-emit-c-like.cpp | 11 | ||||
| -rw-r--r-- | source/slang/slang-ir-hoist-constants.cpp | 96 | ||||
| -rw-r--r-- | source/slang/slang-ir-hoist-constants.h | 13 | ||||
| -rw-r--r-- | source/slang/slang-ir-ssa-simplification.cpp | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir.h | 69 | ||||
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 19 | ||||
| -rw-r--r-- | source/slang/slang-parser.cpp | 1 |
12 files changed, 558 insertions, 67 deletions
diff --git a/source/slang/slang-ast-type.cpp b/source/slang/slang-ast-type.cpp index 664c940a8..a84f04a32 100644 --- a/source/slang/slang-ast-type.cpp +++ b/source/slang/slang-ast-type.cpp @@ -387,15 +387,19 @@ Val* ArrayExpressionType::_substituteImplOverride(ASTBuilder* astBuilder, Substi { int diff = 0; auto elementType = as<Type>(baseType->substituteImpl(astBuilder, subst, &diff)); - auto arrlen = as<IntVal>(arrayLength->substituteImpl(astBuilder, subst, &diff)); - SLANG_ASSERT(arrlen); + IntVal* newArrayLength = nullptr; + if (arrayLength) + { + newArrayLength = as<IntVal>(arrayLength->substituteImpl(astBuilder, subst, &diff)); + SLANG_ASSERT(newArrayLength); + } if (diff) { *ioDiff = 1; auto rsType = getArrayType( astBuilder, elementType, - arrlen); + newArrayLength); return rsType; } return this; diff --git a/source/slang/slang-ast-val.cpp b/source/slang/slang-ast-val.cpp index 16c649f3a..4ed69e282 100644 --- a/source/slang/slang-ast-val.cpp +++ b/source/slang/slang-ast-val.cpp @@ -4,7 +4,7 @@ #include <typeinfo> #include "slang-generated-ast-macro.h" - +#include "slang-diagnostics.h" #include "slang-syntax.h" namespace Slang { @@ -693,17 +693,18 @@ Val* PolynomialIntVal::_substituteImplOverride(ASTBuilder* astBuilder, Substitut for (auto& factor : term->paramFactors) { auto substResult = factor->param->substituteImpl(astBuilder, subst, &diff); - if (auto genericVal = as<GenericParamIntVal>(substResult)) + + if (auto constantVal = as<ConstantIntVal>(substResult)) + { + evaluatedTermConstFactor *= constantVal->value; + } + else if (auto intResult = as<IntVal>(substResult)) { auto newFactor = astBuilder->create<PolynomialIntValFactor>(); - newFactor->param = genericVal; + newFactor->param = intResult; newFactor->power = factor->power; evaluatedTermParamFactors.add(newFactor); } - else if (auto constantVal = as<ConstantIntVal>(substResult)) - { - evaluatedTermConstFactor *= constantVal->value; - } } if (evaluatedTermParamFactors.getCount() == 0) { @@ -727,28 +728,16 @@ Val* PolynomialIntVal::_substituteImplOverride(ASTBuilder* astBuilder, Substitut auto newPolynomial = astBuilder->create<PolynomialIntVal>(); newPolynomial->constantTerm = evaluatedConstantTerm; newPolynomial->terms = _Move(evaluatedTerms); - newPolynomial->canonicalize(astBuilder); - return newPolynomial; + return newPolynomial->canonicalize(astBuilder); } - return nullptr; + return this; } // compute val += opreand*multiplier; bool addToPolynomialTerm(ASTBuilder* astBuilder, PolynomialIntVal* val, IntVal* operand, IntegerLiteralValue multiplier) { - if (auto genVal = as<GenericParamIntVal>(operand)) - { - auto term = astBuilder->create<PolynomialIntValTerm>(); - term->constFactor = multiplier; - auto factor = astBuilder->create<PolynomialIntValFactor>(); - factor->power = 1; - factor->param = genVal; - term->paramFactors.add(factor); - val->terms.add(term); - return true; - } - else if (auto c = as<ConstantIntVal>(operand)) + if (auto c = as<ConstantIntVal>(operand)) { val->constantTerm += c->value * multiplier; return true; @@ -765,6 +754,17 @@ bool addToPolynomialTerm(ASTBuilder* astBuilder, PolynomialIntVal* val, IntVal* } return true; } + else if (auto genVal = as<IntVal>(operand)) + { + auto term = astBuilder->create<PolynomialIntValTerm>(); + term->constFactor = multiplier; + auto factor = astBuilder->create<PolynomialIntValFactor>(); + factor->power = 1; + factor->param = genVal; + term->paramFactors.add(factor); + val->terms.add(term); + return true; + } return false; } @@ -845,24 +845,15 @@ PolynomialIntVal* PolynomialIntVal::mul(ASTBuilder* astBuilder, IntVal* op0, Int result->canonicalize(astBuilder); return result; } - else if (auto val1 = as<GenericParamIntVal>(op1)) + else if (auto cVal1 = as<ConstantIntVal>(op1)) { auto result = astBuilder->create<PolynomialIntVal>(); - result->constantTerm = 0; + result->constantTerm = poly0->constantTerm * cVal1->value; auto factor1 = astBuilder->create<PolynomialIntValFactor>(); - factor1->power = 1; - factor1->param = val1; - if (poly0->constantTerm != 0) - { - auto term0 = astBuilder->create<PolynomialIntValTerm>(); - term0->constFactor = poly0->constantTerm; - term0->paramFactors.add(factor1); - result->terms.add(term0); - } for (auto term : poly0->terms) { auto newTerm = astBuilder->create<PolynomialIntValTerm>(); - newTerm->constFactor = term->constFactor; + newTerm->constFactor = term->constFactor * cVal1->value; newTerm->paramFactors.addRange(term->paramFactors); newTerm->paramFactors.add(factor1); result->terms.add(newTerm); @@ -870,15 +861,24 @@ PolynomialIntVal* PolynomialIntVal::mul(ASTBuilder* astBuilder, IntVal* op0, Int result->canonicalize(astBuilder); return result; } - else if (auto cVal1 = as<ConstantIntVal>(op1)) + else if (auto val1 = as<IntVal>(op1)) { auto result = astBuilder->create<PolynomialIntVal>(); - result->constantTerm = poly0->constantTerm * cVal1->value; + result->constantTerm = 0; auto factor1 = astBuilder->create<PolynomialIntValFactor>(); + factor1->power = 1; + factor1->param = val1; + if (poly0->constantTerm != 0) + { + auto term0 = astBuilder->create<PolynomialIntValTerm>(); + term0->constFactor = poly0->constantTerm; + term0->paramFactors.add(factor1); + result->terms.add(term0); + } for (auto term : poly0->terms) { auto newTerm = astBuilder->create<PolynomialIntValTerm>(); - newTerm->constFactor = term->constFactor * cVal1->value; + newTerm->constFactor = term->constFactor; newTerm->paramFactors.addRange(term->paramFactors); newTerm->paramFactors.add(factor1); result->terms.add(newTerm); @@ -889,51 +889,51 @@ PolynomialIntVal* PolynomialIntVal::mul(ASTBuilder* astBuilder, IntVal* op0, Int else return nullptr; } - else if (auto val0 = as<GenericParamIntVal>(op0)) + else if (as<ConstantIntVal>(op0)) + { + return mul(astBuilder, op1, op0); + } + else if (auto val0 = as<IntVal>(op0)) { if (auto poly1 = as<PolynomialIntVal>(op1)) { return mul(astBuilder, op1, op0); } - else if (auto val1 = as<GenericParamIntVal>(op1)) + else if (auto cVal1 = as<ConstantIntVal>(op1)) { auto result = astBuilder->create<PolynomialIntVal>(); auto term = astBuilder->create<PolynomialIntValTerm>(); - term->constFactor = 1; + term->constFactor = cVal1->value; auto factor0 = astBuilder->create<PolynomialIntValFactor>(); factor0->power = 1; factor0->param = val0; term->paramFactors.add(factor0); - auto factor1 = astBuilder->create<PolynomialIntValFactor>(); - factor1->power = 1; - factor1->param = val1; - term->paramFactors.add(factor1); result->terms.add(term); result->canonicalize(astBuilder); return result; } - else if (auto cVal1 = as<ConstantIntVal>(op1)) + else if (auto val1 = as<IntVal>(op1)) { auto result = astBuilder->create<PolynomialIntVal>(); auto term = astBuilder->create<PolynomialIntValTerm>(); - term->constFactor = cVal1->value; + term->constFactor = 1; auto factor0 = astBuilder->create<PolynomialIntValFactor>(); factor0->power = 1; factor0->param = val0; term->paramFactors.add(factor0); + auto factor1 = astBuilder->create<PolynomialIntValFactor>(); + factor1->power = 1; + factor1->param = val1; + term->paramFactors.add(factor1); result->terms.add(term); result->canonicalize(astBuilder); return result; } } - else if (as<ConstantIntVal>(op0)) - { - return mul(astBuilder, op1, op0); - } return nullptr; } -void PolynomialIntVal::canonicalize(ASTBuilder* builder) +IntVal* PolynomialIntVal::canonicalize(ASTBuilder* builder) { List<PolynomialIntValTerm*> newTerms; IntegerLiteralValue newConstantTerm = constantTerm; @@ -1028,6 +1028,220 @@ void PolynomialIntVal::canonicalize(ASTBuilder* builder) } newTerms2.sort([](PolynomialIntValTerm* t1, PolynomialIntValTerm* t2) {return *t1 < *t2; }); terms = _Move(newTerms2); + constantTerm = newConstantTerm; + if (terms.getCount() == 1 && constantTerm == 0 && terms[0]->constFactor == 1 && terms[0]->paramFactors.getCount() == 1 && + terms[0]->paramFactors[0]->power == 1) + { + return terms[0]->paramFactors[0]->param; + } + if (terms.getCount() == 0) + return builder->create<ConstantIntVal>(constantTerm); + return this; +} + +// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! SomeIntVal !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! + +bool SomeIntVal::_equalsValOverride(Val* val) +{ + if (auto someIntVal = as<SomeIntVal>(val)) + { + if (!funcDeclRef.equals(someIntVal->funcDeclRef)) + return false; + if (args.getCount() != someIntVal->args.getCount()) + return false; + for (Index i = 0; i < args.getCount(); i++) + { + if (!args[i]->equalsVal(someIntVal->args[i])) + return false; + } + return true; + } + return false; +} + +void SomeIntVal::_toTextOverride(StringBuilder& out) +{ + auto argToText = [&](int index) + { + if (as<PolynomialIntVal>(args[index]) || as<SomeIntVal>(args[index])) + { + out << "("; + args[index]->toText(out); + out << ")"; + } + else + { + args[index]->toText(out); + } + }; + Name* name = funcDeclRef.getName(); + if (args.getCount() == 2) + { + argToText(0); + out << (name ? name->text : ""); + argToText(1);; + } + else if (args.getCount() == 1) + { + out << (name ? name->text : ""); + argToText(0); + } + else if (name && name->text == "?:") + { + argToText(0); + out << "?"; + argToText(1); + out << ":"; + argToText(2); + } + else + { + if (name) + { + out << name->text; + } + out << "("; + for (Index i = 0; i < args.getCount(); i++) + { + if (i > 0) out << ", "; + args[i]->toText(out); + } + out << ")"; + } +} + +HashCode SomeIntVal::_getHashCodeOverride() +{ + HashCode result = funcDeclRef.getHashCode(); + for (auto arg : args) + { + result = combineHash(result, arg->getHashCode()); + } + return result; +} + +static bool nameIs(Name* name, const char* val) +{ + if (name && name->text.getUnownedSlice() == val) + return true; + return false; +} + +Val* SomeIntVal::tryFoldImpl(ASTBuilder* astBuilder, DeclRef<Decl> newFuncDecl, List<IntVal*>& newArgs, DiagnosticSink* sink) +{ + // Are all args const now? + List<ConstantIntVal*> constArgs; + bool allConst = true; + for (auto arg : newArgs) + { + if (auto c = as<ConstantIntVal>(arg)) + { + constArgs.add(c); + } + else + { + allConst = false; + break; + } + } + if (allConst) + { + // Evaluate the function. + auto opName = newFuncDecl.getName(); + IntegerLiteralValue resultValue = 0; + if (nameIs(opName, "==")) + { + resultValue = constArgs[0]->value / constArgs[1]->value; + } +#define BINARY_OPERATOR_CASE(op) \ + else if (nameIs(opName, #op)) \ + { \ + resultValue = constArgs[0]->value op constArgs[1]->value; \ + } + BINARY_OPERATOR_CASE(>=) + BINARY_OPERATOR_CASE(<=) + BINARY_OPERATOR_CASE(>) + BINARY_OPERATOR_CASE(<) + BINARY_OPERATOR_CASE(!=) + BINARY_OPERATOR_CASE(<<) + BINARY_OPERATOR_CASE(>>) + BINARY_OPERATOR_CASE(&) + BINARY_OPERATOR_CASE(|) + BINARY_OPERATOR_CASE(^) +#undef BINARY_OPERATOR_CASE +#define DIV_OPERATOR_CASE(op) \ + else if (nameIs(opName, #op)) \ + { \ + if (constArgs[1]->value == 0) \ + { \ + if (sink) \ + sink->diagnose(newFuncDecl.getLoc(), Diagnostics::divideByZero); \ + return nullptr; \ + } \ + resultValue = constArgs[0]->value op constArgs[1]->value; \ + } + DIV_OPERATOR_CASE(/) + DIV_OPERATOR_CASE(%) +#undef DIV_OPERATOR_CASE +#define LOGICAL_OPERATOR_CASE(op) \ + else if (nameIs(opName, #op)) \ + { \ + resultValue = (((constArgs[0]->value!=0) op (constArgs[1]->value!=0)) ? 1 : 0); \ + } + LOGICAL_OPERATOR_CASE(&&) + LOGICAL_OPERATOR_CASE(|| ) +#undef LOGICAL_OPERATOR_CASE + else if (nameIs(opName, "!")) + { + resultValue = ((constArgs[0]->value != 0) ? 1 : 0); + } + else if (nameIs(opName, "~")) + { + resultValue = ~constArgs[0]->value; + } + else if (nameIs(opName, "?:")) + { + resultValue = constArgs[0]->value != 0 ? constArgs[1]->value : constArgs[2]->value; + } + else + { + SLANG_UNREACHABLE("constant folding of SomeIntVal"); + } + return astBuilder->create<ConstantIntVal>(resultValue); + } + return nullptr; +} + +Val* SomeIntVal::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) +{ + int diff = 0; + auto newFuncDeclRef = funcDeclRef.substituteImpl(astBuilder, subst, &diff); + List<IntVal*> newArgs; + for (auto& arg : args) + { + auto substArg = arg->substituteImpl(astBuilder, subst, &diff); + if (substArg != arg) + diff++; + newArgs.add(as<IntVal>(substArg)); + } + *ioDiff += diff; + if (diff) + { + // TODO: report diagnostics back. + auto newVal = tryFoldImpl(astBuilder, newFuncDeclRef, newArgs, nullptr); + if (newVal) + return newVal; + else + { + auto result = astBuilder->create<SomeIntVal>(); + result->args = _Move(newArgs); + result->funcDeclRef = newFuncDeclRef; + result->funcType = funcType; + return result; + } + } + // Nothing found: don't substitute. + return this; } } // namespace Slang diff --git a/source/slang/slang-ast-val.h b/source/slang/slang-ast-val.h index 6eaaa8eb1..64f04abf9 100644 --- a/source/slang/slang-ast-val.h +++ b/source/slang/slang-ast-val.h @@ -52,29 +52,70 @@ protected: {} }; +// An compile time int val as result of some general computation. +class SomeIntVal : public IntVal +{ + SLANG_AST_CLASS(SomeIntVal) + + bool _equalsValOverride(Val* val); + void _toTextOverride(StringBuilder& out); + HashCode _getHashCodeOverride(); + Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); + + DeclRef<Decl> funcDeclRef; + List<IntVal*> args; + Type* funcType; + static Val* tryFoldImpl(ASTBuilder* astBuilder, DeclRef<Decl> newFuncDecl, List<IntVal*> &newArgs, DiagnosticSink* sink); +}; + // polynomial expression "2*a*b^3 + 1" will be represented as: // { constantTerm:1, terms: [ { constFactor:2, paramFactors:[{"a", 1}, {"b", 3}] } ] } class PolynomialIntValFactor : public NodeBase { SLANG_AST_CLASS(PolynomialIntValFactor) public: - GenericParamIntVal* param; + IntVal* param; IntegerLiteralValue power; // for sorting only. bool operator<(const PolynomialIntValFactor& other) const { - if (param->declRef.decl < other.param->declRef.decl) - return true; - else if (param->declRef.decl == other.param->declRef.decl) - return power < other.power; - return false; + if (auto thisGenParam = as<GenericParamIntVal>(param)) + { + if (auto thatGenParam = as<GenericParamIntVal>(other.param)) + { + if (thisGenParam->equalsVal(thatGenParam)) + return power < other.power; + else + return thisGenParam->declRef.decl < thatGenParam->declRef.decl; + } + else + { + return true; + } + } + else + { + if (auto thatGenParam = as<GenericParamIntVal>(other.param)) + { + return false; + } + return param == other.param ? power < other.power : param < other.param; + } + } // for sorting only. bool operator==(const PolynomialIntValFactor& other) const { - if (param->declRef.decl == other.param->declRef.decl && power == other.power) - return true; - return false; + if (auto thisGenParam = as<GenericParamIntVal>(param)) + { + if (auto thatGenParam = as<GenericParamIntVal>(other.param)) + { + if (param->equalsVal(other.param) && power == other.power) + return true; + } + return false; + } + return power == other.power && param == other.param; } bool equals(const PolynomialIntValFactor& other) const { @@ -132,7 +173,10 @@ public: IntegerLiteralValue constantTerm = 0; bool isConstant() { return terms.getCount() == 0; } - void canonicalize(ASTBuilder* builder); + // Canonicalize the polynomial. If the polynomial can be simplified to a constant or a genericparam, + // the method returns the value simplified to. + // Otherwise, in-place modifications are performed and returns this. + IntVal* canonicalize(ASTBuilder* builder); // Overrides should be public so base classes can access bool _equalsValOverride(Val* val); diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 7257790af..8e14af72a 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -977,6 +977,19 @@ namespace Slang return PolynomialIntVal::mul(m_astBuilder, argVals[0], argVals[1]); } } + else if (opName == getName("/") || opName == getName("==") || opName == getName(">=") || opName == getName("<=") || opName == getName("!=") + || opName == getName(">") || opName == getName("<") || opName == getName("&&") || opName == getName("||") || opName == getName("!") + || opName == getName("|") || opName == getName("&") || opName == getName("^") || opName == getName("~") || opName == getName("%") || + opName == getName("?:") || opName == getName("<<") || opName == getName(">>")) + { + auto result = m_astBuilder->create<SomeIntVal>(); + result->args.addRange(argVals, argCount); + result->funcDeclRef = funcDeclRef; + result->funcType = as<Type>(funcDeclRefExpr.getExpr()->type->substitute( + m_astBuilder, funcDeclRefExpr.getSubsts())); + SLANG_RELEASE_ASSERT(result->funcType); + return result; + } return nullptr; } @@ -1062,6 +1075,15 @@ namespace Slang CASE(/); CASE(%); #undef CASE + else if (opName == getName("?:")) + { + if (argCount != 3) + return nullptr; + if (constArgVals[0] != 0) + resultValue = constArgVals[1]; + else + resultValue = constArgVals[2]; + } // TODO(tfoley): more cases else { diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index 14be426d6..27076681d 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -234,7 +234,7 @@ DIAGNOSTIC(20014, Error, classIsReservedKeyword, "'class' is a reserved keyword // // 3xxxx - Semantic analysis // - +DIAGNOSTIC(30002, Error, divideByZero, "divide by zero") DIAGNOSTIC(30003, Error, breakOutsideLoop, "'break' must appear inside loop constructs.") DIAGNOSTIC(30004, Error, continueOutsideLoop, "'continue' must appear inside loop constructs.") DIAGNOSTIC(30005, Error, whilePredicateTypeError, "'while': expression must evaluate to int.") diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp index e98c8c6f3..7a7951ba1 100644 --- a/source/slang/slang-emit-c-like.cpp +++ b/source/slang/slang-emit-c-like.cpp @@ -3381,6 +3381,17 @@ void CLikeSourceEmitter::ensureInstOperandsRec(ComputeEmitActionsContext* ctx, I case kIROp_NativePtrType: requiredLevel = EmitAction::ForwardDeclaration; break; + case kIROp_lookup_interface_method: + case kIROp_FieldExtract: + case kIROp_FieldAddress: + { + auto opType = inst->getOperand(0)->getDataType(); + if (auto nativePtrType = as<IRNativePtrType>(opType)) + { + ensureInstOperand(ctx, nativePtrType->getValueType(), requiredLevel); + } + break; + } default: break; } diff --git a/source/slang/slang-ir-hoist-constants.cpp b/source/slang/slang-ir-hoist-constants.cpp new file mode 100644 index 000000000..87d3487a3 --- /dev/null +++ b/source/slang/slang-ir-hoist-constants.cpp @@ -0,0 +1,96 @@ +// slang-ir-hoist-constants.cpp +#include "slang-ir-hoist-constants.h" +#include "slang-ir-inst-pass-base.h" + +namespace Slang +{ + +struct HoistConstantPass : InstPassBase +{ + HoistConstantPass(IRModule* module) : InstPassBase(module) + {} + + bool changed = false; + + void processModule() + { + sharedBuilderStorage.init(module); + + processAllInsts([this](IRInst* inst) + { + + if (inst->getParent() == module->getModuleInst() || !inst->getParent()) + return; + auto parent = inst->getParent(); + auto p = parent; + while (p) + { + if (as<IRGlobalValueWithCode>(p)) + return; + p = p->parent; + } + while (parent && parent->parent != module->getModuleInst()) + parent = parent->parent; + if (!parent) + return; + switch (inst->getOp()) + { + default: + return; + case kIROp_Add: + case kIROp_Sub: + case kIROp_Mul: + case kIROp_Div: + case kIROp_Module: + case kIROp_Neg: + case kIROp_And: + case kIROp_Or: + case kIROp_Not: + case kIROp_BitAnd: + case kIROp_BitNot: + case kIROp_BitOr: + case kIROp_BitXor: + case kIROp_Select: + case kIROp_Greater: + case kIROp_Less: + case kIROp_Leq: + case kIROp_Geq: + case kIROp_Eql: + case kIROp_Neq: + case kIROp_BitCast: + case kIROp_Lsh: + case kIROp_Rsh: + case kIROp_Construct: + case kIROp_makeVector: + case kIROp_MakeMatrix: + case kIROp_swizzle: + case kIROp_IntLit: + case kIROp_BoolLit: + case kIROp_ArrayType: + case kIROp_Specialize: + case kIROp_VectorType: + break; + } + if (inst->typeUse.get() && inst->typeUse.get()->parent != module->getModuleInst()) + return; + for (UInt i = 0; i < inst->getOperandCount(); i++) + { + if (inst->getOperand(i)->parent != module->getModuleInst()) + return; + } + // all operands are in global scope, we can move this inst to global scope as well. + inst->insertBefore(parent); + changed = true; + }); + } +}; + +bool hoistConstants( + IRModule* module) +{ + HoistConstantPass context(module); + context.processModule(); + return context.changed; +} + +} diff --git a/source/slang/slang-ir-hoist-constants.h b/source/slang/slang-ir-hoist-constants.h new file mode 100644 index 000000000..28d4a0c6b --- /dev/null +++ b/source/slang/slang-ir-hoist-constants.h @@ -0,0 +1,13 @@ +// slang-ir-hoist-constants.h +#pragma once + +namespace Slang +{ +struct IRModule; + + /// A (specialized) generic type may contain insts that computes compile-time constants defined within + /// the type. We should hoist them to global scope so they can be SCCP'd when possible. +bool hoistConstants( + IRModule* module); + +} diff --git a/source/slang/slang-ir-ssa-simplification.cpp b/source/slang/slang-ir-ssa-simplification.cpp index 22aea8d36..f723325c4 100644 --- a/source/slang/slang-ir-ssa-simplification.cpp +++ b/source/slang/slang-ir-ssa-simplification.cpp @@ -6,6 +6,7 @@ #include "slang-ir-dce.h" #include "slang-ir-simplify-cfg.h" #include "slang-ir-peephole.h" +#include "slang-ir-hoist-constants.h" namespace Slang { @@ -21,6 +22,7 @@ namespace Slang while (changed && iterationCounter < kMaxIterations) { changed = false; + changed |= hoistConstants(module); changed |= applySparseConditionalConstantPropagation(module); changed |= peepholeOptimize(module); changed |= simplifyCFG(module); diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h index aea425a9b..c48f4b378 100644 --- a/source/slang/slang-ir.h +++ b/source/slang/slang-ir.h @@ -203,6 +203,36 @@ struct IRInstList : IRInstListBase Iterator end(); }; +template<typename T> +struct IRFilteredInstList : IRInstListBase +{ + IRFilteredInstList() {} + + IRFilteredInstList(IRInst* fst, IRInst* lst); + + explicit IRFilteredInstList(IRInstListBase const& list) + : IRFilteredInstList(list.first, list.last) + {} + + T* getFirst() { return (T*)first; } + T* getLast() { return (T*)last; } + + struct Iterator : public IRInstListBase::Iterator + { + IRInst* exclusiveLast; + Iterator() {} + Iterator(IRInst* inst, IRInst* lastIter) : IRInstListBase::Iterator(inst), exclusiveLast(lastIter) {} + void operator++(); + T* operator*() + { + return (T*)inst; + } + }; + + Iterator begin(); + Iterator end(); +}; + /// A list of contiguous operands that can be iterated over as `IRInst`s. struct IROperandListBase { @@ -741,6 +771,41 @@ typename IRInstList<T>::Iterator IRInstList<T>::end() return Iterator(last ? last->next : nullptr); } +template<typename T> +IRFilteredInstList<T>::IRFilteredInstList(IRInst* fst, IRInst* lst) +{ + first = fst; + last = lst; + + auto lastIter = last ? last->next : nullptr; + while (first != lastIter && !as<T>(first)) + first = first->next; + while (last && last != first && !as<T>(last)) + last = last->prev; +} + +template<typename T> +void IRFilteredInstList<T>::Iterator::operator++() +{ + inst = inst->next; + while (inst != exclusiveLast && !as<T>(inst)) + { + inst = inst->next; + } +} +template<typename T> +typename IRFilteredInstList<T>::Iterator IRFilteredInstList<T>::begin() +{ + auto lastIter = last ? last->next : nullptr; + return IRFilteredInstList<T>::Iterator(first, lastIter); +} + +template<typename T> +typename IRFilteredInstList<T>::Iterator IRFilteredInstList<T>::end() +{ + auto lastIter = last ? last->next : nullptr; + return IRFilteredInstList<T>::Iterator(lastIter, lastIter); +} // Types @@ -1419,14 +1484,14 @@ struct IRStructField : IRInst // struct IRStructType : IRType { - IRInstList<IRStructField> getFields() { return IRInstList<IRStructField>(getChildren()); } + IRFilteredInstList<IRStructField> getFields() { return IRFilteredInstList<IRStructField>(getChildren()); } IR_LEAF_ISA(StructType) }; struct IRClassType : IRType { - IRInstList<IRStructField> getFields() { return IRInstList<IRStructField>(getChildren()); } + IRFilteredInstList<IRStructField> getFields() { return IRFilteredInstList<IRStructField>(getChildren()); } IR_LEAF_ISA(ClassType) }; diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 3d4f0df5d..b351bfe21 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -1287,6 +1287,25 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower lowerType(context, getType(context->astBuilder, val->declRef))); } + LoweredValInfo visitSomeIntVal(SomeIntVal* val) + { + TryClauseEnvironment tryEnv; + List<IRInst*> args; + for (auto arg : val->args) + { + auto loweredArg = lowerVal(context, arg); + args.add(loweredArg.val); + } + auto funcType = lowerType(context, val->funcType); + return emitCallToDeclRef( + context, + as<IRFuncType>(funcType)->getResultType(), + val->funcDeclRef, + funcType, + args, + tryEnv); + } + LoweredValInfo visitPolynomialIntVal(PolynomialIntVal* val) { auto irBuilder = getBuilder(); diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp index 33f2e251b..594ca4cc3 100644 --- a/source/slang/slang-parser.cpp +++ b/source/slang/slang-parser.cpp @@ -1997,6 +1997,7 @@ namespace Slang case TokenType::Semicolon: case TokenType::OpEql: case TokenType::OpNeq: + case TokenType::OpGreater: { return parseGenericApp(parser, base); } |
