diff options
24 files changed, 512 insertions, 171 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index f016ae3d8..a535ba104 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -6918,6 +6918,8 @@ namespace Slang // has an associated derivative function. if (func->findModifier<BackwardDifferentiableAttribute>()) return true; + if (func->findModifier<BackwardDerivativeAttribute>()) + return true; for (auto assocDecl : getAssociatedDeclsForDecl(func)) { switch (assocDecl.kind) diff --git a/source/slang/slang-check-modifier.cpp b/source/slang/slang-check-modifier.cpp index f505b1321..f3623f19f 100644 --- a/source/slang/slang-check-modifier.cpp +++ b/source/slang/slang-check-modifier.cpp @@ -613,7 +613,7 @@ namespace Slang hitObjectAttributesAttr->location = (int32_t)val->value; } - else if (auto forwardDerivativeAttr = as<ForwardDerivativeAttribute>(attr)) + else if (auto derivativeAttr = as<UserDefinedDerivativeAttribute>(attr)) { SLANG_ASSERT(attr->args.getCount() == 1); SLANG_ASSERT(as<Decl>(attrTarget)); @@ -633,7 +633,7 @@ namespace Slang // // Set type to null to indicate that this needs expr needs to be further resolved. diffExpr->type.type = nullptr; - forwardDerivativeAttr->funcExpr = diffExpr; + derivativeAttr->funcExpr = diffExpr; } else if (auto derivativeOfAttr = as<DerivativeOfAttribute>(attr)) { diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index 54d32ae3e..68a86bc00 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -287,6 +287,33 @@ InstPair ForwardDiffTranscriber::transcribeConstruct(IRBuilder* builder, IRInst* } } +static bool _isDifferentiableFunc(IRInst* func) +{ + for (auto decor = func->getFirstDecoration(); decor; decor = decor->getNextDecoration()) + { + switch (decor->getOp()) + { + case kIROp_ForwardDerivativeDecoration: + case kIROp_ForwardDifferentiableDecoration: + case kIROp_BackwardDerivativeDecoration: + case kIROp_BackwardDifferentiableDecoration: + case kIROp_UserDefinedBackwardDerivativeDecoration: + return true; + } + } + return false; +} + +static IRFuncType* _getCalleeActualFuncType(IRInst* callee) +{ + auto type = callee->getFullType(); + if (auto funcType = as<IRFuncType>(type)) + return funcType; + if (auto specialize = as<IRSpecialize>(callee)) + return as<IRFuncType>(findGenericReturnVal(as<IRGeneric>(specialize->getBase()))->getFullType()); + return nullptr; +} + // Differentiating a call instruction here is primarily about generating // an appropriate call list based on whichever parameters have differentials // in the current transcription context. @@ -310,10 +337,7 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig return InstPair(nullptr, nullptr); } - // Since concrete functions are globals, the primal callee is the same - // as the original callee. - // - auto primalCallee = origCallee; + auto primalCallee = lookupPrimalInst(builder, origCallee, origCallee); IRInst* diffCallee = nullptr; @@ -325,8 +349,7 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig // If the user has already provided an differentiated implementation, use that. diffCallee = derivativeReferenceDecor->getForwardDerivativeFunc(); } - else if (primalCallee->findDecoration<IRForwardDifferentiableDecoration>() || - primalCallee->findDecoration<IRBackwardDifferentiableDecoration>()) + else if (_isDifferentiableFunc(primalCallee)) { // If the function is marked for auto-diff, push a `differentiate` inst for a follow up pass // to generate the implementation. @@ -343,7 +366,7 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig return InstPair(primalCall, nullptr); } - auto calleeType = as<IRFuncType>(diffCallee->getDataType()); + auto calleeType = _getCalleeActualFuncType(diffCallee); SLANG_ASSERT(calleeType); SLANG_RELEASE_ASSERT(calleeType->getParamCount() == origCall->getArgCount()); @@ -399,6 +422,7 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig diffCallee, args); builder->markInstAsMixedDifferential(callInst, diffReturnType); + builder->addAutoDiffOriginalValueDecoration(callInst, primalCallee); if (diffReturnType->getOp() != kIROp_VoidType) { @@ -629,7 +653,7 @@ InstPair ForwardDiffTranscriber::transcribeSpecialize(IRBuilder* builder, IRSpec builder->getTypeKind(), diffBaseSpecialize->getBase(), args.getCount(), args.getBuffer()); return InstPair(primalSpecialize, diffSpecialize); } - else if (auto diffDecor = genericInnerVal->findDecoration<IRForwardDifferentiableDecoration>()) + else if (_isDifferentiableFunc(genericInnerVal)) { List<IRInst*> args; for (UInt i = 0; i < primalSpecialize->getArgCount(); i++) @@ -927,8 +951,15 @@ InstPair ForwardDiffTranscriber::transcribeWrapExistential(IRBuilder* builder, I // Create an empty func to represent the transcribed func of `origFunc`. InstPair ForwardDiffTranscriber::transcribeFuncHeader(IRBuilder* inBuilder, IRFunc* origFunc) { - if (auto bwdDecor = origFunc->findDecoration<IRForwardDerivativeDecoration>()) - return InstPair(origFunc, bwdDecor->getForwardDerivativeFunc()); + if (auto fwdDecor = origFunc->findDecoration<IRForwardDerivativeDecoration>()) + { + // If we reach here, the function must have been used directly in a `call` inst, and therefore + // can't be a generic. + // Generic function are always referenced with `specialize` inst and the handling logic for + // custom derivatives is implemented in `transcribeSpecialize`. + SLANG_RELEASE_ASSERT(fwdDecor->getForwardDerivativeFunc()->getOp() == kIROp_Func); + return InstPair(origFunc, fwdDecor->getForwardDerivativeFunc()); + } auto diffFunc = transcribeFuncHeaderImpl(inBuilder, origFunc); @@ -1012,51 +1043,6 @@ InstPair ForwardDiffTranscriber::transcribeFunc(IRBuilder* inBuilder, IRFunc* pr return InstPair(primalFunc, diffFunc); } -// Transcribe a generic definition -InstPair ForwardDiffTranscriber::transcribeGeneric(IRBuilder* inBuilder, IRGeneric* origGeneric) -{ - auto innerVal = findInnerMostGenericReturnVal(origGeneric); - if (auto innerFunc = as<IRFunc>(innerVal)) - { - differentiableTypeConformanceContext.setFunc(innerFunc); - } - else if (auto funcType = as<IRFuncType>(innerVal)) - { - } - else - { - return InstPair(origGeneric, nullptr); - } - - IRGeneric* primalGeneric = origGeneric; - - IRBuilder builder(inBuilder->getSharedBuilder()); - builder.setInsertBefore(origGeneric); - - auto diffGeneric = builder.emitGeneric(); - - // Process type of generic. If the generic is a function, then it's type will also be a - // generic and this logic will transcribe that generic first before continuing with the - // function itself. - // - auto primalType = primalGeneric->getFullType(); - - IRType* diffType = nullptr; - if (primalType) - { - diffType = (IRType*) findOrTranscribeDiffInst(&builder, primalType); - } - - diffGeneric->setFullType(diffType); - - // Transcribe children from origFunc into diffFunc. - builder.setInsertInto(diffGeneric); - for (auto block = origGeneric->getFirstBlock(); block; block = block->getNextBlock()) - this->transcribe(&builder, block); - - return InstPair(primalGeneric, diffGeneric); -} - InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst* origInst) { // Handle common SSA-style operations diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp index af408a5b3..8d6419cf2 100644 --- a/source/slang/slang-ir-autodiff-rev.cpp +++ b/source/slang/slang-ir-autodiff-rev.cpp @@ -46,7 +46,8 @@ namespace Slang } } - newParameterTypes.add(differentiateType(builder, funcType->getResultType())); + if (auto diffResultType = differentiateType(builder, funcType->getResultType())) + newParameterTypes.add(diffResultType); if (intermeidateType) { @@ -58,20 +59,14 @@ namespace Slang return builder->getFuncType(newParameterTypes, diffReturnType); } - static IRInst* getOriginalFuncRef(IRBuilder& builder, IRInst* func, IRInst* useSite) - { - if (!func) return nullptr; - auto userGeneric = findOuterGeneric(useSite); - if (!userGeneric) return func; - auto funcGen = findOuterGeneric(func); - SLANG_RELEASE_ASSERT(funcGen); - return maybeSpecializeWithGeneric(builder, funcGen, userGeneric); - } - IRFuncType* BackwardDiffPrimalTranscriber::differentiateFunctionType(IRBuilder* builder, IRInst* func, IRFuncType* funcType) { - auto funcRef = getOriginalFuncRef(*builder, func, builder->getInsertLoc().getParent()); - auto intermediateType = builder->getBackwardDiffIntermediateContextType(funcRef); + IRType* intermediateType = builder->getBackwardDiffIntermediateContextType(maybeFindOuterGeneric(func)); + if (auto outerGeneric = findOuterGeneric(builder->getInsertLoc().getParent())) + { + intermediateType = (IRType*)specializeWithGeneric(*builder, intermediateType, as<IRGeneric>(outerGeneric)); + } + auto outType = builder->getOutType(intermediateType); List<IRType*> paramTypes; for (UInt i = 0; i < funcType->getParamCount(); i++) @@ -91,13 +86,98 @@ namespace Slang // Don't need to do anything other than add a decoration in the original func to point to the primal func. // The body of the primal func will be generated by propagateTranscriber together with propagate func. addTranscribedFuncDecoration(*builder, primalFunc, diffFunc); - return InstPair(primalFunc, primalFunc); + return InstPair(primalFunc, diffFunc); + } + + static List<IRInst*> _defineFuncParams(IRBuilder* builder, IRFunc* func) + { + auto propFuncType = cast<IRFuncType>(func->getFullType()); + List<IRInst*> params; + for (UInt i = 0; i < propFuncType->getParamCount(); i++) + { + auto paramType = propFuncType->getParamType(i); + auto param = builder->emitParam(paramType); + params.add(param); + } + return params; + } + + void BackwardDiffPropagateTranscriber::generateTrivialDiffFuncFromUserDefinedDerivative( + IRBuilder* builder, + IRFunc* originalFunc, + IRFunc* diffPropFunc, + IRUserDefinedBackwardDerivativeDecoration* udfDecor) + { + // Create an empty struct type to use as the intermediate context type. + auto originalGeneric = findOuterGeneric(originalFunc); + builder->setInsertBefore(originalFunc); + IRInst* emptyStruct = builder->createStructType(); + IRInst* emptyStructType = nullptr; + auto emptyStructGeneric = hoistValueFromGeneric(*builder, emptyStruct, emptyStructType, false); + builder->addBackwardDerivativeIntermediateTypeDecoration(originalFunc, emptyStructGeneric); + + IRInst* udf = udfDecor->getBackwardDerivativeFunc(); + builder->setInsertInto(diffPropFunc); + builder->emitBlock(); + List<IRInst*> params = _defineFuncParams(builder, diffPropFunc); + params.removeLast(); + IRInst* udfRefFromPropFunc = udf; + if (auto specialize = as<IRSpecialize>(udf)) + { + udf = specialize->getBase(); + auto propGeneric = findOuterGeneric(diffPropFunc); + SLANG_RELEASE_ASSERT(propGeneric); + udfRefFromPropFunc = maybeSpecializeWithGeneric(*builder, udf, propGeneric); + } + builder->emitCallInst(builder->getVoidType(), udfRefFromPropFunc, params); + builder->emitReturn(); + + // Now create the trivial primal function. + auto existingDecor = originalFunc->findDecoration<IRBackwardDerivativePrimalDecoration>(); + if (!existingDecor) + { + // We haven't created a header for primal func yet, create it now. + if (originalGeneric) + builder->setInsertBefore(originalGeneric); + else + builder->setInsertBefore(originalFunc); + + autoDiffSharedContext->transcriberSet.primalTranscriber->transcribe(builder, originalGeneric ? originalGeneric : originalFunc); + existingDecor = originalFunc->findDecoration<IRBackwardDerivativePrimalDecoration>(); + } + SLANG_RELEASE_ASSERT(existingDecor); + + // Fill the primal func header with trivial call to original func. + IRInst* existingPrimalFunc = existingDecor->getBackwardDerivativePrimalFunc(); + IRGeneric* existingPriamlFuncGeneric = nullptr; + if (auto specialize = as<IRSpecialize>(existingPrimalFunc)) + { + existingPriamlFuncGeneric = as<IRGeneric>(specialize->getBase()); + existingPrimalFunc = findGenericReturnVal(existingPriamlFuncGeneric); + } + builder->setInsertBefore(existingPrimalFunc); + + builder->setInsertInto(existingPrimalFunc); + builder->emitBlock(); + params = _defineFuncParams(builder, as<IRFunc>(existingPrimalFunc)); + params.removeLast(); + IRInst* originalFuncRefFromPrimalFunc = originalFunc; + if (originalGeneric) + originalFuncRefFromPrimalFunc = maybeSpecializeWithGeneric(*builder, originalGeneric, existingPriamlFuncGeneric); + auto result = builder->emitCallInst( + cast<IRFuncType>(existingPrimalFunc->getFullType())->getResultType(), + originalFuncRefFromPrimalFunc, + params); + builder->emitReturn(result); } IRFuncType* BackwardDiffPropagateTranscriber::differentiateFunctionType(IRBuilder* builder, IRInst* func, IRFuncType* funcType) { - auto funcRef = getOriginalFuncRef(*builder, func, builder->getInsertLoc().getParent()); - auto intermediateType = builder->getBackwardDiffIntermediateContextType(funcRef); + IRType* intermediateType = builder->getBackwardDiffIntermediateContextType(maybeFindOuterGeneric(func)); + if (auto outerGeneric = findOuterGeneric(builder->getInsertLoc().getParent())) + { + intermediateType = (IRType*)specializeWithGeneric(*builder, intermediateType, as<IRGeneric>(outerGeneric)); + } return differentiateFunctionTypeImpl(builder, funcType, intermediateType); } @@ -109,9 +189,15 @@ namespace Slang InstPair BackwardDiffPropagateTranscriber::transcribeFunc(IRBuilder* builder, IRFunc* primalFunc, IRFunc* diffFunc) { - IRGlobalValueWithCode* diffPrimalFunc = nullptr; addTranscribedFuncDecoration(*builder, primalFunc, diffFunc); - transcribeFuncImpl(builder, primalFunc, diffFunc, diffPrimalFunc); + if (auto udf = primalFunc->findDecoration<IRUserDefinedBackwardDerivativeDecoration>()) + { + generateTrivialDiffFuncFromUserDefinedDerivative(builder, primalFunc, diffFunc, udf); + } + else + { + transcribeFuncImpl(builder, primalFunc, diffFunc); + } return InstPair(primalFunc, diffFunc); } @@ -212,18 +298,13 @@ namespace Slang return InstPair(diffBlock, diffBlock); } - static bool isMarkedForBackwardDifferentiation(IRInst* callable) - { - return callable->findDecoration<IRBackwardDifferentiableDecoration>() != nullptr; - } - // Create an empty func to represent the transcribed func of `origFunc`. InstPair BackwardDiffTranscriberBase::transcribeFuncHeaderImpl(IRBuilder* inBuilder, IRFunc* origFunc) { if (auto bwdDiffFunc = findExistingDiffFunc(origFunc)) return InstPair(origFunc, bwdDiffFunc); - if (!isMarkedForBackwardDifferentiation(origFunc)) + if (!isBackwardDifferentiableFunc(origFunc)) return InstPair(nullptr, nullptr); IRBuilder builder = *inBuilder; @@ -253,7 +334,6 @@ namespace Slang // Mark the generated derivative function itself as differentiable. builder.addBackwardDifferentiableDecoration(diffFunc); - // Find and clone `DifferentiableTypeDictionaryDecoration` to the new diffFunc. if (auto dictDecor = origFunc->findDecoration<IRDifferentiableTypeDictionaryDecoration>()) { @@ -339,13 +419,14 @@ namespace Slang } auto outerGeneric = findOuterGeneric(origFunc); + IRType* intermediateType = builder.getBackwardDiffIntermediateContextType(maybeFindOuterGeneric(origFunc)); IRInst* specializedOriginalFunc = origFunc; if (outerGeneric) { specializedOriginalFunc = maybeSpecializeWithGeneric(builder, outerGeneric, findOuterGeneric(header.differential)); + intermediateType = (IRType*)specializeWithGeneric(builder, intermediateType, as<IRGeneric>(findOuterGeneric(header.differential))); } - auto intermediateType = builder.getBackwardDiffIntermediateContextType(specializedOriginalFunc); auto intermediateVar = builder.emitVar(intermediateType); auto origFuncType = as<IRFuncType>(origFunc->getDataType()); @@ -420,11 +501,19 @@ namespace Slang eliminateDeadCode(primalOuterParent); // Forward transcribe the clone of the original func. - ForwardDiffTranscriber fwdTranscriber(autoDiffSharedContext, builder->getSharedBuilder(), sink); - fwdTranscriber.pairBuilder = pairBuilder; + ForwardDiffTranscriber& fwdTranscriber = *static_cast<ForwardDiffTranscriber*>( + autoDiffSharedContext->transcriberSet.forwardTranscriber); + auto oldCount = autoDiffSharedContext->followUpFunctionsToTranscribe.getCount(); IRFunc* fwdDiffFunc = as<IRFunc>(getGenericReturnVal(fwdTranscriber.transcribe(builder, primalOuterParent))); SLANG_ASSERT(fwdDiffFunc); - fwdTranscriber.transcribeFunc(builder, primalFunc, fwdDiffFunc); + auto newCount = autoDiffSharedContext->followUpFunctionsToTranscribe.getCount(); + for (auto i = oldCount; i < newCount; i++) + { + auto pendingTask = autoDiffSharedContext->followUpFunctionsToTranscribe.getLast(); + autoDiffSharedContext->followUpFunctionsToTranscribe.removeLast(); + SLANG_RELEASE_ASSERT(pendingTask.type == FuncBodyTranscriptionTaskType::Forward); + fwdTranscriber.transcribeFunc(builder, pendingTask.originalFunc, pendingTask.resultFunc); + } // Remove the clone of original func. primalOuterParent->removeAndDeallocate(); @@ -453,12 +542,11 @@ namespace Slang } fwdParentGeneric->removeAndDeallocate(); } - return fwdDiffFunc; } // Transcribe a function definition. - void BackwardDiffTranscriberBase::transcribeFuncImpl(IRBuilder* builder, IRFunc* primalFunc, IRFunc* diffPropagateFunc, IRGlobalValueWithCode*& diffPrimalFunc) + void BackwardDiffTranscriberBase::transcribeFuncImpl(IRBuilder* builder, IRFunc* primalFunc, IRFunc* diffPropagateFunc) { SLANG_ASSERT(primalFunc); SLANG_ASSERT(diffPropagateFunc); @@ -546,8 +634,7 @@ namespace Slang IRInst* specializedFunc = nullptr; auto intermediateTypeGeneric = hoistValueFromGeneric(*builder, intermediateType, specializedFunc, true); builder->setInsertBefore(primalFunc); - auto specializedIntermeidateType = maybeSpecializeWithGeneric(*builder, intermediateTypeGeneric, primalOuterGeneric); - builder->addBackwardDerivativeIntermediateTypeDecoration(primalFunc, specializedIntermeidateType); + builder->addBackwardDerivativeIntermediateTypeDecoration(primalFunc, intermediateTypeGeneric); auto primalFuncGeneric = hoistValueFromGeneric(*builder, extractedPrimalFunc, specializedFunc, true); builder->setInsertBefore(primalFunc); @@ -567,7 +654,6 @@ namespace Slang auto specializedBackwardPrimalFunc = maybeSpecializeWithGeneric(*builder, primalFuncGeneric, primalOuterGeneric); builder->addBackwardDerivativePrimalDecoration(primalFunc, specializedBackwardPrimalFunc); } - diffPrimalFunc = as<IRGlobalValueWithCode>(primalOuterGeneric); } void BackwardDiffTranscriberBase::transposeParameterBlock(IRBuilder* builder, IRFunc* diffFunc) @@ -900,7 +986,7 @@ namespace Slang return InstPair(primalSpecialize, diffSpecialize); } - else if (auto diffDecor = genericInnerVal->findDecoration<IRBackwardDifferentiableDecoration>()) + else if (isBackwardDifferentiableFunc(genericInnerVal)) { List<IRInst*> args; for (UInt i = 0; i < primalSpecialize->getArgCount(); i++) diff --git a/source/slang/slang-ir-autodiff-rev.h b/source/slang/slang-ir-autodiff-rev.h index 02a100c80..228bcf588 100644 --- a/source/slang/slang-ir-autodiff-rev.h +++ b/source/slang/slang-ir-autodiff-rev.h @@ -87,7 +87,7 @@ struct BackwardDiffTranscriberBase : AutoDiffTranscriberBase IRFunc* generateNewForwardDerivativeForFunc(IRBuilder* builder, IRFunc* originalFunc, IRFunc* diffPropagateFunc); - void transcribeFuncImpl(IRBuilder* builder, IRFunc* primalFunc, IRFunc* diffPropagateFunc, IRGlobalValueWithCode*& diffPrimalFunc); + void transcribeFuncImpl(IRBuilder* builder, IRFunc* primalFunc, IRFunc* diffPropagateFunc); InstPair transcribeFuncHeaderImpl(IRBuilder* inBuilder, IRFunc* origFunc); @@ -144,6 +144,11 @@ struct BackwardDiffPropagateTranscriber : BackwardDiffTranscriberBase inSharedBuilder, inSink) { } + void generateTrivialDiffFuncFromUserDefinedDerivative( + IRBuilder* builder, + IRFunc* primalFunc, + IRFunc* diffPropFunc, + IRUserDefinedBackwardDerivativeDecoration* udfDecor); virtual IRFuncType* differentiateFunctionType(IRBuilder* builder, IRInst* func, IRFuncType* funcType) override; virtual InstPair transcribeFunc(IRBuilder* builder, IRFunc* primalFunc, IRFunc* diffFunc) override; @@ -189,6 +194,10 @@ struct BackwardDiffTranscriber : BackwardDiffTranscriberBase { return backDecor->getBackwardDerivativeFunc(); } + if (auto backDecor = originalFunc->findDecoration<IRUserDefinedBackwardDerivativeDecoration>()) + { + return backDecor->getBackwardDerivativeFunc(); + } return nullptr; } virtual void addExistingDiffFuncDecor(IRBuilder* builder, IRInst* inst, IRInst* diffFunc) override diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp index deb1b2da9..275b40b46 100644 --- a/source/slang/slang-ir-autodiff-transcriber-base.cpp +++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp @@ -143,6 +143,9 @@ IRInst* AutoDiffTranscriberBase::findOrTranscribePrimalInst(IRBuilder* builder, IRInst* AutoDiffTranscriberBase::maybeCloneForPrimalInst(IRBuilder* builder, IRInst* inst) { + if (!inst) + return nullptr; + IRInst* primal = lookupPrimalInst(builder, inst, nullptr); if (!primal) { @@ -234,6 +237,13 @@ IRType* AutoDiffTranscriberBase::getOrCreateDiffPairType(IRInst* primalType) IRType* AutoDiffTranscriberBase::differentiateType(IRBuilder* builder, IRType* origType) { + auto primalType = lookupPrimalInst(builder, origType, origType); + if (primalType->getOp() == kIROp_Param && + primalType->getParent() && primalType->getParent()->getParent() && + primalType->getParent()->getParent()->getOp() == kIROp_Generic) + { + return (IRType*)differentiableTypeConformanceContext.getDifferentialForType(builder, origType); + } return (IRType*)transcribe(builder, origType); } @@ -725,6 +735,8 @@ InstPair AutoDiffTranscriberBase::transcribeGeneric(IRBuilder* inBuilder, IRGene builder.setInsertBefore(origGeneric); auto diffGeneric = builder.emitGeneric(); + + mapDifferentialInst(origGeneric, diffGeneric); // Process type of generic. If the generic is a function, then it's type will also be a // generic and this logic will transcribe that generic first before continuing with the diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h index fa9f4ffb2..78b8c5098 100644 --- a/source/slang/slang-ir-autodiff-transpose.h +++ b/source/slang/slang-ir-autodiff-transpose.h @@ -413,6 +413,14 @@ struct DiffTransposePass void transposeInst(IRBuilder* builder, IRInst* inst) { + switch (inst->getOp()) + { + case kIROp_ForwardDifferentiate: + return; + default: + break; + } + // Look for gradient entries for this inst. List<RevGradient> gradients; if (hasRevGradients(inst)) @@ -520,14 +528,21 @@ struct DiffTransposePass List<IRInst*> args; List<IRType*> argTypes; - List<bool> isArgPtrTyped; + List<bool> argRequiresLoad; + + auto getDiffPairType = [](IRType* type) + { + if (auto ptrType = as<IRPtrTypeBase>(type)) + type = ptrType->getValueType(); + return as<IRDifferentialPairType>(type); + }; for (UIndex ii = 0; ii < fwdCall->getArgCount(); ii++) { auto arg = fwdCall->getArg(ii); // If this isn't a ptr-type, make a var. - if (!as<IRPtrTypeBase>(arg->getDataType()) && as<IRDifferentialPairType>(arg->getDataType())) + if (!as<IRPtrTypeBase>(arg->getDataType()) && getDiffPairType(arg->getDataType())) { auto pairType = as<IRDifferentialPairType>(arg->getDataType()); @@ -548,24 +563,26 @@ struct DiffTransposePass args.add(var); argTypes.add(builder->getInOutType(pairType)); - isArgPtrTyped.add(false); + argRequiresLoad.add(true); } else { args.add(arg); argTypes.add(arg->getDataType()); - isArgPtrTyped.add(true); + argRequiresLoad.add(false); } } args.add(revValue); argTypes.add(revValue->getDataType()); + argRequiresLoad.add(false); args.add(primalContextDecor->getBackwardDerivativePrimalContextVar()); argTypes.add(builder->getOutType( as<IRPtrTypeBase>( primalContextDecor->getBackwardDerivativePrimalContextVar()->getDataType()) ->getValueType())); + argRequiresLoad.add(false); auto revFnType = builder->getFuncType(argTypes, builder->getVoidType()); auto revCallee = builder->emitBackwardDifferentiatePropagateInst( @@ -578,17 +595,16 @@ struct DiffTransposePass for (UIndex ii = 0; ii < fwdCall->getArgCount(); ii++) { // Is this arg relevant to auto-diff? - if (as<IRDifferentialPairType>(as<IRPtrTypeBase>(args[ii]->getDataType())->getValueType())) + if (auto diffPairType = getDiffPairType(args[ii]->getDataType())) { // If this is ptr typed, ignore (the gradient will be accumulated on the pointer) // automatically. // - if (!isArgPtrTyped[ii]) + if (argRequiresLoad[ii]) { auto diffArgType = (IRType*)diffTypeContext.getDifferentialForType( builder, - as<IRDifferentialPairType>( - as<IRPtrTypeBase>(argTypes[ii])->getValueType())->getValueType()); + diffPairType->getValueType()); auto diffArgPtrType = builder->getPtrType(kIROp_PtrType, diffArgType); gradients.add(RevGradient( @@ -889,7 +905,6 @@ struct DiffTransposePass TranspositionResult transposeInst(IRBuilder* builder, IRInst* fwdInst, IRInst* revValue) { - // Dispatch logic. switch(fwdInst->getOp()) { @@ -924,7 +939,8 @@ struct DiffTransposePass case kIROp_MakeVector: return transposeMakeVector(builder, fwdInst, revValue); - + + case kIROp_Specialize: case kIROp_unconditionalBranch: case kIROp_conditionalBranch: case kIROp_ifElse: diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp index b8a4c4f08..a95fd7b9b 100644 --- a/source/slang/slang-ir-autodiff-unzip.cpp +++ b/source/slang/slang-ir-autodiff-unzip.cpp @@ -252,6 +252,12 @@ struct ExtractPrimalFuncContext SLANG_RELEASE_ASSERT(structType); auto structKey = genTypeBuilder.createStructKey(); genTypeBuilder.setInsertInto(structType); + + if (isChildInstOf(fieldType->getParent(), structType->getParent())) + { + IRCloneEnv cloneEnv; + fieldType = cloneInst(&cloneEnv, &genTypeBuilder, fieldType); + } return genTypeBuilder.createStructField(structType, structKey, (IRType*)fieldType); } @@ -452,19 +458,21 @@ IRFunc* DiffUnzipPass::extractPrimalFunc( if (auto structKeyDecor = inst->findDecoration<IRPrimalValueStructKeyDecoration>()) { builder.setInsertBefore(inst); - auto addr = builder.emitFieldAddress( - builder.getPtrType(inst->getDataType()), + auto val = builder.emitFieldExtract( + inst->getDataType(), intermediateVar, structKeyDecor->getStructKey()); if (inst->getOp() == kIROp_Var) { // This is a var for intermediate context. - inst->replaceUsesWith(addr); + auto tempVar = + builder.emitVar(cast<IRPtrTypeBase>(inst->getFullType())->getValueType()); + builder.emitStore(tempVar, val); + inst->replaceUsesWith(tempVar); } else { // Orindary value. - auto val = builder.emitLoad(addr); inst->replaceUsesWith(val); } instsToRemove.add(inst); diff --git a/source/slang/slang-ir-autodiff-unzip.h b/source/slang/slang-ir-autodiff-unzip.h index 612212dd9..b06ed29bf 100644 --- a/source/slang/slang-ir-autodiff-unzip.h +++ b/source/slang/slang-ir-autodiff-unzip.h @@ -184,19 +184,43 @@ struct DiffUnzipPass return false; } + static IRInst* _getOriginalFunc(IRInst* call) + { + if (auto decor = call->findDecoration<IRAutoDiffOriginalValueDecoration>()) + return decor->getOriginalValue(); + return nullptr; + } + InstPair splitCall(IRBuilder* primalBuilder, IRBuilder* diffBuilder, IRCall* mixedCall) { IRBuilder globalBuilder; globalBuilder.init(autodiffContext->sharedBuilder); - auto fwdCallee = as<IRForwardDifferentiate>(mixedCall->getCallee()); - auto fwdCalleeType = as<IRFuncType>(fwdCallee->getDataType()); - auto baseFn = fwdCallee->getBaseFn(); + auto fwdCalleeType = mixedCall->getCallee()->getDataType(); + auto baseFn = _getOriginalFunc(mixedCall); + SLANG_RELEASE_ASSERT(baseFn); auto primalFuncType = autodiffContext->transcriberSet.primalTranscriber->differentiateFunctionType( primalBuilder, baseFn, as<IRFuncType>(baseFn->getDataType())); - auto intermediateVar = primalBuilder->emitVar(primalBuilder->getBackwardDiffIntermediateContextType(baseFn)); + IRInst* intermediateType = nullptr; + + if (auto specialize = as<IRSpecialize>(baseFn)) + { + auto func = findSpecializeReturnVal(specialize); + auto outerGen = findOuterGeneric(func); + intermediateType = primalBuilder->getBackwardDiffIntermediateContextType(outerGen); + intermediateType = specializeWithGeneric( + *primalBuilder, + intermediateType, + as<IRGeneric>(findOuterGeneric(primalBuilder->getInsertLoc().getParent()))); + } + else + { + intermediateType = primalBuilder->getBackwardDiffIntermediateContextType(baseFn); + } + + auto intermediateVar = primalBuilder->emitVar((IRType*)intermediateType); primalBuilder->addBackwardDerivativePrimalContextDecoration(intermediateVar, intermediateVar); auto primalFn = primalBuilder->emitBackwardDifferentiatePrimalInst(primalFuncType, baseFn); @@ -204,7 +228,7 @@ struct DiffUnzipPass List<IRInst*> primalArgs; for (UIndex ii = 0; ii < mixedCall->getArgCount(); ii++) { - auto arg = mixedCall->getArg(0); + auto arg = mixedCall->getArg(ii); if (isRelevantDifferentialPair(arg->getDataType())) { @@ -232,7 +256,7 @@ struct DiffUnzipPass List<IRInst*> diffArgs; for (UIndex ii = 0; ii < mixedCall->getArgCount(); ii++) { - auto arg = mixedCall->getArg(0); + auto arg = mixedCall->getArg(ii); if (isRelevantDifferentialPair(arg->getDataType())) { diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index 74afa4002..7182375de 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -6,6 +6,21 @@ namespace Slang { + +bool isBackwardDifferentiableFunc(IRInst* func) +{ + for (auto decorations : func->getDecorations()) + { + switch (decorations->getOp()) + { + case kIROp_BackwardDifferentiableDecoration: + case kIROp_UserDefinedBackwardDerivativeDecoration: + return true; + } + } + return false; +} + static IRInst* _lookupWitness(IRBuilder* builder, IRInst* witness, IRInst* requirementKey) { if (auto witnessTable = as<IRWitnessTable>(witness)) @@ -388,7 +403,7 @@ void DifferentiableTypeConformanceContext::buildGlobalWitnessDictionary() { if (auto pairType = as<IRDifferentialPairType>(globalInst)) { - differentiableWitnessDictionary.Add(pairType->getValueType(), pairType->getWitness()); + differentiableWitnessDictionary.AddIfNotExists(pairType->getValueType(), pairType->getWitness()); } } } @@ -406,6 +421,8 @@ void stripDerivativeDecorations(IRInst* inst) case kIROp_BackwardDerivativeIntermediateTypeDecoration: case kIROp_BackwardDerivativePropagateDecoration: case kIROp_BackwardDerivativePrimalDecoration: + case kIROp_UserDefinedBackwardDerivativeDecoration: + case kIROp_AutoDiffOriginalValueDecoration: decor->removeAndDeallocate(); break; default: @@ -435,6 +452,8 @@ void stripAutoDiffDecorationsFromChildren(IRInst* parent) case kIROp_BackwardDerivativePrimalDecoration: case kIROp_BackwardDerivativePrimalContextDecoration: case kIROp_BackwardDerivativePrimalReturnDecoration: + case kIROp_AutoDiffOriginalValueDecoration: + case kIROp_UserDefinedBackwardDerivativeDecoration: decor->removeAndDeallocate(); break; default: @@ -456,27 +475,26 @@ void stripAutoDiffDecorations(IRModule* module) } -void stripBlockTypeDecorations(IRFunc* func) +void stripTempDecorations(IRInst* inst) { - for (auto child : func->getChildren()) + for (auto decor = inst->getFirstDecoration(); decor; ) { - if (auto block = as<IRBlock>(child)) + auto next = decor->getNextDecoration(); + switch (decor->getOp()) { - for (auto decor = block->getFirstDecoration(); decor; ) - { - auto next = decor->getNextDecoration(); - switch (decor->getOp()) - { - case kIROp_DifferentialInstDecoration: - case kIROp_MixedDifferentialInstDecoration: - decor->removeAndDeallocate(); - break; - default: - break; - } - decor = next; - } + case kIROp_DifferentialInstDecoration: + case kIROp_MixedDifferentialInstDecoration: + case kIROp_AutoDiffOriginalValueDecoration: + decor->removeAndDeallocate(); + break; + default: + break; } + decor = next; + } + for (auto child : inst->getChildren()) + { + stripTempDecorations(child); } } @@ -554,9 +572,7 @@ struct AutoDiffPass : public InstPassBase auto inner = findGenericReturnVal(baseGeneric); if (auto typeDecor = inner->findDecoration<IRBackwardDerivativeIntermediateTypeDecoration>()) { - auto typeSpec = cast<IRSpecialize>(typeDecor->getBackwardDerivativeIntermediateType()); - auto typeSpecBase = typeSpec->getBase(); - return typeSpecBase; + return typeDecor->getBackwardDerivativeIntermediateType(); } } else if (auto func = as<IRFunc>(base)) @@ -742,7 +758,7 @@ struct AutoDiffPass : public InstPassBase // passes since they don't expect decorations in blocks. // for (auto diffFunc : autodiffCleanupList) - stripBlockTypeDecorations(diffFunc); + stripTempDecorations(diffFunc); autodiffCleanupList.clear(); diff --git a/source/slang/slang-ir-autodiff.h b/source/slang/slang-ir-autodiff.h index f468b1fca..7479e4eee 100644 --- a/source/slang/slang-ir-autodiff.h +++ b/source/slang/slang-ir-autodiff.h @@ -259,4 +259,5 @@ bool finalizeAutoDiffPass(IRModule* module); void stripDerivativeDecorations(IRInst* inst); +bool isBackwardDifferentiableFunc(IRInst* func); }; diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp index 83351d07b..8413e7e79 100644 --- a/source/slang/slang-ir-check-differentiability.cpp +++ b/source/slang/slang-ir-check-differentiability.cpp @@ -71,6 +71,7 @@ public: { case kIROp_ForwardDerivativeDecoration: case kIROp_ForwardDifferentiableDecoration: + case kIROp_UserDefinedBackwardDerivativeDecoration: case kIROp_BackwardDerivativeDecoration: case kIROp_BackwardDifferentiableDecoration: return true; @@ -140,20 +141,6 @@ public: return false; } - bool isBackwardDifferentiableFunc(IRInst* func) - { - for (auto decorations : func->getDecorations()) - { - switch (decorations->getOp()) - { - case kIROp_BackwardDerivativeDecoration: - case kIROp_BackwardDifferentiableDecoration: - return true; - } - } - return false; - } - bool isDifferentiableType(DifferentiableTypeConformanceContext& context, IRInst* typeInst) { HashSet<IRInst*> processedSet; diff --git a/source/slang/slang-ir-hoist-local-types.cpp b/source/slang/slang-ir-hoist-local-types.cpp index cf091f701..d8b0eab22 100644 --- a/source/slang/slang-ir-hoist-local-types.cpp +++ b/source/slang/slang-ir-hoist-local-types.cpp @@ -16,12 +16,6 @@ struct HoistLocalTypesContext void addToWorkList(IRInst* inst) { - for (auto ii = inst->getParent(); ii; ii = ii->getParent()) - { - if (as<IRGeneric>(ii)) - return; - } - if (workListSet.Contains(inst)) return; @@ -29,19 +23,28 @@ struct HoistLocalTypesContext workListSet.Add(inst); } - void processInst(IRInst* inst) + bool processInst(IRInst* inst) { auto sharedBuilder = &sharedBuilderStorage; if (!as<IRType>(inst)) - return; + return false; if (inst->getParent() == module->getModuleInst()) - return; + return false; + switch (inst->getOp()) + { + case kIROp_InterfaceType: + case kIROp_StructType: + case kIROp_ClassType: + return false; + default: + break; + } IRInstKey key = {inst}; if (auto value = sharedBuilder->getGlobalValueNumberingMap().TryGetValue(key)) { inst->replaceUsesWith(*value); inst->removeAndDeallocate(); - return; + return true; } IRBuilder builder(sharedBuilder); builder.setInsertInto(module->getModuleInst()); @@ -67,7 +70,9 @@ struct HoistLocalTypesContext inst->transferDecorationsTo(newType); inst->replaceUsesWith(newType); inst->removeAndDeallocate(); + return true; } + return false; } void processModule() @@ -75,24 +80,31 @@ struct HoistLocalTypesContext SharedIRBuilder* sharedBuilder = &sharedBuilderStorage; sharedBuilder->init(module); - // Deduplicate equivalent types and build numbering map for global types. - sharedBuilder->deduplicateAndRebuildGlobalNumberingMap(); + for (;;) + { + bool changed = false; + // Deduplicate equivalent types and build numbering map for global types. + sharedBuilder->deduplicateAndRebuildGlobalNumberingMap(); - addToWorkList(module->getModuleInst()); + addToWorkList(module->getModuleInst()); - while (workList.getCount() != 0) - { - IRInst* inst = workList.getLast(); + while (workList.getCount() != 0) + { + IRInst* inst = workList.getLast(); - workList.removeLast(); - workListSet.Remove(inst); + workList.removeLast(); + workListSet.Remove(inst); - processInst(inst); + changed |= processInst(inst); - for (auto child = inst->getLastChild(); child; child = child->getPrevInst()) - { - addToWorkList(child); + for (auto child = inst->getLastChild(); child; child = child->getPrevInst()) + { + addToWorkList(child); + } } + + if (!changed) + break; } } }; diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index ab7453b41..68afbbb95 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -726,6 +726,9 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0) /// Decorated function is marked for the forward-mode differentiation pass. INST(ForwardDifferentiableDecoration, forwardDifferentiable, 0, 0) + /// Decorates a auto-diff transcribed value with the original value that the inst is transcribed from. + INST(AutoDiffOriginalValueDecoration, AutoDiffOriginalValueDecoration, 1, 0) + /// Used by the auto-diff pass to hold a reference to the /// generated derivative function. INST(ForwardDerivativeDecoration, fwdDerivative, 1, 0) diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index b30d489dc..22da763b3 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -566,6 +566,16 @@ struct IRSequentialIDDecoration : IRDecoration IRIntegerValue getSequentialID() { return getSequentialIDOperand()->getValue(); } }; +struct IRAutoDiffOriginalValueDecoration : IRDecoration +{ + enum + { + kOp = kIROp_AutoDiffOriginalValueDecoration + }; + IR_LEAF_ISA(AutoDiffOriginalValueDecoration) + IRInst* getOriginalValue() { return getOperand(0); } +}; + struct IRForwardDifferentiableDecoration : IRDecoration { enum @@ -708,6 +718,7 @@ struct IRUserDefinedBackwardDerivativeDecoration : IRDecoration kOp = kIROp_UserDefinedBackwardDerivativeDecoration }; IR_LEAF_ISA(UserDefinedBackwardDerivativeDecoration) + IRInst* getBackwardDerivativeFunc() { return getOperand(0); } }; struct IRTreatAsDifferentiableDecoration : IRDecoration @@ -3491,6 +3502,11 @@ public: addDecoration(value, kIROp_ForceInlineDecoration); } + void addAutoDiffOriginalValueDecoration(IRInst* value, IRInst* originalVal) + { + addDecoration(value, kIROp_AutoDiffOriginalValueDecoration, originalVal); + } + void addForwardDifferentiableDecoration(IRInst* value) { addDecoration(value, kIROp_ForwardDifferentiableDecoration); diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp index 73d8865ed..fb465f638 100644 --- a/source/slang/slang-ir-util.cpp +++ b/source/slang/slang-ir-util.cpp @@ -1,6 +1,7 @@ #include "slang-ir-util.h" #include "slang-ir-insts.h" #include "slang-ir-clone.h" +#include "slang-ir-dce.h" namespace Slang { @@ -198,6 +199,7 @@ IRInst* hoistValueFromGeneric(IRBuilder& inBuilder, IRInst* value, IRInst*& outS value->replaceUsesWith(outSpecializedVal); value->removeAndDeallocate(); } + eliminateDeadCode(newGeneric); return newGeneric; } diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index b36a2ebec..e400d0a17 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -6671,6 +6671,8 @@ namespace Slang case kIROp_ForwardDifferentiate: case kIROp_BackwardDifferentiate: + case kIROp_BackwardDifferentiatePrimal: + case kIROp_BackwardDifferentiatePropagate: return false; } @@ -6815,6 +6817,13 @@ namespace Slang return nullptr; } + IRInst* maybeFindOuterGeneric(IRInst* inst) + { + auto outerGeneric = findOuterGeneric(inst); + if (!outerGeneric) return inst; + return outerGeneric; + } + IRInst* findOuterMostGeneric(IRInst* inst) { IRInst* currInst = inst; diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h index b4a657545..e22e41f0c 100644 --- a/source/slang/slang-ir.h +++ b/source/slang/slang-ir.h @@ -1758,6 +1758,9 @@ IRInst* findOuterGeneric(IRInst* inst); // Recursively find the outer most generic container. IRInst* findOuterMostGeneric(IRInst* inst); +// Returns `inst` if it is not a generic, otherwise its outer generic. +IRInst* maybeFindOuterGeneric(IRInst* inst); + struct IRSpecialize; IRGeneric* findSpecializedGeneric(IRSpecialize* specialize); IRInst* findSpecializeReturnVal(IRSpecialize* specialize); diff --git a/tests/autodiff/dynamic-dispatch-custom-bwd-derivative.slang b/tests/autodiff/dynamic-dispatch-custom-bwd-derivative.slang new file mode 100644 index 000000000..bd0780174 --- /dev/null +++ b/tests/autodiff/dynamic-dispatch-custom-bwd-derivative.slang @@ -0,0 +1,61 @@ +// Test calling differentiable function through dynamic dispatch. + +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +[anyValueSize(16)] +interface IInterface +{ + static float calc(float x); +} + +struct A : IInterface +{ + static float calc(float x) { return 1.0; } +}; + +struct B : IInterface +{ + static float calc(float x) { return 2.0; } +}; + +void dsqr<T:IInterface>(T obj, inout DifferentialPair<float> x, float dOut) +{ + float diff = 2.0 * x.p * dOut; + updateDiff(x, diff); +} + +[BackwardDerivative(dsqr)] +float sqr<T:IInterface>(T obj, float x) +{ + return no_diff(obj.calc(x)) + x * x; +} + +// Use automatically differentiated outer function to triger the primal/propagate func generation logic +// on a function that has user provided backward derivative. +[BackwardDifferentiable] +float sqr_outter<T:IInterface>(T obj, float x) +{ + return sqr(obj, x); +} + +//TEST_INPUT: type_conformance A:IInterface = 0 +//TEST_INPUT: type_conformance B:IInterface = 1 + + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + var obj = createDynamicObject<IInterface>(dispatchThreadID.x, 0); // A + var p = DifferentialPair<float>(2.0, 1.0); + __bwd_diff(sqr_outter)(obj, p, 1.0); // A.calc, expect 4 + outputBuffer[0] = p.d; + + obj = createDynamicObject<IInterface>(dispatchThreadID.x + 1, 0); // B + p = DifferentialPair<float>(1.5, 1.0); + __bwd_diff(sqr)(obj, p, 1.0); // A.calc, expect 4 + outputBuffer[1] = p.d; // B.calc, expect 3 +} diff --git a/tests/autodiff/dynamic-dispatch-custom-bwd-derivative.slang.expected.txt b/tests/autodiff/dynamic-dispatch-custom-bwd-derivative.slang.expected.txt new file mode 100644 index 000000000..780ba6ed4 --- /dev/null +++ b/tests/autodiff/dynamic-dispatch-custom-bwd-derivative.slang.expected.txt @@ -0,0 +1,6 @@ +type: float +4.000000 +3.000000 +0.000000 +0.000000 +0.000000 diff --git a/tests/autodiff/dynamic-dispatch-custom-fwd-derivative.slang b/tests/autodiff/dynamic-dispatch-custom-fwd-derivative.slang new file mode 100644 index 000000000..930c1c82b --- /dev/null +++ b/tests/autodiff/dynamic-dispatch-custom-fwd-derivative.slang @@ -0,0 +1,53 @@ +// Test calling differentiable function through dynamic dispatch. + +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +[anyValueSize(16)] +interface IInterface +{ + static float calc(float x); +} + +struct A : IInterface +{ + static float calc(float x) { return 1.0; } +}; + +struct B : IInterface +{ + static float calc(float x) { return 2.0; } +}; + +DifferentialPair<float> dsqr<T:IInterface>(T obj, DifferentialPair<float> x) +{ + float primal = obj.calc(x.p) + x.p * x.p; + float diff = 2.0 * x.p * x.d; + return diffPair(primal, diff); +} + +[ForwardDerivative(dsqr)] +float sqr<T:IInterface>(T obj, float x) +{ + return no_diff(obj.calc(x)) + x * x; +} + +//TEST_INPUT: type_conformance A:IInterface = 0 +//TEST_INPUT: type_conformance B:IInterface = 1 + + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + var obj = createDynamicObject<IInterface>(dispatchThreadID.x, 0); // A + var p = DifferentialPair<float>(2.0, 1.0); + + outputBuffer[0] = __fwd_diff(sqr)(obj, p).d; // A.calc, expect 4 + + obj = createDynamicObject<IInterface>(dispatchThreadID.x + 1, 0); // B + p = DifferentialPair<float>(1.5, 1.0); + outputBuffer[1] = __fwd_diff(sqr)(obj, p).d; // B.calc, expect 3 +} diff --git a/tests/autodiff/dynamic-dispatch-custom-fwd-derivative.slang.expected.txt b/tests/autodiff/dynamic-dispatch-custom-fwd-derivative.slang.expected.txt new file mode 100644 index 000000000..780ba6ed4 --- /dev/null +++ b/tests/autodiff/dynamic-dispatch-custom-fwd-derivative.slang.expected.txt @@ -0,0 +1,6 @@ +type: float +4.000000 +3.000000 +0.000000 +0.000000 +0.000000 diff --git a/tests/language-server/robustness-6.slang b/tests/language-server/robustness-6.slang new file mode 100644 index 000000000..ef5924cf3 --- /dev/null +++ b/tests/language-server/robustness-6.slang @@ -0,0 +1,10 @@ +//TEST:LANG_SERVER: +//HOVER:4,8 + +float dsqr<T:II + +[ForwardDerivative(dsqr)] +float sqr<T:IInterface>(T obj, float x) +{ + return no_diff(obj.calc(x)) + x * x; +} diff --git a/tests/language-server/robustness-6.slang.expected.txt b/tests/language-server/robustness-6.slang.expected.txt new file mode 100644 index 000000000..d5aa6c8c9 --- /dev/null +++ b/tests/language-server/robustness-6.slang.expected.txt @@ -0,0 +1,13 @@ +-------- +range: 3,6 - 3,10 +content: +``` +func dsqr<T>(T obj, float x) -> float +``` + +TEST:LANG_SERVER: +HOVER:4,8 + +{REDACTED}.slang(4) + + |
