diff options
| author | Yong He <yonghe@outlook.com> | 2023-03-14 09:31:51 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-03-14 09:31:51 -0700 |
| commit | e291f60c6b083eaa74aed5307a6e9461274c1642 (patch) | |
| tree | bde9b45a9e09ebbe173fae1821237b258a9ff800 /source | |
| parent | a911ca6e06ce41e403b80fe6054162393491c8ac (diff) | |
Support `fwd_diff(bwd_diff(f))`. (#2697)
* Support `fwd_diff(bwd_diff(f))`.
* Fix.
---------
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/slang-ir-addr-inst-elimination.cpp | 7 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-transcriber-base.cpp | 8 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-unzip.cpp | 7 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff.cpp | 414 | ||||
| -rw-r--r-- | source/slang/slang-ir-inst-defs.h | 6 | ||||
| -rw-r--r-- | source/slang/slang-ir-insts.h | 29 | ||||
| -rw-r--r-- | source/slang/slang-ir-peephole.cpp | 13 | ||||
| -rw-r--r-- | source/slang/slang-ir-validate.cpp | 6 | ||||
| -rw-r--r-- | source/slang/slang-ir.cpp | 104 |
9 files changed, 580 insertions, 14 deletions
diff --git a/source/slang/slang-ir-addr-inst-elimination.cpp b/source/slang/slang-ir-addr-inst-elimination.cpp index 42bc88106..6715f2c6a 100644 --- a/source/slang/slang-ir-addr-inst-elimination.cpp +++ b/source/slang/slang-ir-addr-inst-elimination.cpp @@ -150,10 +150,13 @@ struct AddressInstEliminationContext for (auto use = addrInst->firstUse; use; ) { + auto nextUse = use->nextUse; + if (as<IRDecoration>(use->getUser())) + { + use = nextUse; continue; - - auto nextUse = use->nextUse; + } switch (use->getUser()->getOp()) { diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp index 091e7f1ab..0f51a6c62 100644 --- a/source/slang/slang-ir-autodiff-transcriber-base.cpp +++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp @@ -258,10 +258,6 @@ IRType* AutoDiffTranscriberBase::differentiateType(IRBuilder* builder, IRType* o auto diffType = (IRType*)differentiableTypeConformanceContext.getDifferentialForType(builder, origType); return (IRType*)findOrTranscribePrimalInst(builder, diffType); } - else if (origType->getOp() == kIROp_LookupWitness) - { - return (IRType*)findOrTranscribePrimalInst(builder, (IRInst*)primalType); - } return (IRType*)transcribe(builder, origType); } @@ -282,9 +278,7 @@ IRType* AutoDiffTranscriberBase::_differentiateTypeImpl(IRBuilder* builder, IRTy { case kIROp_Param: if (as<IRTypeType>(primalType->getDataType())) - return (IRType*)(differentiableTypeConformanceContext.getDifferentialForType( - builder, - (IRType*)primalType)); + return differentiateType(builder, origType); else if (as<IRWitnessTableType>(primalType->getDataType())) return (IRType*)primalType; diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp index 2347c7a8f..e01452972 100644 --- a/source/slang/slang-ir-autodiff-unzip.cpp +++ b/source/slang/slang-ir-autodiff-unzip.cpp @@ -307,7 +307,12 @@ struct ExtractPrimalFuncContext IRCloneEnv cloneEnv; fieldType = cloneInst(&cloneEnv, &genTypeBuilder, fieldType); } - return genTypeBuilder.createStructField(structType, structKey, (IRType*)fieldType); + auto structField = genTypeBuilder.createStructField(structType, structKey, (IRType*)fieldType); + if (auto diffFieldType = backwardPrimalTranscriber->differentiateType(&genTypeBuilder, (IRType*)fieldType)) + { + genTypeBuilder.addIntermediateContextFieldDifferentialTypeDecoration(structField, diffFieldType); + } + return structField; } void storeInst( diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index edea3847d..4d22d9eed 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -482,6 +482,7 @@ void stripAutoDiffDecorationsFromChildren(IRInst* parent) case kIROp_BackwardDerivativePrimalReturnDecoration: case kIROp_AutoDiffOriginalValueDecoration: case kIROp_UserDefinedBackwardDerivativeDecoration: + case kIROp_IntermediateContextFieldDifferentialTypeDecoration: decor->removeAndDeallocate(); break; default: @@ -631,6 +632,8 @@ struct AutoDiffPass : public InstPassBase { List<IRInst*> args; auto subBase = processIntermediateContextTypeBase(builder, spec->getBase()); + if (!subBase) + return nullptr; for (UInt a = 0; a < spec->getArgCount(); a++) args.add(spec->getArg(a)); auto actualType = builder->emitSpecializeInst( @@ -645,6 +648,9 @@ struct AutoDiffPass : public InstPassBase auto inner = findGenericReturnVal(baseGeneric); if (auto typeDecor = inner->findDecoration<IRBackwardDerivativeIntermediateTypeDecoration>()) { + if (!isTypeFullyDifferentiated(typeDecor->getBackwardDerivativeIntermediateType())) + return nullptr; + return typeDecor->getBackwardDerivativeIntermediateType(); } } @@ -652,6 +658,8 @@ struct AutoDiffPass : public InstPassBase { if (auto typeDecor = func->findDecoration<IRBackwardDerivativeIntermediateTypeDecoration>()) { + if (!isTypeFullyDifferentiated(typeDecor->getBackwardDerivativeIntermediateType())) + return nullptr; return typeDecor->getBackwardDerivativeIntermediateType(); } } @@ -671,6 +679,10 @@ struct AutoDiffPass : public InstPassBase bool lowerIntermediateContextType(IRBuilder* builder) { bool changed = false; + OrderedHashSet<IRInst*> loweredIntermediateTypes; + + // Replace all `BackwardDiffIntermediateContextType` insts with the struct type + // that we generated during backward diff pass. processAllInsts([&](IRInst* inst) { switch (inst->getOp()) @@ -685,6 +697,7 @@ struct AutoDiffPass : public InstPassBase auto type = processIntermediateContextTypeBase(&subBuilder, baseFunc); if (type) { + loweredIntermediateTypes.Add(type); inst->replaceUsesWith(type); inst->removeAndDeallocate(); changed = true; @@ -695,14 +708,400 @@ struct AutoDiffPass : public InstPassBase break; } }); + + // Now we generate the differential type for the intermediate context type + // to allow higher order differentiation. + generateDifferentialImplementationForContextType(loweredIntermediateTypes); return changed; } + // Utility function for topology sorting the intermediate context types. + bool isIntermediateContextTypeReadyForProcess(OrderedHashSet<IRInst*>& contextTypes, OrderedHashSet<IRInst*>& sortedSet, IRInst* t) + { + if (!contextTypes.Contains(t)) + return true; + + switch (t->getOp()) + { + case kIROp_StructType: + { + bool canAddNow = true; + for (auto f : as<IRStructType>(t)->getFields()) + { + if (!isIntermediateContextTypeReadyForProcess(contextTypes, sortedSet, f->getFieldType())) + { + canAddNow = false; + break; + } + } + return canAddNow; + } + case kIROp_Specialize: + return isIntermediateContextTypeReadyForProcess(contextTypes, sortedSet, as<IRSpecialize>(t)->getBase()); + case kIROp_Generic: + return isIntermediateContextTypeReadyForProcess(contextTypes, sortedSet, findGenericReturnVal(as<IRGeneric>(t))); + default: + return true; + } + } + + struct IntermediateContextTypeDifferentialInfo + { + IRInst* diffType = nullptr; + IRInst* diffWitness = nullptr; + IRInst* diffDiffWitness = nullptr; + IRInst* zeroMethod = nullptr; + IRInst* addMethod = nullptr; + }; + + // Register the differential type for an intermediate context type to the derivative functions that uses the type. + void registerDiffContextType( + IRBuilder& builder, + IRDifferentiableTypeDictionaryDecoration* diffDecor, + OrderedDictionary<IRInst*, IntermediateContextTypeDifferentialInfo>& diffTypes, + IRInst* origType) + { + HashSet<IRInst*> registeredType; + for (auto entry : diffDecor->getChildren()) + { + if (auto e = as<IRDifferentiableTypeDictionaryItem>(entry)) + { + registeredType.Add(e->getOperand(0)); + } + } + // Use a work list to recursively walk through all sub fields of the struct type. + List<IRInst*> wlist; + wlist.add(origType); + for (Index i = 0; i < wlist.getCount(); i++) + { + auto t = wlist[i]; + IntermediateContextTypeDifferentialInfo diffInfo; + if (!diffTypes.TryGetValue(t, diffInfo)) + continue; + if (registeredType.Add(t)) + builder.addDifferentiableTypeEntry(diffDecor, t, diffInfo.diffWitness); + else + continue; + + if (auto structType = as<IRStructType>(getResolvedInstForDecorations(t))) + { + for (auto f : structType->getFields()) + { + wlist.add(f->getFieldType()); + } + } + } + } + + void generateDifferentialImplementationForContextType(OrderedHashSet<IRInst*>& contextTypes) + { + // First we are going to topology sort all intermediate context types. + OrderedHashSet<IRInst*> sortedContextTypes; + for (;;) + { + auto lastCount = sortedContextTypes.Count(); + for (auto t : contextTypes) + { + if (sortedContextTypes.Contains(t)) + continue; + // Have all dependent types been added yet? + if (isIntermediateContextTypeReadyForProcess(contextTypes, sortedContextTypes, t)) + sortedContextTypes.Add(t); + } + if (lastCount == sortedContextTypes.Count()) + break; + } + + // After the types are sorted, we start to generate the differential type and IDifferentiable witnesses + // for them. + + OrderedDictionary<IRInst*, IntermediateContextTypeDifferentialInfo> diffTypes; + IRBuilder builder(module); + for (auto t : sortedContextTypes) + { + if (t->getOp() == kIROp_Generic || t->getOp() == kIROp_StructType) + { + // For generics/struct types, we will generate a new generic/struct type representing the differntial. + + SLANG_RELEASE_ASSERT(t->getParent() && t->getParent()->getOp() == kIROp_Module); + builder.setInsertBefore(t); + auto diffInfo = fillDifferentialTypeImplementation(diffTypes, t); + diffTypes[t] = diffInfo; + } + else if (auto specialize = as<IRSpecialize>(t)) + { + // A specialize of a context type translates to a specialize of its differential type/witness. + + IntermediateContextTypeDifferentialInfo baseInfo; + SLANG_RELEASE_ASSERT(diffTypes.TryGetValue(specialize->getBase(), baseInfo)); + builder.setInsertBefore(t); + List<IRInst*> args; + for (UInt i = 0; i < specialize->getArgCount(); i++) + args.add(specialize->getArg(i)); + IntermediateContextTypeDifferentialInfo info; + info.diffType = builder.emitSpecializeInst( + builder.getTypeKind(), baseInfo.diffType, (UInt)args.getCount(), args.getBuffer()); + info.diffWitness = builder.emitSpecializeInst( + builder.getWitnessTableType(autodiffContext->differentiableInterfaceType), + baseInfo.diffWitness, + (UInt)args.getCount(), + args.getBuffer()); + diffTypes[t] = info; + } + else + { + // If `t` is not a specialize, it'd better be processed by now. + // We currently don't support the `LookupInterfaceMethod` case, since it can't + // appear in a derivative function because we will only call the backward diff function without a intermediate-type + // via an interface. + SLANG_RELEASE_ASSERT(diffTypes.ContainsKey(t)); + } + } + + // Register the differential types into the conformance dictionaries of the functions that uses them. + for (auto t : diffTypes) + { + HashSet<IRFunc*> registeredFuncs; + for (auto use = t.Key->firstUse; use; use = use->nextUse) + { + auto parentFunc = getParentFunc(use->getUser()); + if (!parentFunc) + continue; + if (!registeredFuncs.Add(parentFunc)) + continue; + if (auto dictDecor = parentFunc->findDecoration<IRDifferentiableTypeDictionaryDecoration>()) + { + registerDiffContextType(builder, dictDecor, diffTypes, t.Key); + } + } + } + } + + IntermediateContextTypeDifferentialInfo fillDifferentialTypeImplementationForStruct( + OrderedDictionary<IRInst*, IntermediateContextTypeDifferentialInfo>& diffTypes, + IRStructType* originalType, + IRStructType* diffType) + { + IntermediateContextTypeDifferentialInfo result; + result.diffType = diffType; + + IRBuilder builder(diffType); + builder.setInsertInto(diffType); + + // Generate the fields for all differentiable members of the original struct type. + for (auto field : originalType->getFields()) + { + IRInst* diffFieldType = nullptr; + if (auto diffDecor = field->findDecoration<IRIntermediateContextFieldDifferentialTypeDecoration>()) + { + diffFieldType = diffDecor->getDifferentialType(); + } + else + { + IntermediateContextTypeDifferentialInfo diffFieldTypeInfo; + diffTypes.TryGetValue(field->getDataType(), diffFieldTypeInfo); + diffFieldType = diffFieldTypeInfo.diffType; + } + if (diffFieldType) + { + IRBuilder keyBuilder = builder; + keyBuilder.setInsertBefore(maybeFindOuterGeneric(originalType)); + auto diffKey = keyBuilder.createStructKey(); + builder.createStructField(diffType, diffKey, (IRType*)diffFieldType); + builder.addDecoration(field->getKey(), kIROp_DerivativeMemberDecoration, diffKey); + builder.addDecoration(diffKey, kIROp_DerivativeMemberDecoration, diffKey); + } + } + + builder.setInsertAfter(diffType); + + // For now, we are going to structurally derive dadd and dzero methods for intermediate context types, + // because it is tricky for us to obtain the original witness tables for the fields at this point. + // This is inconsistent with how we are dealing with dadd and dzero methods via witness table lookup, + // and can lead to problems if the user defines any non-trivial dadd/dzero methods. + // + // TODO: we should consider rewrite this logic to be witness table lookup based, or simplify the entire + // type system and IR passes to always use structurally derived methods instead of user-provided + // methods. + IRInst* zeroMethod = nullptr; + { + auto zeroMethodType = builder.getFuncType(List<IRType*>(), diffType); + zeroMethod = builder.createFunc(); + zeroMethod->setFullType(zeroMethodType); + result.zeroMethod = zeroMethod; + builder.setInsertInto(zeroMethod); + builder.emitBlock(); + builder.emitReturn(builder.emitDefaultConstruct(diffType)); + } + + builder.setInsertAfter(zeroMethod); + IRInst* addMethod = nullptr; + { + List<IRType*> paramTypes; + paramTypes.add(diffType); + paramTypes.add(diffType); + auto addMethodType = builder.getFuncType(List<IRType*>(), diffType); + addMethod = builder.createFunc(); + result.addMethod = addMethod; + addMethod->setFullType(addMethodType); + builder.setInsertInto(addMethod); + builder.emitBlock(); + auto param1 = builder.emitParam(diffType); + auto param2 = builder.emitParam(diffType); + builder.emitReturn(builder.emitStructuralAdd(param1, param2)); + } + + builder.setInsertAfter(addMethod); + auto diffTypeIsDiffWitness = builder.createWitnessTable(autodiffContext->differentiableInterfaceType, diffType); + auto origTypeIsDiffWitness = builder.createWitnessTable(autodiffContext->differentiableInterfaceType, originalType); + result.diffWitness = origTypeIsDiffWitness; + + builder.createWitnessTableEntry(origTypeIsDiffWitness, autodiffContext->differentialAssocTypeStructKey, diffType); + builder.createWitnessTableEntry(origTypeIsDiffWitness, autodiffContext->differentialAssocTypeWitnessStructKey, diffTypeIsDiffWitness); + builder.createWitnessTableEntry(origTypeIsDiffWitness, autodiffContext->zeroMethodStructKey, zeroMethod); + builder.createWitnessTableEntry(origTypeIsDiffWitness, autodiffContext->addMethodStructKey, addMethod); + + builder.createWitnessTableEntry(diffTypeIsDiffWitness, autodiffContext->differentialAssocTypeStructKey, diffType); + builder.createWitnessTableEntry(diffTypeIsDiffWitness, autodiffContext->differentialAssocTypeWitnessStructKey, diffTypeIsDiffWitness); + builder.createWitnessTableEntry(diffTypeIsDiffWitness, autodiffContext->zeroMethodStructKey, zeroMethod); + builder.createWitnessTableEntry(diffTypeIsDiffWitness, autodiffContext->addMethodStructKey, addMethod); + return result; + } + + IntermediateContextTypeDifferentialInfo fillDifferentialTypeImplementation( + OrderedDictionary<IRInst*, IntermediateContextTypeDifferentialInfo>& diffTypes, + IRInst* originalType) + { + if (originalType->getOp() == kIROp_StructType) + { + IRBuilder builder(originalType); + builder.setInsertBefore(originalType); + auto diffType = builder.createStructType(); + return fillDifferentialTypeImplementationForStruct( + diffTypes, + as<IRStructType>(originalType), + as<IRStructType>(diffType)); + } + else if (auto genType = as<IRGeneric>(originalType)) + { + // For generics, we process the inner struct type as normal, + // and then hoist the additional insts we created from the generic. + + auto structType = as<IRStructType>(findGenericReturnVal(genType)); + SLANG_RELEASE_ASSERT(structType); + + auto innerResult = fillDifferentialTypeImplementation(diffTypes, structType); + IRBuilder builder(originalType); + builder.setInsertBefore(originalType); + + // Now we hoist the new values from the generic to form their independent generics. + IRInst* specInst = nullptr; + IntermediateContextTypeDifferentialInfo result; + if (innerResult.diffType) + result.diffType = hoistValueFromGeneric(builder, innerResult.diffType, specInst, true); + if (innerResult.zeroMethod) + { + hoistValueFromGeneric(builder, innerResult.zeroMethod->getFullType(), specInst, true); + result.zeroMethod = hoistValueFromGeneric(builder, innerResult.zeroMethod, specInst, true); + } + if (innerResult.addMethod) + { + hoistValueFromGeneric(builder, innerResult.addMethod->getFullType(), specInst, true); + result.addMethod = hoistValueFromGeneric(builder, innerResult.addMethod, specInst, true); + } + if (innerResult.diffDiffWitness) + result.diffDiffWitness = hoistValueFromGeneric(builder, innerResult.diffDiffWitness, specInst, true); + if (innerResult.diffWitness) + { + builder.setInsertBefore(innerResult.diffWitness); + List<IRInst*> args; + for (auto param : genType->getParams()) + args.add(param); + as<IRWitnessTable>(innerResult.diffWitness)->setConcreteType((IRType*)builder.emitSpecializeInst( + builder.getTypeKind(), originalType, (UInt)args.getCount(), args.getBuffer())); + result.diffWitness = hoistValueFromGeneric(builder, innerResult.diffWitness, specInst, true); + } + return result; + } + return IntermediateContextTypeDifferentialInfo(); + } + + HashSet<IRInst*> fullyDifferentiatedInsts; + + // Returns true if `type` is fully differentiated, i.e. does not have + // any unmaterialized intermediate context types. + bool isTypeFullyDifferentiated(IRInst* type) + { + if (fullyDifferentiatedInsts.Contains(type)) + return true; + if (type->getOp() == kIROp_BackwardDiffIntermediateContextType) + return false; + if (auto structType = as<IRStructType>(type)) + { + for (auto f : structType->getFields()) + if (!isTypeFullyDifferentiated(f->getFieldType())) + return false; + } + else if (auto genType = as<IRGeneric>(type)) + { + bool result = isTypeFullyDifferentiated(findGenericReturnVal(genType)); + if (result) + fullyDifferentiatedInsts.Add(genType); + return result; + } + switch (type->getOp()) + { + case kIROp_ArrayType: + case kIROp_UnsizedArrayType: + case kIROp_InOutType: + case kIROp_OutType: + case kIROp_PtrType: + case kIROp_DifferentialPairType: + case kIROp_DifferentialPairUserCodeType: + case kIROp_AttributedType: + for (UInt i = 0; i < type->getOperandCount(); i++) + if (!isTypeFullyDifferentiated(type->getOperand(i))) + return false; + default: + fullyDifferentiatedInsts.Add(type); + return true; + } + } + + // Returns true if `func` is fully differentiated, i.e. does not have + // any differentiate insts. + bool isFullyDifferentiated(IRFunc* func) + { + if (fullyDifferentiatedInsts.Contains(func)) + return true; + + for (auto block : func->getBlocks()) + { + for (auto ii : block->getChildren()) + { + switch (ii->getOp()) + { + case kIROp_ForwardDifferentiate: + case kIROp_BackwardDifferentiate: + case kIROp_BackwardDifferentiatePrimal: + case kIROp_BackwardDifferentiatePropagate: + case kIROp_BackwardDiffIntermediateContextType: + return false; + } + if (ii->getDataType() && !isTypeFullyDifferentiated(ii->getDataType())) + return false; + } + } + fullyDifferentiatedInsts.Add(func); + return true; + } + // Process all differentiate calls, and recursively generate code for forward and backward // derivative functions. // bool processReferencedFunctions(IRBuilder* builder) { + fullyDifferentiatedInsts.Clear(); bool hasChanges = false; for (;;) { @@ -725,6 +1124,12 @@ struct AutoDiffPass : public InstPassBase case kIROp_Func: case kIROp_Specialize: case kIROp_LookupWitness: + if (auto innerFunc = as<IRFunc>(getResolvedInstForDecorations(inst->getOperand(0)))) + { + // Skip functions whose body still has a differentiate inst (higher order func). + if (!isFullyDifferentiated(innerFunc)) + return; + } autoDiffWorkList.add(inst); break; default: @@ -845,6 +1250,11 @@ struct AutoDiffPass : public InstPassBase if (!changed) break; + if (lowerIntermediateContextType(builder)) + { + hasChanges = true; + } + // We have done transcribing the functions, now it is time to demote all DifferentialPair types // and their operations down to DifferentialPairUserCodeType and *UserCode operations so they // can be treated just like normal types with no special semantics in future processing, and won't @@ -854,10 +1264,6 @@ struct AutoDiffPass : public InstPassBase hasChanges |= changed; } - if (lowerIntermediateContextType(builder)) - { - hasChanges = true; - } return hasChanges; } diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index 28c682c91..bb8cfc378 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -320,6 +320,9 @@ INST(MakeOptionalValue, makeOptionalValue, 1, 0) INST(MakeOptionalNone, makeOptionalNone, 1, 0) INST(Call, call, 1, 0) +// Structural addition of two values of the same type. +INST(StructuralAdd, structuralAdd, 2, 0) + INST(RTTIObject, rtti_object, 0, 0) INST(Alloca, alloca, 1, 0) @@ -814,6 +817,9 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0) /// forward-differentiated updateElement inst. INST(PrimalElementTypeDecoration, primalElementType, 1, 0) + /// Used by the auto-diff pass to mark the differential type of an intermediate context field. + INST(IntermediateContextFieldDifferentialTypeDecoration, IntermediateContextFieldDifferentialTypeDecoration, 1, 0) + /// Used by the auto-diff pass to hold a reference to a /// differential member of a type in its associated differential type. INST(DerivativeMemberDecoration, derivativeMemberDecoration, 1, 0) diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index f3181ecf7..43893bfe6 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -783,6 +783,20 @@ struct IRPrimalElementTypeDecoration : IRDecoration IRInst* getPrimalElementType() { return getOperand(0); } }; +struct IRIntermediateContextFieldDifferentialTypeDecoration : IRDecoration +{ + enum + { + kOp = kIROp_IntermediateContextFieldDifferentialTypeDecoration + }; + + IR_LEAF_ISA(IntermediateContextFieldDifferentialTypeDecoration) + + IRInst* getDifferentialType() { return getOperand(0); } + IRInst* getDifferentialWitness() { return getOperand(1); } + +}; + struct IRBackwardDifferentiableDecoration : IRDecoration { enum @@ -2205,6 +2219,11 @@ struct IRWitnessTable : IRInst return (IRType*) getOperand(0); } + void setConcreteType(IRType* t) + { + return setOperand(0, t); + } + IR_LEAF_ISA(WitnessTable) }; @@ -2867,6 +2886,7 @@ public: IRInst* addPrimalValueStructKeyDecoration(IRInst* target, IRStructKey* key); IRInst* addPrimalElementTypeDecoration(IRInst* target, IRInst* type); + IRInst* addIntermediateContextFieldDifferentialTypeDecoration(IRInst* target, IRInst* type); // Add a differentiable type entry to the appropriate dictionary. IRInst* addDifferentiableTypeEntry(IRInst* dictDecoration, IRInst* irType, IRInst* conformanceWitness); @@ -2949,6 +2969,15 @@ public: /// the inst. IRInst* emitDefaultConstructRaw(IRType* type); + /// Emits appropriate inst for structurally adding two values of `type`. + /// If `fallback` is true, will emit `StructuralAdd` inst on unknown types. + /// Otherwise, returns nullptr if we can't materialize the inst. + IRInst* emitStructuralAdd(IRInst* val0, IRInst* val1, bool fallback = true); + + /// Emits a raw `StructuralAdd` opcode without attempting to fold/materialize + /// the inst. + IRInst* emitStructuralAddRaw(IRInst* val0, IRInst* val1); + IRInst* emitCast( IRType* type, IRInst* value); diff --git a/source/slang/slang-ir-peephole.cpp b/source/slang/slang-ir-peephole.cpp index 5d5a41726..a5ec50b2c 100644 --- a/source/slang/slang-ir-peephole.cpp +++ b/source/slang/slang-ir-peephole.cpp @@ -633,6 +633,19 @@ struct PeepholeContext : InstPassBase } } break; + case kIROp_StructuralAdd: + { + IRBuilder builder(module); + builder.setInsertBefore(inst); + // See if we can replace the generic add inst with concrete values. + if (auto newCtor = builder.emitStructuralAdd(inst->getOperand(0), inst->getOperand(1), false)) + { + inst->replaceUsesWith(newCtor); + maybeRemoveOldInst(inst); + changed = true; + } + } + break; case kIROp_Add: case kIROp_Mul: case kIROp_Sub: diff --git a/source/slang/slang-ir-validate.cpp b/source/slang/slang-ir-validate.cpp index d5c0aa432..55e0f0168 100644 --- a/source/slang/slang-ir-validate.cpp +++ b/source/slang/slang-ir-validate.cpp @@ -208,6 +208,12 @@ namespace Slang if (allInGlobalScope) return; + // Allow exceptions. + switch (inst->getOp()) + { + case kIROp_DifferentiableTypeDictionaryItem: + return; + } // // We failed to find `operandParent` while walking the ancestors of `inst`, // so something had gone wrong. diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 08c066f5d..f61e5a10e 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -3509,6 +3509,104 @@ namespace Slang return nullptr; } + IRInst* IRBuilder::emitStructuralAddRaw(IRInst* val0, IRInst* val1) + { + IRInst* args[2] = { val0, val1 }; + return emitIntrinsicInst(val0->getFullType(), kIROp_StructuralAdd, 2, args); + } + + IRInst* IRBuilder::emitStructuralAdd(IRInst* val0, IRInst* val1, bool fallback) + { + auto type = val0->getFullType(); + SLANG_RELEASE_ASSERT(val0->getFullType() == val1->getFullType()); + IRType* actualType = val0->getFullType(); + for (;;) + { + if (auto attr = as<IRAttributedType>(actualType)) + actualType = attr->getBaseType(); + else if (auto rateQualified = as<IRRateQualifiedType>(actualType)) + actualType = rateQualified->getValueType(); + else + break; + } + if (as<IRBasicType>(actualType)) + return emitAdd(type, val0, val1); + + switch (actualType->getOp()) + { + case kIROp_PtrType: + case kIROp_VectorType: + case kIROp_MatrixType: + return emitAdd(type, val0, val1); + case kIROp_TupleType: + { + List<IRInst*> elements; + auto tupleType = as<IRTupleType>(actualType); + for (UInt i = 0; i < tupleType->getOperandCount(); i++) + { + auto operand = tupleType->getOperand(i); + if (as<IRAttr>(operand)) + break; + auto inner = emitStructuralAdd( + emitGetTupleElement((IRType*)operand, val0, i), + emitGetTupleElement((IRType*)operand, val1, i), + fallback); + if (!inner) + return nullptr; + elements.add(inner); + } + return emitMakeTuple(tupleType, elements); + } + case kIROp_StructType: + { + List<IRInst*> elements; + auto structType = as<IRStructType>(actualType); + for (auto field : structType->getFields()) + { + auto fieldType = field->getFieldType(); + auto inner = emitStructuralAdd( + emitFieldExtract(fieldType, val0, field->getKey()), + emitFieldExtract(fieldType, val1, field->getKey()), + fallback); + if (!inner) + return nullptr; + elements.add(inner); + } + return emitMakeStruct(type, elements); + } + case kIROp_ArrayType: + { + auto arrayType = as<IRArrayType>(actualType); + if (auto count = as<IRIntLit>(arrayType->getElementCount())) + { + auto elementType = arrayType->getElementType(); + List<IRInst*> elements; + constexpr int maxCount = 4096; + if (count->getValue() > maxCount) + break; + for (IRIntegerValue i = 0; i < count->getValue(); i++) + { + auto index = getIntValue(getIntType(), i); + auto element = emitStructuralAdd( + emitElementExtract(elementType, val0, index), + emitElementExtract(elementType, val1, index), + fallback); + elements.add(element); + } + return emitMakeArray(type, elements.getCount(), elements.getBuffer()); + } + break; + } + default: + break; + } + if (fallback) + { + return emitStructuralAddRaw(val0, val1); + } + return nullptr; + } + static int _getTypeStyleId(IRType* type) { if (auto vectorType = as<IRVectorType>(type)) @@ -3928,6 +4026,11 @@ namespace Slang return addDecoration(target, kIROp_PrimalElementTypeDecoration, type); } + IRInst* IRBuilder::addIntermediateContextFieldDifferentialTypeDecoration(IRInst* target, IRInst* type) + { + return addDecoration(target, kIROp_IntermediateContextFieldDifferentialTypeDecoration, type); + } + RefPtr<IRModule> IRModule::create(Session* session) { RefPtr<IRModule> module = new IRModule(session); @@ -7028,6 +7131,7 @@ namespace Slang case kIROp_Nop: case kIROp_undefined: case kIROp_DefaultConstruct: + case kIROp_StructuralAdd: case kIROp_Specialize: case kIROp_LookupWitness: case kIROp_GetSequentialID: |
