From 96caba75e8dfbb879eff12cbe1a4c148a259f684 Mon Sep 17 00:00:00 2001 From: Yong He Date: Tue, 21 Mar 2023 15:44:21 -0700 Subject: Add texture tri-linear autodiff example. (#2715) * Add quad texture example. * delete output image * remove irrelavent files * update project files * fix * Update example. * Fix. * remove out-texture --------- Co-authored-by: Yong He --- source/slang/slang-check-decl.cpp | 33 ++++++++++++++++++++++----------- source/slang/slang-ir-autodiff-fwd.cpp | 13 +++++++++---- 2 files changed, 31 insertions(+), 15 deletions(-) (limited to 'source') diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 2613e6430..3779f48e3 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -6934,9 +6934,12 @@ namespace Slang arg->type.isLeftValue = param->findModifier() ? true : false; arg->type.type = param->getType(); arg->loc = loc; - if (auto pairType = visitor->getDifferentialPairType(param->getType())) + if (!param->findModifier()) { - arg->type.type = pairType; + if (auto pairType = visitor->getDifferentialPairType(param->getType())) + { + arg->type.type = pairType; + } } imaginaryArguments.add(arg); } @@ -6958,18 +6961,26 @@ namespace Slang arg->type.isLeftValue = param->findModifier() ? true : false; arg->type.type = param->getType(); arg->loc = loc; - if (auto pairType = as(visitor->getDifferentialPairType(param->getType()))) + bool isDiffParam = (!param->findModifier()); + if (isDiffParam) { - arg->type.type = pairType; - if (isOutParam(param)) + if (auto pairType = as(visitor->getDifferentialPairType(param->getType()))) { - // out T -> in T.Differential - arg->type.isLeftValue = false; - arg->type.type = visitor->tryGetDifferentialType( - visitor->getASTBuilder(), pairType->getPrimalType()); + arg->type.type = pairType; + if (isOutParam(param)) + { + // out T -> in T.Differential + arg->type.isLeftValue = false; + arg->type.type = visitor->tryGetDifferentialType( + visitor->getASTBuilder(), pairType->getPrimalType()); + } + } + else + { + isDiffParam = false; } } - else + if (!isDiffParam) { if (isOutParam(param)) { @@ -7010,7 +7021,7 @@ namespace Slang HigherOrderInvokeExpr* higherOrderFuncExpr = visitor->getASTBuilder()->create(); higherOrderFuncExpr->baseFunction = derivativeOfAttr->funcExpr; higherOrderFuncExpr->loc = derivativeOfAttr->loc; - Expr* checkedHigherOrderFuncExpr = visitor->dispatchExpr(higherOrderFuncExpr, *visitor); + Expr* checkedHigherOrderFuncExpr = visitor->dispatchExpr(higherOrderFuncExpr, visitor->allowStaticReferenceToNonStaticMember()); if (!checkedHigherOrderFuncExpr) { visitor->getSink()->diagnose(derivativeOfAttr, Diagnostics::cannotResolveOriginalFunctionForDerivative); diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index cf45a83f5..ac4e3825a 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -525,10 +525,14 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig return InstPair(primalCall, nullptr); } - auto calleeType = _getCalleeActualFuncType(diffCallee); + auto calleeType = _getCalleeActualFuncType(primalCallee); SLANG_ASSERT(calleeType); SLANG_RELEASE_ASSERT(calleeType->getParamCount() == origCall->getArgCount()); + auto diffCalleeType = _getCalleeActualFuncType(diffCallee); + SLANG_ASSERT(diffCalleeType); + SLANG_RELEASE_ASSERT(diffCalleeType->getParamCount() == origCall->getArgCount()); + auto placeholderCall = builder->emitCallInst(nullptr, builder->emitUndefined(builder->getTypeKind()), 0, nullptr); builder->setInsertBefore(placeholderCall); IRBuilder argBuilder = *builder; @@ -545,8 +549,9 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig auto origType = origCall->getArg(ii)->getDataType(); auto primalType = primalArg->getDataType(); - auto paramType = calleeType->getParamType(ii); - if (!isNoDiffType(paramType)) + auto originalParamType = calleeType->getParamType(ii); + auto diffParamType = diffCalleeType->getParamType(ii); + if (!isNoDiffType(originalParamType)) { if (isNoDiffType(primalType)) { @@ -561,7 +566,7 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig auto pairValType = as( pairPtrType ? pairPtrType->getValueType() : pairType); auto diffType = differentiableTypeConformanceContext.getDifferentialTypeFromDiffPairType(&argBuilder, pairValType); - if (auto ptrParamType = as(paramType)) + if (auto ptrParamType = as(diffParamType)) { // Create temp var to pass in/out arguments. auto srcVar = argBuilder.emitVar(ptrParamType->getValueType()); -- cgit v1.2.3