mirror of
https://github.com/likelovewant/ollama-for-amd.git
synced 2025-12-22 23:03:55 +00:00
input: Rename Options to Batch
Options is no longer very descriptive of this struct.
This commit is contained in:
@@ -26,7 +26,7 @@ var ErrNoVisionModel = errors.New("this model is missing data required for image
|
||||
|
||||
// Model implements a specific model architecture, defining the forward pass and any model-specific configuration
|
||||
type Model interface {
|
||||
Forward(ml.Context, input.Options) (ml.Tensor, error)
|
||||
Forward(ml.Context, input.Batch) (ml.Tensor, error)
|
||||
|
||||
Backend() ml.Backend
|
||||
Config() config
|
||||
@@ -280,24 +280,24 @@ func canNil(t reflect.Type) bool {
|
||||
t.Kind() == reflect.Slice
|
||||
}
|
||||
|
||||
func Forward(ctx ml.Context, m Model, opts input.Options) (ml.Tensor, error) {
|
||||
if len(opts.Positions) != len(opts.Sequences) {
|
||||
return nil, fmt.Errorf("length of positions (%v) must match length of seqs (%v)", len(opts.Positions), len(opts.Sequences))
|
||||
func Forward(ctx ml.Context, m Model, batch input.Batch) (ml.Tensor, error) {
|
||||
if len(batch.Positions) != len(batch.Sequences) {
|
||||
return nil, fmt.Errorf("length of positions (%v) must match length of seqs (%v)", len(batch.Positions), len(batch.Sequences))
|
||||
}
|
||||
|
||||
if len(opts.Positions) < 1 {
|
||||
if len(batch.Positions) < 1 {
|
||||
return nil, errors.New("batch size cannot be less than 1")
|
||||
}
|
||||
|
||||
cache := m.Config().Cache
|
||||
if cache != nil {
|
||||
err := cache.StartForward(ctx, opts)
|
||||
err := cache.StartForward(ctx, batch)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
t, err := m.Forward(ctx, opts)
|
||||
t, err := m.Forward(ctx, batch)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user