Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,5 @@ jobs:
# Run go vet on all packages except those with intentional
# unsafe.Pointer usage for GPU runtime bindings via purego/dlopen.
# These warnings are expected and documented in docs/QUALITY.md.
go vet $(go list ./... | grep -v '/internal/cuda$' | grep -v '/internal/hip$' | grep -v '/internal/opencl$' | grep -v '/internal/cudnn$' | grep -v '/internal/tensorrt$' | grep -v '/internal/fpga$' | grep -v '/internal/sycl$' | grep -v '/internal/metal$' | grep -v '/internal/pjrt$' | grep -v '/internal/nccl$')
go vet $(go list ./... | grep -v '/internal/cuda$' | grep -v '/internal/cublas$' | grep -v '/internal/hip$' | grep -v '/internal/opencl$' | grep -v '/internal/cudnn$' | grep -v '/internal/tensorrt$' | grep -v '/internal/fpga$' | grep -v '/internal/sycl$' | grep -v '/internal/metal$' | grep -v '/internal/pjrt$' | grep -v '/internal/nccl$')
- run: go test -race -timeout 300s ./...
47 changes: 47 additions & 0 deletions compute/fused_encoder.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package compute

import "unsafe"

// FusedEncoderProvider is implemented by engines that support fused PatchTST
// encoder layer forward and backward passes. The fused kernel replaces ~78
// discrete engine operations per layer with a single orchestrated call,
// using cuBLAS for GEMMs and custom CUDA sub-kernels for LayerNorm, GELU,
// softmax, head transpose, and residual operations.
//
// Callers must pre-allocate all buffer arrays and pass device pointers.
// Buffer index constants (FEW_*, FEB_*, FEG_*, etc.) are defined in
// internal/cuda/kernels/fused_encoder_fwd_purego.go and fused_encoder_bwd_purego.go.
//
// This API is not covered by the v1 stability guarantee.
type FusedEncoderProvider interface {
	// FusedEncoderAvailable returns true if the fused encoder kernel is loaded.
	// Callers should check this before invoking the forward/backward methods,
	// which return an error when no fused kernel is available.
	FusedEncoderAvailable() bool

	// FusedEncoderForward executes one encoder layer forward pass.
	// weights: [16]unsafe.Pointer to layer weights.
	// bufs: [16]unsafe.Pointer to pre-allocated forward cache buffers; the
	// cached activations are consumed later by FusedEncoderBackward.
	// input/output: [totalRows, dModel] device pointers.
	// The dimension arguments (totalRows, dModel, nHeads, headDim, ffnDim,
	// bsC, numPatches) describe the layer geometry and must be consistent
	// with the sizes of the pre-allocated buffers.
	FusedEncoderForward(
		weights *[16]unsafe.Pointer,
		bufs *[16]unsafe.Pointer,
		input, output unsafe.Pointer,
		totalRows, dModel, nHeads, headDim, ffnDim, bsC, numPatches int,
	) error

	// FusedEncoderBackward computes all gradients for one encoder layer.
	// weights: [16]unsafe.Pointer to layer weights.
	// weightT: [6]unsafe.Pointer to pre-transposed weights.
	// fwdBufs: [16]unsafe.Pointer to forward cache (from FusedEncoderForward).
	// bwdBufs: [15]unsafe.Pointer to backward scratch buffers.
	// grads: [16]unsafe.Pointer to gradient accumulators (accumulated, not zeroed).
	// dOutput: upstream gradient; dInput: output gradient; input: original layer input.
	// The dimension arguments must match those passed to the corresponding
	// FusedEncoderForward call.
	FusedEncoderBackward(
		weights *[16]unsafe.Pointer,
		weightT *[6]unsafe.Pointer,
		fwdBufs *[16]unsafe.Pointer,
		bwdBufs *[15]unsafe.Pointer,
		grads *[16]unsafe.Pointer,
		dOutput, dInput, input unsafe.Pointer,
		totalRows, dModel, nHeads, headDim, ffnDim, bsC, numPatches int,
	) error
}
72 changes: 72 additions & 0 deletions compute/gpu_fused_encoder.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
package compute

import (
"fmt"
"unsafe"

"github.com/zerfoo/ztensor/internal/cublas"
)

// blasHandlePtr extracts the raw cuBLAS handle pointer from the BLAS
// implementation behind b. It returns nil when b is not backed by cuBLAS
// (i.e. does not expose a Handle method) or when the handle itself is nil.
func blasHandlePtr(b any) unsafe.Pointer {
	// The interface is declared locally at the consumer so that BLAS
	// implementations do not need a compile-time dependency on this
	// optional capability.
	type handleProvider interface {
		Handle() *cublas.Handle
	}
	hp, ok := b.(handleProvider)
	if !ok {
		return nil
	}
	h := hp.Handle()
	if h == nil {
		return nil
	}
	return h.Ptr()
}

// FusedEncoderAvailable reports whether the fused encoder forward kernel is
// loaded and the engine's BLAS is backed by a usable cuBLAS handle. Both
// conditions are required before FusedEncoderForward/Backward can succeed.
func (e *GPUEngine[T]) FusedEncoderAvailable() bool {
	if !e.kernels.FusedEncoderFwdAvailable() {
		return false
	}
	return blasHandlePtr(e.blas) != nil
}

// FusedEncoderForward executes one fused encoder layer forward pass,
// dispatching to the F32 kernel orchestrator with the engine's cuBLAS
// handle and CUDA stream. It returns an error if no cuBLAS handle is
// available (see FusedEncoderAvailable).
func (e *GPUEngine[T]) FusedEncoderForward(
	weights *[16]unsafe.Pointer,
	bufs *[16]unsafe.Pointer,
	input, output unsafe.Pointer,
	totalRows, dModel, nHeads, headDim, ffnDim, bsC, numPatches int,
) error {
	handle := blasHandlePtr(e.blas)
	if handle == nil {
		return fmt.Errorf("FusedEncoderForward: cuBLAS handle not available")
	}
	// Make this engine's device current before launching the fused kernel.
	e.setDevice()
	return e.kernels.FusedEncoderFwdF32(
		handle, weights, bufs, input, output,
		totalRows, dModel, nHeads, headDim, ffnDim, bsC, numPatches,
		e.stream,
	)
}

// FusedEncoderBackward computes all gradients for one fused encoder layer.
//
// weights holds the 16 layer weight pointers, weightT the 6 pre-transposed
// weights, fwdBufs the forward cache produced by FusedEncoderForward,
// bwdBufs the 15 backward scratch buffers, and grads the 16 gradient
// accumulators. dOutput is the upstream gradient, dInput receives the
// gradient with respect to the layer input, and input is the original layer
// input. It returns an error if no cuBLAS handle is available.
func (e *GPUEngine[T]) FusedEncoderBackward(
	weights *[16]unsafe.Pointer,
	weightT *[6]unsafe.Pointer,
	fwdBufs *[16]unsafe.Pointer,
	bwdBufs *[15]unsafe.Pointer,
	grads *[16]unsafe.Pointer,
	dOutput, dInput, input unsafe.Pointer,
	totalRows, dModel, nHeads, headDim, ffnDim, bsC, numPatches int,
) error {
	h := blasHandlePtr(e.blas)
	if h == nil {
		return fmt.Errorf("FusedEncoderBackward: cuBLAS handle not available")
	}
	e.setDevice()
	// The KernelRunner interface takes *[16]unsafe.Pointer for weightT, but
	// callers supply only 6 transposed weights. Copy them into the front of
	// a zero-valued 16-slot array; the remaining 10 entries stay nil.
	var wt16 [16]unsafe.Pointer
	copy(wt16[:6], weightT[:])
	return e.kernels.FusedEncoderBwdF32(h, weights, &wt16, fwdBufs, bwdBufs, grads,
		dOutput, dInput, input,
		totalRows, dModel, nHeads, headDim, ffnDim, bsC, numPatches, e.stream)
}

// Compile-time check that GPUEngine implements FusedEncoderProvider.
// The float32 instantiation is checked; the methods above do not reference
// the type parameter T, so other instantiations satisfy it identically.
var _ FusedEncoderProvider = (*GPUEngine[float32])(nil)
4 changes: 4 additions & 0 deletions internal/cublas/cublas_purego.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,10 @@ type Handle struct {
ptr uintptr // cublasHandle_t is a pointer
}

// Ptr returns the raw cuBLAS handle pointer for passing to C functions
// (e.g., the fused encoder kernel orchestrator).
//
// NOTE(review): converting a stored uintptr back to unsafe.Pointer trips
// go vet's unsafeptr check; this is the documented, intentional usage for
// which /internal/cublas is excluded from vet in CI.
func (h *Handle) Ptr() unsafe.Pointer { return unsafe.Pointer(h.ptr) }

// CreateHandle creates a new cuBLAS context handle.
func CreateHandle() (*Handle, error) {
lib, err := getCublasLib()
Expand Down
4 changes: 2 additions & 2 deletions internal/cuda/kernels/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ ifeq ($(CUDA_ARCH),sm_121)
NVCC_FLAGS += -DFLASH_BLOCK_SIZE=64
endif

SRCS = counter.cu dequant_q4k.cu dequant_q5_0.cu dequant_q5k.cu dequant_q6k.cu elementwise.cu elementwise_fp16.cu flash_attention.cu flash_attention2.cu flash_decode.cu fp4_gemv.cu fp8_gemm.cu fp8_ops.cu fused_add_rmsnorm.cu fused_norm_add.cu fused_qk_norm_rope.cu fused_repeat_interleave.cu fused_rope.cu fused_softmax_vmul.cu fused_swiglu.cu gather.cu gather_q8.cu gemm_int8.cu gemm_int4.cu gemm_q4.cu gemm_q8.cu gemv_q4k.cu gemv_q4k_sm121.cu gemv_q5k.cu gemv_q5_0.cu gemv_q6k.cu gemv_warp.cu megakernel_ops.cu offset_memcpy.cu paged_attention.cu ragged_attention.cu rope_select.cu scaled_softmax.cu selective_scan.cu sgemv_m1.cu ternary_gemv.cu transpose.cu rmsnorm.cu argmax.cu
SRCS = counter.cu dequant_q4k.cu dequant_q5_0.cu dequant_q5k.cu dequant_q6k.cu elementwise.cu elementwise_fp16.cu flash_attention.cu flash_attention2.cu flash_decode.cu fp4_gemv.cu fp8_gemm.cu fp8_ops.cu fused_add_rmsnorm.cu fused_encoder_fwd.cu fused_encoder_bwd.cu fused_norm_add.cu fused_qk_norm_rope.cu fused_repeat_interleave.cu fused_rope.cu fused_softmax_vmul.cu fused_swiglu.cu gather.cu gather_q8.cu gemm_int8.cu gemm_int4.cu gemm_q4.cu gemm_q8.cu gemv_q4k.cu gemv_q4k_sm121.cu gemv_q5k.cu gemv_q5_0.cu gemv_q6k.cu gemv_warp.cu megakernel_ops.cu offset_memcpy.cu paged_attention.cu ragged_attention.cu rope_select.cu scaled_softmax.cu selective_scan.cu sgemv_m1.cu ternary_gemv.cu transpose.cu rmsnorm.cu argmax.cu
OBJS = $(SRCS:.cu=.o)
PIC_OBJS = $(SRCS:.cu=.pic.o)
LIB = libkernels.a
Expand All @@ -27,7 +27,7 @@ $(LIB): $(OBJS)
ar rcs $@ $^

$(SO): $(PIC_OBJS)
$(NVCC) -shared -o $@ $^
$(NVCC) -shared -o $@ $^ -lcublas

# Limit register pressure for kernels that benefit from higher occupancy.
# gemm_q4: 40->32 regs/thread, no spills, occupancy 75%->100% (256-thread blocks).
Expand Down
Loading
Loading