|
|
|
@ -125,6 +125,55 @@ func StartServer(ctx context.Context, ollamaHost string) error { |
|
|
|
return nil |
|
|
|
} |
|
|
|
|
|
|
|
func PullIfMissing(ctx context.Context, client *http.Client, scheme, testEndpoint, modelName string) error { |
|
|
|
slog.Debug("checking status of model", "model", modelName) |
|
|
|
showReq := &api.ShowRequest{Name: modelName} |
|
|
|
requestJSON, err := json.Marshal(showReq) |
|
|
|
if err != nil { |
|
|
|
return err |
|
|
|
} |
|
|
|
|
|
|
|
req, err := http.NewRequest("POST", scheme+"://"+testEndpoint+"/api/show", bytes.NewReader(requestJSON)) |
|
|
|
if err != nil { |
|
|
|
return err |
|
|
|
} |
|
|
|
|
|
|
|
// Make the request with the HTTP client
|
|
|
|
response, err := client.Do(req.WithContext(ctx)) |
|
|
|
if err != nil { |
|
|
|
return err |
|
|
|
} |
|
|
|
defer response.Body.Close() |
|
|
|
if response.StatusCode == 200 { |
|
|
|
slog.Info("model already present", "model", modelName) |
|
|
|
return nil |
|
|
|
} |
|
|
|
slog.Info("model missing", "status", response.StatusCode) |
|
|
|
|
|
|
|
pullReq := &api.PullRequest{Name: modelName, Stream: &stream} |
|
|
|
requestJSON, err = json.Marshal(pullReq) |
|
|
|
if err != nil { |
|
|
|
return err |
|
|
|
} |
|
|
|
|
|
|
|
req, err = http.NewRequest("POST", scheme+"://"+testEndpoint+"/api/pull", bytes.NewReader(requestJSON)) |
|
|
|
if err != nil { |
|
|
|
return err |
|
|
|
} |
|
|
|
slog.Info("pulling", "model", modelName) |
|
|
|
|
|
|
|
response, err = client.Do(req.WithContext(ctx)) |
|
|
|
if err != nil { |
|
|
|
return err |
|
|
|
} |
|
|
|
defer response.Body.Close() |
|
|
|
if response.StatusCode != 200 { |
|
|
|
return fmt.Errorf("failed to pull model") // TODO more details perhaps
|
|
|
|
} |
|
|
|
slog.Info("model pulled", "model", modelName) |
|
|
|
return nil |
|
|
|
} |
|
|
|
|
|
|
|
func GenerateTestHelper(ctx context.Context, t *testing.T, client *http.Client, genReq api.GenerateRequest, anyResp []string) { |
|
|
|
requestJSON, err := json.Marshal(genReq) |
|
|
|
if err != nil { |
|
|
|
@ -158,6 +207,11 @@ func GenerateTestHelper(ctx context.Context, t *testing.T, client *http.Client, |
|
|
|
assert.NoError(t, StartServer(ctx, testEndpoint)) |
|
|
|
} |
|
|
|
|
|
|
|
err = PullIfMissing(ctx, client, scheme, testEndpoint, genReq.Model) |
|
|
|
if err != nil { |
|
|
|
t.Fatalf("Error pulling model: %v", err) |
|
|
|
} |
|
|
|
|
|
|
|
// Make the request and get the response
|
|
|
|
req, err := http.NewRequest("POST", scheme+"://"+testEndpoint+"/api/generate", bytes.NewReader(requestJSON)) |
|
|
|
if err != nil { |
|
|
|
@ -172,6 +226,7 @@ func GenerateTestHelper(ctx context.Context, t *testing.T, client *http.Client, |
|
|
|
if err != nil { |
|
|
|
t.Fatalf("Error making request: %v", err) |
|
|
|
} |
|
|
|
defer response.Body.Close() |
|
|
|
body, err := io.ReadAll(response.Body) |
|
|
|
assert.NoError(t, err) |
|
|
|
assert.Equal(t, response.StatusCode, 200, string(body)) |
|
|
|
@ -184,7 +239,12 @@ func GenerateTestHelper(ctx context.Context, t *testing.T, client *http.Client, |
|
|
|
} |
|
|
|
|
|
|
|
// Verify the response contains the expected data
|
|
|
|
atLeastOne := false |
|
|
|
for _, resp := range anyResp { |
|
|
|
assert.Contains(t, strings.ToLower(payload.Response), resp) |
|
|
|
if strings.Contains(strings.ToLower(payload.Response), resp) { |
|
|
|
atLeastOne = true |
|
|
|
break |
|
|
|
} |
|
|
|
} |
|
|
|
assert.True(t, atLeastOne, "none of %v found in %s", anyResp, payload.Response) |
|
|
|
} |
|
|
|
|