diff options
27 files changed, 984 insertions, 99 deletions
@@ -1943,6 +1943,7 @@ extern "C" SLANG_TYPE_KIND_MESH_OUTPUT, SLANG_TYPE_KIND_SPECIALIZED, SLANG_TYPE_KIND_FEEDBACK, + SLANG_TYPE_KIND_POINTER, SLANG_TYPE_KIND_COUNT, }; @@ -2476,6 +2477,7 @@ namespace slang OutputStream = SLANG_TYPE_KIND_OUTPUT_STREAM, Specialized = SLANG_TYPE_KIND_SPECIALIZED, Feedback = SLANG_TYPE_KIND_FEEDBACK, + Pointer = SLANG_TYPE_KIND_POINTER, }; enum ScalarType : SlangScalarTypeIntegral diff --git a/source/slang/slang-ast-base.h b/source/slang/slang-ast-base.h index 6e43b378a..a2fdaf32e 100644 --- a/source/slang/slang-ast-base.h +++ b/source/slang/slang-ast-base.h @@ -208,7 +208,7 @@ SLANG_FORCE_INLINE const T* as(const Type* obj); // `typedef` which gives them a good name when printed as // part of diagnostic messages. // -// In order to operation on types, though, we often want +// In order to operate on types, though, we often want // to look past any sugar, and operate on an underlying // "canonical" type. The representation caches a pointer to // a canonical type on every type, so we can easily diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp index 872a7a8b2..0b5a91d87 100644 --- a/source/slang/slang-emit-c-like.cpp +++ b/source/slang/slang-emit-c-like.cpp @@ -3738,6 +3738,21 @@ void CLikeSourceEmitter::ensureInstOperandsRec(ComputeEmitActionsContext* ctx, I auto requiredLevel = EmitAction::Definition; switch (inst->getOp()) { + case kIROp_PtrType: + { + auto ptrType = static_cast<IRPtrType*>(inst); + auto valueType = ptrType->getValueType(); + + if (ctx->openInsts.contains(valueType)) + { + requiredLevel = EmitAction::ForwardDeclaration; + } + else + { + requiredLevel = EmitAction::Definition; + } + break; + } case kIROp_NativePtrType: requiredLevel = EmitAction::ForwardDeclaration; break; diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index 0c5afb4af..b33565700 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -967,7 +967,9 @@ INST(GetEquivalentStructuredBuffer, getEquivalentStructuredBuffer, 1, 0) INST(TaggedUnionTypeLayout, taggedUnionTypeLayout, 0, HOISTABLE) INST(ExistentialTypeLayout, existentialTypeLayout, 0, HOISTABLE) INST(StructTypeLayout, structTypeLayout, 0, HOISTABLE) - INST_RANGE(TypeLayout, TypeLayoutBase, StructTypeLayout) + // TODO(JS): Ideally we'd have the layout to the pointed to value type (ie 1 instead of 0 here). But to avoid infinite recursion we don't. + INST(PointerTypeLayout, ptrTypeLayout, 0, HOISTABLE) + INST_RANGE(TypeLayout, TypeLayoutBase, PointerTypeLayout) INST(EntryPointLayout, EntryPointLayout, 1, HOISTABLE) INST_RANGE(Layout, VarLayout, EntryPointLayout) diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 91bc4343d..4c540bfdd 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -1499,6 +1499,52 @@ struct IRArrayTypeLayout : IRTypeLayout }; }; +/* TODO(JS): + +It would arguably be "more correct" if the IRPointerTypeLayout, contained a refence to the value/target +type layout. Ie... + +``` +IRTypeLayout* m_valueTypeLayout; +``` + +Unfortunately that doesn't work because it leads to an infinite loop if the target contains a Ptr to the containing struct. + +This isn't so simple to fix (as has been done with similar problems elsewhere), because Layout +also hoists/deduped layouts. + +As it stands the "attributes" describing the layout fields are held as operands and as such are part +of the hash that is used for deduping. That makes sense (if the fields change depending on where/how +a struct type is used), but creates a problem because we can't lookup the type until it is "complete" +(ie has all the fields) and we can't have all the fields if one is a pointer that causes infinite recursion +in lookup. + +The work around for now is to observe that layout of a Ptr doesn't depend on what is being pointed to +and as such we don't store the this in the pointer. +*/ +struct IRPointerTypeLayout : IRTypeLayout +{ + typedef IRTypeLayout Super; + + IR_LEAF_ISA(PointerTypeLayout) + + struct Builder : Super::Builder + { + Builder(IRBuilder* irBuilder) + : Super::Builder(irBuilder) + {} + + IRPointerTypeLayout* build() + { + return cast<IRPointerTypeLayout>(Super::Builder::build()); + } + + protected: + IROp getOp() SLANG_OVERRIDE { return kIROp_PointerTypeLayout; } + void addOperandsImpl(List<IRInst*>& ioOperands) SLANG_OVERRIDE; + }; +}; + /// Specialized layout information for stream-output types struct IRStreamOutputTypeLayout : IRTypeLayout { diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index cac49f5f7..fb121d245 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -967,6 +967,18 @@ namespace Slang ioOperands.add(m_elementTypeLayout); } + // + // IRPointerTypeLayout + // + + void IRPointerTypeLayout::Builder::addOperandsImpl(List<IRInst*>& ioOperands) + { + SLANG_UNUSED(ioOperands); + // TODO(JS): For now we don't store the value types layout to avoid + // infinite recursion. + //ioOperands.add(m_valueTypeLayout); + } + // // IRStreamOutputTypeLayout // diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h index 080d78f6a..75ce44862 100644 --- a/source/slang/slang-ir.h +++ b/source/slang/slang-ir.h @@ -578,7 +578,7 @@ struct IRInst // Each instruction can have zero or more "decorations" // attached to it. A decoration is a specialized kind // of instruction that either attaches metadata to, - // or modifies the sematnics of, its parent instruction. + // or modifies the semantics of, its parent instruction. // IRDecoration* getFirstDecoration(); IRDecoration* getLastDecoration(); diff --git a/source/slang/slang-legalize-types.cpp b/source/slang/slang-legalize-types.cpp index c6d76205f..65e279f52 100644 --- a/source/slang/slang-legalize-types.cpp +++ b/source/slang/slang-legalize-types.cpp @@ -1219,7 +1219,44 @@ LegalType legalizeTypeImpl( } else if (auto ptrType = as<IRPtrTypeBase>(type)) { - auto legalValueType = legalizeType(context, ptrType->getValueType()); + typedef TypeLegalizationContext::PointerValue PointerValue; + + auto valueType = ptrType->getValueType(); + + { + const Index activeIndex = context->activePointerValues.findFirstIndex([valueType](const PointerValue& value) -> bool { return value.type == valueType; }); + + if (activeIndex >= 0) + { + context->activePointerValues[activeIndex].usedCount++; + // If it's *active* then it's currently being legalized. + // We will *assume* that value type will be the same type. + return LegalType::simple(ptrType); + } + } + + // Add the value type so we don't end up in a recursive loop + context->activePointerValues.add(PointerValue{valueType, 0}); + + auto legalValueType = legalizeType(context, valueType); + + const auto lastPointerValue = context->activePointerValues.getLast(); + // Remove it as we don't need anymore + context->activePointerValues.removeLast(); + + if (lastPointerValue.usedCount) + { + // It was recursively used, so we want to make sure our previous assumption was correct + if (legalValueType.flavor != LegalType::Flavor::simple || + legalValueType.obj != nullptr || + legalValueType.irType != valueType) + { + // TODO(JS): + // Ideally we'd handle this in some better way... + SLANG_ASSERT(!"We assumed a Ptr behavior if recursive, but that assumption didn't seem to work out"); + } + } + // If element type hasn't change, return original type. if (legalValueType.flavor == LegalType::Flavor::simple && legalValueType.getSimple() == ptrType->getValueType()) diff --git a/source/slang/slang-legalize-types.h b/source/slang/slang-legalize-types.h index 17029b6b6..4909433e3 100644 --- a/source/slang/slang-legalize-types.h +++ b/source/slang/slang-legalize-types.h @@ -632,6 +632,22 @@ struct IRTypeLegalizationContext /// Dictionary<IRFunc*, RefPtr<LegalFuncInfo>> mapFuncToInfo; + /// + /// Special handling for pointer types. If we have a situation where + /// a type could end up in a loop pointing to itself, the activePointerValues + /// stack records which pointer value types (ie the thing being pointed to) + /// are "active". The usedCount is to indicate how many times the type was + /// used whilst active. If it's !=0, we should check the assumption about what + /// should have been produced. + /// + struct PointerValue + { + IRType* type = nullptr; + Index usedCount = 0; + }; + + List<PointerValue> activePointerValues; + IRBuilder* getBuilder() { return builder; } /// Customization point to decide what types are "special." diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 5df1db03b..58768f2ad 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -101,6 +101,8 @@ struct ExtractedExistentialValInfo; // values are also supported. struct LoweredValInfo { + typedef LoweredValInfo ThisType; + // Which of the cases of value are we looking at? enum class Flavor { @@ -136,11 +138,20 @@ struct LoweredValInfo union { - IRInst* val; + IRInst* val; ExtendedValueInfo* ext; + + // We can compare any of the pointers above by comparing this pointer. If the union + // ever becomes something other than a union of pointers, this would no longer be applicable. + void* aliasPtr; }; Flavor flavor; + // NOTE! This relies on the union, allowing the comparison of any of the pointer type in the union. + // Assumes equality is the same as val pointer/or ext pointer being equal. + bool operator==(const ThisType& rhs) const { return flavor == rhs.flavor && aliasPtr == rhs.aliasPtr; } + bool operator!=(const ThisType& rhs) const { return !(*this == rhs); } + LoweredValInfo() { flavor = Flavor::None; @@ -421,22 +432,7 @@ struct IRGenEnv struct SharedIRGenContext { - SharedIRGenContext( - Session* session, - DiagnosticSink* sink, - bool obfuscateCode, - ModuleDecl* mainModuleDecl = nullptr) - : m_session(session) - , m_sink(sink) - , m_obfuscateCode(obfuscateCode) - , m_mainModuleDecl(mainModuleDecl) - {} - - Session* m_session = nullptr; - DiagnosticSink* m_sink = nullptr; - bool m_obfuscateCode = false; - ModuleDecl* m_mainModuleDecl = nullptr; - + // The "global" environment for mapping declarations to their IR values. IRGenEnv globalEnv; @@ -460,6 +456,27 @@ struct SharedIRGenContext Dictionary<Stmt*, IRBlock*> breakLabels; Dictionary<Stmt*, IRBlock*> continueLabels; + void setGlobalValue(Decl* decl, LoweredValInfo value) + { + globalEnv.mapDeclToValue[decl] = value; + } + + SharedIRGenContext( + Session* session, + DiagnosticSink* sink, + bool obfuscateCode, + ModuleDecl* mainModuleDecl = nullptr) + : m_session(session) + , m_sink(sink) + , m_obfuscateCode(obfuscateCode) + , m_mainModuleDecl(mainModuleDecl) + {} + + Session* m_session = nullptr; + DiagnosticSink* m_sink = nullptr; + bool m_obfuscateCode = false; + ModuleDecl* m_mainModuleDecl = nullptr; + // List of all string literals used in user code, regardless // of how they were used (i.e., whether or not they were hashed). // @@ -473,7 +490,6 @@ struct SharedIRGenContext List<IRInst*> m_stringLiterals; }; - struct IRGenContext { ASTBuilder* astBuilder; @@ -509,6 +525,16 @@ struct IRGenContext , irBuilder(nullptr) {} + void setGlobalValue(Decl* decl, LoweredValInfo value) + { + shared->setGlobalValue(decl, value); + } + + void setValue(Decl* decl, LoweredValInfo value) + { + env->mapDeclToValue[decl] = value; + } + Session* getSession() { return shared->m_session; @@ -537,21 +563,6 @@ struct IRGenContext } }; -void setGlobalValue(SharedIRGenContext* sharedContext, Decl* decl, LoweredValInfo value) -{ - sharedContext->globalEnv.mapDeclToValue[decl] = value; -} - -void setGlobalValue(IRGenContext* context, Decl* decl, LoweredValInfo value) -{ - setGlobalValue(context->shared, decl, value); -} - -void setValue(IRGenContext* context, Decl* decl, LoweredValInfo value) -{ - context->env->mapDeclToValue[decl] = value; -} - ModuleDecl* findModuleDecl(Decl* decl) { for (auto dd = decl; dd; dd = dd->parentDecl) @@ -1935,8 +1946,11 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower IRType* visitPtrType(PtrType* type) { - IRType* valueType = lowerType(context, type->getValueType()); - return getBuilder()->getPtrType(valueType); + auto astValueType = type->getValueType(); + + IRType* irValueType = lowerType(context, astValueType); + + return getBuilder()->getPtrType(irValueType); } IRType* visitDeclRefType(DeclRefType* type) @@ -3830,7 +3844,7 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> { SLANG_ASSERT(argIndex < subst->getArgs().getCount()); auto argVal = lowerVal(subContext, subst->getArgs()[argIndex]); - setValue(subContext, paramDecl, argVal); + subContext->setValue(paramDecl, argVal); } void _lowerSubstitutionEnv(IRGenContext* subContext, Substitutions* subst) @@ -4587,7 +4601,7 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> // feels kind of messy and gross. auto initVal = lowerLValueExpr(context, expr->decl->initExpr); - setGlobalValue(context, expr->decl, initVal); + context->setGlobalValue(expr->decl, initVal); auto bodyVal = lowerSubExpr(expr->body); return bodyVal; } @@ -6726,7 +6740,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> auto irKey = getBuilder()->createStructKey(); addLinkageDecoration(context, irKey, inheritanceDecl); auto keyVal = LoweredValInfo::simple(irKey); - setGlobalValue(context, inheritanceDecl, keyVal); + context->setGlobalValue(inheritanceDecl, keyVal); return keyVal; } } @@ -6760,7 +6774,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> auto irWitnessTable = subBuilder->createWitnessTable(irWitnessTableBaseType, nullptr); // Register the value now, rather than later, to avoid any possible infinite recursion. - setGlobalValue(context, inheritanceDecl, LoweredValInfo::simple(findOuterMostGeneric(irWitnessTable))); + context->setGlobalValue(inheritanceDecl, LoweredValInfo::simple(findOuterMostGeneric(irWitnessTable))); auto irSubType = lowerType(subContext, subType); irWitnessTable->setConcreteType(irSubType); @@ -6917,7 +6931,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> // A global variable's SSA value is a *pointer* to // the underlying storage. - setGlobalValue(context, decl, paramVal); + context->setGlobalValue(decl, paramVal); irParam->moveToEnd(); @@ -7019,7 +7033,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> // for any references to the constant from elsewhere // in the code. // - setGlobalValue(context, decl, loweredValue); + context->setGlobalValue(decl, loweredValue); return loweredValue; } @@ -7073,7 +7087,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> // A global variable's SSA value is a *pointer* to // the underlying storage. - setGlobalValue(context, decl, globalVal); + context->setGlobalValue(decl, globalVal); if (isImportedDecl(decl)) { @@ -7198,7 +7212,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> subBuilder->addHighLevelDeclDecoration(irGlobal, decl); LoweredValInfo globalVal = LoweredValInfo::ptr(irGlobal); - setValue(context, decl, globalVal); + context->setValue(decl, globalVal); // A `static` variable with an initializer needs special handling, // at least if the initializer isn't a compile-time constant. @@ -7289,7 +7303,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> { auto initVal = lowerRValueExpr(context, initExpr); initVal = LoweredValInfo::simple(getSimpleVal(context, initVal)); - setGlobalValue(context, decl, initVal); + context->setGlobalValue(decl, initVal); return initVal; } } @@ -7304,7 +7318,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> assign(context, varVal, initVal); } - setGlobalValue(context, decl, varVal); + context->setGlobalValue(decl, varVal); return varVal; } @@ -7326,7 +7340,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> } auto assocType = context->irBuilder->getAssociatedType( constraintInterfaces.getArrayView().arrayView); - setValue(context, decl, assocType); + context->setValue(decl, assocType); return LoweredValInfo::simple(assocType); } @@ -7392,7 +7406,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> IRInterfaceType* irInterface = subBuilder->createInterfaceType(operandCount, nullptr); // Add `irInterface` to decl mapping now to prevent cyclic lowering. - setValue(context, decl, LoweredValInfo::simple(irInterface)); + context->setValue(decl, LoweredValInfo::simple(irInterface)); // Setup subContext for proper lowering `ThisType`, associated types and // the interface decl's self reference. @@ -7468,7 +7482,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> irInterface->setOperand(entryIndex, constraintEntry); entryIndex++; - setValue(context, constraintDecl, LoweredValInfo::simple(constraintEntry)); + context->setValue(constraintDecl, LoweredValInfo::simple(constraintEntry)); } } else @@ -7490,7 +7504,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> // Add lowered requirement entry to current decl mapping to prevent // the function requirements from being lowered again when we get to // `ensureAllDeclsRec`. - setValue(context, requirementDecl, LoweredValInfo::simple(entry)); + context->setValue(requirementDecl, LoweredValInfo::simple(entry)); } } @@ -7628,6 +7642,12 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> return LoweredValInfo::simple(subBuilder->getVoidType()); } + const auto finishedVal = _getFinishOuterGenericsReturnValue(irAggType, outerGeneric); + + // We add the decl now such that if there are Ptr or other references + // to this type they can still complete + context->setValue(decl, LoweredValInfo::simple(finishedVal)); + addNameHint(context, irAggType, decl); addLinkageDecoration(context, irAggType, decl); @@ -7713,7 +7733,12 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> if (as<NonCopyableTypeAttribute>(modifier)) subBuilder->addNonCopyableTypeDecoration(irAggType); } - return LoweredValInfo::simple(finishOuterGenerics(subBuilder, irAggType, outerGeneric)); + + auto finalFinishedVal = finishOuterGenerics(subBuilder, irAggType, outerGeneric); + // Confirm that _getFinishOuterGenericsReturnValue above returned the same result + SLANG_ASSERT(finalFinishedVal == finishedVal); + + return LoweredValInfo::simple(finalFinishedVal); } void lowerPackOffsetModifier(IRInst* inst, HLSLPackOffsetSemantic* semantic) @@ -7933,7 +7958,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> { auto supType = lowerType(context, constraintDecl->sup.type); auto value = emitGenericConstraintValue(subContext, constraintDecl, supType); - setValue(subContext, constraintDecl, LoweredValInfo::simple(value)); + subContext->setValue(constraintDecl, LoweredValInfo::simple(value)); } IRGeneric* emitOuterGeneric( @@ -7966,14 +7991,14 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> // classifier of the parameter. auto param = subBuilder->emitParam(subBuilder->getTypeType()); addNameHint(context, param, typeParamDecl); - setValue(subContext, typeParamDecl, LoweredValInfo::simple(param)); + subContext->setValue(typeParamDecl, LoweredValInfo::simple(param)); } else if (auto valDecl = as<GenericValueParamDecl>(member)) { auto paramType = lowerType(subContext, valDecl->getType()); auto param = subBuilder->emitParam(paramType); addNameHint(context, param, valDecl); - setValue(subContext, valDecl, LoweredValInfo::simple(param)); + subContext->setValue(valDecl, LoweredValInfo::simple(param)); } } // Then we emit constraint parameters, again in @@ -8110,7 +8135,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> while (parentGeneric) { - // Create a universal type in `outterBlock` that will be used + // Create a universal type in `outerBlock` that will be used // as the type of this generic inst. The return value of the // generic inst will have a specialized type. // For example, if we have a generic function @@ -8227,6 +8252,28 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> return v; } + // This function matches the return value from finishOuterGenerics + // so that we can create the target value without finishOuterGenerics having to be called. + IRInst* _getFinishOuterGenericsReturnValue( + IRInst* val, + IRGeneric* parentGeneric) + { + IRInst* v = val; + while (parentGeneric) + { + // There might be more outer generics, + // so we need to loop until we run out. + v = parentGeneric; + auto parentBlock = as<IRBlock>(v->getParent()); + if (!parentBlock) break; + + parentGeneric = as<IRGeneric>(parentBlock->getParent()); + if (!parentGeneric) break; + + } + return v; + } + // Attach target-intrinsic decorations to an instruction, // based on modifiers on an AST declaration. void addTargetIntrinsicDecorations( @@ -8629,7 +8676,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> if (auto paramDecl = paramInfo.decl) { - setValue(subContext, paramDecl, paramVal); + subContext->setValue(paramDecl, paramVal); } if (paramInfo.isThisParam) @@ -8784,7 +8831,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> } // Register the value now, to avoid any possible infinite recursion when lowering ForwardDerivativeAttribute - setGlobalValue(context, decl, LoweredValInfo::simple(findOuterMostGeneric(irFunc))); + context->setGlobalValue(decl, LoweredValInfo::simple(findOuterMostGeneric(irFunc))); for (auto modifier : decl->modifiers) { @@ -9103,7 +9150,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> // default to the `result` that is returned from this visitor, // so that all the declarations share the same IR representative. // - setGlobalValue(context->shared, funcDecl, result); + context->setGlobalValue(funcDecl, result); } return result; } @@ -9132,24 +9179,33 @@ LoweredValInfo lowerDecl( } } -// Ensure that a version of the given declaration has been emitted to the IR -LoweredValInfo ensureDecl( - IRGenContext* context, - Decl* decl) -{ - auto shared = context->shared; - - LoweredValInfo result; +// We will probably want to put the +LoweredValInfo* _findLoweredValInfo( + IRGenContext* context, + Decl* decl) +{ // Look for an existing value installed in this context auto env = context->env; while(env) { - if(env->mapDeclToValue.tryGetValue(decl, result)) + if(auto result = env->mapDeclToValue.tryGetValue(decl)) return result; env = env->outer; } + return nullptr; +} + +// Ensure that a version of the given declaration has been emitted to the IR +LoweredValInfo ensureDecl( + IRGenContext* context, + Decl* decl) +{ + if (auto valInfoPtr = _findLoweredValInfo(context, decl)) + { + return *valInfoPtr; + } // If we have a decl that's a generic value/type decl then something has gone seriously // wrong @@ -9168,11 +9224,11 @@ LoweredValInfo ensureDecl( subContext.irBuilder = &subIRBuilder; subContext.env = &subEnv; - result = lowerDecl(&subContext, decl); + auto result = lowerDecl(&subContext, decl); // By default assume that any value we are lowering represents // something that should be installed globally. - setGlobalValue(shared, decl, result); + context->setGlobalValue(decl, result); return result; } @@ -10139,6 +10195,16 @@ IRTypeLayout* lowerTypeLayout( IRArrayTypeLayout::Builder builder(context->irBuilder, irElementTypeLayout); return _lowerTypeLayoutCommon(context, &builder, arrayTypeLayout); } + else if (auto ptrTypeLayout = as<PointerTypeLayout>(typeLayout)) + { + // TODO(JS): + // For now we don't lower the value/target type because this could lead to inifinte recursion + // in the way this is currently implemented. + + //auto irValueTypeLayout = lowerTypeLayout(context, ptrTypeLayout->valueTypeLayout); + IRPointerTypeLayout::Builder builder(context->irBuilder); + return _lowerTypeLayoutCommon(context, &builder, ptrTypeLayout); + } else if( auto taggedUnionTypeLayout = as<TaggedUnionTypeLayout>(typeLayout) ) { IRTaggedUnionTypeLayout::Builder builder(context->irBuilder, taggedUnionTypeLayout->tagOffset); diff --git a/source/slang/slang-parameter-binding.cpp b/source/slang/slang-parameter-binding.cpp index f6637b13c..cd238f623 100644 --- a/source/slang/slang-parameter-binding.cpp +++ b/source/slang/slang-parameter-binding.cpp @@ -1997,6 +1997,18 @@ static RefPtr<TypeLayout> processEntryPointVaryingParameter( else if (const auto textureType = as<TextureType>(type)) { return nullptr; } else if(const auto samplerStateType = as<SamplerStateType>(type)) { return nullptr; } else if(const auto constantBufferType = as<ConstantBufferType>(type)) { return nullptr; } + else if (auto ptrType = as<PtrType>(type)) + { + SLANG_ASSERT(ptrType->astNodeType == ASTNodeType::PtrType); + + // Work out the layout for the value/target type + auto valueTypeLayout = processEntryPointVaryingParameter(context, ptrType->getValueType(), state, varLayout); + + RefPtr<PointerTypeLayout> ptrTypeLayout = new PointerTypeLayout(); + ptrTypeLayout->valueTypeLayout = valueTypeLayout; + + return ptrTypeLayout; + } // Catch declaration-reference types late in the sequence, since // otherwise they will include all of the above cases... else if( auto declRefType = as<DeclRefType>(type) ) @@ -2136,6 +2148,7 @@ static RefPtr<TypeLayout> processEntryPointVaryingParameter( SLANG_UNEXPECTED("unhandled type kind"); } } + // If we ran into an error in checking the user's code, then skip this parameter else if( const auto errorType = as<ErrorType>(type) ) { diff --git a/source/slang/slang-reflection-api.cpp b/source/slang/slang-reflection-api.cpp index 5d35c7eef..8ec0979e4 100644 --- a/source/slang/slang-reflection-api.cpp +++ b/source/slang/slang-reflection-api.cpp @@ -310,7 +310,7 @@ SLANG_API SlangTypeKind spReflectionType_GetKind(SlangReflectionType* inType) auto type = convert(inType); if(!type) return SLANG_TYPE_KIND_NONE; - // TODO(tfoley: Don't emit the same type more than once... + // TODO(tfoley): Don't emit the same type more than once... if (const auto basicType = as<BasicExpressionType>(type)) { @@ -360,6 +360,10 @@ SLANG_API SlangTypeKind spReflectionType_GetKind(SlangReflectionType* inType) { return SLANG_TYPE_KIND_FEEDBACK; } + else if (const auto ptrType = as<PtrType>(type)) + { + return SLANG_TYPE_KIND_POINTER; + } // TODO: need a better way to handle this stuff... #define CASE(TYPE) \ else if(as<TYPE>(type)) do { \ @@ -993,6 +997,10 @@ SLANG_API SlangReflectionTypeLayout* spReflectionTypeLayout_GetElementTypeLayout { return convert(matrixTypeLayout->elementTypeLayout); } + else if (auto ptrTypeLayout = as<PointerTypeLayout>(typeLayout)) + { + return convert(ptrTypeLayout->valueTypeLayout.Ptr()); + } return nullptr; } diff --git a/source/slang/slang-type-layout.cpp b/source/slang/slang-type-layout.cpp index c7b9af40b..29cf86f5e 100644 --- a/source/slang/slang-type-layout.cpp +++ b/source/slang/slang-type-layout.cpp @@ -100,6 +100,12 @@ struct DefaultLayoutRulesImpl : SimpleLayoutRulesImpl } } + SimpleLayoutInfo GetPointerLayout() override + { + // We'll assume 64 pointers by default, with 8 byte alignment + return SimpleLayoutInfo(LayoutResourceKind::Uniform, 8, 8); + } + SimpleArrayLayoutInfo GetArrayLayout( SimpleLayoutInfo elementInfo, LayoutSize elementCount) override { SLANG_RELEASE_ASSERT(elementInfo.size.isFinite()); @@ -250,6 +256,15 @@ struct GLSLBaseLayoutRulesImpl : DefaultLayoutRulesImpl return vectorInfo; } + SimpleLayoutInfo GetPointerLayout() override + { + // TODO(JS): + // We'll assume 64 bit "pointer". If we are using these extensions... + // https://github.com/KhronosGroup/GLSL/blob/master/extensions/ext/GLSL_EXT_buffer_reference.txt + // https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/VK_KHR_buffer_device_address.html. + return SimpleLayoutInfo(LayoutResourceKind::Uniform, sizeof(int64_t), sizeof(int64_t)); + } + SimpleArrayLayoutInfo GetArrayLayout(SimpleLayoutInfo elementInfo, LayoutSize elementCount) override { // The size of an array must be rounded up to be a multiple of its alignment. @@ -330,6 +345,12 @@ struct HLSLConstantBufferLayoutRulesImpl : DefaultLayoutRulesImpl return Super::GetArrayLayout(elementInfo, elementCount); } + SimpleLayoutInfo GetPointerLayout() override + { + // Not supported on HLSL currently... + return SimpleLayoutInfo(); + } + UniformLayoutInfo BeginStructLayout() override { return UniformLayoutInfo(0, 16); @@ -396,6 +417,16 @@ struct CPULayoutRulesImpl : DefaultLayoutRulesImpl } } + SimpleLayoutInfo GetPointerLayout() override + { + // TODO(JS): + // NOTE! We are assuming that the layout is the same for the *target* that it is for + // the compilation. + // If we are emitting C++, then there is no way in general to know how that C++ will be compiled + // it could be 32 or 64 (or other) sizes. For now we just assume they are the same. + return SimpleLayoutInfo(LayoutResourceKind::Uniform, sizeof(void*), SLANG_ALIGN_OF(void*)); + } + SimpleArrayLayoutInfo GetArrayLayout( SimpleLayoutInfo elementInfo, LayoutSize elementCount) override { if (elementCount.isInfinite()) @@ -473,6 +504,12 @@ struct CUDALayoutRulesImpl : DefaultLayoutRulesImpl } } + SimpleLayoutInfo GetPointerLayout() override + { + // CUDA/NVTRC only support 64 bit pointers + return SimpleLayoutInfo(LayoutResourceKind::Uniform, sizeof(int64_t), sizeof(int64_t)); + } + SimpleArrayLayoutInfo GetArrayLayout(SimpleLayoutInfo elementInfo, LayoutSize elementCount) override { SLANG_RELEASE_ASSERT(elementInfo.size.isFinite()); @@ -587,6 +624,13 @@ struct DefaultVaryingLayoutRulesImpl : DefaultLayoutRulesImpl getKind(), 1); } + SimpleLayoutInfo GetPointerLayout() override + { + // For pointers assume same logic as for scalars + return SimpleLayoutInfo( + getKind(), + 1); + } SimpleLayoutInfo GetVectorLayout(BaseType elementType, SimpleLayoutInfo, size_t) override { @@ -631,6 +675,13 @@ struct GLSLSpecializationConstantLayoutRulesImpl : DefaultLayoutRulesImpl getKind(), 1); } + SimpleLayoutInfo GetPointerLayout() override + { + // In a sense pointer are just like ScalarLayout, so we'll use the same logic... + return SimpleLayoutInfo( + getKind(), + 1); + } SimpleLayoutInfo GetVectorLayout(BaseType elementType, SimpleLayoutInfo, size_t elementCount) override { @@ -3435,11 +3486,66 @@ static TypeLayoutResult createArrayLikeTypeLayout( return TypeLayoutResult(typeLayout, arrayUniformInfo); } +static void _addLayout(TypeLayoutContext const& context, + Type* type, + TypeLayout* layout) +{ + // Add it *without info*. + // The info can be added with _updateLayout + context.layoutMap[type] = TypeLayoutResult(layout, SimpleLayoutInfo()); +} + +static void _addLayout(TypeLayoutContext const& context, + Type* type, + const TypeLayoutResult& result) +{ + context.layoutMap[type] = result; +} + +static TypeLayoutResult _updateLayout(TypeLayoutContext const& context, + Type* type, + TypeLayout* layout, + const SimpleLayoutInfo& info) +{ + auto layoutResultPtr = context.layoutMap.tryGetValue(type); + SLANG_ASSERT(layoutResultPtr); + if (layoutResultPtr) + { + // Check the layout is the same! + SLANG_ASSERT(layoutResultPtr->layout.get() == layout); + // Update the info + layoutResultPtr->info = info; + } + + return TypeLayoutResult(layout, info); +} + +static TypeLayoutResult _updateLayout(TypeLayoutContext const& context, + Type* type, + const TypeLayoutResult& result) +{ + auto layoutResultPtr = context.layoutMap.tryGetValue(type); + SLANG_ASSERT(layoutResultPtr); + if (layoutResultPtr) + { + // Check the layout is the same! + SLANG_ASSERT(layoutResultPtr->layout.get() == result.layout); + // Update the info + layoutResultPtr->info = result.info; + } + + return result; +} static TypeLayoutResult _createTypeLayout( TypeLayoutContext const& context, Type* type) { + if (auto layoutResultPtr = context.layoutMap.tryGetValue(type)) + { + return *layoutResultPtr; + } + auto rules = context.rules; if (auto parameterGroupType = as<ParameterGroupType>(type)) @@ -3702,6 +3808,30 @@ static TypeLayoutResult _createTypeLayout( { return createArrayLikeTypeLayout(context, arrayType, arrayType->getElementType(), arrayType->getElementCount()); } + else if (auto ptrType = as<PtrType>(type)) + { + RefPtr<PointerTypeLayout> ptrLayout = new PointerTypeLayout(); + + const auto info = rules->GetPointerLayout(); + + const TypeLayoutResult result(ptrLayout, info); + _addLayout(context, type, result); + + ptrLayout->type = type; + ptrLayout->rules = rules; + + ptrLayout->uniformAlignment = info.alignment; + + ptrLayout->addResourceUsage(info.kind, info.size); + + const auto valueTypeLayout = _createTypeLayout( + context, + ptrType->getValueType()); + + ptrLayout->valueTypeLayout = valueTypeLayout.layout; + + return result; + } else if (auto declRefType = as<DeclRefType>(type)) { auto declRef = declRefType->declRef; @@ -3714,6 +3844,8 @@ static TypeLayoutResult _createTypeLayout( typeLayoutBuilder.beginLayout(type, rules); auto typeLayout = typeLayoutBuilder.getTypeLayout(); + _addLayout(context, type, typeLayout); + // First, add all fields with explicit offsets. for (auto field : getFields(structDeclRef, MemberFilterStyle::Instance)) { @@ -3790,7 +3922,7 @@ static TypeLayoutResult _createTypeLayout( typeLayout->pendingDataTypeLayout = pendingDataTypeLayout; } - return typeLayoutBuilder.getTypeLayoutResult(); + return _updateLayout(context, type, typeLayoutBuilder.getTypeLayoutResult()); } else if (auto globalGenericParamDecl = declRef.as<GlobalGenericParamDecl>()) { @@ -4092,6 +4224,9 @@ static TypeLayoutResult _createTypeLayout( UniformLayoutInfo info(0, 1); RefPtr<TaggedUnionTypeLayout> taggedUnionLayout = new TaggedUnionTypeLayout(); + + _addLayout(context, type, taggedUnionLayout); + taggedUnionLayout->type = type; taggedUnionLayout->rules = rules; @@ -4155,7 +4290,7 @@ static TypeLayoutResult _createTypeLayout( taggedUnionLayout->findOrAddResourceInfo(LayoutResourceKind::Uniform)->count = info.size; taggedUnionLayout->uniformAlignment = info.alignment; - return TypeLayoutResult(taggedUnionLayout, info); + return _updateLayout(context, type, taggedUnionLayout, info); } else if( auto existentialSpecializedType = as<ExistentialSpecializedType>(type) ) { @@ -4171,6 +4306,9 @@ static TypeLayoutResult _createTypeLayout( rules->AddStructField(&info, baseTypeLayoutResult.info.getUniformLayout()); RefPtr<ExistentialSpecializedTypeLayout> typeLayout = new ExistentialSpecializedTypeLayout(); + + _addLayout(context, type, typeLayout); + typeLayout->type = type; typeLayout->rules = rules; @@ -4215,7 +4353,7 @@ static TypeLayoutResult _createTypeLayout( typeLayout->addResourceUsage(LayoutResourceKind::Uniform, info.size); } - return makeTypeLayoutResult(typeLayout); + return _updateLayout(context, type, makeTypeLayoutResult(typeLayout)); } // catch-all case in case nothing matched diff --git a/source/slang/slang-type-layout.h b/source/slang/slang-type-layout.h index 0b6dd42d8..a80f6afdd 100644 --- a/source/slang/slang-type-layout.h +++ b/source/slang/slang-type-layout.h @@ -613,6 +613,17 @@ public: RefPtr<TypeLayout> originalElementTypeLayout; }; +/// Type layout for an pointer type +class PointerTypeLayout : public TypeLayout +{ +public: + // TODO(JS): + // Should this derive from SequenceTypeLayout? A pointer is kind of like an array without + // bounds - in that it can be indexed. Of it it can be looked at as an indirection to a value. + // Is the "Just Work"iness applicable? + RefPtr<TypeLayout> valueTypeLayout; +}; + // type layout for a variable with stream-output type class StreamOutputTypeLayout : public TypeLayout { @@ -917,6 +928,9 @@ struct SimpleLayoutRulesImpl // Get size and alignment for an array of elements virtual SimpleArrayLayoutInfo GetArrayLayout(SimpleLayoutInfo elementInfo, LayoutSize elementCount) = 0; + /// Get pointer layout + virtual SimpleLayoutInfo GetPointerLayout() = 0; + // Get layout for a vector or matrix type virtual SimpleLayoutInfo GetVectorLayout(BaseType elementType, SimpleLayoutInfo elementInfo, size_t elementCount) = 0; virtual SimpleArrayLayoutInfo GetMatrixLayout(BaseType elementType, SimpleLayoutInfo elementInfo, size_t rowCount, size_t columnCount) = 0; @@ -949,7 +963,10 @@ struct LayoutRulesImpl { return simpleRules->GetScalarLayout(baseType); } - + SimpleLayoutInfo GetPointerLayout() + { + return simpleRules->GetPointerLayout(); + } SimpleArrayLayoutInfo GetArrayLayout(SimpleLayoutInfo elementInfo, LayoutSize elementCount) { return simpleRules->GetArrayLayout(elementInfo, elementCount); @@ -1013,6 +1030,30 @@ struct LayoutRulesFamilyImpl virtual LayoutRulesImpl* getStructuredBufferRules(TargetRequest* request) = 0; }; + /// A custom tuple to capture the outputs of type layout +struct TypeLayoutResult +{ + /// The actual heap-allocated layout object with all the details + RefPtr<TypeLayout> layout; + + /// A simplified representation of layout information. + /// + /// This information is suitable for the case where a type only + /// consumes a single resource. + /// + SimpleLayoutInfo info; + + /// Default constructor. + TypeLayoutResult() + {} + + /// Construct a result from the given layout object and simple layout info. + TypeLayoutResult(RefPtr<TypeLayout> inLayout, SimpleLayoutInfo const& inInfo) + : layout(inLayout) + , info(inInfo) + {} +}; + struct TypeLayoutContext { ASTBuilder* astBuilder; @@ -1040,6 +1081,9 @@ struct TypeLayoutContext Int specializationArgCount = 0; ExpandedSpecializationArg const* specializationArgs = nullptr; + // Map types to their type layout + Dictionary<Type*, TypeLayoutResult> layoutMap; + LayoutRulesImpl* getRules() { return rules; } LayoutRulesFamilyImpl* getRulesFamily() const { return rules->getLayoutRulesFamily(); } @@ -1088,29 +1132,7 @@ struct TypeLayoutContext // - /// A custom tuple to capture the outputs of type layout -struct TypeLayoutResult -{ - /// The actual heap-allocated layout object with all the details - RefPtr<TypeLayout> layout; - /// A simplified representation of layout information. - /// - /// This information is suitable for the case where a type only - /// consumes a single resource. - /// - SimpleLayoutInfo info; - - /// Default constructor. - TypeLayoutResult() - {} - - /// Construct a result from the given layout object and simple layout info. - TypeLayoutResult(RefPtr<TypeLayout> inLayout, SimpleLayoutInfo const& inInfo) - : layout(inLayout) - , info(inInfo) - {} -}; /// Helper type for building `struct` type layouts struct StructTypeLayoutBuilder diff --git a/tests/language-feature/pointer/pointer-self-reference.slang b/tests/language-feature/pointer/pointer-self-reference.slang new file mode 100644 index 000000000..e78b70db0 --- /dev/null +++ b/tests/language-feature/pointer/pointer-self-reference.slang @@ -0,0 +1,37 @@ +// pointer-self-reference.slang + +//TEST(compute):COMPARE_COMPUTE_EX:-cpu -compute -output-using-type -shaderobj + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<int> outputBuffer; + +struct Thing +{ + int value; + Ptr<Thing> next; +}; + +[numthreads(4, 1, 1)] +void computeMain(int3 dispatchThreadID: SV_DispatchThreadID) +{ + int idx = dispatchThreadID.x; + + Thing things[2]; + + things[0].next = &things[1]; + things[0].value = 27; + + things[1].next = &things[0]; + things[1].value = idx * idx; + + Ptr<Thing> cur = &things[0]; + + for (int i = 0; cur && i < idx; ++i) + { + cur = cur.next; + } + + int v = cur.value; + + outputBuffer[idx] = v; +} diff --git a/tests/language-feature/pointer/pointer-self-reference.slang.expected.txt b/tests/language-feature/pointer/pointer-self-reference.slang.expected.txt new file mode 100644 index 000000000..a5e22e1d0 --- /dev/null +++ b/tests/language-feature/pointer/pointer-self-reference.slang.expected.txt @@ -0,0 +1,5 @@ +type: int32_t +27 +1 +27 +9 diff --git a/tests/reflection/ptr/ptr-generic.slang b/tests/reflection/ptr/ptr-generic.slang new file mode 100644 index 000000000..22f9877e9 --- /dev/null +++ b/tests/reflection/ptr/ptr-generic.slang @@ -0,0 +1,20 @@ +//TEST(64-bit):REFLECTION:-stage compute -no-codegen -target host-callable -entry computeMain + +struct GenericStruct<T, let N: int> +{ + T someT; + int values[N]; + + Ptr<GenericStruct<float, 2>> genericPtr; +}; + +Ptr<GenericStruct<int, 4>> genericPtr; + +RWStructuredBuffer<int> outputBuffer; + +[numthreads(4, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + + outputBuffer[dispatchThreadID.x] = genericPtr.someT; +}
\ No newline at end of file diff --git a/tests/reflection/ptr/ptr-generic.slang.expected b/tests/reflection/ptr/ptr-generic.slang.expected new file mode 100644 index 000000000..7aebc0187 --- /dev/null +++ b/tests/reflection/ptr/ptr-generic.slang.expected @@ -0,0 +1,51 @@ +result code = 0 +standard error = { +} +standard output = { +{ + "parameters": [ + { + "name": "genericPtr", + "binding": {"kind": "uniform", "offset": 0, "size": 8}, + "type": { + "kind": "pointer", + "valueType": "GenericStruct" + } + }, + { + "name": "outputBuffer", + "binding": {"kind": "uniform", "offset": 8, "size": 16}, + "type": { + "kind": "resource", + "baseShape": "structuredBuffer", + "access": "readWrite", + "resultType": { + "kind": "scalar", + "scalarType": "int32" + } + } + } + ], + "entryPoints": [ + { + "name": "computeMain", + "stage:": "compute", + "parameters": [ + { + "name": "dispatchThreadID", + "semanticName": "SV_DISPATCHTHREADID", + "type": { + "kind": "vector", + "elementCount": 3, + "elementType": { + "kind": "scalar", + "scalarType": "uint32" + } + } + } + ], + "threadGroupSize": [4, 1, 1] + } + ] +} +} diff --git a/tests/reflection/ptr/ptr-global.slang b/tests/reflection/ptr/ptr-global.slang new file mode 100644 index 000000000..98ddd7e45 --- /dev/null +++ b/tests/reflection/ptr/ptr-global.slang @@ -0,0 +1,18 @@ +//TEST(64-bit):REFLECTION:-stage compute -no-codegen -target host-callable -entry computeMain + +struct SomeStruct +{ + Ptr<int> regularGlobal; + int* regularGlobal2; + int regularGlobal3; +}; + +RWStructuredBuffer<SomeStruct> inputBuffer; + +RWStructuredBuffer<int> outputBuffer; + +[numthreads(4, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + outputBuffer[dispatchThreadID.x] = int(dispatchThreadID.x); +}
\ No newline at end of file diff --git a/tests/reflection/ptr/ptr-global.slang.expected b/tests/reflection/ptr/ptr-global.slang.expected new file mode 100644 index 000000000..bd57d6a4b --- /dev/null +++ b/tests/reflection/ptr/ptr-global.slang.expected @@ -0,0 +1,82 @@ +result code = 0 +standard error = { +} +standard output = { +{ + "parameters": [ + { + "name": "inputBuffer", + "binding": {"kind": "uniform", "offset": 0, "size": 16}, + "type": { + "kind": "resource", + "baseShape": "structuredBuffer", + "access": "readWrite", + "resultType": { + "kind": "struct", + "name": "SomeStruct", + "fields": [ + { + "name": "regularGlobal", + "type": { + "kind": "pointer", + "valueType": "int" + }, + "binding": {"kind": "uniform", "offset": 0, "size": 8} + }, + { + "name": "regularGlobal2", + "type": { + "kind": "pointer", + "valueType": "int" + }, + "binding": {"kind": "uniform", "offset": 8, "size": 8} + }, + { + "name": "regularGlobal3", + "type": { + "kind": "scalar", + "scalarType": "int32" + }, + "binding": {"kind": "uniform", "offset": 16, "size": 4} + } + ] + } + } + }, + { + "name": "outputBuffer", + "binding": {"kind": "uniform", "offset": 16, "size": 16}, + "type": { + "kind": "resource", + "baseShape": "structuredBuffer", + "access": "readWrite", + "resultType": { + "kind": "scalar", + "scalarType": "int32" + } + } + } + ], + "entryPoints": [ + { + "name": "computeMain", + "stage:": "compute", + "parameters": [ + { + "name": "dispatchThreadID", + "semanticName": "SV_DISPATCHTHREADID", + "type": { + "kind": "vector", + "elementCount": 3, + "elementType": { + "kind": "scalar", + "scalarType": "uint32" + } + } + } + ], + "threadGroupSize": [4, 1, 1] + } + ] +} +} diff --git a/tests/reflection/ptr/ptr-self-reference.slang b/tests/reflection/ptr/ptr-self-reference.slang new file mode 100644 index 000000000..437b7d61b --- /dev/null +++ b/tests/reflection/ptr/ptr-self-reference.slang @@ -0,0 +1,17 @@ +//TEST(64-bit):REFLECTION:-stage compute -no-codegen -target host-callable -entry computeMain + +struct SomeStruct +{ + int payload; + Ptr<SomeStruct> next; +}; + +RWStructuredBuffer<SomeStruct> inputBuffer; + +RWStructuredBuffer<int> outputBuffer; + +[numthreads(4, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + outputBuffer[dispatchThreadID.x] = int(dispatchThreadID.x); +}
\ No newline at end of file diff --git a/tests/reflection/ptr/ptr-self-reference.slang.expected b/tests/reflection/ptr/ptr-self-reference.slang.expected new file mode 100644 index 000000000..e17de757a --- /dev/null +++ b/tests/reflection/ptr/ptr-self-reference.slang.expected @@ -0,0 +1,74 @@ +result code = 0 +standard error = { +} +standard output = { +{ + "parameters": [ + { + "name": "inputBuffer", + "binding": {"kind": "uniform", "offset": 0, "size": 16}, + "type": { + "kind": "resource", + "baseShape": "structuredBuffer", + "access": "readWrite", + "resultType": { + "kind": "struct", + "name": "SomeStruct", + "fields": [ + { + "name": "payload", + "type": { + "kind": "scalar", + "scalarType": "int32" + }, + "binding": {"kind": "uniform", "offset": 0, "size": 4} + }, + { + "name": "next", + "type": { + "kind": "pointer", + "valueType": "SomeStruct" + }, + "binding": {"kind": "uniform", "offset": 8, "size": 8} + } + ] + } + } + }, + { + "name": "outputBuffer", + "binding": {"kind": "uniform", "offset": 16, "size": 16}, + "type": { + "kind": "resource", + "baseShape": "structuredBuffer", + "access": "readWrite", + "resultType": { + "kind": "scalar", + "scalarType": "int32" + } + } + } + ], + "entryPoints": [ + { + "name": "computeMain", + "stage:": "compute", + "parameters": [ + { + "name": "dispatchThreadID", + "semanticName": "SV_DISPATCHTHREADID", + "type": { + "kind": "vector", + "elementCount": 3, + "elementType": { + "kind": "scalar", + "scalarType": "uint32" + } + } + } + ], + "threadGroupSize": [4, 1, 1] + } + ] +} +} diff --git a/tests/reflection/ptr/ptr-struct.slang b/tests/reflection/ptr/ptr-struct.slang new file mode 100644 index 000000000..ee11ca240 --- /dev/null +++ b/tests/reflection/ptr/ptr-struct.slang @@ -0,0 +1,27 @@ +//TEST(64-bit):REFLECTION:-stage compute -no-codegen -target host-callable -entry computeMain + +struct AnotherStruct +{ + float a; + int b; + Ptr<int> ptrC; +}; + +struct SomeStruct +{ + Ptr<int> ptrInt; + int* ptrInt2; + int anInt; + AnotherStruct another; + Ptr<AnotherStruct> anotherPtr; +}; + +RWStructuredBuffer<SomeStruct> inputBuffer; + +RWStructuredBuffer<int> outputBuffer; + +[numthreads(4, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + outputBuffer[dispatchThreadID.x] = int(dispatchThreadID.x); +}
\ No newline at end of file diff --git a/tests/reflection/ptr/ptr-struct.slang.expected b/tests/reflection/ptr/ptr-struct.slang.expected new file mode 100644 index 000000000..1a7ea5562 --- /dev/null +++ b/tests/reflection/ptr/ptr-struct.slang.expected @@ -0,0 +1,124 @@ +result code = 0 +standard error = { +} +standard output = { +{ + "parameters": [ + { + "name": "inputBuffer", + "binding": {"kind": "uniform", "offset": 0, "size": 16}, + "type": { + "kind": "resource", + "baseShape": "structuredBuffer", + "access": "readWrite", + "resultType": { + "kind": "struct", + "name": "SomeStruct", + "fields": [ + { + "name": "ptrInt", + "type": { + "kind": "pointer", + "valueType": "int" + }, + "binding": {"kind": "uniform", "offset": 0, "size": 8} + }, + { + "name": "ptrInt2", + "type": { + "kind": "pointer", + "valueType": "int" + }, + "binding": {"kind": "uniform", "offset": 8, "size": 8} + }, + { + "name": "anInt", + "type": { + "kind": "scalar", + "scalarType": "int32" + }, + "binding": {"kind": "uniform", "offset": 16, "size": 4} + }, + { + "name": "another", + "type": { + "kind": "struct", + "name": "AnotherStruct", + "fields": [ + { + "name": "a", + "type": { + "kind": "scalar", + "scalarType": "float32" + }, + "binding": {"kind": "uniform", "offset": 0, "size": 4} + }, + { + "name": "b", + "type": { + "kind": "scalar", + "scalarType": "int32" + }, + "binding": {"kind": "uniform", "offset": 4, "size": 4} + }, + { + "name": "ptrC", + "type": { + "kind": "pointer", + "valueType": "int" + }, + "binding": {"kind": "uniform", "offset": 8, "size": 8} + } + ] + }, + "binding": {"kind": "uniform", "offset": 24, "size": 16} + }, + { + "name": "anotherPtr", + "type": { + "kind": "pointer", + "valueType": "AnotherStruct" + }, + "binding": {"kind": "uniform", "offset": 40, "size": 8} + } + ] + } + } + }, + { + "name": "outputBuffer", + "binding": {"kind": "uniform", "offset": 16, "size": 16}, + "type": { + "kind": "resource", + "baseShape": "structuredBuffer", + "access": "readWrite", + "resultType": { + "kind": "scalar", + "scalarType": "int32" + } + } + } + ], + "entryPoints": [ + { + "name": "computeMain", + "stage:": "compute", + "parameters": [ + { + "name": "dispatchThreadID", + "semanticName": "SV_DISPATCHTHREADID", + "type": { + "kind": "vector", + "elementCount": 3, + "elementType": { + "kind": "scalar", + "scalarType": "uint32" + } + } + } + ], + "threadGroupSize": [4, 1, 1] + } + ] +} +} diff --git a/tools/gfx/cpu/cpu-device.cpp b/tools/gfx/cpu/cpu-device.cpp index 0db2b0fa7..4b8595e82 100644 --- a/tools/gfx/cpu/cpu-device.cpp +++ b/tools/gfx/cpu/cpu-device.cpp @@ -45,6 +45,11 @@ namespace cpu m_info.timestampFrequency = 1000000000; } + // Can support pointers (or something akin to that) + { + m_features.add("has-ptr"); + } + return SLANG_OK; } diff --git a/tools/gfx/cuda/cuda-device.cpp b/tools/gfx/cuda/cuda-device.cpp index bbf50cc58..7f931afa1 100644 --- a/tools/gfx/cuda/cuda-device.cpp +++ b/tools/gfx/cuda/cuda-device.cpp @@ -185,6 +185,9 @@ SLANG_NO_THROW SlangResult SLANG_MCALL DeviceImpl::initialize(const Desc& desc) // CUDA has support for realtime clock m_features.add("realtime-clock"); + + // Allows use of a ptr like type + m_features.add("has-ptr"); } cudaDeviceProp deviceProps; diff --git a/tools/slang-reflection-test/slang-reflection-test-main.cpp b/tools/slang-reflection-test/slang-reflection-test-main.cpp index 94062cf2b..cd726eb59 100644 --- a/tools/slang-reflection-test/slang-reflection-test-main.cpp +++ b/tools/slang-reflection-test/slang-reflection-test-main.cpp @@ -292,6 +292,7 @@ static void emitReflectionVarBindingInfoJSON( CASE(MIXED, mixed); CASE(REGISTER_SPACE, registerSpace); CASE(GENERIC, generic); + #undef CASE default: @@ -769,6 +770,16 @@ static void emitReflectionTypeInfoJSON( emitReflectionTypeJSON(writer, arrayType->getElementType()); } break; + case slang::TypeReflection::Kind::Pointer: + { + auto pointerType = type; + writer.maybeComma(); + writer << "\"kind\": \"pointer\""; + writer.maybeComma(); + writer << "\"targetType\": "; + emitReflectionTypeJSON(writer, pointerType->getElementType()); + } + break; case slang::TypeReflection::Kind::Struct: { @@ -888,6 +899,40 @@ static void emitReflectionTypeLayoutInfoJSON( emitReflectionTypeInfoJSON(writer, typeLayout->getType()); break; + case slang::TypeReflection::Kind::Pointer: + { + auto valueTypeLayout = typeLayout->getElementTypeLayout(); + SLANG_ASSERT(valueTypeLayout); + + writer.maybeComma(); + writer << "\"kind\": \"pointer\""; + + writer.maybeComma(); + writer << "\"valueType\": "; + + auto typeName = valueTypeLayout->getType()->getName(); + + if (typeName && typeName[0]) + { + // TODO(JS): + // We can't emit the type layout, because the type could contain + // a pointer and we end up in a recursive loop. For now we output the typename. + writer.writeEscapedString(UnownedStringSlice(typeName)); + } + else + { + // TODO(JS): We will need to generate name that we will associate with this type + // as it doesn't seem to have one + writer.writeEscapedString(toSlice("unknown name!")); + SLANG_ASSERT(!"Doesn't have an associated name"); + } + + /* + emitReflectionTypeLayoutJSON( + writer, + valueTypeLayout); */ + } + break; case slang::TypeReflection::Kind::Array: { auto arrayTypeLayout = typeLayout; |
