summaryrefslogtreecommitdiff
path: root/source/slang/slang-ir-specialize.cpp
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2022-10-26 22:21:29 -0400
committerGitHub <noreply@github.com>2022-10-26 19:21:29 -0700
commitf7f0dcadd3b2aca4c0bcd03a96e11c617cf69fc2 (patch)
tree574dff2bcb8c5a3de9e74d18346a424c82d62a7a /source/slang/slang-ir-specialize.cpp
parent939be44ca23476e622dfb24a592383fe2a1da61f (diff)
Adding a differentiable standard library (#2465)
Diffstat (limited to 'source/slang/slang-ir-specialize.cpp')
-rw-r--r--source/slang/slang-ir-specialize.cpp69
1 files changed, 69 insertions, 0 deletions
diff --git a/source/slang/slang-ir-specialize.cpp b/source/slang/slang-ir-specialize.cpp
index 3ef79df28..53ea99a0c 100644
--- a/source/slang/slang-ir-specialize.cpp
+++ b/source/slang/slang-ir-specialize.cpp
@@ -373,7 +373,76 @@ struct SpecializationContext
// represents a definition rather than a declaration.
//
if(!canSpecializeGeneric(genericVal))
+ {
+ // We have to consider a special case here if baseVal is
+ // an intrinsic, and contains a custom differential.
+ // This is a case where the base cannot be specialized since it has
+ // no body, but the custom should be specialized.
+ // A better way to handle this would be to grab a reference to the
+ // appropriate custom differential, if one exists, at checking time
+ // during CheckInvoke() and construct it's specialization args appropriately.
+ //
+ // For now, we will overwrite the specialization args for the differential
+ // using the args for the base.
+ //
+ auto genericReturnVal = findInnerMostGenericReturnVal(genericVal);
+ if (genericReturnVal->findDecoration<IRTargetIntrinsicDecoration>())
+ {
+ if (auto customDiffRef = genericReturnVal->findDecoration<IRJVPDerivativeReferenceDecoration>())
+ {
+ // If we already have a diff func on this specialize, skip.
+ if (auto specDiffRef = specInst->findDecoration<IRJVPDerivativeReferenceDecoration>())
+ {
+ return false;
+ }
+
+ auto specDiffFunc = as<IRSpecialize>(customDiffRef->getJVPFunc());
+
+ // If the base is specialized, the JVP version must be also be a specialized
+ // generic.
+ //
+ SLANG_ASSERT(specDiffFunc);
+
+ // Build specialization arguments from specInst.
+ // Note that if we've reached this point, we can safely assume
+ // that our args are fully specialized/concrete.
+ //
+ UCount argCount = specInst->getArgCount();
+ List<IRInst*> args;
+ for (UIndex ii = 0; ii < argCount; ii++)
+ args.add(specInst->getArg(ii));
+
+ IRBuilder builder(&sharedBuilderStorage);
+
+ // Specialize the custom JVP function type with the original arguments.
+ builder.setInsertInto(module);
+ auto newDiffFuncType = builder.emitSpecializeInst(
+ builder.getTypeKind(),
+ specDiffFunc->getBase()->getDataType(),
+ argCount,
+ args.getBuffer());
+
+ // Specialize the custom JVP function with the original arguments.
+ builder.setInsertBefore(specInst);
+ auto newDiffFunc = builder.emitSpecializeInst(
+ (IRType*) newDiffFuncType,
+ specDiffFunc->getBase(),
+ argCount,
+ args.getBuffer());
+
+ // Add the new spec insts to the list so they get specialized with
+ // the usual logic.
+ //
+ addToWorkList(newDiffFuncType);
+ addToWorkList(newDiffFunc);
+
+ builder.addJVPDerivativeReferenceDecoration(specInst, newDiffFunc);
+
+ return true;
+ }
+ }
return false;
+ }
// Once we know that specialization is possible,
// the actual work is fairly simple.