Compare commits

...

5 Commits

Author SHA1 Message Date
likelovewant
cb02c084eb revert to rocm 5.7 2024-07-14 09:54:14 +08:00
likelovewant
706449c10d Merge branch 'ollama:main' into main 2024-07-14 09:51:57 +08:00
jmorganca
f7ee012300 server: prepend system message in chat handler 2024-07-13 15:08:00 -07:00
likelovewant
90807b2ad0 Merge branch 'ollama:main' into main 2024-07-14 00:58:33 +08:00
Jeffrey Morgan
1ed0aa8fea server: fix context, load_duration and total_duration fields (#5676)
* server: fix `contet`, `load_duration` and `total_duration` fields

* Update server/routes.go
2024-07-13 09:25:31 -07:00
2 changed files with 51 additions and 11 deletions

View File

@@ -23,7 +23,7 @@ const (
var ( var (
// Used to validate if the given ROCm lib is usable // Used to validate if the given ROCm lib is usable
ROCmLibGlobs = []string{"hipblas.dll", "rocblas"} // This is not sufficient to discern v5 vs v6 ROCmLibGlobs = []string{"hipblas.dll", "rocblas"} // This is not sufficient to discern v5 vs v6
RocmStandardLocations = []string{"C:\\Program Files\\AMD\\ROCm\\6.1\\bin"} // TODO glob? RocmStandardLocations = []string{"C:\\Program Files\\AMD\\ROCm\\5.7\\bin"} // TODO glob?
) )
func AMDGetGPUInfo() []RocmGPUInfo { func AMDGetGPUInfo() []RocmGPUInfo {

View File

@@ -102,6 +102,7 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []Capabil
} }
func (s *Server) GenerateHandler(c *gin.Context) { func (s *Server) GenerateHandler(c *gin.Context) {
checkpointStart := time.Now()
var req api.GenerateRequest var req api.GenerateRequest
if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) { if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"}) c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
@@ -129,6 +130,8 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return return
} }
checkpointLoaded := time.Now()
if req.Prompt == "" { if req.Prompt == "" {
c.JSON(http.StatusOK, api.GenerateResponse{ c.JSON(http.StatusOK, api.GenerateResponse{
Model: req.Model, Model: req.Model,
@@ -191,26 +194,48 @@ func (s *Server) GenerateHandler(c *gin.Context) {
ch := make(chan any) ch := make(chan any)
go func() { go func() {
// TODO (jmorganca): avoid building the response twice both here and below
var sb strings.Builder
defer close(ch) defer close(ch)
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{ if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
Prompt: prompt, Prompt: prompt,
Images: images, Images: images,
Format: req.Format, Format: req.Format,
Options: opts, Options: opts,
}, func(r llm.CompletionResponse) { }, func(cr llm.CompletionResponse) {
ch <- api.GenerateResponse{ res := api.GenerateResponse{
Model: req.Model, Model: req.Model,
CreatedAt: time.Now().UTC(), CreatedAt: time.Now().UTC(),
Response: r.Content, Response: cr.Content,
Done: r.Done, Done: cr.Done,
DoneReason: r.DoneReason, DoneReason: cr.DoneReason,
Metrics: api.Metrics{ Metrics: api.Metrics{
PromptEvalCount: r.PromptEvalCount, PromptEvalCount: cr.PromptEvalCount,
PromptEvalDuration: r.PromptEvalDuration, PromptEvalDuration: cr.PromptEvalDuration,
EvalCount: r.EvalCount, EvalCount: cr.EvalCount,
EvalDuration: r.EvalDuration, EvalDuration: cr.EvalDuration,
}, },
} }
if _, err := sb.WriteString(cr.Content); err != nil {
ch <- gin.H{"error": err.Error()}
}
if cr.Done {
res.TotalDuration = time.Since(checkpointStart)
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
if !req.Raw {
tokens, err := r.Tokenize(c.Request.Context(), prompt+sb.String())
if err != nil {
ch <- gin.H{"error": err.Error()}
return
}
res.Context = append(req.Context, tokens...)
}
}
ch <- res
}); err != nil { }); err != nil {
ch <- gin.H{"error": err.Error()} ch <- gin.H{"error": err.Error()}
} }
@@ -1122,6 +1147,8 @@ func (s *Server) ProcessHandler(c *gin.Context) {
} }
func (s *Server) ChatHandler(c *gin.Context) { func (s *Server) ChatHandler(c *gin.Context) {
checkpointStart := time.Now()
var req api.ChatRequest var req api.ChatRequest
if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) { if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"}) c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
@@ -1141,6 +1168,8 @@ func (s *Server) ChatHandler(c *gin.Context) {
return return
} }
checkpointLoaded := time.Now()
if len(req.Messages) == 0 { if len(req.Messages) == 0 {
c.JSON(http.StatusOK, api.ChatResponse{ c.JSON(http.StatusOK, api.ChatResponse{
Model: req.Model, Model: req.Model,
@@ -1152,6 +1181,10 @@ func (s *Server) ChatHandler(c *gin.Context) {
return return
} }
if req.Messages[0].Role != "system" {
req.Messages = append([]api.Message{{Role: "system", Content: m.System}}, req.Messages...)
}
prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, req.Messages) prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, req.Messages)
if err != nil { if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
@@ -1169,7 +1202,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
Format: req.Format, Format: req.Format,
Options: opts, Options: opts,
}, func(r llm.CompletionResponse) { }, func(r llm.CompletionResponse) {
ch <- api.ChatResponse{ res := api.ChatResponse{
Model: req.Model, Model: req.Model,
CreatedAt: time.Now().UTC(), CreatedAt: time.Now().UTC(),
Message: api.Message{Role: "assistant", Content: r.Content}, Message: api.Message{Role: "assistant", Content: r.Content},
@@ -1182,6 +1215,13 @@ func (s *Server) ChatHandler(c *gin.Context) {
EvalDuration: r.EvalDuration, EvalDuration: r.EvalDuration,
}, },
} }
if r.Done {
res.TotalDuration = time.Since(checkpointStart)
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
}
ch <- res
}); err != nil { }); err != nil {
ch <- gin.H{"error": err.Error()} ch <- gin.H{"error": err.Error()}
} }