summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorTim Foley <tfoleyNV@users.noreply.github.com>2017-12-20 17:35:10 -0800
committerGitHub <noreply@github.com>2017-12-20 17:35:10 -0800
commit6f681279d99e72e717bb2b91763b80e570ae725b (patch)
tree501c547ff405aa5227a0ad165b9ec371fcd94ef8 /source
parent35318fb2b08c82f80cbd464e93d81ebe719c40be (diff)
IR: fixes for subscript accessors (#322)
* IR: fixes for subscript accessors Fixes #320 This is a bunch of fixes for handling of `__subscript` operations on builtin types (notably `RWStructuredBuffer` and `StructuredBuffer` at this point). - Automatically add a `GetterDecl` to any subscript decalratio was declithout any accessors. This avoids hitting a null- dereference in the emit logic. - Add a notion of a `RefAccessor` (declared with `ref`) as a peer to getters and setters. The idea is that a `ref` accessor returns a pointer to the element data, so that it can be used for both getting and setting values. This is closer to the behavior of `RWStructuredBuffer` element access in HLSL. - Fixes for dealing with "access chains" where there might be a combination of a subscript (where the is a `get` and `set` but no `ref`) and member access, so that we have to read the base value into a temp, modify it, and then write it back. - This logic is still a bit of a mess, so we will eventually want to take a more consistent pass over this to deal with how we "materialize" values for setters. - Update `RWStructuredBuffer` to have a `ref` accessor, and then fix up the IR tests to handle the new opcode that I added for it. - Note: I didn't handle this as an intrinsic simply because the `tests/ir/*` tests aren't really set up to handle builtins with ugly mangled names. * Fixup: type error in VM for buffer element ref I was using the result type of the op as the element type for computing the element address, but the result type is a pointer to the real element type. This caused test failures on 64-bit platforms, where the stride of the buffer in the `ir/factorial` test needs to be 4. The fix is to assume the result type is a pointer, and extract the pointed-to type out of that.
Diffstat (limited to 'source')
-rw-r--r--source/slang/check.cpp30
-rw-r--r--source/slang/decl-defs.h1
-rw-r--r--source/slang/emit.cpp2
-rw-r--r--source/slang/hlsl.meta.slang10
-rw-r--r--source/slang/hlsl.meta.slang.h10
-rw-r--r--source/slang/ir-inst-defs.h1
-rw-r--r--source/slang/lower-to-ir.cpp263
-rw-r--r--source/slang/parser.cpp4
-rw-r--r--source/slang/vm.cpp27
9 files changed, 292 insertions, 56 deletions
diff --git a/source/slang/check.cpp b/source/slang/check.cpp
index 4840ae30d..b981fb778 100644
--- a/source/slang/check.cpp
+++ b/source/slang/check.cpp
@@ -3575,6 +3575,32 @@ namespace Slang
decl->SetCheckState(DeclCheckState::CheckedHeader);
+ // If we have a subscript declaration with no accessor declarations,
+ // then we should create a single `GetterDecl` to represent
+ // the implicit meaning of their declaration, so:
+ //
+ // subscript(uint index) -> T;
+ //
+ // becomes:
+ //
+ // subscript(uint index) -> T { get; }
+ //
+
+ bool anyAccessors = false;
+ for(auto accessorDecl : decl->getMembersOfType<AccessorDecl>())
+ {
+ anyAccessors = true;
+ }
+
+ if(!anyAccessors)
+ {
+ RefPtr<GetterDecl> getterDecl = new GetterDecl();
+ getterDecl->loc = decl->loc;
+
+ getterDecl->ParentDecl = decl;
+ decl->Members.Add(getterDecl);
+ }
+
for(auto mm : decl->Members)
{
checkDecl(mm);
@@ -4662,6 +4688,10 @@ namespace Slang
{
callExpr->type.IsLeftValue = true;
}
+ for(auto refAccessor : subscriptDeclRef.getDecl()->getMembersOfType<RefAccessorDecl>())
+ {
+ callExpr->type.IsLeftValue = true;
+ }
}
// TODO: there may be other cases that confer l-value-ness
diff --git a/source/slang/decl-defs.h b/source/slang/decl-defs.h
index e24a535c5..7cb9ffc0f 100644
--- a/source/slang/decl-defs.h
+++ b/source/slang/decl-defs.h
@@ -179,6 +179,7 @@ SIMPLE_SYNTAX_CLASS(AccessorDecl, FunctionDeclBase)
SIMPLE_SYNTAX_CLASS(GetterDecl, AccessorDecl)
SIMPLE_SYNTAX_CLASS(SetterDecl, AccessorDecl)
+SIMPLE_SYNTAX_CLASS(RefAccessorDecl, AccessorDecl)
SIMPLE_SYNTAX_CLASS(FuncDecl, FunctionDeclBase)
diff --git a/source/slang/emit.cpp b/source/slang/emit.cpp
index 311f903cc..5cd0ea832 100644
--- a/source/slang/emit.cpp
+++ b/source/slang/emit.cpp
@@ -4875,6 +4875,7 @@ emitDeclImpl(decl, nullptr);
case kIROp_FieldAddress:
case kIROp_getElementPtr:
case kIROp_specialize:
+ case kIROp_BufferElementRef:
return true;
}
@@ -5516,6 +5517,7 @@ emitDeclImpl(decl, nullptr);
break;
case kIROp_BufferLoad:
+ case kIROp_BufferElementRef:
emitIROperand(ctx, inst->getArg(0));
emit("[");
emitIROperand(ctx, inst->getArg(1));
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang
index cdf720006..09cf731b5 100644
--- a/source/slang/hlsl.meta.slang
+++ b/source/slang/hlsl.meta.slang
@@ -44,8 +44,7 @@ struct StructuredBuffer
T Load(int location);
T Load(int location, out uint status);
- __intrinsic_op(bufferLoad)
- __subscript(uint index) -> T;
+ __subscript(uint index) -> T { __intrinsic_op(bufferLoad) get; };
};
__generic<T> __magic_type(HLSLConsumeStructuredBufferType) struct ConsumeStructuredBuffer
@@ -201,11 +200,8 @@ struct RWStructuredBuffer
__subscript(uint index) -> T
{
- __intrinsic_op(bufferLoad)
- get;
-
- __intrinsic_op(bufferStore)
- set;
+ __intrinsic_op(bufferElementRef)
+ ref;
}
};
diff --git a/source/slang/hlsl.meta.slang.h b/source/slang/hlsl.meta.slang.h
index c9ccfcc81..acaef8fbd 100644
--- a/source/slang/hlsl.meta.slang.h
+++ b/source/slang/hlsl.meta.slang.h
@@ -45,8 +45,7 @@ sb << "\n";
sb << " T Load(int location);\n";
sb << " T Load(int location, out uint status);\n";
sb << "\n";
-sb << " __intrinsic_op(bufferLoad)\n";
-sb << " __subscript(uint index) -> T;\n";
+sb << " __subscript(uint index) -> T { __intrinsic_op(bufferLoad) get; };\n";
sb << "};\n";
sb << "\n";
sb << "__generic<T> __magic_type(HLSLConsumeStructuredBufferType) struct ConsumeStructuredBuffer\n";
@@ -203,11 +202,8 @@ sb << " T Load(int location, out uint status);\n";
sb << "\n";
sb << "\t__subscript(uint index) -> T\n";
sb << "\t{\n";
-sb << "\t\t__intrinsic_op(bufferLoad)\n";
-sb << "\t\tget;\n";
-sb << "\n";
-sb << "\t\t__intrinsic_op(bufferStore)\n";
-sb << "\t\tset;\n";
+sb << " __intrinsic_op(bufferElementRef)\n";
+sb << " ref;\n";
sb << "\t}\n";
sb << "};\n";
sb << "\n";
diff --git a/source/slang/ir-inst-defs.h b/source/slang/ir-inst-defs.h
index ffad04467..7eafe89f7 100644
--- a/source/slang/ir-inst-defs.h
+++ b/source/slang/ir-inst-defs.h
@@ -119,6 +119,7 @@ INST(Store, store, 2, 0)
INST(BufferLoad, bufferLoad, 2, 0)
INST(BufferStore, bufferStore, 3, 0)
+INST(BufferElementRef, bufferElementRef, 2, 0)
INST(FieldExtract, get_field, 2, 0)
INST(FieldAddress, get_field_addr, 2, 0)
diff --git a/source/slang/lower-to-ir.cpp b/source/slang/lower-to-ir.cpp
index 8c8b34908..8fc1239cd 100644
--- a/source/slang/lower-to-ir.cpp
+++ b/source/slang/lower-to-ir.cpp
@@ -207,6 +207,9 @@ struct BoundMemberInfo : ExtendedValueInfo
// The (AST-level) declaration reference.
DeclRef<Decl> declRef;
+
+ // The type of this value
+ RefPtr<Type> type;
};
// Represents the result of a swizzle operation in
@@ -362,6 +365,7 @@ LoweredValInfo emitCallToVal(
switch (funcVal.flavor)
{
case LoweredValInfo::Flavor::None:
+ SLANG_UNEXPECTED("null function");
default:
return LoweredValInfo::simple(
builder->emitCallInst(type, getSimpleVal(context, funcVal), argCount, args));
@@ -531,25 +535,60 @@ LoweredValInfo emitCallToDeclRef(
if (auto subscriptDeclRef = funcDeclRef.As<SubscriptDecl>())
{
- // A reference to a subscript declaration is potentially a
- // special case, if we have more than just a getter.
+ // A reference to a subscript declaration is a special case,
+ // because it is not possible to call a subscript directly;
+ // we must call one of its accessors.
+ //
+ // TODO: everything here will also apply to propery declarations
+ // once we have them, so some of this code might be shared
+ // some day.
DeclRef<GetterDecl> getterDeclRef;
bool justAGetter = true;
for (auto accessorDeclRef : getMembersOfType<AccessorDecl>(subscriptDeclRef))
{
+ // If the subscript declares a `ref` accessor, then we can just
+ // invoke that directly to get an l-value we can use.
+ if(auto refAccessorDeclRef = accessorDeclRef.As<RefAccessorDecl>())
+ {
+ // The `ref` accessor will return a pointer to the value, so
+ // we need to reflect that in the type of our `call` instruction.
+ RefPtr<Type> ptrType = context->getSession()->getPtrType(type);
+
+ // Rather than call `emitCallToVal` here, we make a recursive call
+ // to `emitCallToDeclRef` so that it can handle things like intrinsic-op
+ // modifiers attached to the acecssor.
+ LoweredValInfo callVal = emitCallToDeclRef(
+ context,
+ ptrType,
+ refAccessorDeclRef,
+ funcExpr,
+ argCount,
+ args);
+
+ // The result from the call needs to be implicitly dereferenced,
+ // so that it can work as an l-value of the desired result type.
+ return LoweredValInfo::ptr(getSimpleVal(context, callVal));
+ }
+
+ // If we don't find a `ref` accessor, then we want to track whether
+ // this subscript has any accessors other than `get` (assuming
+ // that everything except `get` can be used for setting...).
+
if (auto foundGetterDeclRef = accessorDeclRef.As<GetterDecl>())
{
+ // We found a getter.
getterDeclRef = foundGetterDeclRef;
}
else
{
+ // There was something other than a getter, so we can't
+ // invoke an accessor just now.
justAGetter = false;
- break;
}
}
- if (!justAGetter)
+ if (!justAGetter || !getterDeclRef)
{
// We can't perform an actual call right now, because
// this expression might appear in an r-value or l-value
@@ -575,8 +614,7 @@ LoweredValInfo emitCallToDeclRef(
// Otherwise we are just call the getter, and so that
// is what we need to be emitting a call to...
- if (getterDeclRef)
- funcDeclRef = getterDeclRef;
+ funcDeclRef = getterDeclRef;
}
auto funcDecl = funcDeclRef.getDecl();
@@ -660,7 +698,64 @@ LoweredValInfo emitCallToDeclRef(
return emitCallToDeclRef(context, type, funcDeclRef, funcExpr, args.Count(), args.Buffer());
}
-IRValue* getSimpleVal(IRGenContext* context, LoweredValInfo lowered)
+LoweredValInfo extractField(
+ IRGenContext* context,
+ Type* fieldType,
+ LoweredValInfo base,
+ DeclRef<StructField> field)
+{
+ IRBuilder* builder = context->irBuilder;
+
+ switch (base.flavor)
+ {
+ default:
+ {
+ IRValue* irBase = getSimpleVal(context, base);
+ return LoweredValInfo::simple(
+ builder->emitFieldExtract(
+ fieldType,
+ irBase,
+ builder->getDeclRefVal(field)));
+ }
+ break;
+
+ case LoweredValInfo::Flavor::BoundMember:
+ case LoweredValInfo::Flavor::BoundSubscript:
+ {
+ // The base value is one that is trying to defer a get-vs-set
+ // decision, so we will need to do the same.
+
+ RefPtr<BoundMemberInfo> boundMemberInfo = new BoundMemberInfo();
+ boundMemberInfo->type = fieldType;
+ boundMemberInfo->base = base;
+ boundMemberInfo->declRef = field;
+
+ context->shared->extValues.Add(boundMemberInfo);
+ return LoweredValInfo::boundMember(boundMemberInfo);
+ }
+ break;
+
+ case LoweredValInfo::Flavor::Ptr:
+ {
+ // We are "extracting" a field from an lvalue address,
+ // which means we should just compute an lvalue
+ // representing the field address.
+ IRValue* irBasePtr = base.val;
+ return LoweredValInfo::ptr(
+ builder->emitFieldAddress(
+ context->getSession()->getPtrType(fieldType),
+ irBasePtr,
+ builder->getDeclRefVal(field)));
+ }
+ break;
+ }
+}
+
+
+
+LoweredValInfo materialize(
+ IRGenContext* context,
+ LoweredValInfo lowered)
{
auto builder = context->irBuilder;
@@ -668,13 +763,9 @@ top:
switch(lowered.flavor)
{
case LoweredValInfo::Flavor::None:
- return nullptr;
-
case LoweredValInfo::Flavor::Simple:
- return lowered.val;
-
case LoweredValInfo::Flavor::Ptr:
- return builder->emitLoad(lowered.val);
+ return lowered;
case LoweredValInfo::Flavor::BoundSubscript:
{
@@ -693,7 +784,27 @@ top:
}
SLANG_UNEXPECTED("subscript had no getter");
- UNREACHABLE_RETURN(nullptr);
+ UNREACHABLE_RETURN(LoweredValInfo());
+ }
+ break;
+
+ case LoweredValInfo::Flavor::BoundMember:
+ {
+ auto boundMemberInfo = lowered.getBoundMemberInfo();
+ auto base = materialize(context, boundMemberInfo->base);
+
+ auto declRef = boundMemberInfo->declRef;
+ if( auto fieldDeclRef = declRef.As<StructField>() )
+ {
+ lowered = extractField(context, boundMemberInfo->type, base, fieldDeclRef);
+ goto top;
+ }
+ else
+ {
+
+ SLANG_UNEXPECTED("unexpected member flavor");
+ UNREACHABLE_RETURN(LoweredValInfo());
+ }
}
break;
@@ -701,15 +812,41 @@ top:
{
auto swizzleInfo = lowered.getSwizzledLValueInfo();
- return builder->emitSwizzle(
+ return LoweredValInfo::simple(builder->emitSwizzle(
swizzleInfo->type,
getSimpleVal(context, swizzleInfo->base),
swizzleInfo->elementCount,
- swizzleInfo->elementIndices);
+ swizzleInfo->elementIndices));
}
default:
SLANG_UNEXPECTED("unhandled value flavor");
+ UNREACHABLE_RETURN(LoweredValInfo());
+ }
+
+}
+
+IRValue* getSimpleVal(IRGenContext* context, LoweredValInfo lowered)
+{
+ auto builder = context->irBuilder;
+
+ // First, try to eliminate any "bound" operations along the chain,
+ // so that we are dealing with an ordinary value, or an l-value pointer.
+ lowered = materialize(context, lowered);
+
+ switch(lowered.flavor)
+ {
+ case LoweredValInfo::Flavor::None:
+ return nullptr;
+
+ case LoweredValInfo::Flavor::Simple:
+ return lowered.val;
+
+ case LoweredValInfo::Flavor::Ptr:
+ return builder->emitLoad(lowered.val);
+
+ default:
+ SLANG_UNEXPECTED("unhandled value flavor");
UNREACHABLE_RETURN(nullptr);
}
}
@@ -1042,6 +1179,7 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
else if (auto callableDeclRef = declRef.As<CallableDecl>())
{
RefPtr<BoundMemberInfo> boundMemberInfo = new BoundMemberInfo();
+ boundMemberInfo->type = nullptr;
boundMemberInfo->base = loweredBase;
boundMemberInfo->declRef = callableDeclRef;
return LoweredValInfo::boundMember(boundMemberInfo);
@@ -1543,33 +1681,7 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
LoweredValInfo base,
DeclRef<StructField> field)
{
- switch (base.flavor)
- {
- default:
- {
- IRValue* irBase = getSimpleVal(context, base);
- return LoweredValInfo::simple(
- getBuilder()->emitFieldExtract(
- getSimpleType(fieldType),
- irBase,
- getBuilder()->getDeclRefVal(field)));
- }
- break;
-
- case LoweredValInfo::Flavor::Ptr:
- {
- // We are "extracting" a field from an lvalue address,
- // which means we should just compute an lvalue
- // representing the field address.
- IRValue* irBasePtr = base.val;
- return LoweredValInfo::ptr(
- getBuilder()->emitFieldAddress(
- context->getSession()->getPtrType(getSimpleType(fieldType)),
- irBasePtr,
- getBuilder()->getDeclRefVal(field)));
- }
- break;
- }
+ return Slang::extractField(context, getSimpleType(fieldType), base, field);
}
LoweredValInfo visitStaticMemberExpr(StaticMemberExpr* expr)
@@ -2392,6 +2504,28 @@ void lowerStmt(
return visitor.dispatch(stmt);
}
+static LoweredValInfo maybeMoveMutableTemp(
+ IRGenContext* context,
+ LoweredValInfo const& val)
+{
+ switch(val.flavor)
+ {
+ case LoweredValInfo::Flavor::Ptr:
+ return val;
+
+ default:
+ {
+ IRValue* irVal = getSimpleVal(context, val);
+ auto type = irVal->getType();
+ auto var = createVar(context, type);
+
+ assign(context, var, LoweredValInfo::simple(irVal));
+ return var;
+ }
+ break;
+ }
+}
+
void assign(
IRGenContext* context,
LoweredValInfo const& inLeft,
@@ -2412,6 +2546,7 @@ top:
case LoweredValInfo::Flavor::Ptr:
case LoweredValInfo::Flavor::SwizzledLValue:
case LoweredValInfo::Flavor::BoundSubscript:
+ case LoweredValInfo::Flavor::BoundMember:
{
builder->emitStore(
left.val,
@@ -2511,6 +2646,43 @@ top:
}
break;
+ case LoweredValInfo::Flavor::BoundMember:
+ {
+ auto boundMemberInfo = left.getBoundMemberInfo();
+
+ // If we hit this case, then it means that we are trying to set
+ // a single field in someting that is not atomically set-able.
+ // (e.g., an element of a value where the `subscript` operation
+ // has `get` and `set` but not a `ref` accessor).
+ //
+ // We need to read the entire base value out, modify the field
+ // we care about, and then write it back.
+
+ auto declRef = boundMemberInfo->declRef;
+ if( auto fieldDeclRef = declRef.As<StructField>() )
+ {
+ // materialize the base value and move it into
+ // a mutable temporary if needed
+ auto baseVal = boundMemberInfo->base;
+ auto tempVal = maybeMoveMutableTemp(context, materialize(context, baseVal));
+
+ // extract the field l-value out of the temporary
+ auto tempFieldVal = extractField(context, boundMemberInfo->type, tempVal, fieldDeclRef);
+
+ // assign to the field of the temporary l-value
+ assign(context, tempFieldVal, right);
+
+ // write back the modified temporary to the base l-value
+ assign(context, baseVal, tempVal);
+ }
+ else
+ {
+ SLANG_UNEXPECTED("handled member flavor");
+ }
+
+ }
+ break;
+
default:
SLANG_UNIMPLEMENTED_X("assignment");
break;
@@ -3303,6 +3475,13 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
irResultType = context->getSession()->getVoidType();
}
+ if( auto refAccessorDecl = dynamic_cast<RefAccessorDecl*>(decl) )
+ {
+ // A `ref` accessor needs to return a *pointer* to the value
+ // being accessed, rather than a simple value.
+ irResultType = context->getSession()->getPtrType(irResultType);
+ }
+
auto irFuncType = getFuncType(
context,
paramTypes.Count(),
diff --git a/source/slang/parser.cpp b/source/slang/parser.cpp
index 3c63c9b56..35cc96b5c 100644
--- a/source/slang/parser.cpp
+++ b/source/slang/parser.cpp
@@ -2313,6 +2313,10 @@ namespace Slang
{
decl = new SetterDecl();
}
+ else if( AdvanceIf(parser, "ref") )
+ {
+ decl = new RefAccessorDecl();
+ }
else
{
Unexpected(parser);
diff --git a/source/slang/vm.cpp b/source/slang/vm.cpp
index d795a841b..ffc455232 100644
--- a/source/slang/vm.cpp
+++ b/source/slang/vm.cpp
@@ -877,6 +877,33 @@ void resumeThread(
memcpy(elementData, srcPtr, size);
}
break;
+
+ case kIROp_BufferElementRef:
+ {
+ VMType ptrType = decodeType(frame, &ip);
+ VMType type = ((VMPtrTypeImpl*)ptrType.getImpl())->base;
+
+ UInt argCount = decodeUInt(&ip);
+ void* argPtrs[16] = { 0 };
+ for( UInt aa = 0; aa < argCount; ++aa )
+ {
+ void* argPtr = decodeOperandPtr<void>(frame, &ip);
+ argPtrs[aa] = argPtr;
+ }
+
+ void* dest = decodeOperandPtr<void>(frame, &ip);
+
+ char* bufferData = *(char**)argPtrs[0];
+ uint32_t index = *(uint32_t*)argPtrs[1];
+
+ auto size = type.getSize();
+ char* elementData = bufferData + index*size;
+
+ *(void**)dest = elementData;
+ }
+ break;
+
+
case kIROp_Call:
{
VMType type = decodeType(frame, &ip);