diff options
Diffstat (limited to 'source/slang/slang-check-shader.cpp')
| -rw-r--r-- | source/slang/slang-check-shader.cpp | 174 |
1 files changed, 100 insertions, 74 deletions
diff --git a/source/slang/slang-check-shader.cpp b/source/slang/slang-check-shader.cpp index dc4f84920..45deea109 100644 --- a/source/slang/slang-check-shader.cpp +++ b/source/slang/slang-check-shader.cpp @@ -238,6 +238,25 @@ DeclRef<FuncDecl> findFunctionDeclByName(Module* translationUnit, Name* name, Di { entryPointFuncDeclRef = declRefExpr->declRef.as<FuncDecl>(); + if (!entryPointFuncDeclRef) + { + if (auto genDeclRef = as<GenericDecl>(declRefExpr->declRef)) + { + SharedSemanticsContext context( + translationUnit->getLinkage(), + translationUnit, + sink); + SemanticsVisitor visitor(&context); + entryPointFuncDeclRef = createDefaultSubstitutionsIfNeeded( + translationUnit->getASTBuilder(), + &visitor, + translationUnit->getASTBuilder()->getMemberDeclRef( + genDeclRef, + genDeclRef.getDecl()->inner)) + .as<FuncDecl>(); + } + } + if (entryPointFuncDeclRef && getModule(entryPointFuncDeclRef.getDecl()) != translationUnit) entryPointFuncDeclRef = DeclRef<FuncDecl>(); } @@ -1251,9 +1270,20 @@ RefPtr<ComponentType> createUnspecializedGlobalAndEntryPointsComponentType( RefPtr<ComponentType::SpecializationInfo> Module::_validateSpecializationArgsImpl( SpecializationArg const* args, Index argCount, + Index& outConsumedArgCount, DiagnosticSink* sink) { - SLANG_ASSERT(argCount == getSpecializationParamCount()); + if (argCount < getSpecializationParamCount()) + { + sink->diagnose( + SourceLoc(), + Diagnostics::mismatchSpecializationArguments, + getSpecializationParamCount(), + argCount); + return nullptr; + } + outConsumedArgCount = getSpecializationParamCount(); + argCount = outConsumedArgCount; SharedSemanticsContext semanticsContext(getLinkage(), this, sink); SemanticsVisitor visitor(&semanticsContext); @@ -1468,6 +1498,7 @@ static void _extractSpecializationArgs( RefPtr<ComponentType::SpecializationInfo> EntryPoint::_validateSpecializationArgsImpl( SpecializationArg const* inArgs, Index inArgCount, + Index& outConsumedArgCount, DiagnosticSink* sink) { auto args = inArgs; @@ -1476,15 +1507,16 @@ RefPtr<ComponentType::SpecializationInfo> EntryPoint::_validateSpecializationArg SharedSemanticsContext sharedSemanticsContext(getLinkage(), nullptr, sink); SemanticsVisitor visitor(&sharedSemanticsContext); - // The first N arguments will be for the explicit generic parameters + // The last N arguments will be for the implicit existential arguments // of the entry point (if it has any). // + auto existentialSpecializationParamCount = getExistentialSpecializationParamCount(); auto genericSpecializationParamCount = getGenericSpecializationParamCount(); - SLANG_ASSERT(argCount >= genericSpecializationParamCount); RefPtr<EntryPointSpecializationInfo> info = new EntryPointSpecializationInfo(); DeclRef<FuncDecl> specializedFuncDeclRef = m_funcDeclRef; + Index genericArgCount = genericSpecializationParamCount; if (genericSpecializationParamCount) { // We need to construct a generic application and use @@ -1494,83 +1526,69 @@ RefPtr<ComponentType::SpecializationInfo> EntryPoint::_validateSpecializationArg auto genericDeclRef = m_funcDeclRef.getParent().as<GenericDecl>(); SLANG_ASSERT(genericDeclRef); // otherwise we wouldn't have generic parameters - List<Val*> genericArgs; + bool isVariadic = + (genericDeclRef.getDecl()->getMembersOfType<GenericTypePackParamDecl>().getCount() != + 0); + + // If function is variadic generic, it will consume all the provided arguments. + if (isVariadic) + genericArgCount = argCount - existentialSpecializationParamCount; - for (Index ii = 0; ii < genericSpecializationParamCount; ++ii) + if (genericArgCount < 0) { - auto specializationArg = args[ii]; - genericArgs.add(specializationArg.val); + sink->diagnose( + SourceLoc(), + Diagnostics::mismatchSpecializationArguments, + genericSpecializationParamCount + existentialSpecializationParamCount, + argCount); + return nullptr; } - auto astBuilder = getLinkage()->getASTBuilder(); - for (auto constraintDecl : getMembersOfType<GenericTypeConstraintDecl>( - getLinkage()->getASTBuilder(), - DeclRef<ContainerDecl>(genericDeclRef))) - { - DeclRef<GenericTypeConstraintDecl> constraintDeclRef = - astBuilder->getDirectDeclRef(constraintDecl.getDecl()); - int argIndex = -1; - int ii = 0; - - // Find the generic parameter type (T) that this constraint (T:IFoo) is applying to. - auto genericParamType = getSub(astBuilder, constraintDeclRef); - auto genParamDeclRefType = as<DeclRefType>(genericParamType); - if (!genParamDeclRefType) - { - continue; - } - auto genParamDeclRef = genParamDeclRefType->getDeclRef(); - // Find the generic argument index of the corresponding generic parameter type in the - // generic parameter set. - // - for (auto member : genericDeclRef.getDecl()->getMembersOfType<GenericTypeParamDecl>()) - { - if (member == genParamDeclRef.getDecl()) - { - argIndex = ii; - break; - } - ii++; - } - if (argIndex == -1) - { - SLANG_ASSERT(!"generic parameter not found in generic decl"); - continue; - } - auto sub = as<Type>(args[argIndex].val); - if (!sub) - { - sink->diagnose( - constraintDecl, - Diagnostics::expectedTypeForSpecializationArg, - argIndex); - continue; - } + List<Expr*> genericArgs; - auto sup = getSup(astBuilder, constraintDeclRef); - auto subTypeWitness = visitor.isSubtype(sub, sup, IsSubTypeOptions::None); - if (subTypeWitness) - { - genericArgs.add(subTypeWitness); - } - else + auto astBuilder = getLinkage()->getASTBuilder(); + for (Index ii = 0; ii < genericArgCount; ++ii) + { + auto specializationArg = args[ii]; + if (specializationArg.expr) { - // TODO: diagnose a problem here - sink->diagnose( - constraintDecl, - Diagnostics::typeArgumentDoesNotConformToInterface, - sub, - sup); + genericArgs.add(specializationArg.expr); continue; } + auto typeExpr = astBuilder->create<SharedTypeExpr>(); + typeExpr->type = astBuilder->getTypeType((Type*)specializationArg.val); + genericArgs.add(typeExpr); + } + auto genAppExpr = astBuilder->create<GenericAppExpr>(); + auto genExpr = astBuilder->create<VarExpr>(); + genExpr->declRef = genericDeclRef; + genExpr->type = astBuilder->getOrCreate<GenericDeclRefType>(); + genExpr->checked = true; + genAppExpr->functionExpr = genExpr; + genAppExpr->arguments = _Move(genericArgs); + auto checkedExpr = visitor.CheckTerm(genAppExpr); + if (auto partiallyAppliedExpr = as<PartiallyAppliedGenericExpr>(checkedExpr)) + { + // If checked generic is partially applied generic, we try to force conversion into + // a fully defined declref by calling `trySolveConstraintSystem`. + SemanticsVisitor::ConstraintSystem system; + system.genericDecl = genericDeclRef.getDecl(); + ConversionCost outCost; + specializedFuncDeclRef = visitor + .trySolveConstraintSystem( + &system, + genericDeclRef, + partiallyAppliedExpr->knownGenericArgs.getArrayView(), + outCost) + .as<FuncDecl>(); + } + else if (auto declRefExpr = as<DeclRefExpr>(checkedExpr)) + { + specializedFuncDeclRef = declRefExpr->declRef.as<FuncDecl>(); } - specializedFuncDeclRef = - getLinkage() - ->getASTBuilder() - ->getGenericAppDeclRef(genericDeclRef, genericArgs.getArrayView()) - .as<FuncDecl>(); - SLANG_ASSERT(specializedFuncDeclRef); + if (!specializedFuncDeclRef) + return nullptr; } info->specializedFuncDeclRef = specializedFuncDeclRef; @@ -1580,11 +1598,19 @@ RefPtr<ComponentType::SpecializationInfo> EntryPoint::_validateSpecializationArg // specialization parameters, attached to the value parameters // of the entry point. // - args += genericSpecializationParamCount; - argCount -= genericSpecializationParamCount; + args += genericArgCount; + argCount -= genericArgCount; + outConsumedArgCount = genericArgCount + existentialSpecializationParamCount; - auto existentialSpecializationParamCount = getExistentialSpecializationParamCount(); - SLANG_ASSERT(argCount == existentialSpecializationParamCount); + if (argCount < existentialSpecializationParamCount) + { + sink->diagnose( + SourceLoc(), + Diagnostics::mismatchSpecializationArguments, + genericSpecializationParamCount + existentialSpecializationParamCount, + argCount); + return nullptr; + } for (Index ii = 0; ii < existentialSpecializationParamCount; ++ii) { |
