diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go
index 153a3e57..962931fe 100644
--- a/runner/ollamarunner/runner.go
+++ b/runner/ollamarunner/runner.go
@@ -1009,12 +1009,17 @@ func (s *Server) health(w http.ResponseWriter, r *http.Request) {
 	}
 }
 
-func (s *Server) reserveWorstCaseGraph() error {
+func (s *Server) reserveWorstCaseGraph(prompt bool) error {
 	ctx := s.model.Backend().NewContext()
 	defer ctx.Close()
 
 	var err error
-	inputs := make([]*input.Input, s.batchSize)
+	batchSize := 1
+	if prompt {
+		batchSize = s.batchSize
+	}
+
+	inputs := make([]*input.Input, batchSize)
 	for i := range inputs {
 		inputs[i] = &input.Input{}
 	}
@@ -1031,7 +1036,7 @@ func (s *Server) reserveWorstCaseGraph() error {
 	// - The result may now be larger than a batch (images may not fit in a
 	//   single batch), so trim based on what will fit and must be grouped together.
 	// - Fill out the rest of the space with text tokens.
-	if multimodalProcessor, ok := s.model.(model.MultimodalProcessor); ok {
+	if multimodalProcessor, ok := s.model.(model.MultimodalProcessor); prompt && ok {
 		mmCtx := s.model.Backend().NewContext()
 		defer mmCtx.Close()
 
@@ -1058,10 +1063,10 @@ func (s *Server) reserveWorstCaseGraph() error {
 		}
 	}
 
-	if len(inputs) < s.batchSize {
-		newInputs := make([]*input.Input, s.batchSize)
+	if len(inputs) < batchSize {
+		newInputs := make([]*input.Input, batchSize)
 		copy(newInputs, inputs)
-		for i := len(inputs); i < s.batchSize; i++ {
+		for i := len(inputs); i < batchSize; i++ {
 			newInputs[i] = &input.Input{}
 		}
 		inputs = newInputs
@@ -1160,7 +1165,12 @@ func (s *Server) allocModel(
 	s.seqs = make([]*Sequence, s.parallel)
 	s.seqsSem = semaphore.NewWeighted(int64(s.parallel))
 
-	return s.reserveWorstCaseGraph()
+	err = s.reserveWorstCaseGraph(true)
+	if err != nil {
+		return err
+	}
+
+	return s.reserveWorstCaseGraph(false)
 }
 
 // closeModel frees all memory associated with a model
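
For readers following along: the patch splits graph reservation into two worst cases, a prompt pass sized at the full batch (the only pass where multimodal inputs are considered) and a decode pass sized at a single token. Below is a minimal, self-contained Go sketch of that calling pattern; Server, Input, and the body of reserveWorstCaseGraph here are simplified stand-ins for illustration, not the real ollama runner structures.

package main

import "fmt"

// Input is a stand-in for the runner's input.Input type.
type Input struct{}

// Server is a stand-in carrying only the field this sketch needs.
type Server struct {
	batchSize int
}

// reserveWorstCaseGraph mirrors the shape of the patched function: for the
// prompt pass it sizes the dummy batch at the full batch size; for the
// decode pass it uses a single token, since decoding processes one token
// per sequence per step.
func (s *Server) reserveWorstCaseGraph(prompt bool) error {
	batchSize := 1
	if prompt {
		batchSize = s.batchSize
	}

	inputs := make([]*Input, batchSize)
	for i := range inputs {
		inputs[i] = &Input{}
	}

	// The real runner builds and reserves a compute graph here; this
	// sketch only reports the batch it would reserve for.
	fmt.Printf("reserving worst-case graph: prompt=%v batch=%d\n", prompt, len(inputs))
	return nil
}

func main() {
	s := &Server{batchSize: 512}

	// Reserve both worst cases, as the patched allocModel now does:
	// the prompt pass first, then the single-token decode pass.
	if err := s.reserveWorstCaseGraph(true); err != nil {
		panic(err)
	}
	if err := s.reserveWorstCaseGraph(false); err != nil {
		panic(err)
	}
}

The two-pass reservation presumably exists because a single-token decode batch can produce a differently shaped compute graph than a full prompt batch, so reserving only the prompt case would not necessarily cover decoding.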