diff options
| author | Yong He <yonghe@outlook.com> | 2023-03-21 15:44:21 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-03-21 15:44:21 -0700 |
| commit | 96caba75e8dfbb879eff12cbe1a4c148a259f684 (patch) | |
| tree | 1c7b2f25484ac22c738e006334d4df559bb733a5 /source | |
| parent | 7f11f883d0781952f002b3aa3222a3aa0040f18a (diff) | |
Add texture tri-linear autodiff example. (#2715)
* Add quad texture example.
* delete output image
* remove irrelavent files
* update project files
* fix
* Update example.
* Fix.
* remove out-texture
---------
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 33 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-fwd.cpp | 13 |
2 files changed, 31 insertions, 15 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 2613e6430..3779f48e3 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -6934,9 +6934,12 @@ namespace Slang arg->type.isLeftValue = param->findModifier<OutModifier>() ? true : false; arg->type.type = param->getType(); arg->loc = loc; - if (auto pairType = visitor->getDifferentialPairType(param->getType())) + if (!param->findModifier<NoDiffModifier>()) { - arg->type.type = pairType; + if (auto pairType = visitor->getDifferentialPairType(param->getType())) + { + arg->type.type = pairType; + } } imaginaryArguments.add(arg); } @@ -6958,18 +6961,26 @@ namespace Slang arg->type.isLeftValue = param->findModifier<OutModifier>() ? true : false; arg->type.type = param->getType(); arg->loc = loc; - if (auto pairType = as<DifferentialPairType>(visitor->getDifferentialPairType(param->getType()))) + bool isDiffParam = (!param->findModifier<NoDiffModifier>()); + if (isDiffParam) { - arg->type.type = pairType; - if (isOutParam(param)) + if (auto pairType = as<DifferentialPairType>(visitor->getDifferentialPairType(param->getType()))) { - // out T -> in T.Differential - arg->type.isLeftValue = false; - arg->type.type = visitor->tryGetDifferentialType( - visitor->getASTBuilder(), pairType->getPrimalType()); + arg->type.type = pairType; + if (isOutParam(param)) + { + // out T -> in T.Differential + arg->type.isLeftValue = false; + arg->type.type = visitor->tryGetDifferentialType( + visitor->getASTBuilder(), pairType->getPrimalType()); + } + } + else + { + isDiffParam = false; } } - else + if (!isDiffParam) { if (isOutParam(param)) { @@ -7010,7 +7021,7 @@ namespace Slang HigherOrderInvokeExpr* higherOrderFuncExpr = visitor->getASTBuilder()->create<TDifferentiateExpr>(); higherOrderFuncExpr->baseFunction = derivativeOfAttr->funcExpr; higherOrderFuncExpr->loc = derivativeOfAttr->loc; - Expr* checkedHigherOrderFuncExpr = visitor->dispatchExpr(higherOrderFuncExpr, *visitor); + Expr* checkedHigherOrderFuncExpr = visitor->dispatchExpr(higherOrderFuncExpr, visitor->allowStaticReferenceToNonStaticMember()); if (!checkedHigherOrderFuncExpr) { visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::cannotResolveOriginalFunctionForDerivative); diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index cf45a83f5..ac4e3825a 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -525,10 +525,14 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig return InstPair(primalCall, nullptr); } - auto calleeType = _getCalleeActualFuncType(diffCallee); + auto calleeType = _getCalleeActualFuncType(primalCallee); SLANG_ASSERT(calleeType); SLANG_RELEASE_ASSERT(calleeType->getParamCount() == origCall->getArgCount()); + auto diffCalleeType = _getCalleeActualFuncType(diffCallee); + SLANG_ASSERT(diffCalleeType); + SLANG_RELEASE_ASSERT(diffCalleeType->getParamCount() == origCall->getArgCount()); + auto placeholderCall = builder->emitCallInst(nullptr, builder->emitUndefined(builder->getTypeKind()), 0, nullptr); builder->setInsertBefore(placeholderCall); IRBuilder argBuilder = *builder; @@ -545,8 +549,9 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig auto origType = origCall->getArg(ii)->getDataType(); auto primalType = primalArg->getDataType(); - auto paramType = calleeType->getParamType(ii); - if (!isNoDiffType(paramType)) + auto originalParamType = calleeType->getParamType(ii); + auto diffParamType = diffCalleeType->getParamType(ii); + if (!isNoDiffType(originalParamType)) { if (isNoDiffType(primalType)) { @@ -561,7 +566,7 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig auto pairValType = as<IRDifferentialPairType>( pairPtrType ? pairPtrType->getValueType() : pairType); auto diffType = differentiableTypeConformanceContext.getDifferentialTypeFromDiffPairType(&argBuilder, pairValType); - if (auto ptrParamType = as<IRPtrTypeBase>(paramType)) + if (auto ptrParamType = as<IRPtrTypeBase>(diffParamType)) { // Create temp var to pass in/out arguments. auto srcVar = argBuilder.emitVar(ptrParamType->getValueType()); |
