diff --git a/compute/fused_encoder.go b/compute/fused_encoder.go
index 226ab2f..f72b6a8 100644
--- a/compute/fused_encoder.go
+++ b/compute/fused_encoder.go
@@ -17,6 +17,13 @@ type FusedEncoderProvider interface {
 	// FusedEncoderAvailable returns true if the fused encoder kernel is loaded.
 	FusedEncoderAvailable() bool
 
+	// AllocDeviceFloat32 allocates numElements float32s on the GPU and returns
+	// the device pointer. Memory is pool-managed and freed when the engine closes.
+	AllocDeviceFloat32(numElements int) (unsafe.Pointer, error)
+
+	// CopyToDevice copies len(src) float32 values from host to a device pointer.
+	CopyToDevice(dst unsafe.Pointer, src []float32) error
+
 	// FusedEncoderForward executes one encoder layer forward pass.
 	// weights: [16]unsafe.Pointer to layer weights.
 	// bufs: [16]unsafe.Pointer to pre-allocated forward cache buffers.
diff --git a/compute/gpu_fused_encoder.go b/compute/gpu_fused_encoder.go
index f549305..2d3d0dc 100644
--- a/compute/gpu_fused_encoder.go
+++ b/compute/gpu_fused_encoder.go
@@ -5,6 +5,7 @@ import (
 	"unsafe"
 
 	"github.com/zerfoo/ztensor/internal/cublas"
+	"github.com/zerfoo/ztensor/internal/gpuapi"
 )
 
 // blasHandlePtr extracts the raw cuBLAS handle pointer from the BLAS interface.
@@ -68,5 +69,21 @@ func (e *GPUEngine[T]) FusedEncoderBackward(
 		totalRows, dModel, nHeads, headDim, ffnDim, bsC, numPatches, e.stream)
 }
 
+// AllocDeviceFloat32 allocates GPU memory for numElements float32 values.
+func (e *GPUEngine[T]) AllocDeviceFloat32(numElements int) (unsafe.Pointer, error) {
+	e.setDevice()
+	return e.pool.Alloc(e.deviceID, numElements*4) // 4 bytes per float32
+}
+
+// CopyToDevice copies float32 data from host to device.
+// An empty src is a no-op: there is no first element to take the address of.
+func (e *GPUEngine[T]) CopyToDevice(dst unsafe.Pointer, src []float32) error {
+	if len(src) == 0 {
+		return nil
+	}
+	e.setDevice()
+	return e.runtime.Memcpy(dst, unsafe.Pointer(&src[0]), len(src)*4, gpuapi.MemcpyHostToDevice)
+}
+
 // Compile-time check that GPUEngine implements FusedEncoderProvider.
 var _ FusedEncoderProvider = (*GPUEngine[float32])(nil)