diff options
26 files changed, 311 insertions, 87 deletions
@@ -4649,6 +4649,14 @@ namespace slang */ virtual SLANG_NO_THROW bool SLANG_MCALL isBinaryModuleUpToDate( const char* modulePath, slang::IBlob* binaryModuleBlob) = 0; + + /** Load a module from a string. + */ + virtual SLANG_NO_THROW IModule* SLANG_MCALL loadModuleFromSourceString( + const char* moduleName, + const char* path, + const char* string, + slang::IBlob** outDiagnostics = nullptr) = 0; }; #define SLANG_UUID_ISession ISession::getTypeGuid() diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h index 8e5120fad..ed8cbf514 100644 --- a/source/slang/slang-ast-modifier.h +++ b/source/slang/slang-ast-modifier.h @@ -784,9 +784,9 @@ class GLSLLayoutLocalSizeAttribute : public Attribute // // TODO: These should be accessors that use the // ordinary `args` list, rather than side data. - int32_t x; - int32_t y; - int32_t z; + IntVal* x; + IntVal* y; + IntVal* z; }; // TODO: for attributes that take arguments, the syntax node @@ -839,9 +839,9 @@ class NumThreadsAttribute : public Attribute // // TODO: These should be accessors that use the // ordinary `args` list, rather than side data. - int32_t x; - int32_t y; - int32_t z; + IntVal* x; + IntVal* y; + IntVal* z; }; class MaxVertexCountAttribute : public Attribute diff --git a/source/slang/slang-ast-val.cpp b/source/slang/slang-ast-val.cpp index d1408a3fc..b2b874fad 100644 --- a/source/slang/slang-ast-val.cpp +++ b/source/slang/slang-ast-val.cpp @@ -229,6 +229,11 @@ Val* GenericParamIntVal::_substituteImplOverride(ASTBuilder* /* astBuilder */, S return this; } +bool GenericParamIntVal::_isLinkTimeValOverride() +{ + return getDeclRef().getDecl()->hasModifier<ExternModifier>(); +} + // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! ErrorIntVal !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! void ErrorIntVal::_toTextOverride(StringBuilder& out) @@ -1496,4 +1501,9 @@ Val* PolynomialIntVal::_resolveImplOverride() return this; } +bool IntVal::isLinkTimeVal() +{ + SLANG_AST_NODE_VIRTUAL_CALL(IntVal, isLinkTimeVal, ()); +} + } // namespace Slang diff --git a/source/slang/slang-ast-val.h b/source/slang/slang-ast-val.h index f85a76187..ce494b9da 100644 --- a/source/slang/slang-ast-val.h +++ b/source/slang/slang-ast-val.h @@ -141,6 +141,9 @@ class IntVal : public Val Type* getType() { return as<Type>(getOperand(0)); } Val* _resolveImplOverride() { return this; } + + bool isLinkTimeVal(); + bool _isLinkTimeValOverride() { return false; } }; // Trivial case of a value that is just a constant integer @@ -157,6 +160,7 @@ class ConstantIntVal : public IntVal { setOperands(inType, inValue); } + bool _isLinkTimeValOverride() { return false; } }; // The logical "value" of a reference to a generic value parameter @@ -174,6 +178,8 @@ class GenericParamIntVal : public IntVal { setOperands(inType, inDeclRef); } + + bool _isLinkTimeValOverride(); }; class TypeCastIntVal : public IntVal @@ -191,6 +197,13 @@ class TypeCastIntVal : public IntVal } static Val* tryFoldImpl(ASTBuilder* astBuilder, Type* resultType, Val* base, DiagnosticSink* sink); + + bool _isLinkTimeValOverride() + { + if (auto intBase = as<IntVal>(getBase())) + return intBase->isLinkTimeVal(); + return false; + } }; // An compile time int val as result of some general computation. @@ -215,6 +228,16 @@ class FuncCallIntVal : public IntVal } static Val* tryFoldImpl(ASTBuilder* astBuilder, Type* resultType, DeclRef<Decl> newFuncDecl, List<IntVal*>& newArgs, DiagnosticSink* sink); + + bool _isLinkTimeValOverride() + { + for (auto arg : getArgs()) + { + if (arg->isLinkTimeVal()) + return true; + } + return false; + } }; class WitnessLookupIntVal : public IntVal @@ -236,6 +259,11 @@ class WitnessLookupIntVal : public IntVal static Val* tryFoldOrNull(ASTBuilder* astBuilder, SubtypeWitness* witness, Decl* key); static Val* tryFold(ASTBuilder* astBuilder, SubtypeWitness* witness, Decl* key, Type* type); + + bool _isLinkTimeValOverride() + { + return false; + } }; // polynomial expression "2*a*b^3 + 1" will be represented as: @@ -361,6 +389,16 @@ public: } return false; } + + bool isLinkTimeVal() + { + for (auto factor : getParamFactors()) + { + if (factor->getParam()->isLinkTimeVal()) + return true; + } + return false; + } }; class PolynomialIntVal : public IntVal @@ -387,6 +425,16 @@ public: setOperands(inType, inConstantTerm); addOperands(inTerms); } + + bool _isLinkTimeValOverride() + { + for (auto factor : getTerms()) + { + if (factor->isLinkTimeVal()) + return true; + } + return false; + } }; /// An unknown integer value indicating an erroneous sub-expression @@ -404,6 +452,10 @@ class ErrorIntVal : public IntVal void _toTextOverride(StringBuilder& out); Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); Val* _resolveImplOverride() { return this; } + bool _isLinkTimeValOverride() + { + return false; + } }; // A witness to the fact that some proposition is true, encoded diff --git a/source/slang/slang-check-conformance.cpp b/source/slang/slang-check-conformance.cpp index e73c0723b..d6e73e798 100644 --- a/source/slang/slang-check-conformance.cpp +++ b/source/slang/slang-check-conformance.cpp @@ -264,13 +264,35 @@ namespace Slang { if (auto arrayType = as<ArrayExpressionType>(type)) { - return getTypeTags(arrayType->getElementType()); + auto typeTag = getTypeTags(arrayType->getElementType()); + bool sized = false; + if (auto cint = as<ConstantIntVal>(arrayType->getElementCount())) + { + if (cint->getValue() != kUnsizedArrayMagicLength) + { + sized = true; + } + } + else if (auto intVal = arrayType->getElementCount()) + { + sized = !intVal->isLinkTimeVal(); + } + if (!sized) + typeTag = (TypeTag)((int)typeTag | (int)TypeTag::Unsized); + + return typeTag; } if (auto modifiedType = as<ModifiedType>(type)) { return getTypeTags(modifiedType->getBase()); } - if (auto declRefType = as<DeclRefType>(type)) + if (auto parameterGroupType = as<UniformParameterGroupType>(type)) + { + auto elementTags = getTypeTags(parameterGroupType->getElementType()); + elementTags = (TypeTag)((int)elementTags & ~(int)TypeTag::Unsized); + return elementTags; + } + else if (auto declRefType = as<DeclRefType>(type)) { if (auto aggTypeDecl = as<AggTypeDecl>(declRefType->getDeclRef())) return aggTypeDecl.getDecl()->typeTags; diff --git a/source/slang/slang-check-conversion.cpp b/source/slang/slang-check-conversion.cpp index 9dae07018..a2381c7f7 100644 --- a/source/slang/slang-check-conversion.cpp +++ b/source/slang/slang-check-conversion.cpp @@ -1163,7 +1163,7 @@ namespace Slang bool shouldEmitGeneralWarning = true; if (isScalarIntegerType(toType)) { - if (auto intVal = tryFoldIntegerConstantExpression(fromExpr, nullptr)) + if (auto intVal = tryFoldIntegerConstantExpression(fromExpr, ConstantFoldingKind::CompileTime, nullptr)) { if (auto val = as<ConstantIntVal>(intVal)) { diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 3596b7045..7784100a6 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -1379,7 +1379,7 @@ namespace Slang // if(!isScalarIntegerType(varDecl->type)) return; - tryConstantFoldDeclRef(DeclRef<VarDeclBase>(varDecl), nullptr); + tryConstantFoldDeclRef(DeclRef<VarDeclBase>(varDecl), ConstantFoldingKind::LinkTime, nullptr); } void SemanticsDeclModifiersVisitor::visitStructDecl(StructDecl* structDecl) @@ -1900,6 +1900,32 @@ namespace Slang varDecl->initExpr = CompleteOverloadCandidate(overloadContext, *overloadContext.bestCandidate); } } + + if (auto parentDecl = as<AggTypeDecl>(getParentDecl(varDecl))) + { + auto typeTags = getTypeTags(varDecl->getType()); + parentDecl->addTag(typeTags); + if ((int)typeTags & (int)TypeTag::Unsized) + { + // Unsized decl must appear as the last member of the struct. + for (auto memberIdx = parentDecl->members.getCount() - 1; memberIdx >= 0; memberIdx--) + { + if (parentDecl->members[memberIdx] == varDecl) + { + break; + } + if (auto memberVarDecl = as<VarDeclBase>(parentDecl->members[memberIdx])) + { + if (!memberVarDecl->hasModifier<HLSLStaticModifier>()) + { + getSink()->diagnose(varDecl, Diagnostics::unsizedMemberMustAppearLast); + } + break; + } + } + } + } + if (auto elementType = getConstantBufferElementType(varDecl->getType())) { if (doesTypeHaveTag(elementType, TypeTag::Incomplete)) @@ -2841,7 +2867,7 @@ namespace Slang return false; } - auto satisfyingVal = tryConstantFoldDeclRef(satisfyingMemberDeclRef, nullptr); + auto satisfyingVal = tryConstantFoldDeclRef(satisfyingMemberDeclRef, ConstantFoldingKind::LinkTime, nullptr); if (satisfyingVal) { witnessTable->add( @@ -5925,7 +5951,7 @@ namespace Slang // the tag value for a successor case that doesn't // provide an explicit tag. - IntVal* explicitTagVal = tryConstantFoldExpr(explicitTagValExpr, nullptr); + IntVal* explicitTagVal = tryConstantFoldExpr(explicitTagValExpr, ConstantFoldingKind::CompileTime, nullptr); if(explicitTagVal) { if(auto constIntVal = as<ConstantIntVal>(explicitTagVal)) @@ -5992,7 +6018,7 @@ namespace Slang // We want to enforce that this is an integer constant // expression, but we don't actually care to retain // the value. - CheckIntegerConstantExpression(initExpr, IntegerConstantExpressionCoercionType::AnyInteger, nullptr); + CheckIntegerConstantExpression(initExpr, IntegerConstantExpressionCoercionType::AnyInteger, nullptr, ConstantFoldingKind::CompileTime); decl->tagExpr = initExpr; } @@ -6906,6 +6932,7 @@ namespace Slang indexExpr->indexExprs[0], IntegerConstantExpressionCoercionType::AnyInteger, nullptr, + ConstantFoldingKind::LinkTime, getSink()); Type* d = m_astBuilder->getMeshOutputTypeFromModifier(modifier, base, index); diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 33a1fa680..8b6fe76c7 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -1411,6 +1411,7 @@ namespace Slang IntVal* SemanticsVisitor::tryConstantFoldExpr( SubstExpr<InvokeExpr> invokeExpr, + ConstantFoldingKind kind, ConstantFoldingCircularityInfo* circularityInfo) { // We need all the operands to the expression @@ -1448,7 +1449,7 @@ namespace Slang for(Index a = 0; a < argCount; ++a) { auto argExpr = getArg(invokeExpr, a); - auto argVal = tryFoldIntegerConstantExpression(argExpr, circularityInfo); + auto argVal = tryFoldIntegerConstantExpression(argExpr, kind, circularityInfo); if (!argVal) return nullptr; @@ -1647,6 +1648,7 @@ namespace Slang IntVal* SemanticsVisitor::tryConstantFoldDeclRef( DeclRef<VarDeclBase> const& declRef, + ConstantFoldingKind kind, ConstantFoldingCircularityInfo* circularityInfo) { auto decl = declRef.getDecl(); @@ -1657,9 +1659,17 @@ namespace Slang // In HLSL, `const` is used to mark compile-time constant expressions. if(!decl->hasModifier<ConstModifier>()) return nullptr; - // Extern const is not considered compile-time constant by the front-end. if (decl->hasModifier<ExternModifier>()) - return nullptr; + { + // Extern const is not considered compile-time constant by the front-end. + if (kind == ConstantFoldingKind::CompileTime) + return nullptr; + // But if we are OK with link-time constants, we can still fold it into a val. + auto rs = m_astBuilder->getOrCreate<GenericParamIntVal>( + declRef.substitute(m_astBuilder, declRef.getDecl()->getType()), + declRef); + return rs; + } if (isInterfaceRequirement(decl)) { @@ -1678,11 +1688,12 @@ namespace Slang ensureDecl(declRef.getDecl(), DeclCheckState::DefinitionChecked); ConstantFoldingCircularityInfo newCircularityInfo(decl, circularityInfo); - return tryConstantFoldExpr(getInitExpr(m_astBuilder, declRef), &newCircularityInfo); + return tryConstantFoldExpr(getInitExpr(m_astBuilder, declRef), kind, &newCircularityInfo); } IntVal* SemanticsVisitor::tryConstantFoldExpr( SubstExpr<Expr> expr, + ConstantFoldingKind kind, ConstantFoldingCircularityInfo* circularityInfo) { @@ -1738,7 +1749,7 @@ namespace Slang // are defined in a way that can be used as a constant expression: if(auto varRef = declRef.as<VarDeclBase>()) { - return tryConstantFoldDeclRef(varRef, circularityInfo); + return tryConstantFoldDeclRef(varRef, kind, circularityInfo); } else if(auto enumRef = declRef.as<EnumCaseDecl>()) { @@ -1750,7 +1761,7 @@ namespace Slang return nullptr; ConstantFoldingCircularityInfo newCircularityInfo(enumCaseDecl, circularityInfo); - return tryConstantFoldExpr(tagExpr, &newCircularityInfo); + return tryConstantFoldExpr(tagExpr, kind, &newCircularityInfo); } } } @@ -1762,7 +1773,7 @@ namespace Slang return nullptr; if (!isScalarIntegerType(substType)) return nullptr; - auto val = tryConstantFoldExpr(getArg(castExpr, 0), circularityInfo); + auto val = tryConstantFoldExpr(getArg(castExpr, 0), kind, circularityInfo); if (val) { if (!castExpr.getExpr()->type) @@ -1777,7 +1788,7 @@ namespace Slang } else if (auto invokeExpr = expr.as<InvokeExpr>()) { - auto val = tryConstantFoldExpr(invokeExpr, circularityInfo); + auto val = tryConstantFoldExpr(invokeExpr, kind, circularityInfo); if (val) return val; } @@ -1803,6 +1814,7 @@ namespace Slang IntVal* SemanticsVisitor::tryFoldIntegerConstantExpression( SubstExpr<Expr> expr, + ConstantFoldingKind kind, ConstantFoldingCircularityInfo* circularityInfo) { // Check if type is acceptable for an integer constant expression @@ -1812,10 +1824,10 @@ namespace Slang // Consider operations that we might be able to constant-fold... // - return tryConstantFoldExpr(expr, circularityInfo); + return tryConstantFoldExpr(expr, kind, circularityInfo); } - IntVal* SemanticsVisitor::CheckIntegerConstantExpression(Expr* inExpr, IntegerConstantExpressionCoercionType coercionType, Type* expectedType, DiagnosticSink* sink) + IntVal* SemanticsVisitor::CheckIntegerConstantExpression(Expr* inExpr, IntegerConstantExpressionCoercionType coercionType, Type* expectedType, ConstantFoldingKind kind, DiagnosticSink* sink) { // No need to issue further errors if the expression didn't even type-check. if(IsErrorExpr(inExpr)) return nullptr; @@ -1840,7 +1852,7 @@ namespace Slang // No need to issue further errors if the type coercion failed. if(IsErrorExpr(expr)) return nullptr; - auto result = tryFoldIntegerConstantExpression(expr, nullptr); + auto result = tryFoldIntegerConstantExpression(expr, kind, nullptr); if (!result && sink) { sink->diagnose(expr, Diagnostics::expectedIntegerConstantNotConstant); @@ -1848,12 +1860,12 @@ namespace Slang return result; } - IntVal* SemanticsVisitor::CheckIntegerConstantExpression(Expr* inExpr, IntegerConstantExpressionCoercionType coercionType, Type* expectedType) + IntVal* SemanticsVisitor::CheckIntegerConstantExpression(Expr* inExpr, IntegerConstantExpressionCoercionType coercionType, Type* expectedType, ConstantFoldingKind kind) { - return CheckIntegerConstantExpression(inExpr, coercionType, expectedType, getSink()); + return CheckIntegerConstantExpression(inExpr, coercionType, expectedType, kind, getSink()); } - IntVal* SemanticsVisitor::CheckEnumConstantExpression(Expr* expr) + IntVal* SemanticsVisitor::CheckEnumConstantExpression(Expr* expr, ConstantFoldingKind kind) { // No need to issue further errors if the expression didn't even type-check. if(IsErrorExpr(expr)) return nullptr; @@ -1861,7 +1873,7 @@ namespace Slang // No need to issue further errors if the type coercion failed. if(IsErrorExpr(expr)) return nullptr; - auto result = tryConstantFoldExpr(expr, nullptr); + auto result = tryConstantFoldExpr(expr, kind, nullptr); if (!result) { getSink()->diagnose(expr, Diagnostics::expectedIntegerConstantNotConstant); @@ -1936,7 +1948,7 @@ namespace Slang IntVal* elementCount = nullptr; if (subscriptExpr->indexExprs.getCount() == 1) { - elementCount = CheckIntegerConstantExpression(subscriptExpr->indexExprs[0], IntegerConstantExpressionCoercionType::AnyInteger, nullptr); + elementCount = CheckIntegerConstantExpression(subscriptExpr->indexExprs[0], IntegerConstantExpressionCoercionType::AnyInteger, nullptr, ConstantFoldingKind::LinkTime); } else if (subscriptExpr->indexExprs.getCount() != 0) { diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index a209d96d9..702fe5619 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -1482,6 +1482,9 @@ namespace Slang void checkGenericDeclHeader(GenericDecl* genericDecl); + IntVal* checkLinkTimeConstantIntVal( + Expr* expr); + ConstantIntVal* checkConstantIntVal( Expr* expr); @@ -1808,7 +1811,12 @@ namespace Slang Expr* checkPredicateExpr(Expr* expr); - Expr* checkExpressionAndExpectIntegerConstant(Expr* expr, IntVal** outIntVal); + enum class ConstantFoldingKind + { + CompileTime, + LinkTime, + }; + Expr* checkExpressionAndExpectIntegerConstant(Expr* expr, IntVal** outIntVal, ConstantFoldingKind kind); IntegerLiteralValue GetMinBound(IntVal* val); @@ -1845,15 +1853,16 @@ namespace Slang /// The rest of the links in the chain of declarations being folded ConstantFoldingCircularityInfo* next = nullptr; }; - /// Try to apply front-end constant folding to determine the value of `invokeExpr`. IntVal* tryConstantFoldExpr( SubstExpr<InvokeExpr> invokeExpr, + ConstantFoldingKind kind, ConstantFoldingCircularityInfo* circularityInfo); /// Try to apply front-end constant folding to determine the value of `expr`. IntVal* tryConstantFoldExpr( SubstExpr<Expr> expr, + ConstantFoldingKind kind, ConstantFoldingCircularityInfo* circularityInfo); bool _checkForCircularityInConstantFolding( @@ -1863,6 +1872,7 @@ namespace Slang /// Try to resolve a compile-time constant `IntVal` from the given `declRef`. IntVal* tryConstantFoldDeclRef( DeclRef<VarDeclBase> const& declRef, + ConstantFoldingKind kind, ConstantFoldingCircularityInfo* circularityInfo); /// Try to extract the value of an integer constant expression, either @@ -1871,6 +1881,7 @@ namespace Slang /// IntVal* tryFoldIntegerConstantExpression( SubstExpr<Expr> expr, + ConstantFoldingKind kind, ConstantFoldingCircularityInfo* circularityInfo); // Enforce that an expression resolves to an integer constant, and get its value @@ -1879,10 +1890,10 @@ namespace Slang SpecificType, AnyInteger }; - IntVal* CheckIntegerConstantExpression(Expr* inExpr, IntegerConstantExpressionCoercionType coercionType, Type* expectedType); - IntVal* CheckIntegerConstantExpression(Expr* inExpr, IntegerConstantExpressionCoercionType coercionType, Type* expectedType, DiagnosticSink* sink); + IntVal* CheckIntegerConstantExpression(Expr* inExpr, IntegerConstantExpressionCoercionType coercionType, Type* expectedType, ConstantFoldingKind kind); + IntVal* CheckIntegerConstantExpression(Expr* inExpr, IntegerConstantExpressionCoercionType coercionType, Type* expectedType, ConstantFoldingKind kind, DiagnosticSink* sink); - IntVal* CheckEnumConstantExpression(Expr* expr); + IntVal* CheckEnumConstantExpression(Expr* expr, ConstantFoldingKind kind); Expr* CheckSimpleSubscriptExpr( diff --git a/source/slang/slang-check-modifier.cpp b/source/slang/slang-check-modifier.cpp index 0f49891d0..cf4bf3b02 100644 --- a/source/slang/slang-check-modifier.cpp +++ b/source/slang/slang-check-modifier.cpp @@ -12,13 +12,20 @@ namespace Slang { + IntVal* SemanticsVisitor::checkLinkTimeConstantIntVal( + Expr* expr) + { + expr = CheckExpr(expr); + return CheckIntegerConstantExpression(expr, IntegerConstantExpressionCoercionType::AnyInteger, nullptr, ConstantFoldingKind::LinkTime); + } + ConstantIntVal* SemanticsVisitor::checkConstantIntVal( Expr* expr) { // First type-check the expression as normal expr = CheckExpr(expr); - auto intVal = CheckIntegerConstantExpression(expr, IntegerConstantExpressionCoercionType::AnyInteger, nullptr); + auto intVal = CheckIntegerConstantExpression(expr, IntegerConstantExpressionCoercionType::AnyInteger, nullptr, ConstantFoldingKind::CompileTime); if(!intVal) return nullptr; @@ -37,7 +44,7 @@ namespace Slang // First type-check the expression as normal expr = CheckExpr(expr); - auto intVal = CheckEnumConstantExpression(expr); + auto intVal = CheckEnumConstantExpression(expr, ConstantFoldingKind::CompileTime); if(!intVal) return nullptr; @@ -320,26 +327,33 @@ namespace Slang { SLANG_ASSERT(attr->args.getCount() == 3); - int32_t values[3]; + IntVal* values[3]; for (int i = 0; i < 3; ++i) { - int32_t value = 1; + IntVal* value = nullptr; auto arg = attr->args[i]; if (arg) { - auto intValue = checkConstantIntVal(arg); + auto intValue = checkLinkTimeConstantIntVal(arg); if (!intValue) { return false; } - if (intValue->getValue() < 1) + if (auto constIntVal = as<ConstantIntVal>(intValue)) { - getSink()->diagnose(attr, Diagnostics::nonPositiveNumThreads, intValue->getValue()); - return false; + if (constIntVal->getValue() < 1) + { + getSink()->diagnose(attr, Diagnostics::nonPositiveNumThreads, constIntVal->getValue()); + return false; + } } - value = int32_t(intValue->getValue()); + value = intValue; + } + else + { + value = m_astBuilder->getIntVal(m_astBuilder->getIntType(), 1); } values[i] = value; } @@ -1317,11 +1331,11 @@ namespace Slang { SLANG_ASSERT(attr->args.getCount() == 3); - int32_t values[3]; + IntVal* values[3]; for (int i = 0; i < 3; ++i) { - int32_t value = 1; + IntVal* value = nullptr; auto arg = attr->args[i]; if (arg) @@ -1331,12 +1345,19 @@ namespace Slang { return nullptr; } - if (intValue->getValue() < 1) + if (auto cintVal = as<ConstantIntVal>(intValue)) { - getSink()->diagnose(attr, Diagnostics::nonPositiveNumThreads, intValue->getValue()); - return nullptr; + if (cintVal->getValue() < 1) + { + getSink()->diagnose(attr, Diagnostics::nonPositiveNumThreads, cintVal->getValue()); + return nullptr; + } } - value = int32_t(intValue->getValue()); + value = intValue; + } + else + { + value = m_astBuilder->getIntVal(m_astBuilder->getIntType(), 1); } values[i] = value; } diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp index 84c005f28..1cb6681b3 100644 --- a/source/slang/slang-check-overload.cpp +++ b/source/slang/slang-check-overload.cpp @@ -353,7 +353,7 @@ namespace Slang // or this reference is ill-formed. ensureDecl(valParamRef, DeclCheckState::DefinitionChecked); ConstantFoldingCircularityInfo newCircularityInfo(valParamRef.getDecl(), nullptr); - auto defaultVal = tryConstantFoldExpr(valParamRef.substitute(m_astBuilder, valParamRef.getDecl()->initExpr), &newCircularityInfo); + auto defaultVal = tryConstantFoldExpr(valParamRef.substitute(m_astBuilder, valParamRef.getDecl()->initExpr), ConstantFoldingKind::CompileTime, &newCircularityInfo); if (!defaultVal) return false; checkedArgs.add(defaultVal); diff --git a/source/slang/slang-check-stmt.cpp b/source/slang/slang-check-stmt.cpp index 1090655e5..b819945fd 100644 --- a/source/slang/slang-check-stmt.cpp +++ b/source/slang/slang-check-stmt.cpp @@ -192,10 +192,10 @@ namespace Slang checkLoopInDifferentiableFunc(stmt); } - Expr* SemanticsVisitor::checkExpressionAndExpectIntegerConstant(Expr* expr, IntVal** outIntVal) + Expr* SemanticsVisitor::checkExpressionAndExpectIntegerConstant(Expr* expr, IntVal** outIntVal, ConstantFoldingKind kind) { expr = CheckExpr(expr); - auto intVal = CheckIntegerConstantExpression(expr, IntegerConstantExpressionCoercionType::AnyInteger, nullptr); + auto intVal = CheckIntegerConstantExpression(expr, IntegerConstantExpressionCoercionType::AnyInteger, nullptr, kind); if (outIntVal) *outIntVal = intVal; return expr; @@ -214,7 +214,7 @@ namespace Slang if (stmt->rangeBeginExpr) { - stmt->rangeBeginExpr = checkExpressionAndExpectIntegerConstant(stmt->rangeBeginExpr, &rangeBeginVal); + stmt->rangeBeginExpr = checkExpressionAndExpectIntegerConstant(stmt->rangeBeginExpr, &rangeBeginVal, ConstantFoldingKind::LinkTime); } else { @@ -222,7 +222,7 @@ namespace Slang rangeBeginVal = rangeBeginConst; } - stmt->rangeEndExpr = checkExpressionAndExpectIntegerConstant(stmt->rangeEndExpr, &rangeEndVal); + stmt->rangeEndExpr = checkExpressionAndExpectIntegerConstant(stmt->rangeEndExpr, &rangeEndVal, ConstantFoldingKind::LinkTime); stmt->rangeBeginVal = rangeBeginVal; stmt->rangeEndVal = rangeEndVal; @@ -428,7 +428,7 @@ namespace Slang return; auto initialLitVal = - as<ConstantIntVal>(tryFoldIntegerConstantExpression(initialVal, nullptr)); + as<ConstantIntVal>(tryFoldIntegerConstantExpression(initialVal, ConstantFoldingKind::CompileTime, nullptr)); ConstantIntVal* finalVal = nullptr; auto binaryExpr = as<InfixExpr>(stmt->predicateExpression); @@ -456,7 +456,7 @@ namespace Slang return; if (!rightCompareOperand) return; - if (auto rightVal = tryFoldIntegerConstantExpression(binaryExpr->arguments[1], nullptr)) + if (auto rightVal = tryFoldIntegerConstantExpression(binaryExpr->arguments[1], ConstantFoldingKind::CompileTime, nullptr)) { auto leftVar = as<VarExpr>(leftCompareOperand); if (!leftVar) @@ -464,7 +464,7 @@ namespace Slang predicateVar = leftVar->declRef; finalVal = as<ConstantIntVal>(rightVal); } - else if (auto leftVal = tryFoldIntegerConstantExpression(binaryExpr->arguments[0], nullptr)) + else if (auto leftVal = tryFoldIntegerConstantExpression(binaryExpr->arguments[0], ConstantFoldingKind::CompileTime, nullptr)) { auto rightVar = as<VarExpr>(rightCompareOperand); if (!rightVar) @@ -543,7 +543,7 @@ namespace Slang return; if (opSideEffectExpr->arguments.getCount() == 2) { - auto stepVal = tryFoldIntegerConstantExpression(opSideEffectExpr->arguments[1], nullptr); + auto stepVal = tryFoldIntegerConstantExpression(opSideEffectExpr->arguments[1], ConstantFoldingKind::CompileTime, nullptr); if (!stepVal) return; if (auto constantIntVal = as<ConstantIntVal>(stepVal)) diff --git a/source/slang/slang-check-type.cpp b/source/slang/slang-check-type.cpp index b5ee2c4fc..0ab3998d8 100644 --- a/source/slang/slang-check-type.cpp +++ b/source/slang/slang-check-type.cpp @@ -142,6 +142,7 @@ namespace Slang genericParamType ? IntegerConstantExpressionCoercionType::SpecificType : IntegerConstantExpressionCoercionType::AnyInteger, genericParamType, + ConstantFoldingKind::LinkTime, sink); if(val) return val; diff --git a/source/slang/slang-compiler.h b/source/slang/slang-compiler.h index 0dd0012ad..b0350d618 100755 --- a/source/slang/slang-compiler.h +++ b/source/slang/slang-compiler.h @@ -1789,6 +1789,11 @@ namespace Slang const char* path, slang::IBlob* source, slang::IBlob** outDiagnostics = nullptr) override; + SLANG_NO_THROW slang::IModule* SLANG_MCALL loadModuleFromSourceString( + const char* moduleName, + const char* path, + const char* string, + slang::IBlob** outDiagnostics = nullptr) override; SLANG_NO_THROW SlangResult SLANG_MCALL createCompositeComponentType( slang::IComponentType* const* componentTypes, SlangInt componentTypeCount, diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index 8264d256b..2ced9180e 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -323,6 +323,8 @@ DIAGNOSTIC(30066, Error, classCanOnlyBeInitializedWithNew, "a class can only be DIAGNOSTIC(30067, Error, mutatingMethodOnFunctionInputParameterError, "mutating method '$0' called on `in` parameter '$1'; changes will not be visible to caller. copy the parameter into a local variable if this behavior is intended") DIAGNOSTIC(30068, Warning, mutatingMethodOnFunctionInputParameterWarning, "mutating method '$0' called on `in` parameter '$1'; changes will not be visible to caller. copy the parameter into a local variable if this behavior is intended") +DIAGNOSTIC(30070, Error, unsizedMemberMustAppearLast, "unsized member can only appear as the last member in a composite type.") + DIAGNOSTIC(30100, Error, staticRefToNonStaticMember, "type '$0' cannot be used to refer to non-static member '$1'") DIAGNOSTIC(30200, Error, redeclaration, "declaration of '$0' conflicts with existing declaration") diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 92a9d473f..6fc94c657 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -3533,7 +3533,7 @@ public: IRInst* addFloatingModeOverrideDecoration(IRInst* dest, FloatingPointMode mode); - IRInst* addNumThreadsDecoration(IRInst* inst, Int x, Int y, Int z); + IRInst* addNumThreadsDecoration(IRInst* inst, IRInst* x, IRInst* y, IRInst* z); IRInst* emitSpecializeInst( IRType* type, diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index fd5ea0fc7..696e862d6 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -5075,14 +5075,12 @@ namespace Slang getIntValue(getIntType(), (IRIntegerValue)mode)); } - IRInst* IRBuilder::addNumThreadsDecoration(IRInst* inst, Int x, Int y, Int z) + IRInst* IRBuilder::addNumThreadsDecoration(IRInst* inst, IRInst* x, IRInst* y, IRInst* z) { - IRType* intType = getIntType(); - IRInst* operands[3] = { - getIntValue(intType, x), - getIntValue(intType, y), - getIntValue(intType, z) + x, + y, + z }; return addDecoration(inst, kIROp_NumThreadsDecoration, operands, 3); diff --git a/source/slang/slang-language-server.cpp b/source/slang/slang-language-server.cpp index 0256464c9..b8ebf8e08 100644 --- a/source/slang/slang-language-server.cpp +++ b/source/slang/slang-language-server.cpp @@ -314,7 +314,7 @@ String getDeclSignatureString(DeclRef<Decl> declRef, WorkspaceVersion* version) SemanticsVisitor semanticsVisitor(&semanticContext); if (auto intVal = semanticsVisitor.tryFoldIntegerConstantExpression( declRef.substitute(version->linkage->getASTBuilder(), varDecl->initExpr), - nullptr)) + SemanticsVisitor::ConstantFoldingKind::LinkTime, nullptr)) { if (auto constantInt = as<ConstantIntVal>(intVal)) { diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 146af452f..c5a4da1f6 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -7113,9 +7113,9 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> { getBuilder()->addNumThreadsDecoration( d, - layoutLocalSizeAttr->x, - layoutLocalSizeAttr->y, - layoutLocalSizeAttr->z + getSimpleVal(context, lowerVal(context, layoutLocalSizeAttr->x)), + getSimpleVal(context, lowerVal(context, layoutLocalSizeAttr->y)), + getSimpleVal(context, lowerVal(context, layoutLocalSizeAttr->z)) ); } } @@ -9534,9 +9534,9 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> { getBuilder()->addNumThreadsDecoration( irFunc, - numThreadsAttr->x, - numThreadsAttr->y, - numThreadsAttr->z + getSimpleVal(context, lowerVal(context, numThreadsAttr->x)), + getSimpleVal(context, lowerVal(context, numThreadsAttr->y)), + getSimpleVal(context, lowerVal(context, numThreadsAttr->z)) ); } else if (as<ReadNoneAttribute>(modifier)) diff --git a/source/slang/slang-options.cpp b/source/slang/slang-options.cpp index 13984e530..42eb66517 100644 --- a/source/slang/slang-options.cpp +++ b/source/slang/slang-options.cpp @@ -2833,7 +2833,10 @@ SlangResult OptionsParser::_parse( } [[fallthrough]]; default: - m_sink->diagnose(SourceLoc(), Diagnostics::cannotMatchOutputFileToEntryPoint, rawOutput.path); + if (rawOutput.path.getLength() != 0) + { + m_sink->diagnose(SourceLoc(), Diagnostics::cannotMatchOutputFileToEntryPoint, rawOutput.path); + } break; } } diff --git a/source/slang/slang-reflection-api.cpp b/source/slang/slang-reflection-api.cpp index b9eeccd30..7af3ce0a3 100644 --- a/source/slang/slang-reflection-api.cpp +++ b/source/slang/slang-reflection-api.cpp @@ -2811,9 +2811,12 @@ SLANG_API void spReflectionEntryPoint_getComputeThreadGroupSize( auto numThreadsAttribute = entryPointFunc.getDecl()->findModifier<NumThreadsAttribute>(); if (numThreadsAttribute) { - sizeAlongAxis[0] = numThreadsAttribute->x; - sizeAlongAxis[1] = numThreadsAttribute->y; - sizeAlongAxis[2] = numThreadsAttribute->z; + if (auto cint = as<ConstantIntVal>(numThreadsAttribute->x)) + sizeAlongAxis[0] = (SlangUInt)cint->getValue(); + if (auto cint = as<ConstantIntVal>(numThreadsAttribute->y)) + sizeAlongAxis[1] = (SlangUInt)cint->getValue(); + if (auto cint = as<ConstantIntVal>(numThreadsAttribute->z)) + sizeAlongAxis[2] = (SlangUInt)cint->getValue(); } // diff --git a/source/slang/slang-type-layout.cpp b/source/slang/slang-type-layout.cpp index 6b1e2e115..3f49d62d2 100644 --- a/source/slang/slang-type-layout.cpp +++ b/source/slang/slang-type-layout.cpp @@ -1610,14 +1610,6 @@ static LayoutSize GetElementCount(IntVal* val) } else if (const auto polyIntVal = as<PolynomialIntVal>(val)) { - // TODO: We want to treat the case where the number of - // elements in an array depends on a generic parameter - // much like the case where the number of elements is - // unbounded, *but* we can't just blindly do that because - // an API might disallow unbounded arrays in various - // cases where a generic bound might work (because - // any concrete specialization will have a finite bound...) - // return 0; } SLANG_UNEXPECTED("unhandled integer literal kind"); diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp index 1aa014a4b..dc9f86b48 100644 --- a/source/slang/slang.cpp +++ b/source/slang/slang.cpp @@ -1161,13 +1161,29 @@ slang::IModule* Linkage::loadModuleFromBlob( try { - auto name = getNamePool()->getName(moduleName); + auto getDigestStr = [](auto x) + { + DigestBuilder<SHA1> digestBuilder; + digestBuilder.append(x); + return digestBuilder.finalize().toString(); + }; + + String moduleNameStr = moduleName; + if (!moduleName) + moduleNameStr = getDigestStr(source); + + auto name = getNamePool()->getName(moduleNameStr); RefPtr<LoadedModule> loadedModule; if (mapNameToLoadedModules.tryGetValue(name, loadedModule)) { return loadedModule; } String pathStr = path; + if (pathStr.getLength() == 0) + { + // If path is empty, use a digest from source as path. + pathStr = getDigestStr(source); + } auto pathInfo = PathInfo::makeFromString(pathStr); if (File::exists(pathStr)) { @@ -1205,6 +1221,16 @@ SLANG_NO_THROW slang::IModule* SLANG_MCALL Linkage::loadModuleFromSource( return loadModuleFromBlob(moduleName, path, source, ModuleBlobType::Source, outDiagnostics); } +SLANG_NO_THROW slang::IModule* SLANG_MCALL Linkage::loadModuleFromSourceString( + const char* moduleName, + const char* path, + const char* source, + slang::IBlob** outDiagnostics) +{ + auto sourceBlob = StringBlob::create(UnownedStringSlice(source)); + return loadModuleFromSource(moduleName, path, sourceBlob.get(), outDiagnostics); +} + SLANG_NO_THROW slang::IModule* SLANG_MCALL Linkage::loadModuleFromIRBlob( const char* moduleName, const char* path, diff --git a/tests/diagnostics/unsized.slang b/tests/diagnostics/unsized.slang new file mode 100644 index 000000000..6bcad95a9 --- /dev/null +++ b/tests/diagnostics/unsized.slang @@ -0,0 +1,24 @@ +//DIAGNOSTIC_TEST:SIMPLE(filecheck=CHECK): -target spirv -entry main -stage compute + +extern static const int size = 1; + +struct V +{ + // CHECK-DAG: ([[# @LINE+1]]): error 30070 + int b[size]; + int a[]; +} + +struct Q +{ + // CHECK-DAG: ([[# @LINE+1]]): error 30070 + V v1; + V v2; +} + +[numthreads(1,1,1)] +void main() +{ + V v; + int i = 2; +}
\ No newline at end of file diff --git a/tools/gfx-unit-test/link-time-constant.cpp b/tools/gfx-unit-test/link-time-constant.cpp index 62cb294c1..b349a7dcb 100644 --- a/tools/gfx-unit-test/link-time-constant.cpp +++ b/tools/gfx-unit-test/link-time-constant.cpp @@ -79,6 +79,8 @@ namespace gfx_test R"( export static const bool turnOnFeature = true; export static const float constValue = 2.0; + export static const int numthread = 1; + export static const int arraySize = 4; )")); ComputePipelineStateDesc pipelineDesc = {}; diff --git a/tools/gfx-unit-test/link-time-constant.slang b/tools/gfx-unit-test/link-time-constant.slang index 385932dd4..b23f84ce8 100644 --- a/tools/gfx-unit-test/link-time-constant.slang +++ b/tools/gfx-unit-test/link-time-constant.slang @@ -1,17 +1,22 @@ extern static const bool turnOnFeature; extern static const float constValue; +extern static const int numthread = -1; +extern static const int arraySize = -1; // Main entry-point. Write some value into buffer depending on link // time constant. [shader("compute")] -[numthreads(4, 1, 1)] +[numthreads(numthread, 1, 1)] void computeMain( uint3 sv_dispatchThreadID: SV_DispatchThreadID, uniform RWStructuredBuffer<float> buffer) { + int array[arraySize]; + + array[sv_dispatchThreadID.x] = sv_dispatchThreadID.x; if (turnOnFeature) { - buffer[0] = constValue; + buffer[array[0]] = constValue; } else { |
