summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-check-expr.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang/slang-check-expr.cpp')
-rw-r--r--source/slang/slang-check-expr.cpp66
1 files changed, 27 insertions, 39 deletions
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index b7f99c4e7..a787af211 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -1526,48 +1526,42 @@ namespace Slang
return expr;
}
- // This function proceses primal params (i.e params of the inner function that is being
- // differentiated) that need to be carried over to the function signature for the JVP
- // function. (eg. out types can be discarded)
- //
- Type* primalToInputType(ASTBuilder*, Type* primalType)
- {
- if (auto primalOutType = as<OutType>(primalType))
- return nullptr;
- else if (auto primalInOutType = as<InOutType>(primalType))
- return primalInOutType->getValueType();
-
- return primalType;
- }
- Type* primalToJVPParamType(ASTBuilder* builder, Type* primalType)
+ Type* SemanticsVisitor::_toDifferentialParamType(ASTBuilder* builder, Type* primalType)
{
- // Only float and vector<float> types can be differentiated for now.
-
- if (primalType->equals(builder->getFloatType()))
- return primalType;
- else if (auto primalVectorType = as<VectorExpressionType>(primalType))
- {
- if (auto jvpElementType = primalToJVPParamType(builder, primalVectorType->elementType))
- return builder->getVectorType(jvpElementType, primalVectorType->elementCount);
- }
- else if (auto primalOutType = as<OutType>(primalType))
+ // Check for type modifiers like 'out' and 'inout'. We need to differentiate the
+ // nested type.
+ //
+ if (auto primalOutType = as<OutType>(primalType))
{
- return builder->getOutType(primalToJVPParamType(builder, primalOutType->getValueType()));
+ return builder->getOutType(_toDifferentialParamType(builder, primalOutType->getValueType()));
}
else if (auto primalInOutType = as<InOutType>(primalType))
{
- return builder->getInOutType(primalToJVPParamType(builder, primalInOutType->getValueType()));
+ return builder->getInOutType(_toDifferentialParamType(builder, primalInOutType->getValueType()));
}
- return nullptr;
+
+ // Get a reference to the builtin 'IDifferentiable' interface
+ auto differentiableInterface = builder->getDifferentiableInterface();
+
+ // Check if the provided type inherits from IDifferentiable.
+ // If not, return the original type.
+ if (auto conformanceWitness = as<Witness>(tryGetInterfaceConformanceWitness(primalType, differentiableInterface)))
+ return builder->getDifferentialPairType(primalType, conformanceWitness);
+ else
+ return primalType;
+
}
- Type* primalToJVPReturnType(ASTBuilder* builder, Type* primalType)
+ Type* SemanticsVisitor::_toJVPReturnType(ASTBuilder* builder, Type* primalType)
{
- if(auto jvpType = primalToJVPParamType(builder, primalType))
- return jvpType;
+ if (auto conformanceWitness =
+ as<Witness>(tryGetInterfaceConformanceWitness(
+ primalType,
+ builder->getDifferentiableInterface())))
+ return builder->getDifferentialPairType(primalType, conformanceWitness);
else
- return builder->getVoidType();
+ return primalType;
}
Expr* SemanticsExprVisitor::visitJVPDifferentiateExpr(JVPDifferentiateExpr* expr)
@@ -1588,7 +1582,7 @@ namespace Slang
// The JVP return type is float if primal return type is float
// void otherwise.
//
- jvpType->resultType = primalToJVPReturnType(astBuilder, primalType->getResultType());
+ jvpType->resultType = _toJVPReturnType(astBuilder, primalType->getResultType());
// No support for differentiating function that throw errors, for now.
SLANG_ASSERT(primalType->errorType->equals(astBuilder->getBottomType()));
@@ -1596,13 +1590,7 @@ namespace Slang
for (UInt i = 0; i < primalType->getParamCount(); i++)
{
- if(auto primalInputType = primalToInputType(astBuilder, primalType->getParamType(i)))
- jvpType->paramTypes.add(primalInputType);
- }
-
- for (UInt i = 0; i < primalType->getParamCount(); i++)
- {
- if(auto jvpParamType = primalToJVPParamType(astBuilder, primalType->getParamType(i)))
+ if(auto jvpParamType = _toDifferentialParamType(astBuilder, primalType->getParamType(i)))
jvpType->paramTypes.add(jvpParamType);
}