|
|
|
@ -156,6 +156,54 @@ func load(ctx context.Context, workDir string, model *Model, reqOpts map[string] |
|
|
|
return nil |
|
|
|
} |
|
|
|
|
|
|
|
func ChatModelHandler(c *gin.Context) { |
|
|
|
loaded.mu.Lock() |
|
|
|
defer loaded.mu.Unlock() |
|
|
|
|
|
|
|
var req api.ChatRequest |
|
|
|
if err := c.ShouldBindJSON(&req); err != nil { |
|
|
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) |
|
|
|
return |
|
|
|
} |
|
|
|
|
|
|
|
model, err := GetModel(req.Model) |
|
|
|
if err != nil { |
|
|
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) |
|
|
|
return |
|
|
|
} |
|
|
|
|
|
|
|
prompt, err := model.ChatPrompt(req.Messages) |
|
|
|
if err != nil { |
|
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) |
|
|
|
return |
|
|
|
} |
|
|
|
|
|
|
|
var response string |
|
|
|
fn := func(r api.GenerateResponse) { |
|
|
|
response += r.Response |
|
|
|
} |
|
|
|
|
|
|
|
workDir := c.GetString("workDir") |
|
|
|
if err := load(c.Request.Context(), workDir, model, nil, defaultSessionDuration); err != nil { |
|
|
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) |
|
|
|
return |
|
|
|
} |
|
|
|
|
|
|
|
fmt.Println(prompt) |
|
|
|
|
|
|
|
if err := loaded.llm.Predict(c.Request.Context(), []int{}, prompt, fn); err != nil { |
|
|
|
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) |
|
|
|
} |
|
|
|
|
|
|
|
c.JSON(http.StatusOK, api.ChatResponse{ |
|
|
|
Message: api.Message{ |
|
|
|
Role: "assistant", |
|
|
|
Content: response, |
|
|
|
}, |
|
|
|
CreatedAt: time.Now().UTC(), |
|
|
|
}) |
|
|
|
} |
|
|
|
|
|
|
|
func GenerateHandler(c *gin.Context) { |
|
|
|
loaded.mu.Lock() |
|
|
|
defer loaded.mu.Unlock() |
|
|
|
@ -552,6 +600,7 @@ func Serve(ln net.Listener, allowOrigins []string) error { |
|
|
|
}, |
|
|
|
) |
|
|
|
|
|
|
|
r.POST("/api/chat", ChatModelHandler) |
|
|
|
r.POST("/api/pull", PullModelHandler) |
|
|
|
r.POST("/api/generate", GenerateHandler) |
|
|
|
r.POST("/api/embeddings", EmbeddingHandler) |
|
|
|
|