diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2022-10-26 22:21:29 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-10-26 19:21:29 -0700 |
| commit | f7f0dcadd3b2aca4c0bcd03a96e11c617cf69fc2 (patch) | |
| tree | 574dff2bcb8c5a3de9e74d18346a424c82d62a7a /source/slang/slang-ir-specialize.cpp | |
| parent | 939be44ca23476e622dfb24a592383fe2a1da61f (diff) | |
Adding a differentiable standard library (#2465)
Diffstat (limited to 'source/slang/slang-ir-specialize.cpp')
| -rw-r--r-- | source/slang/slang-ir-specialize.cpp | 69 |
1 files changed, 69 insertions, 0 deletions
diff --git a/source/slang/slang-ir-specialize.cpp b/source/slang/slang-ir-specialize.cpp index 3ef79df28..53ea99a0c 100644 --- a/source/slang/slang-ir-specialize.cpp +++ b/source/slang/slang-ir-specialize.cpp @@ -373,7 +373,76 @@ struct SpecializationContext // represents a definition rather than a declaration. // if(!canSpecializeGeneric(genericVal)) + { + // We have to consider a special case here if baseVal is + // an intrinsic, and contains a custom differential. + // This is a case where the base cannot be specialized since it has + // no body, but the custom should be specialized. + // A better way to handle this would be to grab a reference to the + // appropriate custom differential, if one exists, at checking time + // during CheckInvoke() and construct it's specialization args appropriately. + // + // For now, we will overwrite the specialization args for the differential + // using the args for the base. + // + auto genericReturnVal = findInnerMostGenericReturnVal(genericVal); + if (genericReturnVal->findDecoration<IRTargetIntrinsicDecoration>()) + { + if (auto customDiffRef = genericReturnVal->findDecoration<IRJVPDerivativeReferenceDecoration>()) + { + // If we already have a diff func on this specialize, skip. + if (auto specDiffRef = specInst->findDecoration<IRJVPDerivativeReferenceDecoration>()) + { + return false; + } + + auto specDiffFunc = as<IRSpecialize>(customDiffRef->getJVPFunc()); + + // If the base is specialized, the JVP version must be also be a specialized + // generic. + // + SLANG_ASSERT(specDiffFunc); + + // Build specialization arguments from specInst. + // Note that if we've reached this point, we can safely assume + // that our args are fully specialized/concrete. + // + UCount argCount = specInst->getArgCount(); + List<IRInst*> args; + for (UIndex ii = 0; ii < argCount; ii++) + args.add(specInst->getArg(ii)); + + IRBuilder builder(&sharedBuilderStorage); + + // Specialize the custom JVP function type with the original arguments. + builder.setInsertInto(module); + auto newDiffFuncType = builder.emitSpecializeInst( + builder.getTypeKind(), + specDiffFunc->getBase()->getDataType(), + argCount, + args.getBuffer()); + + // Specialize the custom JVP function with the original arguments. + builder.setInsertBefore(specInst); + auto newDiffFunc = builder.emitSpecializeInst( + (IRType*) newDiffFuncType, + specDiffFunc->getBase(), + argCount, + args.getBuffer()); + + // Add the new spec insts to the list so they get specialized with + // the usual logic. + // + addToWorkList(newDiffFuncType); + addToWorkList(newDiffFunc); + + builder.addJVPDerivativeReferenceDecoration(specInst, newDiffFunc); + + return true; + } + } return false; + } // Once we know that specialization is possible, // the actual work is fairly simple. |
