summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rwxr-xr-xsource/slang/slang-compiler.h2
-rw-r--r--source/slang/slang-reflection-api.cpp13
-rw-r--r--source/slang/slang.cpp80
-rw-r--r--tools/slang-unit-test/unit-test-function-reflection.cpp65
4 files changed, 146 insertions, 14 deletions
diff --git a/source/slang/slang-compiler.h b/source/slang/slang-compiler.h
index 980490df5..c1251488b 100755
--- a/source/slang/slang-compiler.h
+++ b/source/slang/slang-compiler.h
@@ -2281,6 +2281,8 @@ namespace Slang
Expr* funcExpr,
List<Type*> argTypes,
DiagnosticSink* sink);
+
+ bool isSpecialized(DeclRef<Decl> declRef);
DiagnosticSink::Flags diagnosticSinkFlags = 0;
diff --git a/source/slang/slang-reflection-api.cpp b/source/slang/slang-reflection-api.cpp
index 45bbd400c..ae351aee9 100644
--- a/source/slang/slang-reflection-api.cpp
+++ b/source/slang/slang-reflection-api.cpp
@@ -808,9 +808,9 @@ SlangReflectionFunction* tryConvertExprToFunctionReflection(ASTBuilder* astBuild
auto declRef = declRefExpr->declRef;
if (auto genericDeclRef = declRef.as<GenericDecl>())
{
- auto innerDeclRef = substituteDeclRef(
- SubstitutionSet(genericDeclRef), astBuilder, genericDeclRef.getDecl()->inner);
- declRef = createDefaultSubstitutionsIfNeeded(astBuilder, nullptr, innerDeclRef);
+ auto innerDeclRef = createDefaultSubstitutionsIfNeeded(astBuilder, nullptr, genericDeclRef.getDecl()->inner);
+ declRef = substituteDeclRef(
+ SubstitutionSet(genericDeclRef), astBuilder, innerDeclRef);
}
if (auto funcDeclRef = declRef.as<FunctionDeclBase>())
@@ -3194,7 +3194,12 @@ SLANG_API SlangReflectionFunction* spReflectionFunction_specializeWithArgTypes(
try
{
DiagnosticSink sink(linkage->getSourceManager(), Lexer::sourceLocationLexer);
- return convert(linkage->specializeWithArgTypes(funcExpr, argTypeList, &sink).as<FunctionDeclBase>());
+ auto resultFunc = linkage->specializeWithArgTypes(funcExpr, argTypeList, &sink).as<FunctionDeclBase>();
+
+ if (sink.getErrorCount() != 0)
+ return nullptr; // Failed coercion.
+
+ return convert(resultFunc);
}
catch (...)
{
diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp
index 313e94439..7949fd9f7 100644
--- a/source/slang/slang.cpp
+++ b/source/slang/slang.cpp
@@ -1369,16 +1369,68 @@ DeclRef<GenericDecl> getGenericParentDeclRef(
// Create substituted parent decl ref.
auto decl = declRef.getDecl();
- while (!as<GenericDecl>(decl))
+ while (decl && !as<GenericDecl>(decl))
{
decl = decl->parentDecl;
}
+ if (!decl)
+ {
+ // No generic parent
+ return DeclRef<GenericDecl>();
+ }
+
auto genericDecl = as<GenericDecl>(decl);
auto genericDeclRef = createDefaultSubstitutionsIfNeeded(astBuilder, visitor, DeclRef(genericDecl)).as<GenericDecl>();
return substituteDeclRef(SubstitutionSet(declRef), astBuilder, genericDeclRef).as<GenericDecl>();
}
+bool Linkage::isSpecialized(DeclRef<Decl> declRef)
+{
+ // For now, we only support two 'states': fully applied or not at all.
+ // If we add support for partial specialization, we will need to update this logic.
+ //
+ // If it's not specialized, then declRef will be the one with default substitutions.
+ //
+ SemanticsVisitor visitor(getSemanticsForReflection());
+
+ auto decl = declRef.getDecl();
+ while (decl && !as<GenericDecl>(decl))
+ {
+ decl = decl->parentDecl;
+ }
+
+ if(!decl)
+ return true; // no generics => always specialized
+
+ auto defaultArgs = getDefaultSubstitutionArgs(getASTBuilder(), &visitor, as<GenericDecl>(decl));
+ auto currentArgs = SubstitutionSet(declRef).findGenericAppDeclRef(as<GenericDecl>(decl))->getArgs();
+
+ if (defaultArgs.getCount() != currentArgs.getCount()) // should really never happen.
+ return true;
+
+ for (Index i = 0; i < defaultArgs.getCount(); ++i)
+ {
+ if (defaultArgs[i] != currentArgs[i])
+ return true;
+ }
+
+ return false;
+}
+
+bool isFuncGeneric(DeclRef<Decl> declRef)
+{
+ if (auto funcDecl = as<FuncDecl>(declRef.getDecl()))
+ {
+ if (funcDecl->parentDecl && as<GenericDecl>(funcDecl->parentDecl))
+ {
+ return true;
+ }
+ }
+
+ return false;
+}
+
DeclRef<Decl> Linkage::specializeWithArgTypes(
Expr* funcExpr,
List<Type*> argTypes,
@@ -1387,16 +1439,22 @@ DeclRef<Decl> Linkage::specializeWithArgTypes(
SemanticsVisitor visitor(getSemanticsForReflection());
visitor = visitor.withSink(sink);
- ASTBuilder* astBuilder = getASTBuilder();
+ SLANG_AST_BUILDER_RAII(getASTBuilder());
if (auto declRefFuncExpr = as<DeclRefExpr>(funcExpr))
{
- auto genericDeclRefExpr = astBuilder->create<DeclRefExpr>();
- genericDeclRefExpr->declRef = getGenericParentDeclRef(
- getASTBuilder(),
- &visitor,
- declRefFuncExpr->declRef);
- funcExpr = genericDeclRefExpr;
+ if (isFuncGeneric(declRefFuncExpr->declRef) && !isSpecialized(declRefFuncExpr->declRef))
+ {
+ if (auto genericDeclRef = getGenericParentDeclRef(
+ getCurrentASTBuilder(),
+ &visitor,
+ declRefFuncExpr->declRef))
+ {
+ auto genericDeclRefExpr = getCurrentASTBuilder()->create<DeclRefExpr>();
+ genericDeclRefExpr->declRef = genericDeclRef;
+ funcExpr = genericDeclRefExpr;
+ }
+ }
}
List<Expr*> argExprs;
@@ -1407,17 +1465,19 @@ DeclRef<Decl> Linkage::specializeWithArgTypes(
// Create an 'empty' expr with the given type. Ideally, the expression itself should not matter
// only its checked type.
//
- auto argExpr = astBuilder->create<VarExpr>();
+ auto argExpr = getCurrentASTBuilder()->create<VarExpr>();
argExpr->type = argType;
+ argExpr->type.isLeftValue = true;
argExprs.add(argExpr);
}
// Construct invoke expr.
- auto invokeExpr = astBuilder->create<InvokeExpr>();
+ auto invokeExpr = getCurrentASTBuilder()->create<InvokeExpr>();
invokeExpr->functionExpr = funcExpr;
invokeExpr->arguments = argExprs;
auto checkedInvokeExpr = visitor.CheckInvokeExprWithCheckedOperands(invokeExpr);
+
return as<DeclRefExpr>(as<InvokeExpr>(checkedInvokeExpr)->functionExpr)->declRef;
}
diff --git a/tools/slang-unit-test/unit-test-function-reflection.cpp b/tools/slang-unit-test/unit-test-function-reflection.cpp
index 2b52a8691..f893da69d 100644
--- a/tools/slang-unit-test/unit-test-function-reflection.cpp
+++ b/tools/slang-unit-test/unit-test-function-reflection.cpp
@@ -39,6 +39,10 @@ SLANG_UNIT_TEST(functionReflection)
float foo(float x) { return x; }
float foo(float x, uint i) { return x + i; }
+
+ int bar1(IFloat a, IFloat b) { return 0; }
+ int bar2<T>(T a, float3 b) { return 0; }
+ int bar3(float3 b) { return 0; }
)";
auto moduleName = "moduleG" + String(Process::getId());
@@ -122,5 +126,66 @@ SLANG_UNIT_TEST(functionReflection)
};
auto resolvedFunctionReflection = overloadReflection->specializeWithArgTypes(2, argTypes);
SLANG_CHECK(resolvedFunctionReflection == firstOverload);
+
+ //
+ // More testing for specializeWithArgTypes
+ //
+
+ // bar1 (IFloat, IFloat) -> int
+ //
+ auto bar1Reflection = module->getLayout()->findFunctionByName("bar1");
+ SLANG_CHECK(bar1Reflection != nullptr);
+ SLANG_CHECK(bar1Reflection->isOverloaded() == false);
+ SLANG_CHECK(bar1Reflection->getParameterCount() == 2);
+
+ auto float3Type = module->getLayout()->findTypeByName("float3");
+ SLANG_CHECK(float3Type != nullptr);
+ argTypes[0] = float3Type;
+ argTypes[1] = float3Type;
+
+ resolvedFunctionReflection = bar1Reflection->specializeWithArgTypes(2, argTypes);
+
+ SLANG_CHECK(resolvedFunctionReflection != nullptr);
+ SLANG_CHECK(resolvedFunctionReflection->getParameterCount() == 2);
+ SLANG_CHECK(getTypeFullName(resolvedFunctionReflection->getParameterByIndex(0)->getType()) == "IFloat");
+ SLANG_CHECK(getTypeFullName(resolvedFunctionReflection->getParameterByIndex(1)->getType()) == "IFloat");
+
+ // bar2 (T : IFloat, float3) -> int
+ //
+ auto bar2Reflection = module->getLayout()->findFunctionByName("bar2");
+ SLANG_CHECK(bar2Reflection != nullptr);
+ SLANG_CHECK(bar2Reflection->isOverloaded() == false);
+ SLANG_CHECK(bar2Reflection->getParameterCount() == 2);
+
+ auto floatType = module->getLayout()->findTypeByName("float");
+ SLANG_CHECK(floatType != nullptr);
+ argTypes[0] = floatType;
+ argTypes[1] = float3Type;
+
+ resolvedFunctionReflection = bar2Reflection->specializeWithArgTypes(2, argTypes);
+
+ SLANG_CHECK(resolvedFunctionReflection != nullptr);
+ SLANG_CHECK(resolvedFunctionReflection->getParameterCount() == 2);
+ SLANG_CHECK(getTypeFullName(resolvedFunctionReflection->getParameterByIndex(0)->getType()) == "float");
+ SLANG_CHECK(getTypeFullName(resolvedFunctionReflection->getParameterByIndex(1)->getType()) == "vector<float,3>");
+
+
+ // failure case
+ argTypes[0] = floatType;
+ argTypes[1] = module->getLayout()->findTypeByName("float2");
+ resolvedFunctionReflection = bar2Reflection->specializeWithArgTypes(2, argTypes);
+ SLANG_CHECK(resolvedFunctionReflection == nullptr); // any errors should result in a nullptr.
+
+ // bar3 (float3) -> int
+ // (trivial case)
+ auto bar3Reflection = module->getLayout()->findFunctionByName("bar3");
+ SLANG_CHECK(bar3Reflection != nullptr);
+ SLANG_CHECK(bar3Reflection->isOverloaded() == false);
+ SLANG_CHECK(bar3Reflection->getParameterCount() == 1);
+
+ argTypes[0] = float3Type;
+ resolvedFunctionReflection = bar3Reflection->specializeWithArgTypes(1, argTypes);
+ SLANG_CHECK(resolvedFunctionReflection != nullptr);
+ SLANG_CHECK(resolvedFunctionReflection == bar3Reflection);
}