diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2024-10-23 20:28:49 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-10-23 17:28:49 -0700 |
| commit | a0bea07503c68160ad2e88986ba98cfc2161bdff (patch) | |
| tree | 4afbd4009607a5b44e2bc72d13a27627a3501acb | |
| parent | 5a161dd799cfc62dcfee281bfaff9819a8be43ad (diff) | |
Fix several bugs with `specializeWithArgTypes()` (#5365)
* Fix several bugs with `specializeWithArgTypes()`
* Make all types L-values for the purposes of reflection API resolution
| -rwxr-xr-x | source/slang/slang-compiler.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-reflection-api.cpp | 13 | ||||
| -rw-r--r-- | source/slang/slang.cpp | 80 | ||||
| -rw-r--r-- | tools/slang-unit-test/unit-test-function-reflection.cpp | 65 |
4 files changed, 146 insertions, 14 deletions
diff --git a/source/slang/slang-compiler.h b/source/slang/slang-compiler.h index 980490df5..c1251488b 100755 --- a/source/slang/slang-compiler.h +++ b/source/slang/slang-compiler.h @@ -2281,6 +2281,8 @@ namespace Slang Expr* funcExpr, List<Type*> argTypes, DiagnosticSink* sink); + + bool isSpecialized(DeclRef<Decl> declRef); DiagnosticSink::Flags diagnosticSinkFlags = 0; diff --git a/source/slang/slang-reflection-api.cpp b/source/slang/slang-reflection-api.cpp index 45bbd400c..ae351aee9 100644 --- a/source/slang/slang-reflection-api.cpp +++ b/source/slang/slang-reflection-api.cpp @@ -808,9 +808,9 @@ SlangReflectionFunction* tryConvertExprToFunctionReflection(ASTBuilder* astBuild auto declRef = declRefExpr->declRef; if (auto genericDeclRef = declRef.as<GenericDecl>()) { - auto innerDeclRef = substituteDeclRef( - SubstitutionSet(genericDeclRef), astBuilder, genericDeclRef.getDecl()->inner); - declRef = createDefaultSubstitutionsIfNeeded(astBuilder, nullptr, innerDeclRef); + auto innerDeclRef = createDefaultSubstitutionsIfNeeded(astBuilder, nullptr, genericDeclRef.getDecl()->inner); + declRef = substituteDeclRef( + SubstitutionSet(genericDeclRef), astBuilder, innerDeclRef); } if (auto funcDeclRef = declRef.as<FunctionDeclBase>()) @@ -3194,7 +3194,12 @@ SLANG_API SlangReflectionFunction* spReflectionFunction_specializeWithArgTypes( try { DiagnosticSink sink(linkage->getSourceManager(), Lexer::sourceLocationLexer); - return convert(linkage->specializeWithArgTypes(funcExpr, argTypeList, &sink).as<FunctionDeclBase>()); + auto resultFunc = linkage->specializeWithArgTypes(funcExpr, argTypeList, &sink).as<FunctionDeclBase>(); + + if (sink.getErrorCount() != 0) + return nullptr; // Failed coercion. + + return convert(resultFunc); } catch (...) { diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp index 313e94439..7949fd9f7 100644 --- a/source/slang/slang.cpp +++ b/source/slang/slang.cpp @@ -1369,16 +1369,68 @@ DeclRef<GenericDecl> getGenericParentDeclRef( // Create substituted parent decl ref. auto decl = declRef.getDecl(); - while (!as<GenericDecl>(decl)) + while (decl && !as<GenericDecl>(decl)) { decl = decl->parentDecl; } + if (!decl) + { + // No generic parent + return DeclRef<GenericDecl>(); + } + auto genericDecl = as<GenericDecl>(decl); auto genericDeclRef = createDefaultSubstitutionsIfNeeded(astBuilder, visitor, DeclRef(genericDecl)).as<GenericDecl>(); return substituteDeclRef(SubstitutionSet(declRef), astBuilder, genericDeclRef).as<GenericDecl>(); } +bool Linkage::isSpecialized(DeclRef<Decl> declRef) +{ + // For now, we only support two 'states': fully applied or not at all. + // If we add support for partial specialization, we will need to update this logic. + // + // If it's not specialized, then declRef will be the one with default substitutions. + // + SemanticsVisitor visitor(getSemanticsForReflection()); + + auto decl = declRef.getDecl(); + while (decl && !as<GenericDecl>(decl)) + { + decl = decl->parentDecl; + } + + if(!decl) + return true; // no generics => always specialized + + auto defaultArgs = getDefaultSubstitutionArgs(getASTBuilder(), &visitor, as<GenericDecl>(decl)); + auto currentArgs = SubstitutionSet(declRef).findGenericAppDeclRef(as<GenericDecl>(decl))->getArgs(); + + if (defaultArgs.getCount() != currentArgs.getCount()) // should really never happen. + return true; + + for (Index i = 0; i < defaultArgs.getCount(); ++i) + { + if (defaultArgs[i] != currentArgs[i]) + return true; + } + + return false; +} + +bool isFuncGeneric(DeclRef<Decl> declRef) +{ + if (auto funcDecl = as<FuncDecl>(declRef.getDecl())) + { + if (funcDecl->parentDecl && as<GenericDecl>(funcDecl->parentDecl)) + { + return true; + } + } + + return false; +} + DeclRef<Decl> Linkage::specializeWithArgTypes( Expr* funcExpr, List<Type*> argTypes, @@ -1387,16 +1439,22 @@ DeclRef<Decl> Linkage::specializeWithArgTypes( SemanticsVisitor visitor(getSemanticsForReflection()); visitor = visitor.withSink(sink); - ASTBuilder* astBuilder = getASTBuilder(); + SLANG_AST_BUILDER_RAII(getASTBuilder()); if (auto declRefFuncExpr = as<DeclRefExpr>(funcExpr)) { - auto genericDeclRefExpr = astBuilder->create<DeclRefExpr>(); - genericDeclRefExpr->declRef = getGenericParentDeclRef( - getASTBuilder(), - &visitor, - declRefFuncExpr->declRef); - funcExpr = genericDeclRefExpr; + if (isFuncGeneric(declRefFuncExpr->declRef) && !isSpecialized(declRefFuncExpr->declRef)) + { + if (auto genericDeclRef = getGenericParentDeclRef( + getCurrentASTBuilder(), + &visitor, + declRefFuncExpr->declRef)) + { + auto genericDeclRefExpr = getCurrentASTBuilder()->create<DeclRefExpr>(); + genericDeclRefExpr->declRef = genericDeclRef; + funcExpr = genericDeclRefExpr; + } + } } List<Expr*> argExprs; @@ -1407,17 +1465,19 @@ DeclRef<Decl> Linkage::specializeWithArgTypes( // Create an 'empty' expr with the given type. Ideally, the expression itself should not matter // only its checked type. // - auto argExpr = astBuilder->create<VarExpr>(); + auto argExpr = getCurrentASTBuilder()->create<VarExpr>(); argExpr->type = argType; + argExpr->type.isLeftValue = true; argExprs.add(argExpr); } // Construct invoke expr. - auto invokeExpr = astBuilder->create<InvokeExpr>(); + auto invokeExpr = getCurrentASTBuilder()->create<InvokeExpr>(); invokeExpr->functionExpr = funcExpr; invokeExpr->arguments = argExprs; auto checkedInvokeExpr = visitor.CheckInvokeExprWithCheckedOperands(invokeExpr); + return as<DeclRefExpr>(as<InvokeExpr>(checkedInvokeExpr)->functionExpr)->declRef; } diff --git a/tools/slang-unit-test/unit-test-function-reflection.cpp b/tools/slang-unit-test/unit-test-function-reflection.cpp index 2b52a8691..f893da69d 100644 --- a/tools/slang-unit-test/unit-test-function-reflection.cpp +++ b/tools/slang-unit-test/unit-test-function-reflection.cpp @@ -39,6 +39,10 @@ SLANG_UNIT_TEST(functionReflection) float foo(float x) { return x; } float foo(float x, uint i) { return x + i; } + + int bar1(IFloat a, IFloat b) { return 0; } + int bar2<T>(T a, float3 b) { return 0; } + int bar3(float3 b) { return 0; } )"; auto moduleName = "moduleG" + String(Process::getId()); @@ -122,5 +126,66 @@ SLANG_UNIT_TEST(functionReflection) }; auto resolvedFunctionReflection = overloadReflection->specializeWithArgTypes(2, argTypes); SLANG_CHECK(resolvedFunctionReflection == firstOverload); + + // + // More testing for specializeWithArgTypes + // + + // bar1 (IFloat, IFloat) -> int + // + auto bar1Reflection = module->getLayout()->findFunctionByName("bar1"); + SLANG_CHECK(bar1Reflection != nullptr); + SLANG_CHECK(bar1Reflection->isOverloaded() == false); + SLANG_CHECK(bar1Reflection->getParameterCount() == 2); + + auto float3Type = module->getLayout()->findTypeByName("float3"); + SLANG_CHECK(float3Type != nullptr); + argTypes[0] = float3Type; + argTypes[1] = float3Type; + + resolvedFunctionReflection = bar1Reflection->specializeWithArgTypes(2, argTypes); + + SLANG_CHECK(resolvedFunctionReflection != nullptr); + SLANG_CHECK(resolvedFunctionReflection->getParameterCount() == 2); + SLANG_CHECK(getTypeFullName(resolvedFunctionReflection->getParameterByIndex(0)->getType()) == "IFloat"); + SLANG_CHECK(getTypeFullName(resolvedFunctionReflection->getParameterByIndex(1)->getType()) == "IFloat"); + + // bar2 (T : IFloat, float3) -> int + // + auto bar2Reflection = module->getLayout()->findFunctionByName("bar2"); + SLANG_CHECK(bar2Reflection != nullptr); + SLANG_CHECK(bar2Reflection->isOverloaded() == false); + SLANG_CHECK(bar2Reflection->getParameterCount() == 2); + + auto floatType = module->getLayout()->findTypeByName("float"); + SLANG_CHECK(floatType != nullptr); + argTypes[0] = floatType; + argTypes[1] = float3Type; + + resolvedFunctionReflection = bar2Reflection->specializeWithArgTypes(2, argTypes); + + SLANG_CHECK(resolvedFunctionReflection != nullptr); + SLANG_CHECK(resolvedFunctionReflection->getParameterCount() == 2); + SLANG_CHECK(getTypeFullName(resolvedFunctionReflection->getParameterByIndex(0)->getType()) == "float"); + SLANG_CHECK(getTypeFullName(resolvedFunctionReflection->getParameterByIndex(1)->getType()) == "vector<float,3>"); + + + // failure case + argTypes[0] = floatType; + argTypes[1] = module->getLayout()->findTypeByName("float2"); + resolvedFunctionReflection = bar2Reflection->specializeWithArgTypes(2, argTypes); + SLANG_CHECK(resolvedFunctionReflection == nullptr); // any errors should result in a nullptr. + + // bar3 (float3) -> int + // (trivial case) + auto bar3Reflection = module->getLayout()->findFunctionByName("bar3"); + SLANG_CHECK(bar3Reflection != nullptr); + SLANG_CHECK(bar3Reflection->isOverloaded() == false); + SLANG_CHECK(bar3Reflection->getParameterCount() == 1); + + argTypes[0] = float3Type; + resolvedFunctionReflection = bar3Reflection->specializeWithArgTypes(1, argTypes); + SLANG_CHECK(resolvedFunctionReflection != nullptr); + SLANG_CHECK(resolvedFunctionReflection == bar3Reflection); } |
