summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--slang.h39
-rw-r--r--source/core/hash.h28
-rw-r--r--source/slang/check.cpp35
-rw-r--r--source/slang/compiler.h33
-rw-r--r--source/slang/diagnostics.h2
-rw-r--r--source/slang/ir-legalize-types.cpp54
-rw-r--r--source/slang/lower-to-ir.cpp18
-rw-r--r--source/slang/reflection.cpp44
-rw-r--r--source/slang/slang.cpp19
-rw-r--r--source/slang/syntax.cpp102
-rw-r--r--source/slang/syntax.h26
-rw-r--r--source/slang/type-defs.h13
-rw-r--r--tests/compute/interface-shader-param-legalization.slang54
-rw-r--r--tests/compute/interface-shader-param-legalization.slang.expected.txt4
14 files changed, 431 insertions, 40 deletions
diff --git a/slang.h b/slang.h
index 32a9ba17e..6175d419e 100644
--- a/slang.h
+++ b/slang.h
@@ -1812,6 +1812,8 @@ extern "C"
SLANG_API int spReflectionTypeLayout_getGenericParamIndex(SlangReflectionTypeLayout* type);
+ SLANG_API SlangReflectionTypeLayout* spReflectionTypeLayout_getPendingDataTypeLayout(SlangReflectionTypeLayout* type);
+
// Variable Reflection
SLANG_API char const* spReflectionVariable_GetName(SlangReflectionVariable* var);
@@ -1843,6 +1845,9 @@ extern "C"
SLANG_API SlangStage spReflectionVariableLayout_getStage(
SlangReflectionVariableLayout* var);
+
+ SLANG_API SlangReflectionVariableLayout* spReflectionVariableLayout_getPendingDataLayout(SlangReflectionVariableLayout* var);
+
// Shader Parameter Reflection
typedef SlangReflectionVariableLayout SlangReflectionParameter;
@@ -1897,6 +1902,14 @@ extern "C"
SLANG_API SlangUInt spReflection_getGlobalConstantBufferBinding(SlangReflection* reflection);
SLANG_API size_t spReflection_getGlobalConstantBufferSize(SlangReflection* reflection);
+ SLANG_API SlangReflectionType* spReflection_specializeType(
+ SlangReflection* reflection,
+ SlangReflectionType* type,
+ SlangInt specializationArgCount,
+ SlangReflectionType* const* specializationArgs,
+ ISlangBlob** outDiagnostics);
+
+
#ifdef __cplusplus
}
@@ -2232,6 +2245,13 @@ namespace slang
return spReflectionTypeLayout_getGenericParamIndex(
(SlangReflectionTypeLayout*) this);
}
+
+ TypeLayoutReflection* getPendingDataTypeLayout()
+ {
+ return (TypeLayoutReflection*) spReflectionTypeLayout_getPendingDataTypeLayout(
+ (SlangReflectionTypeLayout*) this);
+ }
+
};
struct Modifier
@@ -2350,6 +2370,11 @@ namespace slang
{
return spReflectionVariableLayout_getStage((SlangReflectionVariableLayout*) this);
}
+
+ VariableLayoutReflection* getPendingDataLayout()
+ {
+ return (VariableLayoutReflection*) spReflectionVariableLayout_getPendingDataLayout((SlangReflectionVariableLayout*) this);
+ }
};
struct EntryPointReflection
@@ -2487,6 +2512,20 @@ namespace slang
(SlangReflection*) this,
name);
}
+
+ TypeReflection* specializeType(
+ TypeReflection* type,
+ SlangInt specializationArgCount,
+ TypeReflection* const* specializationArgs,
+ ISlangBlob** outDiagnostics)
+ {
+ return (TypeReflection*) spReflection_specializeType(
+ (SlangReflection*) this,
+ (SlangReflectionType*) type,
+ specializationArgCount,
+ (SlangReflectionType* const*) specializationArgs,
+ outDiagnostics);
+ }
};
}
diff --git a/source/core/hash.h b/source/core/hash.h
index fc0bca737..83e99179b 100644
--- a/source/core/hash.h
+++ b/source/core/hash.h
@@ -7,6 +7,8 @@
namespace Slang
{
+ typedef int HashCode;
+
inline int GetHashCode(double key)
{
return FloatAsInt((float)key);
@@ -120,6 +122,32 @@ namespace Slang
{
return (left * 16777619) ^ right;
}
+
+ struct Hasher
+ {
+ public:
+ Hasher() {}
+
+ template<typename T>
+ void hashValue(T const& value)
+ {
+ m_hashCode = combineHash(m_hashCode, GetHashCode(value));
+ }
+
+ template<typename T>
+ void hashObject(T const& object)
+ {
+ m_hashCode = combineHash(m_hashCode, object->GetHashCode());
+ }
+
+ HashCode getResult() const
+ {
+ return m_hashCode;
+ }
+
+ private:
+ HashCode m_hashCode = 0;
+ };
}
#endif
diff --git a/source/slang/check.cpp b/source/slang/check.cpp
index 6a40f436a..d51785112 100644
--- a/source/slang/check.cpp
+++ b/source/slang/check.cpp
@@ -10763,6 +10763,41 @@ static bool doesParameterMatch(
Slang::_specializeExistentialTypeParams(getLinkage(), m_globalExistentialSlots, args, sink);
}
+ Type* Linkage::specializeType(
+ Type* unspecializedType,
+ Int argCount,
+ Type* const* args,
+ DiagnosticSink* sink)
+ {
+ // TODO: We should cache and re-use specialized types
+ // when the exact same arguments are provided again later.
+
+ SemanticsVisitor visitor(this, sink);
+
+
+ ExistentialTypeSlots slots;
+ _collectExistentialTypeParamsRec(slots, unspecializedType);
+
+ assert(slots.paramTypes.getCount() == argCount);
+
+ for( Int aa = 0; aa < argCount; ++aa )
+ {
+ auto argType = args[aa];
+
+ ExistentialTypeSlots::Arg arg;
+ arg.type = argType;
+ arg.witness = visitor.tryGetSubtypeWitness(argType, slots.paramTypes[aa]);
+ slots.args.add(arg);
+ }
+
+ RefPtr<ExistentialSpecializedType> specializedType = new ExistentialSpecializedType();
+ specializedType->baseType = unspecializedType;
+ specializedType->slots = slots;
+
+ m_specializedTypes.add(specializedType);
+
+ return specializedType;
+ }
/// Specialize a program to global generic arguments
RefPtr<Program> createSpecializedProgram(
diff --git a/source/slang/compiler.h b/source/slang/compiler.h
index 233fa3d05..9b7e06be0 100644
--- a/source/slang/compiler.h
+++ b/source/slang/compiler.h
@@ -133,31 +133,6 @@ namespace Slang
ComPtr<ISlangBlob> blob;
};
- /// Collects information about existential type parameters and their arguments.
- struct ExistentialTypeSlots
- {
- /// For each type parameter, holds the interface/existential type that constrains it.
- List<RefPtr<Type>> paramTypes;
-
- /// An argument for an existential type parameter.
- ///
- /// Comprises a concrete type and a witness for its conformance to the desired
- /// interface/existential type for the corresponding parameter.
- ///
- struct Arg
- {
- RefPtr<Type> type;
- RefPtr<Val> witness;
- };
-
- /// Any arguments provided for the existential type parameters.
- ///
- /// It is possible for `args` to be empty even if `paramTypes` is non-empty;
- /// that situation represents an unspecialized program or entry point.
- ///
- List<Arg> args;
- };
-
/// Information collected about global or entry-point shader parameters
struct ShaderParamInfo
{
@@ -665,6 +640,12 @@ namespace Slang
RefPtr<Expr> parseTypeString(String typeStr, RefPtr<Scope> scope);
+ Type* specializeType(
+ Type* unspecializedType,
+ Int argCount,
+ Type* const* args,
+ DiagnosticSink* sink);
+
/// Add a mew target amd return its index.
UInt addTarget(
CodeGenTarget target);
@@ -754,6 +735,8 @@ namespace Slang
/// Is the given module in the middle of being imported?
bool isBeingImported(Module* module);
+
+ List<RefPtr<Type>> m_specializedTypes;
};
/// Shared functionality between front- and back-end compile requests.
diff --git a/source/slang/diagnostics.h b/source/slang/diagnostics.h
index 7d76ffa85..8e5ba809b 100644
--- a/source/slang/diagnostics.h
+++ b/source/slang/diagnostics.h
@@ -227,6 +227,8 @@ namespace Slang
/// During propagation of an exception for an internal
/// error, note that this source location was involved
void noteInternalErrorLoc(SourceLoc const& loc);
+
+ SlangResult getBlobIfNeeded(ISlangBlob** outBlob);
};
/// An `ISlangWriter` that writes directly to a diagnostic sink.
diff --git a/source/slang/ir-legalize-types.cpp b/source/slang/ir-legalize-types.cpp
index cfc495070..18039315e 100644
--- a/source/slang/ir-legalize-types.cpp
+++ b/source/slang/ir-legalize-types.cpp
@@ -127,7 +127,8 @@ static LegalVal declareVars(
LegalVarChain const& varChain,
UnownedStringSlice nameHint,
IRInst* leafVar,
- IRGlobalNameInfo* globalNameInfo);
+ IRGlobalNameInfo* globalNameInfo,
+ bool isSpecial);
/// Unwrap a value with flavor `wrappedBuffer`
///
@@ -1266,9 +1267,10 @@ static LegalVal legalizeLocalVar(
IRVar* irLocalVar)
{
// Legalize the type for the variable's value
+ auto originalValueType = irLocalVar->getDataType()->getValueType();
auto legalValueType = legalizeType(
context,
- irLocalVar->getDataType()->getValueType());
+ originalValueType);
auto originalRate = irLocalVar->getRate();
@@ -1311,7 +1313,7 @@ static LegalVal legalizeLocalVar(
UnownedStringSlice nameHint = findNameHint(irLocalVar);
context->builder->setInsertBefore(irLocalVar);
- LegalVal newVal = declareVars(context, kIROp_Var, legalValueType, typeLayout, varChain, nameHint, irLocalVar, nullptr);
+ LegalVal newVal = declareVars(context, kIROp_Var, legalValueType, typeLayout, varChain, nameHint, irLocalVar, nullptr, context->isSpecialType(originalValueType));
// Remove the old local var.
irLocalVar->removeFromParent();
@@ -1345,7 +1347,7 @@ static LegalVal legalizeParam(
UnownedStringSlice nameHint = findNameHint(originalParam);
context->builder->setInsertBefore(originalParam);
- auto newVal = declareVars(context, kIROp_Param, legalParamType, nullptr, LegalVarChain(), nameHint, originalParam, nullptr);
+ auto newVal = declareVars(context, kIROp_Param, legalParamType, nullptr, LegalVarChain(), nameHint, originalParam, nullptr, context->isSpecialType(originalParam->getDataType()));
originalParam->removeFromParent();
context->replacedInstructions.add(originalParam);
@@ -2219,12 +2221,31 @@ static LegalVal declareVars(
IRTypeLegalizationContext* context,
IROp op,
LegalType type,
- TypeLayout* typeLayout,
- LegalVarChain const& varChain,
+ TypeLayout* inTypeLayout,
+ LegalVarChain const& inVarChain,
UnownedStringSlice nameHint,
IRInst* leafVar,
- IRGlobalNameInfo* globalNameInfo)
+ IRGlobalNameInfo* globalNameInfo,
+ bool isSpecial)
{
+ LegalVarChain varChain = inVarChain;
+ TypeLayout* typeLayout = inTypeLayout;
+ if( isSpecial )
+ {
+ if( varChain.pendingChain )
+ {
+ varChain.primaryChain = varChain.pendingChain;
+ varChain.pendingChain = nullptr;
+ }
+ if( typeLayout )
+ {
+ if( auto pendingTypeLayout = typeLayout->pendingDataTypeLayout )
+ {
+ typeLayout = pendingTypeLayout;
+ }
+ }
+ }
+
switch (type.flavor)
{
case LegalType::Flavor::none:
@@ -2247,7 +2268,8 @@ static LegalVal declareVars(
varChain,
nameHint,
leafVar,
- globalNameInfo);
+ globalNameInfo,
+ isSpecial);
return LegalVal::implicitDeref(val);
}
break;
@@ -2255,8 +2277,8 @@ static LegalVal declareVars(
case LegalType::Flavor::pair:
{
auto pairType = type.getPair();
- auto ordinaryVal = declareVars(context, op, pairType->ordinaryType, typeLayout, varChain, nameHint, leafVar, globalNameInfo);
- auto specialVal = declareVars(context, op, pairType->specialType, typeLayout, varChain, nameHint, leafVar, globalNameInfo);
+ auto ordinaryVal = declareVars(context, op, pairType->ordinaryType, typeLayout, varChain, nameHint, leafVar, globalNameInfo, false);
+ auto specialVal = declareVars(context, op, pairType->specialType, typeLayout, varChain, nameHint, leafVar, globalNameInfo, true);
return LegalVal::pair(ordinaryVal, specialVal, pairType->pairInfo);
}
@@ -2305,7 +2327,8 @@ static LegalVal declareVars(
newVarChain,
fieldNameHint,
ee.key,
- globalNameInfo);
+ globalNameInfo,
+ true);
TuplePseudoVal::Element element;
element.key = ee.key;
@@ -2348,9 +2371,10 @@ static LegalVal legalizeGlobalVar(
IRGlobalVar* irGlobalVar)
{
// Legalize the type for the variable's value
+ auto originalValueType = irGlobalVar->getDataType()->getValueType();
auto legalValueType = legalizeType(
context,
- irGlobalVar->getDataType()->getValueType());
+ originalValueType);
switch (legalValueType.flavor)
{
@@ -2373,7 +2397,7 @@ static LegalVal legalizeGlobalVar(
UnownedStringSlice nameHint = findNameHint(irGlobalVar);
context->builder->setInsertBefore(irGlobalVar);
- LegalVal newVal = declareVars(context, kIROp_GlobalVar, legalValueType, nullptr, LegalVarChain(), nameHint, irGlobalVar, &globalNameInfo);
+ LegalVal newVal = declareVars(context, kIROp_GlobalVar, legalValueType, nullptr, LegalVarChain(), nameHint, irGlobalVar, &globalNameInfo, context->isSpecialType(originalValueType));
// Register the new value as the replacement for the old
registerLegalizedValue(context, irGlobalVar, newVal);
@@ -2417,7 +2441,7 @@ static LegalVal legalizeGlobalConstant(
UnownedStringSlice nameHint = findNameHint(irGlobalConstant);
context->builder->setInsertBefore(irGlobalConstant);
- LegalVal newVal = declareVars(context, kIROp_GlobalConstant, legalValueType, nullptr, LegalVarChain(), nameHint, irGlobalConstant, &globalNameInfo);
+ LegalVal newVal = declareVars(context, kIROp_GlobalConstant, legalValueType, nullptr, LegalVarChain(), nameHint, irGlobalConstant, &globalNameInfo, context->isSpecialType(irGlobalConstant->getDataType()));
// Register the new value as the replacement for the old
registerLegalizedValue(context, irGlobalConstant, newVal);
@@ -2466,7 +2490,7 @@ static LegalVal legalizeGlobalParam(
UnownedStringSlice nameHint = findNameHint(irGlobalParam);
context->builder->setInsertBefore(irGlobalParam);
- LegalVal newVal = declareVars(context, kIROp_GlobalParam, legalValueType, typeLayout, varChain, nameHint, irGlobalParam, &globalNameInfo);
+ LegalVal newVal = declareVars(context, kIROp_GlobalParam, legalValueType, typeLayout, varChain, nameHint, irGlobalParam, &globalNameInfo, context->isSpecialType(irGlobalParam->getDataType()));
// Register the new value as the replacement for the old
registerLegalizedValue(context, irGlobalParam, newVal);
diff --git a/source/slang/lower-to-ir.cpp b/source/slang/lower-to-ir.cpp
index 2e9915669..a7be244c8 100644
--- a/source/slang/lower-to-ir.cpp
+++ b/source/slang/lower-to-ir.cpp
@@ -1608,6 +1608,24 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
return LoweredValInfo::simple(irType);
}
+ LoweredValInfo visitExistentialSpecializedType(ExistentialSpecializedType* type)
+ {
+ auto irBaseType = lowerType(context, type->baseType);
+
+ List<IRInst*> slotArgs;
+ for(auto arg : type->slots.args)
+ {
+ auto irArgType = lowerType(context, arg.type);
+ auto irArgWitness = lowerSimpleVal(context, arg.witness);
+
+ slotArgs.add(irArgType);
+ slotArgs.add(irArgWitness);
+ }
+
+ auto irType = getBuilder()->getBindExistentialsType(irBaseType, slotArgs.getCount(), slotArgs.getBuffer());
+ return LoweredValInfo::simple(irType);
+ }
+
// We do not expect to encounter the following types in ASTs that have
// passed front-end semantic checking.
#define UNEXPECTED_CASE(NAME) IRType* visit##NAME(NAME*) { SLANG_UNEXPECTED(#NAME); UNREACHABLE_RETURN(nullptr); }
diff --git a/source/slang/reflection.cpp b/source/slang/reflection.cpp
index 326a27854..4ac48d2e7 100644
--- a/source/slang/reflection.cpp
+++ b/source/slang/reflection.cpp
@@ -860,6 +860,24 @@ SLANG_API int spReflectionTypeLayout_getGenericParamIndex(SlangReflectionTypeLay
}
}
+SLANG_API SlangReflectionTypeLayout* spReflectionTypeLayout_getPendingDataTypeLayout(SlangReflectionTypeLayout* inTypeLayout)
+{
+ auto typeLayout = convert(inTypeLayout);
+ if(!typeLayout) return nullptr;
+
+ auto pendingDataTypeLayout = typeLayout->pendingDataTypeLayout.Ptr();
+ return convert(pendingDataTypeLayout);
+}
+
+SLANG_API SlangReflectionVariableLayout* spReflectionVariableLayout_getPendingDataLayout(SlangReflectionVariableLayout* inVarLayout)
+{
+ auto varLayout = convert(inVarLayout);
+ if(!varLayout) return nullptr;
+
+ auto pendingDataLayout = varLayout->pendingVarLayout.Ptr();
+ return convert(pendingDataLayout);
+}
+
// Variable Reflection
@@ -1381,3 +1399,29 @@ SLANG_API size_t spReflection_getGlobalConstantBufferSize(SlangReflection* inPro
if (!uniform) return 0;
return getReflectionSize(uniform->count);
}
+
+SLANG_API SlangReflectionType* spReflection_specializeType(
+ SlangReflection* inProgramLayout,
+ SlangReflectionType* inType,
+ SlangInt specializationArgCount,
+ SlangReflectionType* const* specializationArgs,
+ ISlangBlob** outDiagnostics)
+{
+ auto programLayout = convert(inProgramLayout);
+ if(!programLayout) return nullptr;
+
+ auto unspecializedType = convert(inType);
+ if(!unspecializedType) return nullptr;
+
+ auto linkage = programLayout->getProgram()->getLinkage();
+
+ DiagnosticSink sink;
+ sink.sourceManager = linkage->getSourceManager();
+
+ auto specializedType = linkage->specializeType(unspecializedType, specializationArgCount, (Type* const*) specializationArgs, &sink);
+
+ sink.getBlobIfNeeded(outDiagnostics);
+
+ return convert(specializedType);
+}
+
diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp
index 7556ac9b2..c78a27f54 100644
--- a/source/slang/slang.cpp
+++ b/source/slang/slang.cpp
@@ -567,6 +567,9 @@ Type* Program::getTypeFromString(String typeStr, DiagnosticSink* sink)
return type;
}
+
+
+
CompileRequestBase::CompileRequestBase(
Linkage* linkage,
DiagnosticSink* sink)
@@ -1444,6 +1447,22 @@ void DiagnosticSink::noteInternalErrorLoc(SourceLoc const& loc)
internalErrorLocsNoted++;
}
+SlangResult DiagnosticSink::getBlobIfNeeded(ISlangBlob** outBlob)
+{
+ // If the client doesn't want an output blob, there is nothing to do.
+ //
+ if(!outBlob) return SLANG_OK;
+
+ // If there were no errors, and there was no diagnostic output, there is nothing to do.
+ if(!GetErrorCount() && !outputBuffer.getLength()) return SLANG_OK;
+
+ Slang::ComPtr<ISlangBlob> blob = Slang::StringUtil::createStringBlob(outputBuffer);
+ *outBlob = blob.detach();
+
+ return SLANG_OK;
+}
+
+
Session* CompileRequestBase::getSession()
{
return getLinkage()->getSession();
diff --git a/source/slang/syntax.cpp b/source/slang/syntax.cpp
index c069c69d7..17c85175d 100644
--- a/source/slang/syntax.cpp
+++ b/source/slang/syntax.cpp
@@ -2757,7 +2757,109 @@ char const* getGLSLNameForImageFormat(ImageFormat format)
}
}
+//
+// ExistentialSpecializedType
+//
+
+String ExistentialSpecializedType::ToString()
+{
+ String result;
+ result.append("__ExistentialSpecializedType(");
+ result.append(baseType->ToString());
+ for( auto arg : slots.args )
+ {
+ result.append(", ");
+ result.append(arg.type->ToString());
+ }
+ result.append(")");
+ return result;
+}
+
+bool ExistentialSpecializedType::EqualsImpl(Type * type)
+{
+ auto other = as<ExistentialSpecializedType>(type);
+ if(!other)
+ return false;
+
+ if(!baseType->Equals(other->baseType))
+ return false;
+ auto argCount = slots.args.getCount();
+ if(argCount != other->slots.args.getCount())
+ return false;
+ for( Index ii = 0; ii < argCount; ++ii )
+ {
+ if(!slots.args[ii].type->Equals(other->slots.args[ii].type))
+ return false;
+
+ if(!slots.args[ii].witness->EqualsVal(other->slots.args[ii].witness))
+ return false;
+ }
+ return true;
+}
+
+int ExistentialSpecializedType::GetHashCode()
+{
+ Hasher hasher;
+ hasher.hashObject(baseType);
+ for(auto arg : slots.args)
+ {
+ hasher.hashObject(arg.type);
+ hasher.hashObject(arg.witness);
+ }
+ return hasher.getResult();
+}
+
+RefPtr<Type> ExistentialSpecializedType::CreateCanonicalType()
+{
+ RefPtr<ExistentialSpecializedType> canType = new ExistentialSpecializedType();
+ canType->setSession(getSession());
+
+ canType->baseType = baseType->GetCanonicalType();
+ for( auto paramType : slots.paramTypes )
+ {
+ canType->slots.paramTypes.add( paramType->GetCanonicalType() );
+ }
+ for( auto arg : slots.args )
+ {
+ ExistentialTypeSlots::Arg canArg;
+ canArg.type = arg.type->GetCanonicalType();
+ canArg.witness = arg.witness;
+ canType->slots.args.add(canArg);
+ }
+ return canType;
+}
+
+RefPtr<Val> ExistentialSpecializedType::SubstituteImpl(SubstitutionSet subst, int* ioDiff)
+{
+ int diff = 0;
+
+ auto substBaseType = baseType->SubstituteImpl(subst, &diff).as<Type>();
+
+ ExistentialTypeSlots substSlots;
+ for( auto paramType : slots.paramTypes )
+ {
+ substSlots.paramTypes.add( paramType->SubstituteImpl(subst, &diff).as<Type>() );
+ }
+ for( auto arg : slots.args )
+ {
+ ExistentialTypeSlots::Arg substArg;
+ substArg.type = arg.type->SubstituteImpl(subst, &diff).as<Type>();
+ substArg.witness = arg.witness->SubstituteImpl(subst, &diff);
+ substSlots.args.add(substArg);
+ }
+
+ if(!diff)
+ return this;
+
+ (*ioDiff)++;
+
+ RefPtr<ExistentialSpecializedType> substType = new ExistentialSpecializedType();
+ substType->setSession(getSession());
+ substType->baseType = substBaseType;
+ substType->slots = substSlots;
+ return substType;
+}
} // namespace Slang
diff --git a/source/slang/syntax.h b/source/slang/syntax.h
index eb7cee40a..aa3944d0a 100644
--- a/source/slang/syntax.h
+++ b/source/slang/syntax.h
@@ -1104,6 +1104,32 @@ namespace Slang
typedef Dictionary<unsigned int, RefPtr<RefObject>> AttributeArgumentValueDict;
+ /// Collects information about existential type parameters and their arguments.
+ struct ExistentialTypeSlots
+ {
+ /// For each type parameter, holds the interface/existential type that constrains it.
+ List<RefPtr<Type>> paramTypes;
+
+ /// An argument for an existential type parameter.
+ ///
+ /// Comprises a concrete type and a witness for its conformance to the desired
+ /// interface/existential type for the corresponding parameter.
+ ///
+ struct Arg
+ {
+ RefPtr<Type> type;
+ RefPtr<Val> witness;
+ };
+
+ /// Any arguments provided for the existential type parameters.
+ ///
+ /// It is possible for `args` to be empty even if `paramTypes` is non-empty;
+ /// that situation represents an unspecialized program or entry point.
+ ///
+ List<Arg> args;
+ };
+
+
// Generate class definition for all syntax classes
#define SYNTAX_FIELD(TYPE, NAME) TYPE NAME;
#define FIELD(TYPE, NAME) TYPE NAME;
diff --git a/source/slang/type-defs.h b/source/slang/type-defs.h
index 2d376d754..d0c00c73a 100644
--- a/source/slang/type-defs.h
+++ b/source/slang/type-defs.h
@@ -475,3 +475,16 @@ RAW(
virtual RefPtr<Val> SubstituteImpl(SubstitutionSet subst, int* ioDiff) override;
)
END_SYNTAX_CLASS()
+
+SYNTAX_CLASS(ExistentialSpecializedType, Type)
+RAW(
+ RefPtr<Type> baseType;
+ ExistentialTypeSlots slots;
+
+ virtual String ToString() override;
+ virtual bool EqualsImpl(Type * type) override;
+ virtual int GetHashCode() override;
+ virtual RefPtr<Type> CreateCanonicalType() override;
+ virtual RefPtr<Val> SubstituteImpl(SubstitutionSet subst, int* ioDiff) override;
+)
+END_SYNTAX_CLASS() \ No newline at end of file
diff --git a/tests/compute/interface-shader-param-legalization.slang b/tests/compute/interface-shader-param-legalization.slang
new file mode 100644
index 000000000..8c63d81ac
--- /dev/null
+++ b/tests/compute/interface-shader-param-legalization.slang
@@ -0,0 +1,54 @@
+// interface-shader-param-legalization.slang
+
+// Test case where concrete type implementing
+// an interface has resource-type fields nested in it.
+
+//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute
+
+interface IModifier
+{
+ int modify(int val);
+}
+
+IModifier gModifier;
+
+int test(
+ int val)
+{
+ return gModifier.modify(val);
+}
+
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):dxbinding(0),glbinding(0),out
+RWStructuredBuffer<int> gOutputBuffer;
+
+[numthreads(4, 1, 1)]
+void computeMain(
+ uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ let tid = dispatchThreadID.x;
+
+ let inputVal : int = tid;
+ let outputVal = test(inputVal);
+
+ gOutputBuffer[tid] = outputVal;
+}
+
+// Now that we've define all the logic of the entry point,
+// we will define some concrete types that we can plug
+// in for the interface-type parameters.
+
+struct ConcreteData
+{
+ int offset;
+}
+
+struct ConcreteModifier : IModifier
+{
+ ConstantBuffer<ConcreteData> data;
+
+ int modify(int val) { return val + data.offset; }
+}
+
+//TEST_INPUT: globalExistentialType ConcreteModifier
+//TEST_INPUT:cbuffer(data=[256], stride=4):dxbinding(0),glbinding(0),out
diff --git a/tests/compute/interface-shader-param-legalization.slang.expected.txt b/tests/compute/interface-shader-param-legalization.slang.expected.txt
new file mode 100644
index 000000000..f94894bb2
--- /dev/null
+++ b/tests/compute/interface-shader-param-legalization.slang.expected.txt
@@ -0,0 +1,4 @@
+100
+101
+102
+103