runner.go: Better abstract vision model integration

- Update mllama to take the cross attention state as embeddings in
  a batch, more similar to how Llava handles it. This improves
  integration with the input cache.
- Pass locations in a prompt for embeddings using tags similar to Llava
  (see the prompt-tag sketch below).
- Abstract the interface to vision models so the main runner accesses Clip
  and Mllama similarly.

Co-authored-by: Michael Yang <mxyng@pm.me>
Authored by Jesse Gross on 2024-10-11 15:34:01 -07:00, committed by Jesse Gross
commit c826e57475, parent 712e99d477
13 changed files with 534 additions and 454 deletions
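The prompt-tag idea from the description can be pictured as placeholder markers in the prompt text that stand in for image embeddings. A minimal, self-contained sketch, assuming an `[img-N]` tag format (the tag syntax and the `splitPrompt` helper are illustrative, not part of this commit):

```go
package main

import (
	"fmt"
	"regexp"
)

// imageTag matches placeholders such as "[img-0]" that mark where an image's
// embeddings belong relative to the surrounding text. The tag format is an
// assumption for illustration; it is not defined in this change.
var imageTag = regexp.MustCompile(`\[img-(\d+)\]`)

// splitPrompt returns the text chunks of a prompt and the image indices that
// appear between them, in order.
func splitPrompt(prompt string) (texts []string, images []string) {
	last := 0
	for _, loc := range imageTag.FindAllStringSubmatchIndex(prompt, -1) {
		texts = append(texts, prompt[last:loc[0]])
		images = append(images, prompt[loc[2]:loc[3]])
		last = loc[1]
	}
	texts = append(texts, prompt[last:])
	return texts, images
}

func main() {
	texts, images := splitPrompt("describe [img-0] and compare it with [img-1]")
	fmt.Printf("%q\n", texts)  // ["describe " " and compare it with " ""]
	fmt.Printf("%q\n", images) // ["0" "1"]
}
```

Each text chunk would be tokenized as usual, and each referenced image would contribute its embedding rows at that point, so image data flows through the same batch and input cache as ordinary tokens.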

@@ -111,6 +111,28 @@ func PrintSystemInfo() string {
 	return C.GoString(C.llama_print_system_info()) + compiler
 }
 
+func GetModelArch(modelPath string) (string, error) {
+	mp := C.CString(modelPath)
+	defer C.free(unsafe.Pointer(mp))
+
+	gguf_ctx := C.gguf_init_from_file(mp, C.struct_gguf_init_params{no_alloc: true, ctx: (**C.struct_ggml_context)(C.NULL)})
+	if gguf_ctx == nil {
+		return "", errors.New("unable to load model file")
+	}
+	defer C.gguf_free(gguf_ctx)
+
+	key := C.CString("general.architecture")
+	defer C.free(unsafe.Pointer(key))
+	arch_index := C.gguf_find_key(gguf_ctx, key)
+	if int(arch_index) < 0 {
+		return "", errors.New("unknown model architecture")
+	}
+
+	arch := C.gguf_get_val_str(gguf_ctx, arch_index)
+
+	return C.GoString(arch), nil
+}
+
 type ContextParams struct {
 	c C.struct_llama_context_params
 }
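A possible caller for the new GetModelArch helper, sketched under the assumption that the runner picks a vision backend from the projector's GGUF metadata; `loadVisionBackend` and `projectorPath` are illustrative names, not part of this change, and the package is assumed to be imported as `llama`:

```go
// Hypothetical caller, not code from this commit: read general.architecture
// from the projector's GGUF metadata and load the matching vision context.
func loadVisionBackend(lc *llama.Context, projectorPath string) (*llama.ClipContext, *llama.MllamaContext, error) {
	arch, err := llama.GetModelArch(projectorPath)
	if err != nil {
		return nil, nil, err
	}

	switch arch {
	case "clip":
		clip, err := llama.NewClipContext(lc, projectorPath)
		return clip, nil, err
	case "mllama":
		mllama, err := llama.NewMllamaContext(lc, projectorPath)
		return nil, mllama, err
	default:
		return nil, nil, fmt.Errorf("unknown vision model architecture: %s", arch)
	}
}
```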
@@ -443,71 +465,36 @@ func Quantize(infile, outfile string, ftype uint32) error {
 	return nil
 }
 
-// llava
+// vision processing
 type ClipContext struct {
-	c        *C.struct_clip_ctx
-	m        *C.struct_mllama_ctx
-	IsMllama bool
-	embedPin runtime.Pinner
-	pinned   bool
+	c *C.struct_clip_ctx
 }
 
-func getVisionArch(mp *C.char) (string, error) {
-	gguf_ctx := C.gguf_init_from_file(mp, C.struct_gguf_init_params{no_alloc: true, ctx: (**C.struct_ggml_context)(C.NULL)})
-	if gguf_ctx == nil {
-		return "", errors.New("unable to load vision projector")
-	}
-	defer C.gguf_free(gguf_ctx)
-
-	arch_index := C.gguf_find_key(gguf_ctx, C.CString("general.architecture"))
-	if int(arch_index) < 0 {
-		return "", errors.New("unknown vision model architecture")
-	}
-
-	arch := C.gguf_get_val_str(gguf_ctx, arch_index)
-
-	return C.GoString(arch), nil
-}
-
-func NewClipContext(modelPath string) (*ClipContext, error) {
+func NewClipContext(llamaContext *Context, modelPath string) (*ClipContext, error) {
 	mp := C.CString(modelPath)
 	defer C.free(unsafe.Pointer(mp))
+	c := C.clip_model_load(mp, 1)
 
-	arch, err := getVisionArch(mp)
-	if err != nil {
-		return nil, err
+	projEmbedSize := int(C.clip_n_mmproj_embd(c))
+	modelEmbedSize := llamaContext.Model().NEmbd()
+	if projEmbedSize != modelEmbedSize {
+		return nil, fmt.Errorf("projector embedding size (%d) does not match model (%d)", projEmbedSize, modelEmbedSize)
 	}
 
-	var cc ClipContext
-	if arch == "clip" {
-		cc.c = C.clip_model_load(mp, 1)
-	} else if arch == "mllama" {
-		cc.m = C.mllama_model_load(mp, 1)
-		cc.IsMllama = true
-	} else {
-		return nil, fmt.Errorf("unknown vision model architecture: %s", arch)
-	}
-
-	// XXX: check embedding size?
-	return &cc, nil
+	return &ClipContext{c: c}, nil
 }
 
 func (c *ClipContext) Free() {
-	if c.c != nil {
-		C.clip_free(c.c)
-	}
-	if c.m != nil {
-		C.mllama_free(c.m)
-	}
+	C.clip_free(c.c)
 }
 
-func NewLlavaImageEmbed(llamaContext *Context, clipContext *ClipContext, data []byte) [][]float32 {
-	c := C.llava_image_embed_make_with_bytes(clipContext.c, C.int(llamaContext.numThreads), (*C.uchar)(unsafe.Pointer(&data[0])), C.int(len(data)))
+func (c *ClipContext) NewEmbed(llamaContext *Context, data []byte) [][]float32 {
+	l := C.llava_image_embed_make_with_bytes(c.c, C.int(llamaContext.numThreads), (*C.uchar)(unsafe.Pointer(&data[0])), C.int(len(data)))
 
-	numTokens := int(c.n_image_pos)
+	numTokens := int(l.n_image_pos)
 	numEmbed := llamaContext.Model().NEmbd()
 
-	s := unsafe.Slice((*float32)(c.embed), numEmbed*numTokens)
+	s := unsafe.Slice((*float32)(l.embed), numEmbed*numTokens)
 
 	embed := make([][]float32, numTokens)
 	rows := make([]float32, len(s))
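For the clip path, the runner-facing flow could look like the sketch below; the `embedWithClip` helper, its error handling, and when the context is actually freed are illustrative choices, not part of this diff:

```go
// Hypothetical helper, not code from this commit: turn image bytes into
// per-position embedding rows that the runner can queue like ordinary
// prompt inputs.
func embedWithClip(lc *llama.Context, projectorPath string, imageBytes []byte) ([][]float32, error) {
	clipCtx, err := llama.NewClipContext(lc, projectorPath)
	if err != nil {
		return nil, err
	}
	defer clipCtx.Free()

	// One row per image token; each row holds lc.Model().NEmbd() floats and
	// can be placed into a batch at the position marked by its image tag.
	return clipCtx.NewEmbed(lc, imageBytes), nil
}
```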
@@ -517,51 +504,57 @@ func NewLlavaImageEmbed(llamaContext *Context, clipContext *ClipContext, data []
 		embed[i] = rows[i*numEmbed : (i+1)*numEmbed]
 	}
 
-	C.llava_image_embed_free(c)
+	C.llava_image_embed_free(l)
 
 	return embed
 }
 
-func NewMllamaImageEmbed(llamaContext *Context, clipContext *ClipContext, data []byte, aspectRatioId int) [][]float32 {
+type MllamaContext struct {
+	c *C.struct_mllama_ctx
+}
+
+func NewMllamaContext(llamaContext *Context, modelPath string) (*MllamaContext, error) {
+	mp := C.CString(modelPath)
+	defer C.free(unsafe.Pointer(mp))
+	c := C.mllama_model_load(mp, 1)
+
+	projEmbedSize := int(C.mllama_n_embd(c))
+	modelEmbedSize := llamaContext.Model().NEmbd()
+	if projEmbedSize != modelEmbedSize {
+		return nil, fmt.Errorf("projector embedding size (%d) does not match model (%d)", projEmbedSize, modelEmbedSize)
+	}
+
+	return &MllamaContext{c: c}, nil
+}
+
+func (m *MllamaContext) Free() {
+	C.mllama_free(m.c)
+}
+
+func (m *MllamaContext) NewEmbed(llamaContext *Context, data []byte, aspectRatioId int) [][]float32 {
 	img := C.mllama_image_init()
 	defer C.mllama_image_free(img)
 
 	C.mllama_image_load_from_data(unsafe.Pointer(&data[0]), C.int(len(data)), 560, 560, 3, 4, C.int(aspectRatioId), img)
 
-	numTokens := int(C.mllama_n_positions(clipContext.m) * C.mllama_n_tiles(clipContext.m))
-	numEmbed := llamaContext.Model().NEmbd()
+	rows := make([]float32, m.EmbedSize(llamaContext))
+	C.mllama_image_encode(m.c, C.int(llamaContext.numThreads), img, (*C.float)(unsafe.Pointer(&rows[0])))
 
-	rows := make([]float32, numEmbed*numTokens)
-	C.mllama_image_encode(clipContext.m, C.int(llamaContext.numThreads), img, (*C.float)(unsafe.Pointer(&rows[0])))
-
-	embed := make([][]float32, numTokens)
-	for i := range embed {
-		embed[i] = rows[i*numEmbed : (i+1)*numEmbed]
-	}
+	embed := make([][]float32, 1)
+	embed[0] = rows
 
 	return embed
 }
 
-// This really needs to be set on a batch instead
-func MllamaSetCrossAttn(llamaContext *Context, clipContext *ClipContext, embed [][]float32) {
-	if embed != nil {
-		if clipContext.pinned {
-			panic("Cross attention state already pinned")
-		}
+func (m *MllamaContext) EmbedSize(llamaContext *Context) int {
+	numTokens := int(C.mllama_n_positions(m.c) * C.mllama_n_tiles(m.c))
+	numEmbed := llamaContext.Model().NEmbd()
 
-		embedData := &embed[0][0]
-		clipContext.embedPin.Pin(embedData)
-		clipContext.pinned = true
+	return numTokens * numEmbed
+}
 
-		C.llama_set_cross_attn_state(llamaContext.c, (*C.float)(unsafe.Pointer(embedData)))
-	} else {
-		C.llama_set_cross_attn_state(llamaContext.c, (*C.float)(C.NULL))
-
-		if clipContext.pinned {
-			clipContext.embedPin.Unpin()
-			clipContext.pinned = false
-		}
-	}
+func (c *Context) SetCrossAttention(state bool) {
+	C.llama_set_cross_attention(c.c, C.bool(state))
 }
 
 // sampling
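On the mllama side the shape of the data changes: the encoder output comes back as one wide row sized by EmbedSize, and cross attention becomes a per-context switch instead of pinned global state. A sketch under runner-style assumptions (the helper names and the `hasImage` flag are illustrative, not part of this diff):

```go
// Hypothetical usage, not code from this commit.
func embedWithMllama(lc *llama.Context, mc *llama.MllamaContext, imageBytes []byte, aspectRatioID int) [][]float32 {
	// Unlike the clip path, this returns a single row of mc.EmbedSize(lc)
	// floats holding the whole cross-attention state for the image.
	return mc.NewEmbed(lc, imageBytes, aspectRatioID)
}

func decodeBatch(lc *llama.Context, hasImage bool) {
	// Cross attention is now a per-context flag: enable it for batches that
	// carry or follow an image embedding, disable it for text-only batches.
	lc.SetCrossAttention(hasImage)

	// ...build the batch and decode as usual...
}
```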