summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2022-11-21 10:29:57 -0500
committerGitHub <noreply@github.com>2022-11-21 10:29:57 -0500
commit545de51298ddda52ac51ded03ad489c98bdda397 (patch)
treedef78374f743d2c722fbde45eba60951a6f5c8f9
parentd58e08f8237a1888ceaad53402d534679ea83b1a (diff)
WIP: Fixed inout struct and added testing for calls to non-differentiable functions (#2505)
* Added non-differentiable call test * Extended testing for nondifferentiable calls * Fixed subtle issue with extensions on generic types not applying the correct substitutions, leading to unspecialized generics at the emit stage * More fixes. inout struct params now work fine * Update inout-struct-parameters-jvp.slang * Update slang-ir.cpp * Fixed hoisting lookup_interface_method * Fixed non-diff call return value * Fixed issue with phi nodes * Fixed problem with IRSpecialize preventing hoisitng of DifferentialPairType * Fixed non-diff call test to conform to the new 'no_diff' system
-rw-r--r--source/slang/slang-check-decl.cpp7
-rw-r--r--source/slang/slang-emit.cpp1
-rw-r--r--source/slang/slang-ir-diff-jvp.cpp222
-rw-r--r--source/slang/slang-ir-diff-jvp.h1
-rw-r--r--source/slang/slang-ir-insts.h2
-rw-r--r--source/slang/slang-ir.cpp49
-rw-r--r--tests/autodiff/inout-struct-parameters-jvp.slang41
-rw-r--r--tests/autodiff/inout-struct-parameters-jvp.slang.expected.txt5
-rw-r--r--tests/autodiff/nondiff-call.slang66
-rw-r--r--tests/autodiff/nondiff-call.slang.expected.txt6
10 files changed, 309 insertions, 91 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index 009d0a987..5a1218abe 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -6024,6 +6024,13 @@ namespace Slang
// without any additional substitutions.
if (extDecl->targetType->equals(type))
{
+ /*
+ auto subst = trySolveConstraintSystem(
+ &constraints,
+ DeclRef<Decl>(extGenericDecl, nullptr).as<GenericDecl>(),
+ as<GenericSubstitution>(as<DeclRefType>(type)->declRef.substitutions.substitutions));
+ return DeclRef<Decl>(extDecl, subst).as<ExtensionDecl>();
+ */
return createDefaultSubstitutionsIfNeeded(m_astBuilder, this, extDeclRef).as<ExtensionDecl>();
}
diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp
index cd5f58925..69ea29c7a 100644
--- a/source/slang/slang-emit.cpp
+++ b/source/slang/slang-emit.cpp
@@ -358,6 +358,7 @@ Result linkAndOptimizeIR(
// perform specialization of functions based on parameter
// values that need to be compile-time constants.
//
+
dumpIRIfEnabled(codeGenContext, irModule, "BEFORE-SPECIALIZE");
if (!codeGenContext->isSpecializationDisabled())
specializeModule(irModule);
diff --git a/source/slang/slang-ir-diff-jvp.cpp b/source/slang/slang-ir-diff-jvp.cpp
index 4ee16aafc..c9ca687e4 100644
--- a/source/slang/slang-ir-diff-jvp.cpp
+++ b/source/slang/slang-ir-diff-jvp.cpp
@@ -92,16 +92,33 @@ struct DifferentialPairTypeBuilder
IRInst* emitFieldAccessor(IRBuilder* builder, IRInst* baseInst, IRStructKey* key)
{
- auto baseTypeInfo = lowerDiffPairType(builder, baseInst->getDataType());
- if (baseTypeInfo.isTrivial)
+ IRInst* pairType = nullptr;
+ if (auto basePtrType = as<IRPtrTypeBase>(baseInst->getDataType()))
{
- if (key == globalPrimalKey)
- return baseInst;
- else
- return builder->getDifferentialBottom();
+ auto baseTypeInfo = lowerDiffPairType(builder, basePtrType->getValueType());
+
+ // TODO(sai): Not sure at the moment how to handle diff-bottom pointer types,
+ // especially since we probably don't need diff bottom anymore.
+ //
+ SLANG_ASSERT(!baseTypeInfo.isTrivial);
+
+ pairType = builder->getPtrType(kIROp_PtrType, (IRType*)baseTypeInfo.loweredType);
+ }
+ else
+ {
+ auto baseTypeInfo = lowerDiffPairType(builder, baseInst->getDataType());
+ if (baseTypeInfo.isTrivial)
+ {
+ if (key == globalPrimalKey)
+ return baseInst;
+ else
+ return builder->getDifferentialBottom();
+ }
+
+ pairType = baseTypeInfo.loweredType;
}
- if (auto basePairStructType = as<IRStructType>(baseTypeInfo.loweredType))
+ if (auto basePairStructType = as<IRStructType>(pairType))
{
return as<IRFieldExtract>(builder->emitFieldExtract(
findField(basePairStructType, key)->getFieldType(),
@@ -109,7 +126,7 @@ struct DifferentialPairTypeBuilder
key
));
}
- else if (auto ptrType = as<IRPtrTypeBase>(baseTypeInfo.loweredType))
+ else if (auto ptrType = as<IRPtrTypeBase>(pairType))
{
if (auto ptrInnerSpecializedType = as<IRSpecialize>(ptrType->getValueType()))
{
@@ -135,7 +152,7 @@ struct DifferentialPairTypeBuilder
key));
}
}
- else if (auto specializedType = as<IRSpecialize>(baseTypeInfo.loweredType))
+ else if (auto specializedType = as<IRSpecialize>(pairType))
{
// TODO: Stopped here -> The type being emitted is incorrect. don't emit the generic's
// type, emit the specialization type.
@@ -333,7 +350,9 @@ struct JVPTranscriber
JVPTranscriber(AutoDiffSharedContext* shared, SharedIRBuilder* inSharedBuilder)
: differentiableTypeConformanceContext(shared), sharedBuilder(inSharedBuilder)
- {}
+ {
+
+ }
DiagnosticSink* getSink()
{
@@ -449,6 +468,17 @@ struct JVPTranscriber
return builder->getFuncType(newParameterTypes, diffReturnType);
}
+ IRWitnessTable* getDifferentialBottomWitness()
+ {
+ IRBuilder builder(sharedBuilder);
+ builder.setInsertInto(sharedBuilder->getModule()->getModuleInst());
+ auto result =
+ as<IRWitnessTable>(differentiableTypeConformanceContext.lookUpConformanceForType(
+ builder.getDifferentialBottomType()));
+ SLANG_ASSERT(result);
+ return result;
+ }
+
// Get or construct `:IDifferentiable` conformance for a DifferentiablePair.
IRWitnessTable* getDifferentialPairWitness(IRInst* inDiffPairType)
{
@@ -456,23 +486,20 @@ struct JVPTranscriber
builder.setInsertInto(inDiffPairType->parent);
auto diffPairType = as<IRDifferentialPairType>(inDiffPairType);
SLANG_ASSERT(diffPairType);
- auto diffType = differentiateType(&builder, diffPairType->getValueType());
- IRInst* tableInst = nullptr;
- if (!differentiableTypeConformanceContext.differentiableWitnessDictionary.TryGetValue(diffPairType, tableInst))
- {
- IRWitnessTable* table = builder.createWitnessTable(autoDiffSharedContext->differentiableInterfaceType, diffPairType);
- // The witness that `diffType`
- auto differentialType = builder.getDifferentialPairType(
- diffType,
- differentiableTypeConformanceContext.differentiableWitnessDictionary[diffType]
- .GetValue());
- builder.createWitnessTableEntry(table, autoDiffSharedContext->differentialAssocTypeStructKey, differentialType);
- // Omit the method synthesis here, since we can just intercept those directly at `getXXMethodForType`.
- differentiableTypeConformanceContext.differentiableWitnessDictionary[diffPairType] = table;
- tableInst = table;
- }
- return as<IRWitnessTable>(tableInst);
+ auto table = builder.createWitnessTable(autoDiffSharedContext->differentiableInterfaceType, diffPairType);
+
+ // Differentiate the pair type to get it's differential (which is itself a pair)
+ auto diffDiffPairType = differentiateType(&builder, diffPairType);
+
+ // And place it in the synthesized witness table.
+ builder.createWitnessTableEntry(table, autoDiffSharedContext->differentialAssocTypeStructKey, diffDiffPairType);
+ // Omit the method synthesis here, since we can just intercept those directly at `getXXMethodForType`.
+
+ // Record this in the context for future lookups
+ differentiableTypeConformanceContext.differentiableWitnessDictionary[diffPairType] = table;
+
+ return table;
}
IRType* getOrCreateDiffPairType(IRInst* primalType, IRInst* witness)
@@ -490,10 +517,19 @@ struct JVPTranscriber
builder.setInsertInto(primalType->parent);
auto witness = as<IRWitnessTable>(
differentiableTypeConformanceContext.lookUpConformanceForType((IRType*)primalType));
- if (!witness && as<IRDifferentialPairType>(primalType))
+
+ if (!witness)
{
- witness = getDifferentialPairWitness(primalType);
+ if (auto primalPairType = as<IRDifferentialPairType>(primalType))
+ {
+ witness = getDifferentialPairWitness(primalPairType);
+ }
+ else
+ {
+ witness = getDifferentialBottomWitness();
+ }
}
+
return builder.getDifferentialPairType(
(IRType*)primalType,
witness);
@@ -630,8 +666,8 @@ struct JVPTranscriber
builder->addNameHintDecoration(diffPairParam, diffPairVarName.getUnownedSlice());
SLANG_ASSERT(diffPairParam);
-
- if (auto pairType = as<IRDifferentialPairType>(diffPairParam->getDataType()))
+
+ if (auto pairType = as<IRDifferentialPairType>(diffPairType))
{
return InstPair(
builder->emitDifferentialPairGetPrimal(diffPairParam),
@@ -639,16 +675,23 @@ struct JVPTranscriber
(IRType*)pairBuilder->getDiffTypeFromPairType(builder, pairType),
diffPairParam));
}
- // If this is an `in/inout DifferentialPair<>` parameter, we can't produce
- // its primal and diff parts right now because they would represent a reference
- // to a pair field, which doesn't make sense since pair types are considered mutable.
- // We encode the result as if the param is non-differentiable, and handle it
- // with special care at load/store.
- return InstPair(diffPairParam, nullptr);
+ else if (auto pairPtrType = as<IRPtrTypeBase>(diffPairType))
+ {
+ auto ptrInnerPairType = as<IRDifferentialPairType>(pairPtrType->getValueType());
+
+ return InstPair(
+ builder->emitDifferentialPairAddressPrimal(diffPairParam),
+ builder->emitDifferentialPairAddressDifferential(
+ builder->getPtrType(
+ kIROp_PtrType,
+ (IRType*)pairBuilder->getDiffTypeFromPairType(builder, ptrInnerPairType)),
+ diffPairParam));
+ }
}
+
return InstPair(
- cloneInst(&cloneEnv, builder, origParam),
- nullptr);
+ cloneInst(&cloneEnv, builder, origParam),
+ nullptr);
}
else
{
@@ -660,6 +703,7 @@ struct JVPTranscriber
}
return InstPair(primal, diff);
}
+
}
// Returns "d<var-name>" to use as a name hint for variables and parameters.
@@ -784,6 +828,7 @@ struct JVPTranscriber
{
// Special case load from an `out` param, which will not have corresponding `diff` and
// `primal` insts yet.
+
auto load = builder->emitLoad(primalPtr);
auto primalElement = builder->emitDifferentialPairGetPrimal(load);
auto diffElement = builder->emitDifferentialPairGetDifferential(
@@ -1401,30 +1446,25 @@ struct JVPTranscriber
InstPair transcribeIfElse(IRBuilder* builder, IRIfElse* origIfElse)
{
- // The loop comes with three blocks.. we just need to transcribe each one
- // and assemble the new loop instruction.
+ // IfElse Statements come with 4 blocks. We transcribe each block into it's
+ // linear form, and then wire them up in the same way as the original if-else
- // Transcribe the target block (this is the 'condition' part of the loop, which
- // will branch into the loop body).
- // Note that for the condition we use the primal inst (condition values should not have a
- // differential)
+ // Transcribe condition block
auto primalConditionBlock = findOrTranscribePrimalInst(builder, origIfElse->getCondition());
SLANG_ASSERT(primalConditionBlock);
- // Transcribe the break block (this is the block after the exiting the loop)
+ // Transcribe 'true' block (condition block branches into this if true)
auto diffTrueBlock = findOrTranscribeDiffInst(builder, origIfElse->getTrueBlock());
SLANG_ASSERT(diffTrueBlock);
- // Transcribe the continue block (this is the 'update' part of the loop, which will
- // branch into the condition block)
+ // Transcribe 'false' block (condition block branches into this if true)
+ // TODO (sai): What happens if there's no false block?
auto diffFalseBlock = findOrTranscribeDiffInst(builder, origIfElse->getFalseBlock());
SLANG_ASSERT(diffFalseBlock);
- // Transcribe the continue block (this is the 'update' part of the loop, which will
- // branch into the condition block)
+ // Transcribe 'after' block (true and false blocks branch into this)
auto diffAfterBlock = findOrTranscribeDiffInst(builder, origIfElse->getAfterBlock());
SLANG_ASSERT(diffAfterBlock);
-
List<IRInst*> diffIfElseArgs;
diffIfElseArgs.add(primalConditionBlock);
@@ -2462,6 +2502,9 @@ struct JVPDerivativeContext : public InstPassBase
sharedBuilder->init(module);
sharedBuilder->deduplicateAndRebuildGlobalNumberingMap();
+ // TODO(sai): Move this call.
+ transcriberStorage.differentiableTypeConformanceContext.buildGlobalWitnessDictionary();
+
IRBuilder builderStorage(sharedBuilderStorage);
IRBuilder* builder = &builderStorage;
@@ -2477,6 +2520,9 @@ struct JVPDerivativeContext : public InstPassBase
//
modified |= simplifyDifferentialBottomType(builder);
+ // De-duplicate any remaining types.
+ sharedBuilder->deduplicateAndRebuildGlobalNumberingMap();
+
modified |= processPairTypes(builder, module->getModuleInst());
modified |= eliminateDifferentialBottomType(builder);
@@ -2665,7 +2711,13 @@ struct JVPDerivativeContext : public InstPassBase
{
if (auto getDiffInst = as<IRDifferentialPairGetDifferential>(inst))
{
- if (lowerPairType(builder, getDiffInst->getBase()->getDataType(), nullptr))
+ auto pairType = getDiffInst->getBase()->getDataType();
+ if (auto pairPtrType = as<IRPtrTypeBase>(pairType))
+ {
+ pairType = pairPtrType->getValueType();
+ }
+
+ if (lowerPairType(builder, pairType, nullptr))
{
builder->setInsertBefore(getDiffInst);
IRInst* diffFieldExtract = nullptr;
@@ -2677,7 +2729,13 @@ struct JVPDerivativeContext : public InstPassBase
}
else if (auto getPrimalInst = as<IRDifferentialPairGetPrimal>(inst))
{
- if (lowerPairType(builder, getPrimalInst->getBase()->getDataType(), nullptr))
+ auto pairType = getPrimalInst->getBase()->getDataType();
+ if (auto pairPtrType = as<IRPtrTypeBase>(pairType))
+ {
+ pairType = pairPtrType->getValueType();
+ }
+
+ if (lowerPairType(builder, pairType, nullptr))
{
builder->setInsertBefore(getPrimalInst);
@@ -2695,41 +2753,29 @@ struct JVPDerivativeContext : public InstPassBase
bool processPairTypes(IRBuilder* builder, IRInst* instWithChildren)
{
bool modified = false;
- // Hoist and deduplicate all pair types to global scope when possible.
- // This avoids emitting different struct types for equivalent pair types.
+ // Hoist all pair types to global scope when possible.
auto moduleInst = module->getModuleInst();
- Dictionary<IRInst*, IRInst*> diffPairTypes;
- for (;;)
- {
- bool changed = false;
- sharedBuilderStorage.deduplicateAndRebuildGlobalNumberingMap();
- processInstsOfType<IRDifferentialPairType>(kIROp_DifferentialPairType, [&](IRDifferentialPairType* originalPairType)
+ processInstsOfType<IRDifferentialPairType>(kIROp_DifferentialPairType, [&](IRInst* originalPairType)
+ {
+ if (originalPairType->parent != moduleInst)
{
- IRInst* finalType = nullptr;
- if (diffPairTypes.TryGetValue(originalPairType->getValueType(), finalType))
- {
- if (finalType != originalPairType)
- {
- originalPairType->replaceUsesWith(finalType);
- originalPairType->removeAndDeallocate();
- changed = true;
- return;
- }
- }
- diffPairTypes[originalPairType->getValueType()] = originalPairType;
- if (originalPairType->parent != moduleInst)
+ originalPairType->removeFromParent();
+ ShortList<IRInst*> operands;
+ for (UInt i = 0; i < originalPairType->getOperandCount(); i++)
{
- if (originalPairType->getValueType()->getParent() != originalPairType->getParent())
- {
- originalPairType->insertAfter(originalPairType->getValueType());
- changed = true;
- return;
- }
+ operands.add(originalPairType->getOperand(i));
}
- });
- if (!changed)
- break;
- }
+ auto newPairType = builder->findOrEmitHoistableInst(
+ originalPairType->getFullType(),
+ originalPairType->getOp(),
+ originalPairType->getOperandCount(),
+ operands.getArrayView().getBuffer());
+ originalPairType->replaceUsesWith(newPairType);
+ originalPairType->removeAndDeallocate();
+ }
+ });
+
+ sharedBuilderStorage.deduplicateAndRebuildGlobalNumberingMap();
processAllInsts([&](IRInst* inst)
{
@@ -3138,4 +3184,14 @@ IRInst* DifferentiableTypeConformanceContext::lookUpInterfaceMethod(IRBuilder* b
return nullptr;
}
+void DifferentiableTypeConformanceContext::buildGlobalWitnessDictionary()
+{
+ for (auto globalInst : sharedContext->moduleInst->getChildren())
+ {
+ if (auto pairType = as<IRDifferentialPairType>(globalInst))
+ {
+ differentiableWitnessDictionary.Add(pairType->getValueType(), pairType->getWitness());
+ }
+ }
+}
}
diff --git a/source/slang/slang-ir-diff-jvp.h b/source/slang/slang-ir-diff-jvp.h
index 9e0f9cfcc..5e2a7f44f 100644
--- a/source/slang/slang-ir-diff-jvp.h
+++ b/source/slang/slang-ir-diff-jvp.h
@@ -121,6 +121,7 @@ namespace Slang
void setFunc(IRGlobalValueWithCode* func);
+ void buildGlobalWitnessDictionary();
// Lookup a witness table for the concreteType. One should exist if concreteType
// inherits (successfully) from IDifferentiable.
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index fcdeed17a..4434210c9 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -2696,6 +2696,8 @@ public:
IRInst* emitMakeOptionalNone(IRInst* optType, IRInst* defaultValue);
IRInst* emitDifferentialPairGetDifferential(IRType* diffType, IRInst* diffPair);
IRInst* emitDifferentialPairGetPrimal(IRInst* diffPair);
+ IRInst* emitDifferentialPairAddressDifferential(IRType* diffType, IRInst* diffPair);
+ IRInst* emitDifferentialPairAddressPrimal(IRInst* diffPair);
IRInst* emitMakeVector(
IRType* type,
UInt argCount,
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index de86a6a52..c12872320 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -3147,6 +3147,9 @@ namespace Slang
IRInst* IRBuilder::emitMakeDifferentialPair(IRType* type, IRInst* primal, IRInst* differential)
{
+ SLANG_RELEASE_ASSERT(as<IRDifferentialPairType>(type));
+ SLANG_RELEASE_ASSERT(as<IRDifferentialPairType>(type)->getValueType() != nullptr);
+
IRInst* args[] = {primal, differential};
auto inst = createInstWithTrailingArgs<IRMakeDifferentialPair>(
this, kIROp_MakeDifferentialPair, type, 2, args);
@@ -3160,6 +3163,18 @@ namespace Slang
UInt argCount,
IRInst* const* args)
{
+ auto innerReturnVal = findInnerMostGenericReturnVal(as<IRGeneric>(genericVal));
+
+ if (as<IRWitnessTable>(innerReturnVal))
+ {
+ return findOrEmitHoistableInst(
+ type,
+ kIROp_Specialize,
+ genericVal,
+ argCount,
+ args);
+ }
+
auto inst = createInstWithTrailingArgs<IRSpecialize>(
this,
kIROp_Specialize,
@@ -3186,15 +3201,13 @@ namespace Slang
//
SLANG_ASSERT(witnessTableVal->getOp() != kIROp_StructKey);
- auto inst = createInst<IRLookupWitnessMethod>(
- this,
- kIROp_lookup_interface_method,
- type,
- witnessTableVal,
- interfaceMethodVal);
+ IRInst* args[] = {witnessTableVal, interfaceMethodVal};
- addInst(inst);
- return inst;
+ return findOrEmitHoistableInst(
+ type,
+ kIROp_lookup_interface_method,
+ 2,
+ args);
}
IRInst* IRBuilder::emitGetSequentialIDInst(IRInst* rttiObj)
@@ -3467,6 +3480,15 @@ namespace Slang
&diffPair);
}
+ IRInst* IRBuilder::emitDifferentialPairAddressDifferential(IRType* diffType, IRInst* diffPair)
+ {
+ return emitIntrinsicInst(
+ diffType,
+ kIROp_DifferentialPairGetDifferential,
+ 1,
+ &diffPair);
+ }
+
IRInst* IRBuilder::emitDifferentialPairGetPrimal(IRInst* diffPair)
{
auto valueType = as<IRDifferentialPairType>(diffPair->getDataType())->getValueType();
@@ -3477,6 +3499,17 @@ namespace Slang
&diffPair);
}
+ IRInst* IRBuilder::emitDifferentialPairAddressPrimal(IRInst* diffPair)
+ {
+ auto valueType = as<IRDifferentialPairType>(
+ as<IRPtrTypeBase>(diffPair->getDataType())->getValueType())->getValueType();
+ return emitIntrinsicInst(
+ this->getPtrType(kIROp_PtrType, valueType),
+ kIROp_DifferentialPairGetPrimal,
+ 1,
+ &diffPair);
+ }
+
IRInst* IRBuilder::emitMakeMatrix(
IRType* type,
UInt argCount,
diff --git a/tests/autodiff/inout-struct-parameters-jvp.slang b/tests/autodiff/inout-struct-parameters-jvp.slang
new file mode 100644
index 000000000..80ff57b7d
--- /dev/null
+++ b/tests/autodiff/inout-struct-parameters-jvp.slang
@@ -0,0 +1,41 @@
+//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+typedef DifferentialPair<float> dpfloat;
+
+struct A : IDifferentiable
+{
+ float p;
+ float3 q;
+}
+
+[ForwardDifferentiable]
+void g(A a, inout A aout)
+{
+ float t = a.p + a.q.y * a.q.x;
+ aout.p = aout.p + t;
+ aout.q = aout.q * t;
+}
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ float p = 1.0;
+ float3 q = float3(1.0, 2.0, 3.0);
+
+ float dp = 1.0;
+ float3 dq = float3(1.0, 0.5, 0.25);
+
+ DifferentialPair<A> dpa = DifferentialPair<A>({p, q}, {dp, dq});
+
+ __fwd_diff(g)(DifferentialPair<A>( { p, q }, { dp, dq }), dpa);
+
+ outputBuffer[0] = dpa.p.p; // Expect: 4.0
+ outputBuffer[1] = dpa.d.q.x; // Expect: 6.5
+ outputBuffer[2] = dpa.d.q.y; // Expect: 8.5
+ outputBuffer[3] = dpa.d.q.z; // Expect: 11.25
+
+} \ No newline at end of file
diff --git a/tests/autodiff/inout-struct-parameters-jvp.slang.expected.txt b/tests/autodiff/inout-struct-parameters-jvp.slang.expected.txt
new file mode 100644
index 000000000..4cc3c313d
--- /dev/null
+++ b/tests/autodiff/inout-struct-parameters-jvp.slang.expected.txt
@@ -0,0 +1,5 @@
+type: float
+4.000000
+6.500000
+8.500000
+11.25000
diff --git a/tests/autodiff/nondiff-call.slang b/tests/autodiff/nondiff-call.slang
new file mode 100644
index 000000000..d62de1b78
--- /dev/null
+++ b/tests/autodiff/nondiff-call.slang
@@ -0,0 +1,66 @@
+//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+typedef DifferentialPair<float> dpfloat;
+typedef DifferentialPair<float3> dpfloat3;
+
+[ForwardDifferentiable]
+float f(float x)
+{
+ return x * x + x * x * x;
+}
+
+[ForwardDifferentiable]
+float f2(float x)
+{
+ return f(x);
+}
+
+float g(float x)
+{
+ return x * x + x * x * x;
+}
+
+[ForwardDifferentiable]
+float g2(float x)
+{
+ return no_diff(g(x));
+}
+
+struct A
+{
+ float o;
+
+ [ForwardDifferentiable]
+ float doSomethingDifferentiable(float b)
+ {
+ return o + b;
+ }
+
+ float doSomethingNotDifferentiable(float b)
+ {
+ return o * b;
+ }
+}
+
+[ForwardDifferentiable]
+float h2(A a, float k)
+{
+ float v = k * k;
+ return no_diff(a.doSomethingNotDifferentiable(k)) + v;
+}
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ {
+ outputBuffer[0] = f2(1.0); // Expect: 2.0
+ outputBuffer[1] = __fwd_diff(f2)(dpfloat(1.0, 1.0)).d; // Expect: 5.0
+ outputBuffer[2] = __fwd_diff(f2)(dpfloat(1.0, 1.0)).p; // Expect: 2.0
+ outputBuffer[3] = __fwd_diff(g2)(dpfloat(1.0, 1.0)).d; // Expect: 0.0
+ outputBuffer[4] = __fwd_diff(h2)({1.0}, DifferentialPair<float>(1.0, 2.0)).d; // Expect: 4.0
+ }
+}
diff --git a/tests/autodiff/nondiff-call.slang.expected.txt b/tests/autodiff/nondiff-call.slang.expected.txt
new file mode 100644
index 000000000..8f85913bc
--- /dev/null
+++ b/tests/autodiff/nondiff-call.slang.expected.txt
@@ -0,0 +1,6 @@
+type: float
+2.000000
+5.000000
+2.000000
+0.000000
+4.000000 \ No newline at end of file