summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-ir-diff-jvp.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang/slang-ir-diff-jvp.cpp')
-rw-r--r--source/slang/slang-ir-diff-jvp.cpp1478
1 files changed, 1167 insertions, 311 deletions
diff --git a/source/slang/slang-ir-diff-jvp.cpp b/source/slang/slang-ir-diff-jvp.cpp
index 5eee13d5e..843428c01 100644
--- a/source/slang/slang-ir-diff-jvp.cpp
+++ b/source/slang/slang-ir-diff-jvp.cpp
@@ -7,6 +7,10 @@
#include "slang-ir-dce.h"
#include "slang-ir-eliminate-phis.h"
+// origX, primalX, diffX
+// origX -> primalX (cloneEnv)
+// origX -> diffX (instMapD)
+
namespace Slang
{
@@ -24,7 +28,7 @@ typedef Pair<IRInst*, IRInst*> InstPair;
struct DifferentiableTypeConformanceContext
{
- Dictionary<IRInst*, IRInst*> witnessTableMap;
+ Dictionary<IRInst*, IRInst*> witnessTableMap;
IRInst* inst = nullptr;
@@ -39,6 +43,18 @@ struct DifferentiableTypeConformanceContext
// type in the conformance table associated with the concrete type.
//
IRStructKey* differentialAssocTypeStructKey = nullptr;
+
+ // The struct key for the 'zero()' associated type
+ // defined inside IDifferential. We use this to lookup the
+ // implementation of zero() for a given type.
+ //
+ IRStructKey* zeroMethodStructKey = nullptr;
+
+ // The struct key for the 'add()' associated type
+ // defined inside IDifferential. We use this to lookup the
+ // implementation of add() for a given type.
+ //
+ IRStructKey* addMethodStructKey = nullptr;
// Modules that don't use differentiable types
// won't have the IDifferentiable interface type available.
@@ -56,6 +72,9 @@ struct DifferentiableTypeConformanceContext
{
differentiableInterfaceType = parent->differentiableInterfaceType;
differentialAssocTypeStructKey = parent->differentialAssocTypeStructKey;
+ zeroMethodStructKey = parent->zeroMethodStructKey;
+ addMethodStructKey = parent->addMethodStructKey;
+
isInterfaceAvailable = parent->isInterfaceAvailable;
}
else
@@ -64,17 +83,13 @@ struct DifferentiableTypeConformanceContext
if (differentiableInterfaceType)
{
differentialAssocTypeStructKey = findDifferentialTypeStructKey();
+ zeroMethodStructKey = findZeroMethodStructKey();
+ addMethodStructKey = findAddMethodStructKey();
if (differentialAssocTypeStructKey)
isInterfaceAvailable = true;
}
}
-
- if (isInterfaceAvailable)
- {
- // Load all witness tables corresponding to the IDifferentiable interface.
- loadWitnessTablesForInterface(differentiableInterfaceType);
- }
}
DifferentiableTypeConformanceContext(IRInst* inst) :
@@ -84,35 +99,30 @@ struct DifferentiableTypeConformanceContext
// Lookup a witness table for the concreteType. One should exist if concreteType
// inherits (successfully) from IDifferentiable.
//
- IRInst* lookUpConformanceForType(IRInst* type)
+ IRInst* lookUpConformanceForType(IRBuilder* builder, IRInst* type)
{
SLANG_ASSERT(isInterfaceAvailable);
+ // TODO: Cache the returned value to avoid repeatedly scanning through
+ // blocks looking for the type entries.
+ //
+ if (auto irWitness = builder->findDifferentiableTypeEntry(type, type->getParent()))
+ {
+ return irWitness;
+ }
- if (witnessTableMap.ContainsKey(type))
- return witnessTableMap[type];
- else if (parent)
- return parent->lookUpConformanceForType(type);
- else
- return nullptr;
+ return nullptr;
}
-
- // Lookup and return the 'Differential' type declared in the concrete type
- // in order to conform to the IDifferentiable interface.
- // Note that inside a generic block, this will be a witness table lookup instruction
- // that gets resolved during the specialization pass.
- //
- IRInst* getDifferentialForType(IRBuilder* builder, IRType* origType)
- {
- SLANG_ASSERT(isInterfaceAvailable);
- if (auto conformance = lookUpConformanceForType(origType))
+ IRInst* lookUpInterfaceMethod(IRBuilder* builder, IRType* origType, IRStructKey* key)
+ {
+ if (auto conformance = lookUpConformanceForType(builder, origType))
{
if (auto witnessTable = as<IRWitnessTable>(conformance))
{
for (auto entry : witnessTable->getEntries())
{
- if (entry->getRequirementKey() == differentialAssocTypeStructKey)
- return as<IRType>(entry->getSatisfyingVal());
+ if (entry->getRequirementKey() == key)
+ return entry->getSatisfyingVal();
}
}
else if (auto witnessTableParam = as<IRParam>(conformance))
@@ -120,12 +130,32 @@ struct DifferentiableTypeConformanceContext
return builder->emitLookupInterfaceMethodInst(
builder->getTypeKind(),
witnessTableParam,
- differentialAssocTypeStructKey);
+ key);
}
}
return nullptr;
}
+
+ // Lookup and return the 'Differential' type declared in the concrete type
+ // in order to conform to the IDifferentiable interface.
+ // Note that inside a generic block, this will be a witness table lookup instruction
+ // that gets resolved during the specialization pass.
+ //
+ IRInst* getDifferentialForType(IRBuilder* builder, IRType* origType)
+ {
+ return lookUpInterfaceMethod(builder, origType, differentialAssocTypeStructKey);
+ }
+
+ IRInst* getZeroMethodForType(IRBuilder* builder, IRType* origType)
+ {
+ return lookUpInterfaceMethod(builder, origType, zeroMethodStructKey);
+ }
+
+ IRInst* getAddMethodForType(IRBuilder* builder, IRType* origType)
+ {
+ return lookUpInterfaceMethod(builder, origType, addMethodStructKey);
+ }
private:
@@ -150,11 +180,26 @@ struct DifferentiableTypeConformanceContext
IRStructKey* findDifferentialTypeStructKey()
{
+ return getIDifferentiableStructKeyAtIndex(0);
+ }
+
+ IRStructKey* findZeroMethodStructKey()
+ {
+ return getIDifferentiableStructKeyAtIndex(1);
+ }
+
+ IRStructKey* findAddMethodStructKey()
+ {
+ return getIDifferentiableStructKeyAtIndex(2);
+ }
+
+ IRStructKey* getIDifferentiableStructKeyAtIndex(UInt index)
+ {
if (as<IRModuleInst>(inst) && differentiableInterfaceType)
{
- // Assume for now that IDifferentiable has exactly one field: the 'Differential' associated type.
- SLANG_ASSERT(differentiableInterfaceType->getOperandCount() == 1);
- if (auto entry = as<IRInterfaceRequirementEntry>(differentiableInterfaceType->getOperand(0)))
+ // Assume for now that IDifferentiable has exactly three fields.
+ SLANG_ASSERT(differentiableInterfaceType->getOperandCount() == 4);
+ if (auto entry = as<IRInterfaceRequirementEntry>(differentiableInterfaceType->getOperand(index)))
return as<IRStructKey>(entry->getRequirementKey());
else
{
@@ -200,12 +245,18 @@ struct DifferentiableTypeConformanceContext
genericParam = genericParam->getNextParam();
}
- UCount tableIndex = 0;
+ Count tableIndex = 0;
while (genericParam)
{
SLANG_ASSERT(!as<IRTypeType>(genericParam->getDataType()));
+
+ if (tableIndex >= typeParams.getCount())
+ break;
+
if (auto witnessTableType = as<IRWitnessTableType>(genericParam->getDataType()))
{
+ // TODO(sai): Heavily flawed way to find the right witness table.
+ // Rewrite this part
if (witnessTableType->getConformanceType() == differentiableInterfaceType)
witnessTableMap.Add(typeParams[tableIndex], genericParam);
}
@@ -222,6 +273,40 @@ struct DifferentiableTypeConformanceContext
};
+
+IRInst* findGlobal(IRInst* inst)
+{
+ if (inst->getParent() != inst->getModule()->getModuleInst())
+ {
+ return findGlobal(inst->getParent());
+ }
+
+ return inst;
+}
+
+void moveGlobalToBeforeUses(IRBuilder*, IRInst* globalInst)
+{
+ HashSet<IRInst*> globalsOfUses;
+ for (auto use = globalInst->firstUse; use; use = use->nextUse)
+ {
+ globalsOfUses.Add(findGlobal(use->getUser()));
+ }
+
+ IRInst* earliestUse = nullptr;
+ for (auto cursor = globalInst; cursor; cursor = cursor->getPrevInst())
+ {
+ if (globalsOfUses.Contains(cursor))
+ {
+ earliestUse = cursor;
+ }
+ }
+
+ if (earliestUse)
+ {
+ globalInst->insertBefore(earliestUse);
+ }
+}
+
struct DifferentialPairTypeBuilder
{
@@ -229,95 +314,246 @@ struct DifferentialPairTypeBuilder
diffConformanceContext(diffConformanceContext)
{}
- IRInst* emitPrimalFieldAccess(IRBuilder* builder, IRInst* baseInst)
+ IRStructField* findField(IRInst* type, IRStructKey* key)
{
- if (auto basePairStructType = as<IRStructType>(baseInst->getDataType()))
+ if (auto irStructType = as<IRStructType>(type))
{
- auto primalField = as<IRStructField>(basePairStructType->getFirstChild());
- SLANG_ASSERT(primalField);
-
- return as<IRFieldExtract>(builder->emitFieldExtract(
- primalField->getFieldType(),
- baseInst,
- primalField->getKey()
- ));
+ for (auto field : irStructType->getFields())
+ {
+ if (field->getKey() == key)
+ {
+ return field;
+ }
+ }
}
- else if (auto ptrType = as<IRPtrTypeBase>(baseInst->getDataType()))
+ else if (auto irSpecialize = as<IRSpecialize>(type))
{
- if (auto pairStructType = as<IRStructType>(ptrType->getValueType()))
+ if (auto irGeneric = as<IRGeneric>(irSpecialize->getBase()))
{
- auto primalField = as<IRStructField>(pairStructType->getFirstChild());
- SLANG_ASSERT(primalField);
-
- return as<IRFieldAddress>(builder->emitFieldAddress(
- builder->getPtrType(primalField->getFieldType()),
- baseInst,
- primalField->getKey()
- ));
+ if (auto irGenericStructType = as<IRStructType>(findInnerMostGenericReturnVal(irGeneric)))
+ {
+ return findField(irGenericStructType, key);
+ }
}
}
- else
+
+ return nullptr;
+ }
+
+ IRInst* findSpecializationForParam(IRInst* specializeInst, IRInst* genericParam)
+ {
+ // Get base generic that's being specialized.
+ auto genericType = as<IRGeneric>(as<IRSpecialize>(specializeInst)->getBase());
+ SLANG_ASSERT(genericType);
+
+ // Find the index of genericParam in the base generic.
+ int paramIndex = -1;
+ int currentIndex = 0;
+ for (auto param : genericType->getParams())
{
- SLANG_UNREACHABLE("basePairType must be an IRStructType or PtrType<IRStructType>");
+ if (param == genericParam)
+ paramIndex = currentIndex;
+ currentIndex ++;
}
- return nullptr;
+
+ SLANG_ASSERT(paramIndex >= 0);
+
+ // Return the corresponding operand in the specialization inst.
+ return specializeInst->getOperand(1 + paramIndex);
}
- IRInst* emitDiffFieldAccess(IRBuilder* builder, IRInst* baseInst)
+ IRInst* emitFieldAccessor(IRBuilder* builder, IRInst* baseInst, IRStructKey* key)
{
if (auto basePairStructType = as<IRStructType>(baseInst->getDataType()))
{
- auto diffField = as<IRStructField>(basePairStructType->getFirstChild()->getNextInst());
- SLANG_ASSERT(diffField);
-
return as<IRFieldExtract>(builder->emitFieldExtract(
- diffField->getFieldType(),
+ findField(basePairStructType, key)->getFieldType(),
baseInst,
- diffField->getKey()
+ key
));
}
else if (auto ptrType = as<IRPtrTypeBase>(baseInst->getDataType()))
{
- if (auto pairStructType = as<IRStructType>(ptrType->getValueType()))
+ if (auto ptrInnerSpecializedType = as<IRSpecialize>(ptrType->getValueType()))
{
- auto diffField = as<IRStructField>(pairStructType->getFirstChild()->getNextInst());
- SLANG_ASSERT(diffField);
-
- return as<IRFieldAddress>(builder->emitFieldAddress(
- builder->getPtrType(diffField->getFieldType()),
+ auto genericType = findInnerMostGenericReturnVal(as<IRGeneric>(ptrInnerSpecializedType->getBase()));
+ if (auto genericBasePairStructType = as<IRStructType>(genericType))
+ {
+ return as<IRFieldAddress>(builder->emitFieldAddress(
+ builder->getPtrType((IRType*)
+ findSpecializationForParam(
+ ptrInnerSpecializedType,
+ findField(ptrInnerSpecializedType, key)->getFieldType())),
baseInst,
- diffField->getKey()
+ key
));
+ }
+ }
+ else if (auto ptrBaseStructType = as<IRStructType>(ptrType->getValueType()))
+ {
+ return as<IRFieldAddress>(builder->emitFieldAddress(
+ builder->getPtrType((IRType*)
+ findField(ptrBaseStructType, key)->getFieldType()),
+ baseInst,
+ key));
+ }
+ }
+ else if (auto specializedType = as<IRSpecialize>(baseInst->getDataType()))
+ {
+ // TODO: Stopped here -> The type being emitted is incorrect. don't emit the generic's
+ // type, emit the specialization type.
+ //
+ auto genericType = findInnerMostGenericReturnVal(as<IRGeneric>(specializedType->getBase()));
+ if (auto genericBasePairStructType = as<IRStructType>(genericType))
+ {
+ return as<IRFieldExtract>(builder->emitFieldExtract(
+ (IRType*)findSpecializationForParam(
+ specializedType,
+ findField(genericBasePairStructType, key)->getFieldType()),
+ baseInst,
+ key
+ ));
+ }
+ else if (auto genericPtrType = as<IRPtrTypeBase>(genericType))
+ {
+ if (auto genericPairStructType = as<IRStructType>(genericPtrType->getValueType()))
+ {
+ return as<IRFieldAddress>(builder->emitFieldAddress(
+ builder->getPtrType((IRType*)
+ findSpecializationForParam(
+ specializedType,
+ findField(genericPairStructType, key)->getFieldType())),
+ baseInst,
+ key
+ ));
+ }
}
}
else
{
- SLANG_UNREACHABLE("basePairType must be an IRStructType or PtrType<IRStructType>");
+ SLANG_UNEXPECTED("Unrecognized field. Cannot emit field accessor");
}
return nullptr;
}
+
+ IRInst* emitPrimalFieldAccess(IRBuilder* builder, IRInst* baseInst)
+ {
+ return emitFieldAccessor(builder, baseInst, this->globalPrimalKey);
+ }
+
+ IRInst* emitDiffFieldAccess(IRBuilder* builder, IRInst* baseInst)
+ {
+ return emitFieldAccessor(builder, baseInst, this->globalDiffKey);
+ }
+
+ void relocateNewTypes(IRBuilder* builder)
+ {
+ for (auto typeInst : generatedTypeList)
+ {
+ moveGlobalToBeforeUses(builder, typeInst);
+ }
+ }
+
+ void _createGenericDiffPairType(IRBuilder* builder)
+ {
+ // Insert directly at top level (skip any generic scopes etc.)
+ auto insertLoc = builder->getInsertLoc();
+ builder->setInsertInto(builder->getModule()->getModuleInst());
+
+ // Make a generic version of the pair struct.
+ auto irGeneric = builder->emitGeneric();
+ irGeneric->setFullType(builder->getTypeKind());
+ builder->setInsertInto(irGeneric);
+
+ generatedTypeList.add(irGeneric);
+
+ auto irBlock = builder->emitBlock();
+ builder->setInsertInto(irBlock);
+
+ auto pTypeParam = builder->emitParam(builder->getTypeType());
+ builder->addNameHintDecoration(pTypeParam, UnownedTerminatedStringSlice("pT"));
+
+ auto dTypeParam = builder->emitParam(builder->getTypeType());
+ builder->addNameHintDecoration(dTypeParam, UnownedTerminatedStringSlice("dT"));
+
+ auto irStructType = builder->createStructType();
+ builder->emitReturn(irStructType);
+
+ auto primalKey = _getOrCreatePrimalStructKey(builder);
+ builder->addNameHintDecoration(primalKey, UnownedTerminatedStringSlice("primal"));
+ builder->createStructField(irStructType, primalKey, (IRType*) pTypeParam);
+
+ auto diffKey = _getOrCreateDiffStructKey(builder);
+ builder->addNameHintDecoration(diffKey, UnownedTerminatedStringSlice("differential"));
+ builder->createStructField(irStructType, diffKey, (IRType*) dTypeParam);
+
+ // Reset cursor when done.
+ builder->setInsertLoc(insertLoc);
+
+ this->genericDiffPairType = irGeneric;
+ }
+
+ IRStructKey* _getOrCreateDiffStructKey(IRBuilder* builder)
+ {
+ if (!this->globalDiffKey)
+ {
+ // Insert directly at top level (skip any generic scopes etc.)
+ auto insertLoc = builder->getInsertLoc();
+ builder->setInsertInto(builder->getModule()->getModuleInst());
+
+ this->globalDiffKey = builder->createStructKey();
+ builder->addNameHintDecoration(this->globalDiffKey , UnownedTerminatedStringSlice("differential"));
+
+ builder->setInsertLoc(insertLoc);
+ }
+
+ return this->globalDiffKey;
+ }
+
+ IRStructKey* _getOrCreatePrimalStructKey(IRBuilder* builder)
+ {
+ if (!this->globalPrimalKey)
+ {
+ // Insert directly at top level (skip any generic scopes etc.)
+ auto insertLoc = builder->getInsertLoc();
+ builder->setInsertInto(builder->getModule()->getModuleInst());
+
+ this->globalPrimalKey = builder->createStructKey();
+ builder->addNameHintDecoration(this->globalPrimalKey , UnownedTerminatedStringSlice("primal"));
+
+ builder->setInsertLoc(insertLoc);
+ }
+
+ return this->globalPrimalKey;
+ }
+
+ IRInst* _getOrCreateGenericDiffPairType(IRBuilder* builder)
+ {
+ if (!this->genericDiffPairType)
+ {
+ _createGenericDiffPairType(builder);
+ }
+
+ SLANG_ASSERT(this->genericDiffPairType);
+ return this->genericDiffPairType;
+ }
- IRStructType* _createDiffPairType(IRBuilder* builder, IRType* origBaseType)
+ IRInst* _createDiffPairType(IRBuilder* builder, IRType* origBaseType)
{
if (auto diffBaseType = diffConformanceContext->getDifferentialForType(builder, origBaseType))
{
- auto diffPairType = builder->createStructType();
-
- // Create a keys for the primal and differential fields.
- IRStructKey* origKey = builder->createStructKey();
- builder->addNameHintDecoration(origKey, UnownedTerminatedStringSlice("primal"));
- builder->createStructField(diffPairType, origKey, origBaseType);
+ SLANG_ASSERT(!as<IRParam>(origBaseType));
- IRStructKey* diffKey = builder->createStructKey();
- builder->addNameHintDecoration(diffKey, UnownedTerminatedStringSlice("differential"));
- builder->createStructField(diffPairType, diffKey, (IRType*)(diffBaseType));
+ auto pairStructType = builder->createStructType();
+ builder->createStructField(pairStructType, _getOrCreatePrimalStructKey(builder), origBaseType);
+ builder->createStructField(pairStructType, _getOrCreateDiffStructKey(builder), (IRType*) diffBaseType);
- return diffPairType;
+ return pairStructType;
}
return nullptr;
}
- IRStructType* getOrCreateDiffPairType(IRBuilder* builder, IRType* origBaseType)
+ IRInst* getOrCreateDiffPairType(IRBuilder* builder, IRType* origBaseType)
{
if (pairTypeCache.ContainsKey(origBaseType))
return pairTypeCache[origBaseType];
@@ -328,10 +564,17 @@ struct DifferentialPairTypeBuilder
return pairType;
}
- Dictionary<IRType*, IRStructType*> pairTypeCache;
+ Dictionary<IRInst*, IRInst*> pairTypeCache;
DifferentiableTypeConformanceContext* diffConformanceContext;
+
+ IRStructKey* globalPrimalKey = nullptr;
+
+ IRStructKey* globalDiffKey = nullptr;
+ IRInst* genericDiffPairType = nullptr;
+
+ List<IRInst*> generatedTypeList;
};
struct JVPTranscriber
@@ -341,6 +584,9 @@ struct JVPTranscriber
// their differential values.
Dictionary<IRInst*, IRInst*> instMapD;
+ // Set of insts currently being transcribed. Used to avoid infinite loops.
+ HashSet<IRInst*> instsInProgress;
+
// Cloning environment to hold mapping from old to new copies for the primal
// instructions.
IRCloneEnv cloneEnv;
@@ -362,7 +608,17 @@ struct JVPTranscriber
void mapDifferentialInst(IRInst* origInst, IRInst* diffInst)
{
- instMapD.Add(origInst, diffInst);
+ if (hasDifferentialInst(origInst))
+ {
+ if (lookupDiffInst(origInst) != diffInst)
+ {
+ SLANG_UNEXPECTED("Inconsistent differential mappings");
+ }
+ }
+ else
+ {
+ instMapD.Add(origInst, diffInst);
+ }
}
void mapPrimalInst(IRInst* origInst, IRInst* primalInst)
@@ -439,6 +695,7 @@ struct JVPTranscriber
for (UIndex i = 0; i < funcType->getParamCount(); i++)
{
auto origType = funcType->getParamType(i);
+ origType = (IRType*) lookupPrimalInst(origType, origType);
if (auto diffPairType = tryGetDiffPairType(builder, origType))
newParameterTypes.add(diffPairType);
else
@@ -448,7 +705,8 @@ struct JVPTranscriber
// Transcribe return type to a pair.
// This will be void if the primal return type is non-differentiable.
//
- if (auto returnPairType = tryGetDiffPairType(builder, funcType->getResultType()))
+ auto origResultType = (IRType*) lookupPrimalInst(funcType->getResultType(), funcType->getResultType());
+ if (auto returnPairType = tryGetDiffPairType(builder, origResultType))
diffReturnType = returnPairType;
else
diffReturnType = builder->getVoidType();
@@ -458,41 +716,101 @@ struct JVPTranscriber
IRType* differentiateType(IRBuilder* builder, IRType* origType)
{
- switch (origType->getOp())
- {
- case kIROp_HalfType:
- case kIROp_FloatType:
- case kIROp_DoubleType:
- case kIROp_VectorType:
- return (IRType*)(diffConformanceContext->getDifferentialForType(builder, origType));
- case kIROp_OutType:
- return builder->getOutType(differentiateType(builder, as<IROutType>(origType)->getValueType()));
- case kIROp_InOutType:
- return builder->getInOutType(differentiateType(builder, as<IRInOutType>(origType)->getValueType()));
- default:
+ if (auto ptrType = as<IRPtrTypeBase>(origType))
+ return builder->getPtrType(
+ origType->getOp(),
+ differentiateType(builder, ptrType->getValueType()));
+
+ // If there is an explicit primal version of this type in the local scope, load that
+ // otherwise use the original type.
+ //
+ IRInst* primalType = lookupPrimalInst(origType, origType);
+
+ // Special case certain compound types (PtrType, FuncType, etc..)
+ // otherwise try to lookup a differential definition for the given type.
+ // If one does not exist, then we assume it's not differentiable.
+ //
+ switch (primalType->getOp())
+ {
+ case kIROp_Param:
+ if (as<IRTypeType>(primalType->getDataType()))
+ return (IRType*)(diffConformanceContext->getDifferentialForType(
+ builder,
+ (IRType*)primalType));
+ else if (as<IRWitnessTableType>(primalType->getDataType()))
+ return (IRType*)primalType;
+
+ case kIROp_ArrayType:
+ {
+ auto primalArrayType = as<IRArrayType>(primalType);
+ if (auto diffElementType = differentiateType(builder, primalArrayType->getElementType()))
+ return builder->getArrayType(
+ diffElementType,
+ primalArrayType->getElementCount());
+ else
+ return nullptr;
+ }
+
+ case kIROp_FuncType:
+ return differentiateFunctionType(builder, as<IRFuncType>(primalType));
+
+ case kIROp_OutType:
+ if (auto diffValueType = differentiateType(builder, as<IROutType>(primalType)->getValueType()))
+ return builder->getOutType(diffValueType);
+ else
+ return nullptr;
+
+ case kIROp_InOutType:
+ if (auto diffValueType = differentiateType(builder, as<IRInOutType>(primalType)->getValueType()))
+ return builder->getInOutType(diffValueType);
+ else
return nullptr;
+
+ case kIROp_TupleType:
+ {
+ auto tupleType = as<IRTupleType>(primalType);
+ List<IRType*> diffTypeList;
+ // TODO: what if we have type parameters here?
+ for (UIndex ii = 0; ii < tupleType->getOperandCount(); ii++)
+ diffTypeList.add(
+ differentiateType(builder, (IRType*)tupleType->getOperand(ii)));
+
+ return builder->getTupleType(diffTypeList);
+ }
+
+ default:
+ return (IRType*)(diffConformanceContext->getDifferentialForType(builder, (IRType*)primalType));
}
}
- IRType* tryGetDiffPairType(IRBuilder* builder, IRType* origType)
+ IRType* tryGetDiffPairType(IRBuilder* builder, IRType* primalType)
{
// If this is a PtrType (out, inout, etc..), then create diff pair from
// value type and re-apply the appropropriate PtrType wrapper.
//
- if (auto origPtrType = as<IRPtrTypeBase>(origType))
+ if (auto origPtrType = as<IRPtrTypeBase>(primalType))
{
if (auto diffPairValueType = tryGetDiffPairType(builder, origPtrType->getValueType()))
- return builder->getPtrType(origType->getOp(), diffPairValueType);
+ return builder->getPtrType(primalType->getOp(), diffPairValueType);
else
return nullptr;
}
- return pairBuilder->getOrCreateDiffPairType(builder, origType);
+ return (IRType*)pairBuilder->getOrCreateDiffPairType(builder, primalType);
}
InstPair transcribeParam(IRBuilder* builder, IRParam* origParam)
{
- if (auto diffPairType = tryGetDiffPairType(builder, origParam->getFullType()))
+ auto primalDataType = lookupPrimalInst(origParam->getDataType(), origParam->getDataType());
+ // Do not differentiate generic type (and witness table) parameters
+ if (as<IRTypeType>(primalDataType) || as<IRWitnessTableType>(primalDataType))
+ {
+ return InstPair(
+ cloneInst(&cloneEnv, builder, origParam),
+ nullptr);
+ }
+
+ if (auto diffPairType = tryGetDiffPairType(builder, (IRType*)primalDataType))
{
IRParam* diffPairParam = builder->emitParam(diffPairType);
@@ -507,6 +825,7 @@ struct JVPTranscriber
pairBuilder->emitDiffFieldAccess(builder, diffPairParam));
}
+
return InstPair(
cloneInst(&cloneEnv, builder, origParam),
nullptr);
@@ -570,15 +889,13 @@ struct JVPTranscriber
auto diffLeft = findOrTranscribeDiffInst(builder, origLeft);
auto diffRight = findOrTranscribeDiffInst(builder, origRight);
- auto leftZero = builder->getFloatValue(origLeft->getDataType(), 0.0);
- auto rightZero = builder->getFloatValue(origRight->getDataType(), 0.0);
if (diffLeft || diffRight)
{
- diffLeft = diffLeft ? diffLeft : leftZero;
- diffRight = diffRight ? diffRight : rightZero;
+ diffLeft = diffLeft ? diffLeft : getDifferentialZeroOfType(builder, primalLeft->getDataType());
+ diffRight = diffRight ? diffRight : getDifferentialZeroOfType(builder, primalRight->getDataType());
- auto resultType = origArith->getDataType();
+ auto resultType = primalArith->getDataType();
switch(origArith->getOp())
{
case kIROp_Add:
@@ -608,17 +925,36 @@ struct JVPTranscriber
return InstPair(primalArith, nullptr);
}
+
+ InstPair transcribeBinaryLogic(IRBuilder* builder, IRInst* origLogic)
+ {
+ SLANG_ASSERT(origLogic->getOperandCount() == 2);
+
+ // TODO: Check other boolean cases.
+ if (as<IRBoolType>(origLogic->getDataType()))
+ {
+ // Boolean operations are not differentiable. For the linearization
+ // pass, we do not need to do anything but copy them over to the ne
+ // function.
+ auto primalLogic = cloneInst(&cloneEnv, builder, origLogic);
+ return InstPair(primalLogic, nullptr);
+ }
+
+ SLANG_UNEXPECTED("Logical operation with non-boolean result");
+ }
+
InstPair transcribeLoad(IRBuilder* builder, IRLoad* origLoad)
{
auto origPtr = origLoad->getPtr();
auto primalLoad = cloneInst(&cloneEnv, builder, origLoad);
+ IRInst* diffLoad = nullptr;
+
if (auto diffPtr = lookupDiffInst(origPtr, nullptr))
{
- IRLoad* diffLoad = as<IRLoad>(builder->emitLoad(diffPtr));
- SLANG_ASSERT(diffLoad);
-
+ // Default case, we're loading from a known differential inst.
+ diffLoad = as<IRLoad>(builder->emitLoad(diffPtr));
return InstPair(primalLoad, diffLoad);
}
return InstPair(primalLoad, nullptr);
@@ -634,15 +970,17 @@ struct JVPTranscriber
auto diffStoreLocation = lookupDiffInst(origStoreLocation, nullptr);
auto diffStoreVal = lookupDiffInst(origStoreVal, nullptr);
+ IRInst* diffStore = nullptr;
+
// If the stored value has a differential version,
// emit a store instruction for the differential parameter.
// Otherwise, emit nothing since there's nothing to load.
//
if (diffStoreLocation && diffStoreVal)
{
- IRStore* diffStore = as<IRStore>(
- builder->emitStore(diffStoreLocation, diffStoreVal));
- SLANG_ASSERT(diffStore);
+ // Default case, storing the entire type (and not a member)
+ diffStore = as<IRStore>(
+ builder->emitStore(diffStoreLocation, diffStoreVal));
return InstPair(primalStore, diffStore);
}
@@ -653,14 +991,31 @@ struct JVPTranscriber
InstPair transcribeReturn(IRBuilder* builder, IRReturn* origReturn)
{
IRInst* origReturnVal = origReturn->getVal();
-
- if (auto pairType = tryGetDiffPairType(builder, origReturnVal->getDataType()))
+
+ auto returnDataType = (IRType*) lookupPrimalInst(origReturnVal->getDataType(), origReturnVal->getDataType());
+ if (as<IRFunc>(origReturnVal) || as<IRGeneric>(origReturnVal) || as<IRStructType>(origReturnVal) || as<IRFuncType>(origReturnVal))
+ {
+ // If the return value is itself a function, generic or a struct then this
+ // is likely to be a generic scope. In this case, we lookup the differential
+ // and return that.
+ IRInst* primalReturnVal = findOrTranscribePrimalInst(builder, origReturnVal);
+ IRInst* diffReturnVal = findOrTranscribeDiffInst(builder, origReturnVal);
+
+ // Neither of these should be nullptr.
+ SLANG_RELEASE_ASSERT(primalReturnVal && diffReturnVal);
+ IRReturn* diffReturn = as<IRReturn>(builder->emitReturn(diffReturnVal));
+
+ return InstPair(diffReturn, diffReturn);
+ }
+ else if (auto pairType = tryGetDiffPairType(builder, returnDataType))
{
IRInst* primalReturnVal = findOrTranscribePrimalInst(builder, origReturnVal);
-
IRInst* diffReturnVal = findOrTranscribeDiffInst(builder, origReturnVal);
if(!diffReturnVal)
- diffReturnVal = getZeroOfType(builder, origReturnVal->getDataType());
+ diffReturnVal = getDifferentialZeroOfType(builder, returnDataType);
+
+ // If the pair type can be formed, this must be non-null.
+ SLANG_RELEASE_ASSERT(diffReturnVal);
auto diffPair = builder->emitMakeDifferentialPair(pairType, primalReturnVal, diffReturnVal);
IRReturn* pairReturn = as<IRReturn>(builder->emitReturn(diffPair));
@@ -668,10 +1023,12 @@ struct JVPTranscriber
}
else
{
- // If the differential return value is not available, emit a
- // void return.
- IRInst* voidReturn = builder->emitReturn();
- return InstPair(voidReturn, voidReturn);
+ // If the return type is not differentiable, emit the primal value only.
+ IRInst* primalReturnVal = findOrTranscribePrimalInst(builder, origReturnVal);
+
+ IRInst* primalReturn = builder->emitReturn(primalReturnVal);
+ return InstPair(primalReturn, nullptr);
+
}
}
@@ -682,15 +1039,43 @@ struct JVPTranscriber
InstPair transcribeConstruct(IRBuilder* builder, IRInst* origConstruct)
{
IRInst* primalConstruct = cloneInst(&cloneEnv, builder, origConstruct);
+
+ // Check if the output type can be differentiated. If it cannot be
+ // differentiated, don't differentiate the inst
+ //
+ auto primalConstructType = (IRType*) lookupPrimalInst(origConstruct->getDataType(), origConstruct->getDataType());
+ if (auto diffConstructType = differentiateType(builder, primalConstructType))
+ {
+ UCount operandCount = origConstruct->getOperandCount();
- if (as<IRConstant>(origConstruct->getOperand(0)) && origConstruct->getOperandCount() == 1)
- return InstPair(primalConstruct, nullptr);
+ List<IRInst*> diffOperands;
+ for (UIndex ii = 0; ii < operandCount; ii++)
+ {
+ // If the operand has a differential version, replace the original with
+ // the differential. Otherwise, use a zero.
+ //
+ if (auto diffInst = lookupDiffInst(origConstruct->getOperand(ii), nullptr))
+ diffOperands.add(diffInst);
+ else
+ {
+ auto operandDataType = origConstruct->getOperand(ii)->getDataType();
+ operandDataType = (IRType*) lookupPrimalInst(operandDataType, operandDataType);
+ diffOperands.add(getDifferentialZeroOfType(builder, operandDataType));
+ }
+ }
+
+ return InstPair(
+ primalConstruct,
+ builder->emitIntrinsicInst(
+ diffConstructType,
+ origConstruct->getOp(),
+ operandCount,
+ diffOperands.getBuffer()));
+ }
else
- getSink()->diagnose(origConstruct->sourceLoc,
- Diagnostics::unimplemented,
- "this construct instruction cannot be differentiated");
-
- return InstPair(primalConstruct, nullptr);
+ {
+ return InstPair(primalConstruct, nullptr);
+ }
}
// Differentiating a call instruction here is primarily about generating
@@ -699,13 +1084,21 @@ struct JVPTranscriber
//
InstPair transcribeCall(IRBuilder* builder, IRCall* origCall)
{
- if (auto origCallee = as<IRFunc>(origCall->getCallee()))
+
+ if (as<IRFunc>(origCall->getCallee()))
{
-
+ auto origCallee = origCall->getCallee();
+
+ // Since concrete functions are globals, the primal callee is the same
+ // as the original callee.
+ //
+ auto primalCallee = origCallee;
+
+ // TODO: If inner is not differentiable, treat as non-differentiable call.
// Build the differential callee
IRInst* diffCall = builder->emitJVPDifferentiateInst(
- differentiateFunctionType(builder, as<IRFuncType>(origCallee->getFullType())),
- origCallee);
+ differentiateFunctionType(builder, as<IRFuncType>(primalCallee->getFullType())),
+ primalCallee);
List<IRInst*> args;
// Go over the parameter list and create pairs for each input (if required)
@@ -715,17 +1108,17 @@ struct JVPTranscriber
auto primalArg = findOrTranscribePrimalInst(builder, origArg);
SLANG_ASSERT(primalArg);
- auto origType = origArg->getDataType();
- if (auto pairType = tryGetDiffPairType(builder, origType))
+ auto primalType = primalArg->getDataType();
+ if (auto pairType = tryGetDiffPairType(builder, primalType))
{
-
auto diffArg = findOrTranscribeDiffInst(builder, origArg);
- // TODO(sai): This part is flawed. Replace with a call to the
- // 'zero()' interface method.
if (!diffArg)
- diffArg = getZeroOfType(builder, origType);
+ diffArg = getDifferentialZeroOfType(builder, primalType);
+ // If a pair type can be formed, this must be non-null.
+ SLANG_RELEASE_ASSERT(diffArg);
+
auto diffPair = builder->emitMakeDifferentialPair(pairType, primalArg, diffArg);
args.add(diffPair);
@@ -737,8 +1130,11 @@ struct JVPTranscriber
}
}
+ auto diffReturnType = tryGetDiffPairType(builder, origCall->getFullType());
+ SLANG_ASSERT(diffReturnType);
+
auto callInst = builder->emitCallInst(
- tryGetDiffPairType(builder, origCall->getFullType()),
+ diffReturnType,
diffCall,
args);
@@ -746,6 +1142,13 @@ struct JVPTranscriber
pairBuilder->emitPrimalFieldAccess(builder, callInst),
pairBuilder->emitDiffFieldAccess(builder, callInst));
}
+ else if(as<IRSpecialize>(origCall->getCallee()) ||
+ as<IRLookupWitnessMethod>(origCall->getCallee()))
+ {
+ getSink()->diagnose(origCall->sourceLoc,
+ Diagnostics::unimplemented,
+ "attempting to differentiate unspecialized callee or an interface method");
+ }
else
{
// Note that this can only happen if the callee is a result
@@ -774,7 +1177,7 @@ struct JVPTranscriber
return InstPair(
primalSwizzle,
builder->emitSwizzle(
- differentiateType(builder, origSwizzle->getDataType()),
+ differentiateType(builder, primalSwizzle->getDataType()),
diffBase,
origSwizzle->getElementCount(),
swizzleIndices.getBuffer()));
@@ -806,7 +1209,7 @@ struct JVPTranscriber
return InstPair(
primalInst,
builder->emitIntrinsicInst(
- differentiateType(builder, origInst->getDataType()),
+ differentiateType(builder, primalInst->getDataType()),
origInst->getOp(),
operandCount,
diffOperands.getBuffer()));
@@ -819,17 +1222,44 @@ struct JVPTranscriber
case kIROp_unconditionalBranch:
auto origBranch = as<IRUnconditionalBranch>(origInst);
- // Branches with extra operands not handled currently.
- if (origBranch->getOperandCount() > 1)
- break;
+ // Grab the differentials for any phi nodes.
+ List<IRInst*> pairArgs;
+ for (UIndex ii = 0; ii < origBranch->getArgCount(); ii++)
+ {
+ auto origArg = origBranch->getArg(ii);
- IRInst* diffBranch = nullptr;
+ IRInst* pairArg = nullptr;
+ if (auto diffPairType = tryGetDiffPairType(builder, (IRType*)origArg->getDataType()))
+ {
+ auto diffArg = lookupDiffInst(origArg, nullptr);
+ if (!diffArg)
+ {
+ diffArg = getDifferentialZeroOfType(builder, (IRType*)origArg->getDataType());
+ }
+
+ pairArg = builder->emitMakeDifferentialPair(
+ diffPairType,
+ lookupPrimalInst(origArg),
+ diffArg);
+ }
+ else
+ {
+ pairArg = lookupPrimalInst(origArg);
+ }
+ pairArgs.add(pairArg);
+ }
- if (auto diffBlock = lookupDiffInst(origBranch->getTargetBlock(), nullptr))
- diffBranch = builder->emitBranch(as<IRBlock>(diffBlock));
+ IRInst* diffBranch = nullptr;
+ if (auto diffBlock = findOrTranscribeDiffInst(builder, origBranch->getTargetBlock()))
+ {
+ diffBranch = builder->emitBranch(
+ as<IRBlock>(diffBlock),
+ pairArgs.getCount(),
+ pairArgs.getBuffer());
+ }
// For now, every block in the original fn must have a corresponding
- // block to compute both primals and derivatives.
+ // block to compute *both* primals and derivatives (i.e linearized block)
SLANG_ASSERT(diffBranch);
return InstPair(diffBranch, diffBranch);
@@ -843,12 +1273,13 @@ struct JVPTranscriber
return InstPair(nullptr, nullptr);
}
-
InstPair transcribeConst(IRBuilder*, IRInst* origInst)
{
switch(origInst->getOp())
{
case kIROp_FloatLit:
+ case kIROp_VoidLit:
+ case kIROp_IntLit:
return InstPair(origInst, nullptr);
}
@@ -860,49 +1291,439 @@ struct JVPTranscriber
return InstPair(nullptr, nullptr);
}
+ InstPair transcribeSpecialize(IRBuilder* builder, IRSpecialize* origSpecialize)
+ {
+ // This is slightly counter-intuitive, but we don't perform any differentiation
+ // logic here. We simple clone the original specialize which points to the original function,
+ // or the cloned version in case we're inside a generic scope.
+ // The differentiation logic is inserted later when this is used in an IRCall.
+ // This decision is mostly to maintain a uniform convention of JVPDifferentiate(Specialize(Fn))
+ // rather than have Specialize(JVPDifferentiate(Fn))
+ //
+ auto diffSpecialize = cloneInst(&cloneEnv, builder, origSpecialize);
+ return InstPair(diffSpecialize, diffSpecialize);
+ }
+
+ InstPair transcibeLookupInterfaceMethod(IRBuilder* builder, IRLookupWitnessMethod* origLookup)
+ {
+ // This is slightly counter-intuitive, but we don't perform any differentiation
+ // logic here. We simple clone the original lookup which points to the original function,
+ // or the cloned version in case we're inside a generic scope.
+ // The differentiation logic is inserted later when this is used in an IRCall.
+ // This decision is mostly to maintain a uniform convention of JVPDifferentiate(Lookup(Table))
+ // rather than have Lookup(JVPDifferentiate(Table))
+ //
+ auto diffLookup = cloneInst(&cloneEnv, builder, origLookup);
+ return InstPair(diffLookup, diffLookup);
+ }
+
// In differential computation, the 'default' differential value is always zero.
// This is a consequence of differential computing being inherently linear. As a
// result, it's useful to have a method to generate zero literals of any (arithmetic) type.
+ // The current implementation requires that types are defined linearly.
//
- IRInst* getZeroOfType(IRBuilder* builder, IRType* type)
+ IRInst* getDifferentialZeroOfType(IRBuilder* builder, IRType* primalType)
+ {
+ if (auto diffType = differentiateType(builder, primalType))
+ {
+ // Since primalType has a corresponding differential type, we can lookup the
+ // definition for zero().
+ auto zeroMethod = this->diffConformanceContext->getZeroMethodForType(builder, primalType);
+ SLANG_ASSERT(zeroMethod);
+
+ auto emptyArgList = List<IRInst*>();
+ return builder->emitCallInst((IRType*)diffType, zeroMethod, emptyArgList);
+ }
+ else
+ {
+ // We special case a few non-differentiable types that sometimes appear in places
+ // where we're forced to provide a differential zero value. For instance,
+ // float3(float, float, int) is accepted by the compiler, but is tricky in the context
+ // of differentiation since int is non-differentiable, and should be cast to float first.
+ // In the absence of such casts, this piece of code generates appropriate zero values.
+ //
+ switch (primalType->getOp())
+ {
+ case kIROp_IntType:
+ return builder->getIntValue(primalType, 0);
+ default:
+ getSink()->diagnose(primalType->sourceLoc,
+ Diagnostics::internalCompilerError,
+ "could not generate zero value for given type");
+ return nullptr;
+ }
+ }
+ }
+
+ InstPair transcribeBlock(IRBuilder* builder, IRBlock* origBlock)
+ {
+ auto oldLoc = builder->getInsertLoc();
+
+ IRInst* diffBlock = builder->emitBlock();
+
+ // Note: for blocks, we setup the mapping _before_
+ // processing the children since we could encounter
+ // a lookup while processing the children.
+ //
+ mapPrimalInst(origBlock, diffBlock);
+ mapDifferentialInst(origBlock, diffBlock);
+
+ builder->setInsertInto(diffBlock);
+
+ // First transcribe every parameter in the block.
+ for (auto param = origBlock->getFirstParam(); param; param = param->getNextParam())
+ this->transcribe(builder, param);
+
+ // Look for the differentiable type dictionary and clone it (and anything else we might need).
+ // TODO: This logic might have issues if there are additional instructions (say lookup_interface_requirement)
+ // that are operands.
+ // TODO: This is currently cloning the global dictionary. Should only clone dictionaries in generic blocks.
+ if (auto origDict = builder->findDifferentiableTypeDictionary(origBlock))
+ {
+ auto clonedDict = cloneInst(&cloneEnv, builder, origDict);
+ mapPrimalInst(origDict, clonedDict);
+ mapDifferentialInst(origDict, clonedDict);
+ }
+
+ // Then, run through every instruction and use the transcriber to generate the appropriate
+ // derivative code.
+ //
+ for (auto child = origBlock->getFirstOrdinaryInst(); child; child = child->getNextInst())
+ this->transcribe(builder, child);
+
+ builder->setInsertLoc(oldLoc);
+
+ return InstPair(diffBlock, diffBlock);
+ }
+
+ InstPair transcribeFieldExtract(IRBuilder* builder, IRFieldExtract* origExtract)
{
- switch (type->getOp())
+ IRInst* origBase = origExtract->getBase();
+ auto primalBase = findOrTranscribePrimalInst(builder, origBase);
+ auto diffBase = findOrTranscribeDiffInst(builder, origBase);
+
+ auto primalExtractType = (IRType*)lookupPrimalInst(origExtract->getDataType(), origExtract->getDataType());
+
+ IRInst* primalExtract = builder->emitFieldExtract(primalExtractType, primalBase, origExtract->getField());
+ IRInst* diffExtract = nullptr;
+
+ if (auto diffExtractType = differentiateType(builder, primalExtractType))
{
- case kIROp_FloatType:
- case kIROp_HalfType:
- case kIROp_DoubleType:
- return builder->getFloatValue(type, 0.0);
- case kIROp_IntType:
- return builder->getIntValue(type, 0);
- case kIROp_VectorType:
+ // Check if we have a getter.
+ if (auto getterDecoration = origExtract->findDecoration<IRDifferentialGetterDecoration>())
{
- IRInst* args[] = {getZeroOfType(builder, as<IRVectorType>(type)->getElementType())};
- return builder->emitIntrinsicInst(
- type,
- kIROp_constructVectorFromScalar,
- 1,
+
+ IRInst* getterFunc = getterDecoration->getGetterFunc();
+
+ // Must be a method with a single parameter.
+ SLANG_ASSERT(as<IRFuncType>(getterFunc->getDataType())->getParamCount() == 1);
+
+ // Our getter func accepts a _pointer_ to the target type
+ // So we have to create a variable and store our type into memory
+ // here. This will eventually get optimized out in later passes.
+ //
+ auto diffTempVar = builder->emitVar(
+ diffBase->getDataType());
+
+ builder->emitStore(diffTempVar, diffBase);
+
+ List<IRInst*> args;
+ args.add(diffTempVar);
+
+ // Emit a call to the getter. The getter will return a reference type.
+ // We need to load from this to go to a non-ptr 'solid' type.
+ //
+ auto diffGetterCall = builder->emitCallInst(
+ as<IRFuncType>(getterFunc->getDataType())->getResultType(),
+ getterFunc,
args);
+
+ diffExtract = builder->emitLoad(diffGetterCall);
}
- default:
- getSink()->diagnose(type->sourceLoc,
- Diagnostics::internalCompilerError,
- "could not generate zero value for given type");
- return nullptr;
}
+
+ return InstPair(primalExtract, diffExtract);
+ }
+
+ InstPair transcribeFieldAddress(IRBuilder* builder, IRFieldAddress* origAddress)
+ {
+ IRInst* origBase = origAddress->getBase();
+ auto primalBase = findOrTranscribePrimalInst(builder, origBase);
+ auto diffBase = findOrTranscribeDiffInst(builder, origBase);
+
+ auto primalAddressType = (IRType*)lookupPrimalInst(origAddress->getDataType(), origAddress->getDataType());
+
+ IRInst* primalAddress = builder->emitFieldAddress(primalAddressType, primalBase, origAddress->getField());
+ IRInst* diffAddress = nullptr;
+
+ if (auto diffAddressType = differentiateType(builder, primalAddressType))
+ {
+ // If we have a getter associated with this field, we want to use that.
+ if (auto getterDecoration = origAddress->findDecoration<IRDifferentialGetterDecoration>())
+ {
+ auto getterFunc = getterDecoration->getGetterFunc();
+
+ // Add the base differential inst as the argument.
+ List<IRInst*> args;
+ args.add(diffBase);
+
+ diffAddress = builder->emitCallInst(
+ as<IRFuncType>(getterFunc->getDataType())->getResultType(),
+ getterFunc,
+ args);
+ }
+
+ }
+
+ return InstPair(primalAddress, diffAddress);
+ }
+
+
+ InstPair transcribeGetElement(IRBuilder* builder, IRInst* origGetElementPtr)
+ {
+ SLANG_ASSERT(as<IRGetElement>(origGetElementPtr) || as<IRGetElementPtr>(origGetElementPtr));
+
+ IRInst* origBase = origGetElementPtr->getOperand(0);
+ auto primalBase = findOrTranscribePrimalInst(builder, origBase);
+ auto primalIndex = findOrTranscribePrimalInst(builder, origGetElementPtr->getOperand(1));
+
+ auto primalType = (IRType*)lookupPrimalInst(origGetElementPtr->getDataType(), origGetElementPtr->getDataType());
+
+ IRInst* primalOperands[] = {primalBase, primalIndex};
+ IRInst* primalGetElementPtr = builder->emitIntrinsicInst(
+ primalType,
+ origGetElementPtr->getOp(),
+ 2,
+ primalOperands);
+
+ IRInst* diffGetElementPtr = nullptr;
+
+ if (auto diffType = differentiateType(builder, primalType))
+ {
+ if (auto diffBase = findOrTranscribeDiffInst(builder, origBase))
+ {
+ IRInst* diffOperands[] = {diffBase, primalIndex};
+ diffGetElementPtr = builder->emitIntrinsicInst(
+ diffType,
+ origGetElementPtr->getOp(),
+ 2,
+ diffOperands);
+ }
+ }
+
+ return InstPair(primalGetElementPtr, diffGetElementPtr);
+ }
+
+
+ InstPair transcribeLoop(IRBuilder* builder, IRLoop* origLoop)
+ {
+ // The loop comes with three blocks.. we just need to transcribe each one
+ // and assemble the new loop instruction.
+
+ // Transcribe the target block (this is the 'condition' part of the loop, which
+ // will branch into the loop body)
+ auto diffTargetBlock = findOrTranscribeDiffInst(builder, origLoop->getTargetBlock());
+
+ // Transcribe the break block (this is the block after the exiting the loop)
+ auto diffBreakBlock = findOrTranscribeDiffInst(builder, origLoop->getBreakBlock());
+
+ // Transcribe the continue block (this is the 'update' part of the loop, which will
+ // branch into the condition block)
+ auto diffContinueBlock = findOrTranscribeDiffInst(builder, origLoop->getContinueBlock());
+
+
+ List<IRInst*> diffLoopOperands;
+ diffLoopOperands.add(diffTargetBlock);
+ diffLoopOperands.add(diffBreakBlock);
+ diffLoopOperands.add(diffContinueBlock);
+
+ // If there are any other operands, use their primal versions.
+ for (UIndex ii = diffLoopOperands.getCount(); ii < origLoop->getOperandCount(); ii++)
+ {
+ auto primalOperand = findOrTranscribePrimalInst(builder, origLoop->getOperand(ii));
+ diffLoopOperands.add(primalOperand);
+ }
+
+ IRInst* diffLoop = builder->emitIntrinsicInst(
+ nullptr,
+ kIROp_loop,
+ diffLoopOperands.getCount(),
+ diffLoopOperands.getBuffer());
+
+ return InstPair(diffLoop, diffLoop);
+ }
+
+ InstPair transcribeIfElse(IRBuilder* builder, IRIfElse* origIfElse)
+ {
+ // The loop comes with three blocks.. we just need to transcribe each one
+ // and assemble the new loop instruction.
+
+ // Transcribe the target block (this is the 'condition' part of the loop, which
+ // will branch into the loop body).
+ // Note that for the condition we use the primal inst (condition values should not have a
+ // differential)
+ auto primalConditionBlock = findOrTranscribePrimalInst(builder, origIfElse->getCondition());
+ SLANG_ASSERT(primalConditionBlock);
+
+ // Transcribe the break block (this is the block after the exiting the loop)
+ auto diffTrueBlock = findOrTranscribeDiffInst(builder, origIfElse->getTrueBlock());
+ SLANG_ASSERT(diffTrueBlock);
+
+ // Transcribe the continue block (this is the 'update' part of the loop, which will
+ // branch into the condition block)
+ auto diffFalseBlock = findOrTranscribeDiffInst(builder, origIfElse->getFalseBlock());
+ SLANG_ASSERT(diffFalseBlock);
+
+ // Transcribe the continue block (this is the 'update' part of the loop, which will
+ // branch into the condition block)
+ auto diffAfterBlock = findOrTranscribeDiffInst(builder, origIfElse->getAfterBlock());
+ SLANG_ASSERT(diffAfterBlock);
+
+
+ List<IRInst*> diffIfElseArgs;
+ diffIfElseArgs.add(primalConditionBlock);
+ diffIfElseArgs.add(diffTrueBlock);
+ diffIfElseArgs.add(diffFalseBlock);
+ diffIfElseArgs.add(diffAfterBlock);
+
+ // If there are any other operands, use their primal versions.
+ for (UIndex ii = diffIfElseArgs.getCount(); ii < origIfElse->getOperandCount(); ii++)
+ {
+ auto primalOperand = findOrTranscribePrimalInst(builder, origIfElse->getOperand(ii));
+ diffIfElseArgs.add(primalOperand);
+ }
+
+ IRInst* diffLoop = builder->emitIntrinsicInst(
+ nullptr,
+ kIROp_ifElse,
+ diffIfElseArgs.getCount(),
+ diffIfElseArgs.getBuffer());
+
+ return InstPair(diffLoop, diffLoop);
+ }
+
+ // Transcribe a function definition.
+ InstPair transcribeFunc(IRBuilder* builder, IRFunc* origFunc)
+ {
+ IRFunc* primalFunc = nullptr;
+
+ auto oldLoc = builder->getInsertLoc();
+
+ // If this is a top-level function, there is no need to clone it
+ // since it is visible in all the scopes.
+ // Otherwise, we need to clone it in case of generic scopes.
+ //
+ // TODO(sai): Is this the correct thing to do? Can a function cloned inside a
+ // generic scope but is not the return value of that generic, be used within
+ // that scope? Or do we have to call out to the original generic specialized with
+ // the current generic params?
+ //
+ bool isTopLevelFunc = (as<IRModuleInst>(origFunc->parent) != nullptr);
+ if (isTopLevelFunc)
+ {
+ builder->setInsertBefore(origFunc);
+ primalFunc = origFunc;
+ }
+ else
+ {
+ // TODO(sai): this might never be called, and it might never make sense
+ // to call it either. Potentially remove this.
+ primalFunc = as<IRFunc>(
+ cloneInst(&cloneEnv, builder, origFunc));
+ }
+
+ auto diffFunc = builder->createFunc();
+
+ SLANG_ASSERT(as<IRFuncType>(origFunc->getFullType()));
+ IRType* diffFuncType = this->differentiateFunctionType(
+ builder,
+ as<IRFuncType>(origFunc->getFullType()));
+ diffFunc->setFullType(diffFuncType);
+
+ // TODO(sai): Replace naming scheme
+ // if (auto jvpName = this->getJVPFuncName(builder, primalFn))
+ // builder->addNameHintDecoration(diffFunc, jvpName);
+
+ // Transcribe children from origFunc into diffFunc
+ builder->setInsertInto(diffFunc);
+ for (auto block = origFunc->getFirstBlock(); block; block = block->getNextBlock())
+ this->transcribe(builder, block);
+
+ // Reset builder position
+ builder->setInsertLoc(oldLoc);
+
+ return InstPair(primalFunc, diffFunc);
+ }
+
+ // Transcribe a generic definition
+ InstPair transcribeGeneric(IRBuilder* builder, IRGeneric* origGeneric)
+ {
+ // For now, we assume there's only one generic layer. So this inst must be top level
+ bool isTopLevel = (as<IRModuleInst>(origGeneric->getParent()) != nullptr);
+ SLANG_RELEASE_ASSERT(isTopLevel);
+
+ IRGeneric* primalGeneric = origGeneric;
+
+ auto oldLoc = builder->getInsertLoc();
+ builder->setInsertBefore(origGeneric);
+
+ auto diffGeneric = builder->emitGeneric();
+
+ // Process type of generic. If the generic is a function, then it's type will also be a
+ // generic and this logic will transcribe that generic first before continuing with the
+ // function itself.
+ //
+ auto primalType = primalGeneric->getFullType();
+
+ IRType* diffType = nullptr;
+ if (primalType)
+ {
+ diffType = (IRType*) findOrTranscribeDiffInst(builder, primalType);
+ }
+
+ diffGeneric->setFullType(diffType);
+
+ // TODO(sai): Replace naming scheme
+ // if (auto jvpName = this->getJVPFuncName(builder, primalFn))
+ // builder->addNameHintDecoration(diffFunc, jvpName);
+
+ // Transcribe children from origFunc into diffFunc.
+ builder->setInsertInto(diffGeneric);
+ for (auto block = origGeneric->getFirstBlock(); block; block = block->getNextBlock())
+ this->transcribe(builder, block);
+
+ // Reset builder position.
+ builder->setInsertLoc(oldLoc);
+
+ return InstPair(primalGeneric, diffGeneric);
}
IRInst* transcribe(IRBuilder* builder, IRInst* origInst)
{
+ // If a differential intstruction is already mapped for
+ // this original inst, return that.
+ //
+ if (auto diffInst = lookupDiffInst(origInst, nullptr))
+ {
+ SLANG_ASSERT(lookupPrimalInst(origInst)); // Consistency check.
+ return diffInst;
+ }
+
+ // Otherwise, dispatch to the appropriate method
+ // depending on the op-code.
+ //
+ instsInProgress.Add(origInst);
InstPair pair = transcribeInst(builder, origInst);
if (auto primalInst = pair.primal)
{
mapPrimalInst(origInst, pair.primal);
-
mapDifferentialInst(origInst, pair.differential);
return pair.differential;
}
+ instsInProgress.Remove(origInst);
+
getSink()->diagnose(origInst->sourceLoc,
Diagnostics::internalCompilerError,
"failed to transcibe instruction");
@@ -911,7 +1732,7 @@ struct JVPTranscriber
InstPair transcribeInst(IRBuilder* builder, IRInst* origInst)
{
- // Handle common operations
+ // Handle common SSA-style operations
switch (origInst->getOp())
{
case kIROp_Param:
@@ -934,6 +1755,14 @@ struct JVPTranscriber
case kIROp_Sub:
case kIROp_Div:
return transcribeBinaryArith(builder, origInst);
+
+ case kIROp_Less:
+ case kIROp_Greater:
+ case kIROp_And:
+ case kIROp_Or:
+ case kIROp_Geq:
+ case kIROp_Leq:
+ return transcribeBinaryLogic(builder, origInst);
case kIROp_Construct:
return transcribeConstruct(builder, origInst);
@@ -945,24 +1774,91 @@ struct JVPTranscriber
return transcribeSwizzle(builder, as<IRSwizzle>(origInst));
case kIROp_constructVectorFromScalar:
+ case kIROp_MakeTuple:
return transcribeByPassthrough(builder, origInst);
case kIROp_unconditionalBranch:
- case kIROp_conditionalBranch:
return transcribeControlFlow(builder, origInst);
case kIROp_FloatLit:
+ case kIROp_IntLit:
+ case kIROp_VoidLit:
return transcribeConst(builder, origInst);
+ case kIROp_Specialize:
+ getSink()->diagnose(origInst->sourceLoc,
+ Diagnostics::unexpected,
+ "should not be attempting to differentiate anything specialized here.");
+
+ case kIROp_lookup_interface_method:
+ return transcibeLookupInterfaceMethod(builder, as<IRLookupWitnessMethod>(origInst));
+
+ case kIROp_FieldExtract:
+ return transcribeFieldExtract(builder, as<IRFieldExtract>(origInst));
+
+ case kIROp_FieldAddress:
+ return transcribeFieldAddress(builder, as<IRFieldAddress>(origInst));
+
+ case kIROp_getElement:
+ case kIROp_getElementPtr:
+ return transcribeGetElement(builder, origInst);
+
+ case kIROp_loop:
+ return transcribeLoop(builder, as<IRLoop>(origInst));
+
+ case kIROp_ifElse:
+ return transcribeIfElse(builder, as<IRIfElse>(origInst));
+
+ case kIROp_DifferentiableTypeDictionary:
+ // Ignore dictionary insts.
+ return InstPair(nullptr, nullptr);
+
}
// If none of the cases have been hit, check if the instruction is a
- // type.
- // For now we don't have logic to differentiate types that appear in blocks.
- // So, we clone and avoid differentiating them.
- //
+ // type. Only need to explicitly differentiate types if they appear inside a block.
+ //
if (auto origType = as<IRType>(origInst))
- return InstPair(cloneInst(&cloneEnv, builder, origType), nullptr);
+ {
+ // If this is a generic type, transcibe the parent
+ // generic and derive the type from the transcribed generic's
+ // return value.
+ //
+ if (as<IRGeneric>(origType->getParent()->getParent()) &&
+ findInnerMostGenericReturnVal(as<IRGeneric>(origType->getParent()->getParent())) == origType &&
+ !instsInProgress.Contains(origType->getParent()->getParent()))
+ {
+ auto origGenericType = origType->getParent()->getParent();
+ auto diffGenericType = findOrTranscribeDiffInst(builder, origGenericType);
+ auto innerDiffGenericType = findInnerMostGenericReturnVal(as<IRGeneric>(diffGenericType));
+ return InstPair(
+ origGenericType,
+ innerDiffGenericType
+ );
+ }
+ else if (as<IRBlock>(origType->getParent()))
+ return InstPair(
+ cloneInst(&cloneEnv, builder, origType),
+ differentiateType(builder, origType));
+ else
+ return InstPair(
+ cloneInst(&cloneEnv, builder, origType),
+ nullptr);
+ }
+
+ // Handle instructions with children
+ switch (origInst->getOp())
+ {
+ case kIROp_Func:
+ return transcribeFunc(builder, as<IRFunc>(origInst));
+
+ case kIROp_Block:
+ return transcribeBlock(builder, as<IRBlock>(origInst));
+
+ case kIROp_Generic:
+ return transcribeGeneric(builder, as<IRGeneric>(origInst));
+ }
+
// If we reach this statement, the instruction type is likely unhandled.
getSink()->diagnose(origInst->sourceLoc,
@@ -1042,6 +1938,14 @@ struct JVPDerivativeContext
// IRMakeDifferentialPair with an IRMakeStruct.
//
modified |= processPairTypes(builder, module->getModuleInst(), (&diffConformanceContextStorage));
+
+ // Temporary fix: Move generated types, if any, to before their use locations.
+ (&pairBuilderStorage)->relocateNewTypes(builder);
+
+ // Remove all kIROp_DifferentiableTypeDictionary instructions and
+ // kIROp_DifferentialGetterDecoration decorations
+ //
+ modified |= stripDiffTypeInformation(builder, module->getModuleInst());
return modified;
}
@@ -1079,19 +1983,45 @@ struct JVPDerivativeContext
if (auto jvpDiffInst = as<IRJVPDifferentiate>(child))
{
- auto baseFunction = jvpDiffInst->getBaseFn();
+ auto baseInst = jvpDiffInst->getBaseFn();
+
+ IRGlobalValueWithCode* baseFunction = nullptr;
+
+ if (auto specializeInst = as<IRSpecialize>(baseInst))
+ {
+ baseFunction = as<IRGlobalValueWithCode>(specializeInst->getBase());
+ }
+ else if (auto globalValWithCode = as<IRGlobalValueWithCode>(baseInst))
+ {
+ baseFunction = globalValWithCode;
+ }
+
+ SLANG_ASSERT(baseFunction);
+
// If the JVP Reference already exists, no need to
// differentiate again.
//
- if(lookupJVPReference(baseFunction)) continue;
+ if (lookupJVPReference(baseFunction)) continue;
- if (isFunctionMarkedForJVP(as<IRGlobalValueWithCode>(baseFunction)))
+ if (isMarkedForJVP(baseFunction))
{
- IRFunc* jvpFunction = emitJVPFunction(builder, as<IRFunc>(baseFunction));
- builder->addJVPDerivativeReferenceDecoration(baseFunction, jvpFunction);
- workQueue->push(jvpFunction);
+ if (as<IRFunc>(baseFunction) || as<IRGeneric>(baseFunction))
+ {
+ IRInst* diffFunc = (&transcriberStorage)->transcribe(builder, baseFunction);
+ SLANG_ASSERT(diffFunc);
+ builder->addJVPDerivativeReferenceDecoration(baseFunction, diffFunc);
+ workQueue->push(diffFunc);
+ }
+ else
+ {
+ // TODO(Sai): This would probably be better with a more specific
+ // error code.
+ getSink()->diagnose(jvpDiffInst->sourceLoc,
+ Diagnostics::internalCompilerError,
+ "Unexpected instruction. Expected func or generic");
+ }
}
- else
+ else
{
// TODO(Sai): This would probably be better with a more specific
// error code.
@@ -1106,55 +2036,33 @@ struct JVPDerivativeContext
return true;
}
- // Run through all the global-level instructions,
- // looking for callables.
- // Note: We're only processing global callables (IRGlobalValueWithCode)
- // for now.
- //
- bool processMarkedGlobalFunctions(IRBuilder* builder)
+ IRInst* lowerPairType(IRBuilder* builder, IRType* type, DifferentiableTypeConformanceContext*)
{
- for (auto inst : module->getGlobalInsts())
+
+ if (auto pairType = as<IRDifferentialPairType>(type))
{
- // If the instr is a callable, get all the basic blocks
- if (auto callable = as<IRGlobalValueWithCode>(inst))
- {
- if (isFunctionMarkedForJVP(callable))
- {
- SLANG_ASSERT(as<IRFunc>(callable));
+ builder->setInsertBefore(pairType);
- IRFunc* jvpFunction = emitJVPFunction(builder, as<IRFunc>(callable));
- builder->addJVPDerivativeReferenceDecoration(callable, jvpFunction);
+ auto diffPairStructType = (&pairBuilderStorage)->getOrCreateDiffPairType(
+ builder,
+ pairType->getValueType());
- unmarkForJVP(callable);
- }
- }
- }
- return true;
- }
+ pairType->replaceUsesWith(diffPairStructType);
+ pairType->removeAndDeallocate();
- IRInst* lowerPairType(IRBuilder* builder, IRType* type, DifferentiableTypeConformanceContext* diffContext)
- {
- if (diffContext->isInterfaceAvailable)
+ return diffPairStructType;
+ }
+ else if (auto loweredStructType = as<IRStructType>(type))
{
- if (auto pairType = as<IRDifferentialPairType>(type))
- {
- builder->setInsertBefore(pairType);
-
- auto diffPairStructType = (&pairBuilderStorage)->getOrCreateDiffPairType(
- builder,
- pairType->getValueType());
-
- pairType->replaceUsesWith(diffPairStructType);
- pairType->removeAndDeallocate();
-
- return diffPairStructType;
- }
- else if (auto loweredStructType = as<IRStructType>(type))
- {
- // Already lowered to struct.
- return loweredStructType;
- }
+ // Already lowered to struct.
+ return loweredStructType;
}
+ else if (auto specializedStructType = as<IRSpecialize>(type))
+ {
+ // Already lowered to specialized struct.
+ return specializedStructType;
+ }
+
return nullptr;
}
@@ -1171,7 +2079,7 @@ struct JVPDerivativeContext
operands.add(makePairInst->getPrimalValue());
operands.add(makePairInst->getDifferentialValue());
- auto makeStructInst = builder->emitMakeStruct(as<IRStructType>(diffPairStructType), operands);
+ auto makeStructInst = builder->emitMakeStruct((IRType*)(diffPairStructType), operands);
makePairInst->replaceUsesWith(makeStructInst);
makePairInst->removeAndDeallocate();
@@ -1258,10 +2166,43 @@ struct JVPDerivativeContext
return modified;
}
+ bool stripDiffTypeInformation(IRBuilder* builder, IRInst* parent)
+ {
+ bool modified = false;
+
+ auto child = parent->getFirstChild();
+ while (child)
+ {
+ auto nextChild = child->getNextInst();
+
+ if (child->getOp() == kIROp_DifferentiableTypeDictionary)
+ {
+ child->removeAndDeallocate();
+ child = nextChild;
+ modified = true;
+ continue;
+ }
+
+ if (auto getterDecoration = child->findDecoration<IRDifferentialGetterDecoration>())
+ {
+ getterDecoration->removeAndDeallocate();
+ }
+
+ if (child->getFirstChild() != nullptr)
+ {
+ modified |= stripDiffTypeInformation(builder, child);
+ }
+
+ child = nextChild;
+ }
+
+ return modified;
+ }
+
// Checks decorators to see if the function should
// be differentiated (kIROp_JVPDerivativeMarkerDecoration)
//
- bool isFunctionMarkedForJVP(IRGlobalValueWithCode* callable)
+ bool isMarkedForJVP(IRGlobalValueWithCode* callable)
{
for(auto decoration = callable->getFirstDecoration();
decoration;
@@ -1292,63 +2233,8 @@ struct JVPDerivativeContext
}
}
- List<IRParam*> emitFuncParameters(IRBuilder* builder, IRFuncType* dataType)
- {
- List<IRParam*> params;
- for(UIndex i = 0; i < dataType->getParamCount(); i++)
- {
- params.add(
- builder->emitParam(dataType->getParamType(i)));
- }
- return params;
- }
-
- // Perform forward-mode automatic differentiation on
- // the intstructions.
- //
- IRFunc* emitJVPFunction(IRBuilder* builder,
- IRFunc* primalFn)
- {
- eliminatePhisInFunc(LivenessMode::Disabled, module, primalFn);
-
- builder->setInsertBefore(primalFn->getNextInst());
-
- auto jvpFn = builder->createFunc();
-
- SLANG_ASSERT(as<IRFuncType>(primalFn->getFullType()));
- IRType* jvpFuncType = transcriberStorage.differentiateFunctionType(
- builder,
- as<IRFuncType>(primalFn->getFullType()));
- jvpFn->setFullType(jvpFuncType);
-
- if (auto jvpName = getJVPFuncName(builder, primalFn))
- builder->addNameHintDecoration(jvpFn, jvpName);
-
- builder->setInsertInto(jvpFn);
-
- // Emit a block instruction for every block in the function, and map it as the
- // corresponding differential.
- //
- for (auto block = primalFn->getFirstBlock(); block; block = block->getNextBlock())
- {
- auto jvpBlock = builder->emitBlock();
- transcriberStorage.mapDifferentialInst(block, jvpBlock);
- transcriberStorage.mapPrimalInst(block, jvpBlock);
- }
-
- // Go back over the blocks, and process the children of each block.
- for (auto block = primalFn->getFirstBlock(); block; block = block->getNextBlock())
- {
- auto jvpBlock = as<IRBlock>(transcriberStorage.lookupDiffInst(block, block));
- SLANG_ASSERT(jvpBlock);
- emitJVPBlock(builder, block, jvpBlock);
- }
-
- return jvpFn;
- }
-
IRStringLit* getJVPFuncName(IRBuilder* builder,
- IRFunc* func)
+ IRInst* func)
{
auto oldLoc = builder->getInsertLoc();
builder->setInsertBefore(func);
@@ -1368,36 +2254,6 @@ struct JVPDerivativeContext
return name;
}
- IRBlock* emitJVPBlock(IRBuilder* builder,
- IRBlock* origBlock,
- IRBlock* jvpBlock = nullptr)
- {
- JVPTranscriber* transcriber = &(transcriberStorage);
-
- // Create if not already created, and then insert into new block.
- if (!jvpBlock)
- jvpBlock = builder->emitBlock();
- else
- builder->setInsertInto(jvpBlock);
-
-
- // First transcribe every parameter in the block.
- for (auto param = origBlock->getFirstParam(); param; param = param->getNextParam())
- {
- transcriber->transcribe(builder, param);
- }
-
- // Then, run through every instruction and use the transcriber to generate the appropriate
- // derivative code.
- //
- for (auto child = origBlock->getFirstOrdinaryInst(); child; child = child->getNextInst())
- {
- transcriber->transcribe(builder, child);
- }
-
- return jvpBlock;
- }
-
JVPDerivativeContext(IRModule* module, DiagnosticSink* sink) :
module(module), sink(sink),
diffConformanceContextStorage(module->getModuleInst()),