ollamarunner: fix deadlock

Sending on hardErrCh can deadlock: the run loop that would receive it is
blocked in forwardBatch waiting on computeStartedCh, which never gets sent.
Since the only response to hardErrCh is to panic anyway, just panic at the
failure site instead.
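
For illustration only, a minimal, self-contained sketch of the hang described above. It is not code from this repository; computeStarted, hardErr, and the worker goroutine are hypothetical stand-ins for computeStartedCh, hardErrCh, and the forwardBatch/computeBatch pair.

package main

import (
	"errors"
	"fmt"
	"time"
)

func main() {
	computeStarted := make(chan struct{}) // signalled once compute actually starts
	hardErr := make(chan error, 1)        // hard failures reported here

	// Worker, standing in for computeBatch: it hits a hard error before it
	// ever signals computeStarted, reports it on hardErr, and returns.
	go func() {
		hardErr <- errors.New("caching disabled but unable to fit entire input in a batch")
		// computeStarted is never sent.
	}()

	// Run loop, standing in for run/forwardBatch: it waits on computeStarted
	// before it would ever get back to the select that drains hardErr, so the
	// reported error is never seen and the runner silently wedges.
	select {
	case <-computeStarted:
		fmt.Println("compute started") // never happens in this scenario
	case <-time.After(time.Second):
		fmt.Println("hang: forwardBatch never unblocks, hardErr is never drained")
	}

	// Panicking at the failure site, as this commit does, makes the failure
	// immediate instead of leaving the loop stuck.
}

Because an unrecovered panic in any goroutine terminates the whole process, panicking directly in computeBatch achieves what the hardErrCh path intended without relying on the run loop ever reaching its select.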
Author: Michael Yang
Date: 2025-10-10 16:38:12 -07:00
Committed by: Michael Yang
parent aab2190420
commit 1a2feb2a97

@@ -321,9 +321,6 @@ type Server struct {
 	// TODO (jmorganca): make this n_batch
 	batchSize int
 
-	// Used to signal a hard failure during async processing which will panic the runner
-	hardErrCh chan error
-
 	// Simple counter used only for trace logging batches
 	batchID int
@@ -411,8 +408,6 @@ func (s *Server) run(ctx context.Context) {
 		select {
 		case <-ctx.Done():
 			return
-		case err := <-s.hardErrCh:
-			panic(err)
 		default:
 			var err error
 			nextBatch, err := s.forwardBatch(previousBatch)
@@ -663,9 +658,7 @@ func (s *Server) computeBatch(activeBatch batchState) {
 			// don't sample prompt processing
 			if len(seq.inputs) != 0 {
 				if !s.cache.enabled {
-					s.hardErrCh <- fmt.Errorf("caching disabled but unable to fit entire input in a batch")
-					s.mu.Unlock()
-					return
+					panic("caching disabled but unable to fit entire input in a batch")
 				}
 				continue
 			}
@@ -720,8 +713,7 @@ func (s *Server) computeBatch(activeBatch batchState) {
 		logutil.Trace("computeBatch: vocab details", "batchID", activeBatch.id, "seqIdx", i, "len(logits)", len(outputs), "len(activeBatch.batch.Outputs)", activeBatch.batch.Outputs.Dim(0), "vocabSize", vocabSize, "iBatches", iBatches)
 		token, err := seq.sampler.Sample(outputs[iBatches[i]*vocabSize : (iBatches[i]+1)*vocabSize])
 		if err != nil {
-			s.hardErrCh <- fmt.Errorf("failed to sample token: %w", err)
-			return
+			panic("failed to sample token")
 		}
 		nextBatchTokens[i].Token = token
@@ -738,8 +730,7 @@ func (s *Server) computeBatch(activeBatch batchState) {
 			piece, err := s.model.(model.TextProcessor).Decode([]int32{token})
 			if err != nil {
-				s.hardErrCh <- fmt.Errorf("failed to decode token: %w", err)
-				return
+				panic("failed to decode token")
 			}
 			seq.pendingResponses = append(seq.pendingResponses, piece)
@@ -1321,7 +1312,6 @@ func Execute(args []string) error {
 	server := &Server{
 		modelPath: *mpath,
 		status:    llm.ServerStatusLaunched,
-		hardErrCh: make(chan error, 1),
 	}
 	server.cond = sync.NewCond(&server.mu)