summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2022-07-13 15:55:30 -0400
committerGitHub <noreply@github.com>2022-07-13 15:55:30 -0400
commit4af61e2296a49876c2d9e7cf192ae825302a83de (patch)
treed067b944b9794fe5061bbf51e8ef6a39d5fcefbf
parent564f0d84a9c5276c05e8288955a7685f96278d1b (diff)
Added support for differentiating out and inout parameters. (#2323)
* Added out/inout tests * Added support for out and inout parameters. Still untested * Fixed and tested support for out and inout types * Removed some comments
-rw-r--r--source/slang/slang-check-expr.cpp29
-rw-r--r--source/slang/slang-ir-diff-jvp.cpp106
-rw-r--r--tests/autodiff/inout-parameters-jvp.slang30
-rw-r--r--tests/autodiff/inout-parameters-jvp.slang.expected.txt5
-rw-r--r--tests/autodiff/out-parameters-jvp.slang28
-rw-r--r--tests/autodiff/out-parameters-jvp.slang.expected.txt5
6 files changed, 183 insertions, 20 deletions
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index 03da084d3..67e8bf650 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -1509,18 +1509,40 @@ namespace Slang
return expr;
}
+ // This function proceses primal params (i.e params of the inner function that is being
+ // differentiated) that need to be carried over to the function signature for the JVP
+ // function. (eg. out types can be discarded)
+ //
+ Type* primalToInputType(ASTBuilder*, Type* primalType)
+ {
+ if (auto primalOutType = as<OutType>(primalType))
+ return nullptr;
+ else if (auto primalInOutType = as<InOutType>(primalType))
+ return primalInOutType->getValueType();
+
+ return primalType;
+ }
+
Type* primalToJVPParamType(ASTBuilder* builder, Type* primalType)
{
// Only float and float3 types can be differentiated for now.
- if(primalType->equals(builder->getFloatType()))
+ if (primalType->equals(builder->getFloatType()))
return primalType;
- else if(auto primalVectorType = as<VectorExpressionType>(primalType))
+ else if (auto primalVectorType = as<VectorExpressionType>(primalType))
{
// TODO(sai): There's probably a more elegant way to check if a type is a float3?
if (getIntVal(primalVectorType->elementCount) == 3 && primalVectorType->elementType->equals(builder->getFloatType()))
return primalVectorType;
}
+ else if (auto primalOutType = as<OutType>(primalType))
+ {
+ return builder->getOutType(primalToJVPParamType(builder, primalOutType->getValueType()));
+ }
+ else if (auto primalInOutType = as<InOutType>(primalType))
+ {
+ return builder->getInOutType(primalToJVPParamType(builder, primalInOutType->getValueType()));
+ }
return nullptr;
}
@@ -1558,7 +1580,8 @@ namespace Slang
for (UInt i = 0; i < primalType->getParamCount(); i++)
{
- jvpType->paramTypes.add(primalType->getParamType(i));
+ if(auto primalInputType = primalToInputType(astBuilder, primalType->getParamType(i)))
+ jvpType->paramTypes.add(primalInputType);
}
for (UInt i = 0; i < primalType->getParamCount(); i++)
diff --git a/source/slang/slang-ir-diff-jvp.cpp b/source/slang/slang-ir-diff-jvp.cpp
index 5b77d483d..f5afccd0c 100644
--- a/source/slang/slang-ir-diff-jvp.cpp
+++ b/source/slang/slang-ir-diff-jvp.cpp
@@ -56,7 +56,9 @@ struct JVPTranscriber
// Add all primal parameters to the list.
for (UIndex i = 0; i < funcType->getParamCount(); i++)
{
- parameterTypesD.add(funcType->getParamType(i));
+ // TODO(sai): Move this check to a separate function.
+ if (!as<IROutType>(funcType->getParamType(i)))
+ parameterTypesD.add(funcType->getParamType(i));
}
// Add differential versions for the types we support.
@@ -85,7 +87,12 @@ struct JVPTranscriber
case kIROp_DoubleType:
return builder->getType(typeP->getOp());
case kIROp_VectorType:
+ // TODO(sai): Call differentiateType() on typeP.
return as<IRVectorType>(typeP);
+ case kIROp_OutType:
+ return builder->getOutType(differentiateType(builder, as<IROutType>(typeP)->getValueType()));
+ case kIROp_InOutType:
+ return builder->getInOutType(differentiateType(builder, as<IRInOutType>(typeP)->getValueType()));
default:
return nullptr;
}
@@ -102,21 +109,45 @@ struct JVPTranscriber
return nullptr;
}
+ IRInst* emitInputParam(IRBuilder* builder, IRParam* paramP)
+ {
+ // Convert primal 'inout' types into pure input types, because a
+ // JVP transformed function must never have primal side-effects.
+ //
+ if (auto inoutTypeP = as<IRInOutType>(paramP->getDataType()))
+ {
+ auto newParamP = builder->emitParam(inoutTypeP->getValueType());
+ cloneEnv.mapOldValToNew.Add(paramP, newParamP);
+
+ return newParamP;
+ }
+ else if (as<IROutType>(paramP->getDataType()))
+ {
+ getSink()->diagnose(paramP->sourceLoc,
+ Diagnostics::unexpected,
+ "encountered unexpected output parameter");
+ return nullptr;
+ }
+ else
+ return as<IRParam>(cloneInst(&cloneEnv, builder, paramP));
+ }
+
List<IRParam*> transcribeParams(IRBuilder* builder, IRInstList<IRParam> paramListP)
{
// Clone (and emit) all the primal parameters.
List<IRParam*> newParamListP;
for (auto paramP : paramListP)
{
- newParamListP.add(as<IRParam>(cloneInst(&cloneEnv, builder, paramP)));
+ if(requiresPrimalClone(builder, paramP))
+ newParamListP.add(as<IRParam>(emitInputParam(builder, paramP)));
}
// Now emit differentials.
List<IRParam*> newParamListD;
- for (auto paramP : newParamListP)
+ for (auto paramP : paramListP)
{
IRParam* paramD = as<IRParam>(differentiateParam(builder, paramP));
- mapDifferentialInst(paramP, paramD);
+ mapDifferentialInst(findCloneForOperand(&cloneEnv, paramP), paramD);
newParamListD.add(paramD);
}
@@ -187,15 +218,16 @@ struct JVPTranscriber
IRInst* differentiateLoad(IRBuilder* builder, IRLoad* loadP)
{
- if (auto varP = as<IRVar>(loadP->getPtr()))
+ auto ptrP = loadP->getPtr();
+ if (as<IRVar>(ptrP) || as<IRParam>(ptrP))
{
// If the loaded parameter has a differential version,
// emit a load instruction for the differential parameter.
// Otherwise, emit nothing since there's nothing to load.
//
- if (auto varD = as<IRVar>(getDifferentialInst(varP)))
+ if (auto ptrD = getDifferentialInst(ptrP, nullptr))
{
- IRLoad* loadD = as<IRLoad>(builder->emitLoad(varD));
+ IRLoad* loadD = as<IRLoad>(builder->emitLoad(ptrD));
SLANG_ASSERT(loadD);
return loadD;
}
@@ -212,14 +244,14 @@ struct JVPTranscriber
{
IRInst* storeLocation = storeP->getPtr();
IRInst* storeVal = storeP->getVal();
- if (auto destParam = as<IRVar>(storeLocation))
+ if (as<IRVar>(storeLocation) || as<IRParam>(storeLocation))
{
// If the stored value has a differential version,
// emit a store instruction for the differential parameter.
// Otherwise, emit nothing since there's nothing to load.
//
IRInst* storeValD = getDifferentialInst(storeVal);
- IRVar* storeLocationD = as<IRVar>(getDifferentialInst(destParam));
+ IRInst* storeLocationD = getDifferentialInst(storeLocation);
if (storeValD && storeLocationD)
{
IRStore* storeD = as<IRStore>(
@@ -239,13 +271,18 @@ struct JVPTranscriber
IRInst* differentiateReturn(IRBuilder* builder, IRReturn* returnP)
{
IRInst* returnVal = findCloneForOperand(&cloneEnv, returnP->getVal());
- if (auto returnValD = getDifferentialInst(returnVal))
+ if (auto returnValD = getDifferentialInst(returnVal, nullptr))
{
IRReturn* returnD = as<IRReturn>(builder->emitReturn(returnValD));
SLANG_ASSERT(returnD);
return returnD;
}
- return nullptr;
+ else
+ {
+ // If the differential return value is not available, emit a
+ // void return.
+ return builder->emitReturn();
+ }
}
// Since int/float literals are sometimes nested inside an IRConstructor
@@ -352,16 +389,38 @@ struct JVPTranscriber
}
// Logic for whether a primal instruction needs to be replicated
- // in the differential function. For puerly functional blocks with
- // no side-effects, it's safe to replicate everything except the
- // return instruction.
- //
+ // in the differential function. We detect and avoid replicating
+ // side-effect instructions.
+ //
bool requiresPrimalClone(IRBuilder*, IRInst* instP)
{
if (as<IRReturn>(instP))
return false;
- else
- return true;
+ else if (auto paramP = as<IRParam>(instP))
+ {
+ // Out-type parameters are discarded from the parameter list,
+ // since pure JVP functions to not write to primal outputs.
+ //
+ if (as<IROutType>(paramP->getDataType()))
+ return false;
+ }
+ else if (auto storeP = as<IRStore>(instP))
+ {
+ IRInst* storeLocation = storeP->getPtr();
+
+ // Writing to a parameter is a side-effect that should be avoided.
+ if(as<IRParam>(storeLocation))
+ return false;
+
+ // If attempting to store to a location without a clone,
+ // then this instruction likely has side-effects external to the
+ // current function.
+ //
+ if(!lookUp(&cloneEnv, storeLocation))
+ return false;
+ }
+
+ return true;
}
IRInst* transcribe(IRBuilder* builder, IRInst* oldInstP)
@@ -374,6 +433,19 @@ struct JVPTranscriber
//
if (requiresPrimalClone(builder, oldInstP))
instP = cloneInst(&cloneEnv, builder, oldInstP);
+ else
+ {
+ // We replace the operands of the old instruction with their clones,
+ // if available.
+ //
+ for(UInt ii = 0; ii < oldInstP->getOperandCount(); ++ii)
+ {
+ auto oldOperand = oldInstP->getOperand(ii);
+ auto newOperand = findCloneForOperand(&cloneEnv, oldOperand);
+
+ instP->getOperands()[ii].init(instP, newOperand);
+ }
+ }
SLANG_ASSERT(instP);
IRInst* instD = differentiateInst(builder, instP);
diff --git a/tests/autodiff/inout-parameters-jvp.slang b/tests/autodiff/inout-parameters-jvp.slang
new file mode 100644
index 000000000..989e56c02
--- /dev/null
+++ b/tests/autodiff/inout-parameters-jvp.slang
@@ -0,0 +1,30 @@
+//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+__differentiate_jvp void h(float x, float y, inout float z)
+{
+ float m = x + y;
+ float n = x - y;
+ z = z + m * n + 2 * x * y;
+}
+
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ float x = 2.0;
+ float y = 3.5;
+ float z = 1.0;
+ float dx = 1.0;
+ float dy = 0.5;
+ float dz = 2.5;
+
+ __jvp(h)(x, y, z, dx, dy, dz);
+
+ outputBuffer[0] = dz; // Expect: 12.0
+ outputBuffer[1] = z; // Expect: 1.0
+
+} \ No newline at end of file
diff --git a/tests/autodiff/inout-parameters-jvp.slang.expected.txt b/tests/autodiff/inout-parameters-jvp.slang.expected.txt
new file mode 100644
index 000000000..d8a590c0e
--- /dev/null
+++ b/tests/autodiff/inout-parameters-jvp.slang.expected.txt
@@ -0,0 +1,5 @@
+type: float
+12.0
+1.0
+0.0
+0.0 \ No newline at end of file
diff --git a/tests/autodiff/out-parameters-jvp.slang b/tests/autodiff/out-parameters-jvp.slang
new file mode 100644
index 000000000..58c6cfeb0
--- /dev/null
+++ b/tests/autodiff/out-parameters-jvp.slang
@@ -0,0 +1,28 @@
+//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+__differentiate_jvp void h(float x, float y, out float result)
+{
+ float m = x + y;
+ float n = x - y;
+ result = m * n + 2 * x * y;
+}
+
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ float x = 2.0;
+ float y = 3.5;
+ float dx = 1.0;
+ float dy = 0.5;
+
+ float dresult = 0.0f;
+ __jvp(h)(x, y, dx, dy, dresult);
+
+ outputBuffer[0] = dresult; // Expect: 9.5
+
+} \ No newline at end of file
diff --git a/tests/autodiff/out-parameters-jvp.slang.expected.txt b/tests/autodiff/out-parameters-jvp.slang.expected.txt
new file mode 100644
index 000000000..555935fc4
--- /dev/null
+++ b/tests/autodiff/out-parameters-jvp.slang.expected.txt
@@ -0,0 +1,5 @@
+type: float
+9.5
+0.0
+0.0
+0.0 \ No newline at end of file