Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions compute/fused_encoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,13 @@ type FusedEncoderProvider interface {
// FusedEncoderAvailable returns true if the fused encoder kernel is loaded.
FusedEncoderAvailable() bool

// AllocDeviceFloat32 allocates numElements float32s on the GPU and returns
// the device pointer. Memory is pool-managed and freed when the engine closes.
AllocDeviceFloat32(numElements int) (unsafe.Pointer, error)

// CopyToDevice copies len(src) float32 values from host to a device pointer.
CopyToDevice(dst unsafe.Pointer, src []float32) error

// FusedEncoderForward executes one encoder layer forward pass.
// weights: [16]unsafe.Pointer to layer weights.
// bufs: [16]unsafe.Pointer to pre-allocated forward cache buffers.
Expand Down
13 changes: 13 additions & 0 deletions compute/gpu_fused_encoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"unsafe"

"github.com/zerfoo/ztensor/internal/cublas"
"github.com/zerfoo/ztensor/internal/gpuapi"
)

// blasHandlePtr extracts the raw cuBLAS handle pointer from the BLAS interface.
Expand Down Expand Up @@ -68,5 +69,17 @@ func (e *GPUEngine[T]) FusedEncoderBackward(
totalRows, dModel, nHeads, headDim, ffnDim, bsC, numPatches, e.stream)
}

// AllocDeviceFloat32 reserves device memory large enough to hold numElements
// float32 values and returns the resulting device pointer.
//
// The engine's device is selected first, then the request is forwarded to the
// engine's memory pool with the size expressed in bytes (4 bytes per element).
// The pointer and error are returned exactly as the pool reports them.
func (e *GPUEngine[T]) AllocDeviceFloat32(numElements int) (unsafe.Pointer, error) {
	// Size in bytes of a single float32 element.
	const float32Bytes = 4

	e.setDevice()

	byteCount := numElements * float32Bytes
	return e.pool.Alloc(e.deviceID, byteCount)
}

// CopyToDevice copies len(src) float32 values from host memory to the device
// pointer dst.
//
// An empty (or nil) src is treated as a no-op and returns nil: there is nothing
// to transfer, and taking the address of src[0] on an empty slice would panic.
// Otherwise the engine's device is selected and len(src)*4 bytes are copied
// host-to-device through the engine runtime; the runtime's error is returned
// unchanged.
func (e *GPUEngine[T]) CopyToDevice(dst unsafe.Pointer, src []float32) error {
	if len(src) == 0 {
		// Guard against indexing an empty slice below.
		return nil
	}

	e.setDevice()

	// 4 bytes per float32 element.
	return e.runtime.Memcpy(dst, unsafe.Pointer(&src[0]), len(src)*4, gpuapi.MemcpyHostToDevice)
}

// Compile-time check that *GPUEngine[float32] implements FusedEncoderProvider.
// The value is discarded; the declaration exists only so the build fails if a
// method required by the interface is missing or has a mismatched signature.
var _ FusedEncoderProvider = (*GPUEngine[float32])(nil)
Loading