diff options
Diffstat (limited to 'source/slang/ir-specialize.cpp')
| -rw-r--r-- | source/slang/ir-specialize.cpp | 473 |
1 files changed, 431 insertions, 42 deletions
diff --git a/source/slang/ir-specialize.cpp b/source/slang/ir-specialize.cpp index 57d14c11a..b57d2b58f 100644 --- a/source/slang/ir-specialize.cpp +++ b/source/slang/ir-specialize.cpp @@ -98,6 +98,9 @@ struct SpecializationContext // whether generic, existential, etc. // List<IRInst*> workList; + HashSet<IRInst*> workListSet; + + HashSet<IRInst*> cleanInsts; void addToWorkList( IRInst* inst) @@ -112,7 +115,14 @@ struct SpecializationContext return; } + if(workListSet.Contains(inst)) + return; + workList.add(inst); + workListSet.Add(inst); + cleanInsts.Remove(inst); + + addUsersToWorkList(inst); } // When a transformation makes a change to an instruction, @@ -367,6 +377,7 @@ struct SpecializationContext case kIROp_Specialize: case kIROp_lookup_interface_method: case kIROp_ExtractExistentialType: + case kIROp_BindExistentialsType: break; } } @@ -431,6 +442,13 @@ struct SpecializationContext maybeSpecializeLoad(as<IRLoad>(inst)); break; + case kIROp_FieldExtract: + maybeSpecializeFieldExtract(as<IRFieldExtract>(inst)); + break; + case kIROp_FieldAddress: + maybeSpecializeFieldAddress(as<IRFieldAddress>(inst)); + break; + case kIROp_BindExistentialsType: maybeSpecializeBindExistentialsType(as<IRBindExistentialsType>(inst)); break; @@ -566,12 +584,18 @@ struct SpecializationContext // addToWorkList(module->getModuleInst()); + while(workList.getCount() != 0) + { + // We will then iterate until our work list goes dry. // while(workList.getCount() != 0) { IRInst* inst = workList.getLast(); + workList.removeLast(); + workListSet.Remove(inst); + cleanInsts.Add(inst); // For each instruction we process, we want to perform // a few steps. @@ -610,6 +634,10 @@ struct SpecializationContext } } + addDirtyInstsToWorkListRec(module->getModuleInst()); + + } + // Once the work list has gone dry, we should have the invariant // that there are no `specialize` instructions inside of non-generic // functions that in turn reference a generic type/function, *except* @@ -617,6 +645,19 @@ struct SpecializationContext // which case we wouldn't want to specialize it anyway. } + void addDirtyInstsToWorkListRec(IRInst* inst) + { + if( !cleanInsts.Contains(inst) ) + { + addToWorkList(inst); + } + + for(auto child = inst->getLastChild(); child; child = child->getPrevInst()) + { + addDirtyInstsToWorkListRec(child); + } + } + // Given a `call` instruction in the IR, we need to detect the case // where the callee has some interface-type parameter(s) and at the // call site it is statically clear what concrete type(s) the arguments @@ -719,6 +760,19 @@ struct SpecializationContext auto witnessTable = makeExistential->getWitnessTable(); key.vals.add(witnessTable); } + else if( auto wrapExistential = as<IRWrapExistential>(arg) ) + { + auto val = wrapExistential->getWrappedValue(); + auto valType = val->getFullType(); + key.vals.add(valType); + + UInt slotOperandCount = wrapExistential->getSlotOperandCount(); + for( UInt ii = 0; ii < slotOperandCount; ++ii ) + { + auto slotOperand = wrapExistential->getSlotOperand(ii); + key.vals.add(slotOperand); + } + } else { SLANG_UNEXPECTED("missing case for existential argument"); @@ -770,6 +824,11 @@ struct SpecializationContext auto val = makeExistential->getWrappedValue(); newArgs.add(val); } + else if( auto wrapExistential = as<IRWrapExistential>(arg) ) + { + auto val = wrapExistential->getWrappedValue(); + newArgs.add(val); + } else { SLANG_UNEXPECTED("missing case for existential argument"); @@ -838,6 +897,13 @@ struct SpecializationContext if(as<IRMakeExistential>(inst)) return true; + // A `wrapExistential(v, T0,w0, T1, w1, ...)` instruction + // is just a generalization of `makeExistential`, so it + // can apply in the same cases. + // + if(as<IRWrapExistential>(inst)) + return true; + // If we start to specialize functions that take arrays // of existentials as input, we will need a strategy to // determine arguments suitable for use in specializing @@ -904,7 +970,66 @@ struct SpecializationContext // IRInst* replacementVal = nullptr; - if( !isExistentialType(oldParam->getDataType()) ) + // The trickier case is when we have an existential-type + // parameter, because we need to extract out the concrete + // type that is coming from the call site. + // + if( auto oldMakeExistential = as<IRMakeExistential>(arg) ) + { + // In this case, the `arg` is `makeExistential(val, witnessTable)` + // and we know that the specialized call site will just be + // passing in `val`. + // + auto val = oldMakeExistential->getWrappedValue(); + auto witnessTable = oldMakeExistential->getWitnessTable(); + + // Our specialized function needs to take a parameter with the + // same type as `val`, to match the call site(s) that will be + // created. + // + auto valType = val->getFullType(); + auto newParam = builder->createParam(valType); + newParams.add(newParam); + + // Within the body of the function we cannot just use `val` + // directly, because the existing code expects an existential + // value, including its witness table. + // + // Therefore we will create a `makeExistential(newParam, witnessTable)` + // in the body of the new function and use *that* as the replacement + // value for the original parameter (since it will have the + // correct existential type, and stores the right witness table). + // + auto newMakeExistential = builder->emitMakeExistential(oldParam->getFullType(), newParam, witnessTable); + newBodyInsts.add(newMakeExistential); + replacementVal = newMakeExistential; + } + else if( auto oldWrapExistential = as<IRWrapExistential>(arg) ) + { + auto val = oldWrapExistential->getWrappedValue(); + auto valType = val->getFullType(); + + auto newParam = builder->createParam(valType); + newParams.add(newParam); + + // Within the body of the function we cannot just use `val` + // directly, because the existing code expects an existential + // value, including its witness table. + // + // Therefore we will create a `makeExistential(newParam, witnessTable)` + // in the body of the new function and use *that* as the replacement + // value for the original parameter (since it will have the + // correct existential type, and stores the right witness table). + // + auto newWrapExistential = builder->emitWrapExistential( + oldParam->getFullType(), + newParam, + oldWrapExistential->getSlotOperandCount(), + oldWrapExistential->getSlotOperands()); + newBodyInsts.add(newWrapExistential); + replacementVal = newWrapExistential; + } + else { // For parameters that don't have an existential type, // there is nothing interesting to do. The new function @@ -915,47 +1040,6 @@ struct SpecializationContext newParams.add(newParam); replacementVal = newParam; } - else - { - // The trickier case is when we have an existential-type - // parameter, because we need to extract out the concrete - // type that is coming from the call site. - // - if( auto oldMakeExistential = as<IRMakeExistential>(arg) ) - { - // In this case, the `arg` is `makeExistential(val, witnessTable)` - // and we know that the specialized call site will just be - // passing in `val`. - // - auto val = oldMakeExistential->getWrappedValue(); - auto witnessTable = oldMakeExistential->getWitnessTable(); - - // Our specialized function needs to take a parameter with the - // same type as `val`, to match the call site(s) that will be - // created. - // - auto valType = val->getFullType(); - auto newParam = builder->createParam(valType); - newParams.add(newParam); - - // Within the body of the function we cannot just use `val` - // directly, because the existing code expects an existential - // value, including its witness table. - // - // Therefore we will create a `makeExistential(newParam, witnessTable)` - // in the body of the new function and use *that* as the replacement - // value for the original parameter (since it will have the - // correct existential type, and stores the right witness table). - // - auto newMakeExistential = builder->emitMakeExistential(oldParam->getFullType(), newParam, witnessTable); - newBodyInsts.add(newMakeExistential); - replacementVal = newMakeExistential; - } - else - { - SLANG_UNEXPECTED("missing case for existential argument"); - } - } // Whatever replacement value was constructed, we need to // register it as the replacement for the original parameter. @@ -1207,6 +1291,246 @@ struct SpecializationContext } } + UInt calcExistentialBoxSlotCount(IRType* type) + { + top: + if( as<IRExistentialBoxType>(type) ) + { + return 2; + } + else if( auto ptrType = as<IRPtrTypeBase>(type) ) + { + type = ptrType->getValueType(); + goto top; + } + else if( auto ptrLikeType = as<IRPointerLikeType>(type) ) + { + type = ptrLikeType->getElementType(); + goto top; + } + else if( auto structType = as<IRStructType>(type) ) + { + UInt count = 0; + for( auto field : structType->getFields() ) + { + count += calcExistentialBoxSlotCount(field->getFieldType()); + } + return count; + } + else + { + return 0; + } + } + + void maybeSpecializeFieldExtract(IRFieldExtract* inst) + { + auto baseArg = inst->getBase(); + auto fieldKey = inst->getField(); + + if( auto wrapInst = as<IRWrapExistential>(baseArg) ) + { + // We have `getField(wrapExistential(val, ...), fieldKey)` + // + auto val = wrapInst->getWrappedValue(); + + // We know what type we are expected to produce. + // + auto resultType = inst->getFullType(); + + IRBuilder builder; + builder.sharedBuilder = &sharedBuilderStorage; + builder.setInsertBefore(inst); + + // We'd *like* to replace this instruction with + // `wrapExistential(getField(val, fieldKey), ...)` instead, since that + // will enable subsequent specializations. + // + // To do that, we need to figure out: + // + // 1. What type that inner `getField` would return (what + // is the type of the `fieldKey` field in `val`?) + // + // 2. Which of the existential slot operands in `...` there + // actually apply to the given field. + // + + // To determine these things, we need the type of + // `val` to be a structure type so that we can look + // up the field corresponding to `fieldKey`. + // + auto valType = val->getDataType(); + auto valStructType = as<IRStructType>(valType); + if(!valStructType) + return; + + UInt slotOperandOffset = 0; + + IRStructField* foundField = nullptr; + for( auto valField : valStructType->getFields() ) + { + if( valField->getKey() == fieldKey ) + { + foundField = valField; + break; + } + + slotOperandOffset += calcExistentialBoxSlotCount(valField->getFieldType()); + } + + if(!foundField) + return; + + auto foundFieldType = foundField->getFieldType(); + + List<IRInst*> slotOperands; + UInt slotOperandCount = calcExistentialBoxSlotCount(foundFieldType); + + for( UInt ii = 0; ii < slotOperandCount; ++ii ) + { + slotOperands.add(wrapInst->getSlotOperand(slotOperandOffset + ii)); + } + + auto newGetField = builder.emitFieldExtract( + foundFieldType, + val, + fieldKey); + + auto newWrapExistentialInst = builder.emitWrapExistential( + resultType, + newGetField, + slotOperandCount, + slotOperands.getBuffer()); + + addUsersToWorkList(inst); + inst->replaceUsesWith(newWrapExistentialInst); + inst->removeAndDeallocate(); + } + } + + + void maybeSpecializeFieldAddress(IRFieldAddress* inst) + { + auto baseArg = inst->getBase(); + auto fieldKey = inst->getField(); + + if( auto wrapInst = as<IRWrapExistential>(baseArg) ) + { + // We have `getFieldAddr(wrapExistential(val, ...), fieldKey)` + // + auto val = wrapInst->getWrappedValue(); + + // We know what type we are expected to produce. + // + auto resultType = inst->getFullType(); + + IRBuilder builder; + builder.sharedBuilder = &sharedBuilderStorage; + builder.setInsertBefore(inst); + + // We'd *like* to replace this instruction with + // `wrapExistential(getFieldAddr(val, fieldKey), ...)` instead, since that + // will enable subsequent specializations. + // + // To do that, we need to figure out: + // + // 1. What type that inner `getFieldAddr` would return (what + // is the type of the `fieldKey` field in `val`?) + // + // 2. Which of the existential slot operands in `...` there + // actually apply to the given field. + // + + // To determine these things, we need the type of + // `val` to be a (pointer to a) structure type so that we can look + // up the field corresponding to `fieldKey`. + // + auto valType = tryGetPointedToType(&builder, val->getDataType()); + if(!valType) + return; + + auto valStructType = as<IRStructType>(valType); + if(!valStructType) + return; + + UInt slotOperandOffset = 0; + + IRStructField* foundField = nullptr; + for( auto valField : valStructType->getFields() ) + { + if( valField->getKey() == fieldKey ) + { + foundField = valField; + break; + } + + slotOperandOffset += calcExistentialBoxSlotCount(valField->getFieldType()); + } + + if(!foundField) + return; + + auto foundFieldType = foundField->getFieldType(); + + List<IRInst*> slotOperands; + UInt slotOperandCount = calcExistentialBoxSlotCount(foundFieldType); + + for( UInt ii = 0; ii < slotOperandCount; ++ii ) + { + slotOperands.add(wrapInst->getSlotOperand(slotOperandOffset + ii)); + } + + auto newGetFieldAddr = builder.emitFieldAddress( + builder.getPtrType(foundFieldType), + val, + fieldKey); + + auto newWrapExistentialInst = builder.emitWrapExistential( + resultType, + newGetFieldAddr, + slotOperandCount, + slotOperands.getBuffer()); + + addUsersToWorkList(inst); + inst->replaceUsesWith(newWrapExistentialInst); + inst->removeAndDeallocate(); + } + } + + UInt calcExistentialTypeParamSlotCount(IRType* type) + { + top: + if( as<IRInterfaceType>(type) ) + { + return 2; + } + else if( auto ptrType = as<IRPtrTypeBase>(type) ) + { + type = ptrType->getValueType(); + goto top; + } + else if( auto ptrLikeType = as<IRPointerLikeType>(type) ) + { + type = ptrLikeType->getElementType(); + goto top; + } + else if( auto structType = as<IRStructType>(type) ) + { + UInt count = 0; + for( auto field : structType->getFields() ) + { + count += calcExistentialTypeParamSlotCount(field->getFieldType()); + } + return count; + } + else + { + return 0; + } + } + + Dictionary<IRSimpleSpecializationKey, IRStructType*> existentialSpecializedStructs; + void maybeSpecializeBindExistentialsType(IRBindExistentialsType* type) { auto baseType = type->getBaseType(); @@ -1253,17 +1577,82 @@ struct SpecializationContext baseElementType, slotOperandCount, type->getExistentialArgs()); + addToWorkList(wrappedElementType); auto newPtrLikeType = builder.getType( basePtrLikeType->op, 1, &wrappedElementType); + addToWorkList(newPtrLikeType); addUsersToWorkList(type); type->replaceUsesWith(newPtrLikeType); type->removeAndDeallocate(); return; } + else if( auto baseStructType = as<IRStructType>(baseType) ) + { + // In order to bind a `struct` type we will generate + // a new specialized `struct` type on demand and then + // cache and re-use it. + // + // We don't want to start specializing here unless + // all the operand types (and witness tables) we + // will be specializing to are themselves fully + // specialized, so that we can be sure that we + // have a unique type. + // + if( !areAllOperandsFullySpecialized(type) ) + return; + + // Now we we check to see if we've already created + // a specialized struct type or not. + // + IRSimpleSpecializationKey key; + key.vals.add(baseStructType); + for( UInt ii = 0; ii < slotOperandCount; ++ii ) + { + key.vals.add(type->getExistentialArg(ii)); + } + + IRStructType* newStructType = nullptr; + if( !existentialSpecializedStructs.TryGetValue(key, newStructType) ) + { + builder.setInsertBefore(baseStructType); + newStructType = builder.createStructType(); + + auto fieldSlotArgs = type->getExistentialArgs(); + + for( auto oldField : baseStructType->getFields() ) + { + // TODO: we need to figure out which of the specialization arguments + // apply to this field... + + auto oldFieldType = oldField->getFieldType(); + auto fieldSlotArgCount = calcExistentialTypeParamSlotCount(oldFieldType); + + auto newFieldType = builder.getBindExistentialsType( + oldFieldType, + fieldSlotArgCount, + fieldSlotArgs); + + addToWorkList(newFieldType); + + fieldSlotArgs += fieldSlotArgCount; + + builder.createStructField(newStructType, oldField->getKey(), newFieldType); + } + + existentialSpecializedStructs.Add(key, newStructType); + addToWorkList(newStructType); + } + + addUsersToWorkList(type); + type->replaceUsesWith(newStructType); + type->removeAndDeallocate(); + return; + + } } // The handling of specialization for global generic type |
