summaryrefslogtreecommitdiffstats
path: root/Whisper/ML/ConstantBuffer.cpp
blob: 5f3bfbe1c0ea707397e5088424660327fcce12fb (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
#include "stdafx.h"
#include "ConstantBuffer.h"
#include "../D3D/MappedResource.h"
using namespace DirectCompute;

// Creates the D3D11 constant buffer the first time it is called.
// The buffer is dynamic, CPU-writable, and 96 bytes long.
// Returns the HRESULT of ID3D11Device::CreateBuffer, or HRESULT_FROM_WIN32( ERROR_ALREADY_INITIALIZED )
// when the buffer has already been created.
HRESULT ConstantBuffer::create()
{
	// Guard: never replace an existing buffer
	if( nullptr != buffer )
		return HRESULT_FROM_WIN32( ERROR_ALREADY_INITIALIZED );

	// The largest update() overload writes three tensor shapes, and each shape takes two 16-byte vectors
	constexpr unsigned int bytesPerVector = 16;
	constexpr unsigned int maxShapes = 3;
	constexpr unsigned int vectorsPerShape = 2;
	constexpr unsigned int byteWidth = bytesPerVector * maxShapes * vectorsPerShape;

	const CD3D11_BUFFER_DESC bufferDesc{ byteWidth, D3D11_BIND_CONSTANT_BUFFER, D3D11_USAGE_DYNAMIC, D3D11_CPU_ACCESS_WRITE };
	return device()->CreateBuffer( &bufferDesc, nullptr, &buffer );
}

namespace
{
	// Stores one tensor shape as two consecutive 16-byte vectors starting at `dest`:
	// first the value of sizeVec(), then the value of stridesVec(), 32 bytes in total.
	// Unaligned stores are used, so `dest` does not have to be 16-byte aligned.
	__forceinline void copy32( __m128i* dest, const TensorShape& shape )
	{
		const __m128i size = shape.sizeVec();
		_mm_storeu_si128( &dest[ 0 ], size );

		const __m128i strides = shape.stridesVec();
		_mm_storeu_si128( &dest[ 1 ], strides );
	}
}

// Writes a single tensor shape into the first 32 bytes of the constant buffer.
// Returns S_OK once the shape has been copied.
// NOTE(review): CHECK presumably returns the failing HRESULT early when map() fails — confirm against the macro.
HRESULT ConstantBuffer::update( const TensorShape& t0 )
{
	// NOTE(review): the `false` argument presumably requests a write-only mapping,
	// and the mapping is presumably released when `mapping` goes out of scope — confirm in MappedResource.h
	MappedResource mapping;
	CHECK( mapping.map( buffer, false ) );

	__m128i* const dest = reinterpret_cast<__m128i*>( mapping.data() );
	copy32( dest, t0 );

	return S_OK;
}

// Writes two tensor shapes into the constant buffer: t0 at byte offset 0, t1 at byte offset 32.
// Returns S_OK once both shapes have been copied.
// NOTE(review): CHECK presumably returns the failing HRESULT early when map() fails — confirm against the macro.
HRESULT ConstantBuffer::update( const TensorShape& t0, const TensorShape& t1 )
{
	// NOTE(review): the `false` argument presumably requests a write-only mapping,
	// and the mapping is presumably released when `mapping` goes out of scope — confirm in MappedResource.h
	MappedResource mapping;
	CHECK( mapping.map( buffer, false ) );

	// Every shape occupies two 16-byte vectors in the buffer
	constexpr int vectorsPerShape = 2;
	__m128i* const dest = reinterpret_cast<__m128i*>( mapping.data() );

	copy32( &dest[ 0 * vectorsPerShape ], t0 );
	copy32( &dest[ 1 * vectorsPerShape ], t1 );

	return S_OK;
}

// Writes three tensor shapes into the constant buffer: t0, t1 and t2 at byte offsets 0, 32 and 64.
// Returns S_OK once all shapes have been copied.
// NOTE(review): CHECK presumably returns the failing HRESULT early when map() fails — confirm against the macro.
HRESULT ConstantBuffer::update( const TensorShape& t0, const TensorShape& t1, const TensorShape& t2 )
{
	// NOTE(review): the `false` argument presumably requests a write-only mapping,
	// and the mapping is presumably released when `mapping` goes out of scope — confirm in MappedResource.h
	MappedResource mapping;
	CHECK( mapping.map( buffer, false ) );

	// Shapes are laid out back to back, in argument order, two 16-byte vectors each
	const TensorShape* const shapes[] = { &t0, &t1, &t2 };
	__m128i* dest = reinterpret_cast<__m128i*>( mapping.data() );
	for( const TensorShape* shape : shapes )
	{
		copy32( dest, *shape );
		dest += 2;
	}

	return S_OK;
}

void ConstantBuffer::bind() const
{
	ID3D11Buffer* p = buffer;
	assert( nullptr != p );
	context()->CSSetConstantBuffers( 0, 1, &p );
}