summaryrefslogtreecommitdiffstats
path: root/Whisper/Whisper/DecoderInputBuffers.cpp
blob: 68d3cecb61d0951c07daee835274ea7b59822c46 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
#include "stdafx.h"
#include "DecoderInputBuffers.h"
#include "../D3D/createBuffer.h"
#include "../D3D/MappedResource.h"
using namespace DirectCompute;

// Ensures the dynamic GPU buffer can hold at least `size` 32-bit elements, and sets the logical size to `size`.
// The buffer is only re-created when `size` exceeds the current capacity; otherwise only m_size is updated.
// Throws E_INVALIDARG when size is 0, or when it is too large to round up to a multiple of 256 within uint32_t.
// Failures of createBuffer are propagated through check(); in that case the object is left empty
// (no buffer, zero size and capacity), same as after clear().
void DecoderInputBuffers::resize( uint32_t size )
{
	if( 0 == size )
		throw E_INVALIDARG;

	if( size <= m_capacity )
	{
		// The existing buffer is large enough, reuse it
		m_size = size;
		return;
	}

	// Largest element count which can be rounded up to a multiple of 256 without wrapping around uint32_t
	constexpr uint32_t maxSize = ~( 0xFFu );
	if( size > maxSize )
		throw E_INVALIDARG;

	// Release the old buffer, and reset the sizes so the object stays consistent if createBuffer below fails.
	// Without the reset, a failed allocation would leave a non-zero m_capacity with no buffer behind it,
	// and a subsequent smaller resize() would return early without allocating anything.
	embd = nullptr;
	m_capacity = 0;
	m_size = 0;

	// Round up to a multiple of 256 elements, so gradually growing sizes don't re-create the buffer every time
	const uint32_t newCapacity = ( size + 0xFFu ) & maxSize;
	// 4 bytes per element
	const size_t totalBytes = (size_t)4 * newCapacity;

	check( createBuffer( eBufferUse::Dynamic, totalBytes, &embd, nullptr, nullptr ) );

	m_capacity = newCapacity;
	m_size = size;
}

namespace
{
	// Wraps `buffer` into a tensor of `length` elements, with GPU views of DXGI_FORMAT_R32_UINT format.
	// The shape is [ length, 1, 1, 1 ]. Errors from TensorGpuViews::create are propagated through check().
	// Functions in an anonymous namespace already have internal linkage, so no `static` keyword is needed.
	Tensor createView( ID3D11Buffer* buffer, uint32_t length )
	{
		Tensor tensor;

		// Create the GPU views through the TensorGpuViews part of the tensor object
		TensorGpuViews& gpuViews = tensor;
		check( gpuViews.create( buffer, DXGI_FORMAT_R32_UINT, length, false ) );

		// Set the shape, then the strides; setDenseStrides presumably derives them from `ne`, so keep this order
		tensor.ne = { length, 1, 1, 1 };
		tensor.setDenseStrides();

		return tensor;
	}
}

// Uploads m_size 32-bit integers from `rsi` into the dynamic GPU buffer, and returns a tensor view over that buffer.
// `rsi` must point to at least m_size elements, i.e. the count passed to the latest resize() call.
// Throws OLE_E_BLANK when no buffer has been allocated, E_POINTER when `rsi` is null.
// Failures to map the buffer or to create the views are propagated through check().
Tensor DecoderInputBuffers::embedding( const int* rsi ) const
{
	if( nullptr == embd || m_size == 0 )
		throw OLE_E_BLANK;
	// memcpy from a null pointer is undefined behavior, reject it before touching the GPU buffer
	if( nullptr == rsi )
		throw E_POINTER;

	// The buffer was allocated with 4 bytes per element
	static_assert( sizeof( int ) == 4, "The input buffer expects 32-bit integers" );

	// Upload the data.
	// The nested scope is presumably there so MappedResource unmaps in its destructor before the view is created — confirm.
	{
		MappedResource mapped;
		check( mapped.map( embd, false ) );
		memcpy( mapped.data(), rsi, m_size * sizeof( int ) );
	}

	return createView( embd, m_size );
}

// Releases the GPU buffer, and resets both the logical size and the capacity to zero.
// After this call, resize() with any non-zero size allocates a new buffer.
void DecoderInputBuffers::clear()
{
	m_size = m_capacity = 0;
	embd = nullptr;
}