Skip to content

Commit

Permalink
[js/webgpu] fix external buffer registration (microsoft#22254)
Browse files Browse the repository at this point in the history
### Description

Fixes the problem of running into failure when GPU inputs shuffled
between iterations.
  • Loading branch information
fs-eire authored Sep 28, 2024
1 parent 52a8c1c commit 1bda91f
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 20 deletions.
9 changes: 7 additions & 2 deletions js/web/lib/wasm/jsep/backend-webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -785,15 +785,20 @@ export class WebGpuBackend {
this.sessionExternalDataMapping.set(sessionId, sessionInputOutputMapping);
}

// the buffer may be user created, or managed by GPU data manager.
// The GPU data manager will not manage these buffers. we register them as external buffers.
//
// The map `sessionInputOutputMapping` is used to store the data ID and buffer for each input/output. Once a
// specific input/output is registered, the data ID will not change.
const previousBuffer = sessionInputOutputMapping.get(index);
const id = this.gpuDataManager.registerExternalBuffer(buffer, size, previousBuffer?.[1]);
const id = this.gpuDataManager.registerExternalBuffer(buffer, size, previousBuffer);
sessionInputOutputMapping.set(index, [id, buffer]);
return id;
}
unregisterBuffers(sessionId: number): void {
const sessionInputOutputMapping = this.sessionExternalDataMapping.get(sessionId);
if (sessionInputOutputMapping) {
sessionInputOutputMapping.forEach((bufferInfo) => this.gpuDataManager.unregisterExternalBuffer(bufferInfo[1]));
sessionInputOutputMapping.forEach((bufferInfo) => this.gpuDataManager.unregisterExternalBuffer(bufferInfo[0]));
this.sessionExternalDataMapping.delete(sessionId);
}
}
Expand Down
25 changes: 7 additions & 18 deletions js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,12 @@ export interface GpuDataManager {
* GPU data manager only manages a mapping between the buffer and the GPU data ID. It will not manage the lifecycle of
* the external buffer.
*/
registerExternalBuffer(buffer: GPUBuffer, originalSize: number, previousBuffer?: GPUBuffer): number;
registerExternalBuffer(buffer: GPUBuffer, originalSize: number, previous?: [GpuDataId, GPUBuffer]): number;

/**
* unregister an external buffer for IO Binding.
*/
unregisterExternalBuffer(buffer: GPUBuffer): void;
unregisterExternalBuffer(id: GpuDataId): void;

/**
* destroy all gpu buffers.
Expand Down Expand Up @@ -196,9 +196,6 @@ class GpuDataManagerImpl implements GpuDataManager {
// The reusable uniform buffers
private freeUniformBuffers: Map<number, GPUBuffer[]>;

// The external buffers registered users for IO Binding.
private externalBuffers: Map<GPUBuffer, GpuDataId>;

// The pendingBuffers for capture graph.
// a SessionID -> GPUBuffer[] mapping.
private capturedPendingBuffers: Map<number, GPUBuffer[]>;
Expand All @@ -209,7 +206,6 @@ class GpuDataManagerImpl implements GpuDataManager {
this.freeUniformBuffers = new Map();
this.buffersForUploadingPending = [];
this.buffersPending = [];
this.externalBuffers = new Map();
this.capturedPendingBuffers = new Map();

for (const [key] of bucketFreelist) {
Expand Down Expand Up @@ -284,14 +280,11 @@ class GpuDataManagerImpl implements GpuDataManager {
);
}

registerExternalBuffer(buffer: GPUBuffer, originalSize: number, previousBuffer?: GPUBuffer): number {
registerExternalBuffer(buffer: GPUBuffer, originalSize: number, previous?: [GpuDataId, GPUBuffer]): number {
let id: number | undefined;
if (previousBuffer) {
id = this.externalBuffers.get(previousBuffer);
if (id === undefined) {
throw new Error('previous buffer is not registered');
}
if (buffer === previousBuffer) {
if (previous) {
id = previous[0];
if (buffer === previous[1]) {
LOG_DEBUG(
'verbose',
() =>
Expand All @@ -304,25 +297,21 @@ class GpuDataManagerImpl implements GpuDataManager {
throw new Error(`Registering a different external buffer under graph capture mode is not supported yet.
Please use the previous external buffer!`);
}
this.externalBuffers.delete(previousBuffer);
} else {
id = createNewGpuDataId();
}

this.storageCache.set(id, { gpuData: { id, type: GpuDataType.default, buffer }, originalSize });
this.externalBuffers.set(buffer, id);
LOG_DEBUG(
'verbose',
() => `[WebGPU] GpuDataManager.registerExternalBuffer(size=${originalSize}) => id=${id}, registered.`,
);
return id;
}

unregisterExternalBuffer(buffer: GPUBuffer): void {
const id = this.externalBuffers.get(buffer);
unregisterExternalBuffer(id: GpuDataId): void {
if (id !== undefined) {
this.storageCache.delete(id);
this.externalBuffers.delete(buffer);
LOG_DEBUG('verbose', () => `[WebGPU] GpuDataManager.unregisterExternalBuffer() => id=${id}`);
}
}
Expand Down

0 comments on commit 1bda91f

Please sign in to comment.