summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-check-decl.cpp33
-rw-r--r--source/slang/slang-ir-autodiff-fwd.cpp13
2 files changed, 31 insertions, 15 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index 2613e6430..3779f48e3 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -6934,9 +6934,12 @@ namespace Slang
arg->type.isLeftValue = param->findModifier<OutModifier>() ? true : false;
arg->type.type = param->getType();
arg->loc = loc;
- if (auto pairType = visitor->getDifferentialPairType(param->getType()))
+ if (!param->findModifier<NoDiffModifier>())
{
- arg->type.type = pairType;
+ if (auto pairType = visitor->getDifferentialPairType(param->getType()))
+ {
+ arg->type.type = pairType;
+ }
}
imaginaryArguments.add(arg);
}
@@ -6958,18 +6961,26 @@ namespace Slang
arg->type.isLeftValue = param->findModifier<OutModifier>() ? true : false;
arg->type.type = param->getType();
arg->loc = loc;
- if (auto pairType = as<DifferentialPairType>(visitor->getDifferentialPairType(param->getType())))
+ bool isDiffParam = (!param->findModifier<NoDiffModifier>());
+ if (isDiffParam)
{
- arg->type.type = pairType;
- if (isOutParam(param))
+ if (auto pairType = as<DifferentialPairType>(visitor->getDifferentialPairType(param->getType())))
{
- // out T -> in T.Differential
- arg->type.isLeftValue = false;
- arg->type.type = visitor->tryGetDifferentialType(
- visitor->getASTBuilder(), pairType->getPrimalType());
+ arg->type.type = pairType;
+ if (isOutParam(param))
+ {
+ // out T -> in T.Differential
+ arg->type.isLeftValue = false;
+ arg->type.type = visitor->tryGetDifferentialType(
+ visitor->getASTBuilder(), pairType->getPrimalType());
+ }
+ }
+ else
+ {
+ isDiffParam = false;
}
}
- else
+ if (!isDiffParam)
{
if (isOutParam(param))
{
@@ -7010,7 +7021,7 @@ namespace Slang
HigherOrderInvokeExpr* higherOrderFuncExpr = visitor->getASTBuilder()->create<TDifferentiateExpr>();
higherOrderFuncExpr->baseFunction = derivativeOfAttr->funcExpr;
higherOrderFuncExpr->loc = derivativeOfAttr->loc;
- Expr* checkedHigherOrderFuncExpr = visitor->dispatchExpr(higherOrderFuncExpr, *visitor);
+ Expr* checkedHigherOrderFuncExpr = visitor->dispatchExpr(higherOrderFuncExpr, visitor->allowStaticReferenceToNonStaticMember());
if (!checkedHigherOrderFuncExpr)
{
visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::cannotResolveOriginalFunctionForDerivative);
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp
index cf45a83f5..ac4e3825a 100644
--- a/source/slang/slang-ir-autodiff-fwd.cpp
+++ b/source/slang/slang-ir-autodiff-fwd.cpp
@@ -525,10 +525,14 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig
return InstPair(primalCall, nullptr);
}
- auto calleeType = _getCalleeActualFuncType(diffCallee);
+ auto calleeType = _getCalleeActualFuncType(primalCallee);
SLANG_ASSERT(calleeType);
SLANG_RELEASE_ASSERT(calleeType->getParamCount() == origCall->getArgCount());
+ auto diffCalleeType = _getCalleeActualFuncType(diffCallee);
+ SLANG_ASSERT(diffCalleeType);
+ SLANG_RELEASE_ASSERT(diffCalleeType->getParamCount() == origCall->getArgCount());
+
auto placeholderCall = builder->emitCallInst(nullptr, builder->emitUndefined(builder->getTypeKind()), 0, nullptr);
builder->setInsertBefore(placeholderCall);
IRBuilder argBuilder = *builder;
@@ -545,8 +549,9 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig
auto origType = origCall->getArg(ii)->getDataType();
auto primalType = primalArg->getDataType();
- auto paramType = calleeType->getParamType(ii);
- if (!isNoDiffType(paramType))
+ auto originalParamType = calleeType->getParamType(ii);
+ auto diffParamType = diffCalleeType->getParamType(ii);
+ if (!isNoDiffType(originalParamType))
{
if (isNoDiffType(primalType))
{
@@ -561,7 +566,7 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig
auto pairValType = as<IRDifferentialPairType>(
pairPtrType ? pairPtrType->getValueType() : pairType);
auto diffType = differentiableTypeConformanceContext.getDifferentialTypeFromDiffPairType(&argBuilder, pairValType);
- if (auto ptrParamType = as<IRPtrTypeBase>(paramType))
+ if (auto ptrParamType = as<IRPtrTypeBase>(diffParamType))
{
// Create temp var to pass in/out arguments.
auto srcVar = argBuilder.emitVar(ptrParamType->getValueType());