|
|
|
@ -34,13 +34,15 @@ func cosineSimilarity[V float32 | float64](v1, v2 []V) V { |
|
|
|
func TestAllMiniLMEmbeddings(t *testing.T) { |
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) |
|
|
|
defer cancel() |
|
|
|
client, _, cleanup := InitServerConnection(ctx, t) |
|
|
|
defer cleanup() |
|
|
|
|
|
|
|
req := api.EmbeddingRequest{ |
|
|
|
Model: "all-minilm", |
|
|
|
Prompt: "why is the sky blue?", |
|
|
|
} |
|
|
|
|
|
|
|
res, err := embeddingTestHelper(ctx, t, req) |
|
|
|
res, err := embeddingTestHelper(ctx, client, t, req) |
|
|
|
|
|
|
|
if err != nil { |
|
|
|
t.Fatalf("error: %v", err) |
|
|
|
@ -62,13 +64,15 @@ func TestAllMiniLMEmbeddings(t *testing.T) { |
|
|
|
func TestAllMiniLMEmbed(t *testing.T) { |
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) |
|
|
|
defer cancel() |
|
|
|
client, _, cleanup := InitServerConnection(ctx, t) |
|
|
|
defer cleanup() |
|
|
|
|
|
|
|
req := api.EmbedRequest{ |
|
|
|
Model: "all-minilm", |
|
|
|
Input: "why is the sky blue?", |
|
|
|
} |
|
|
|
|
|
|
|
res, err := embedTestHelper(ctx, t, req) |
|
|
|
res, err := embedTestHelper(ctx, client, t, req) |
|
|
|
|
|
|
|
if err != nil { |
|
|
|
t.Fatalf("error: %v", err) |
|
|
|
@ -98,13 +102,15 @@ func TestAllMiniLMEmbed(t *testing.T) { |
|
|
|
func TestAllMiniLMBatchEmbed(t *testing.T) { |
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) |
|
|
|
defer cancel() |
|
|
|
client, _, cleanup := InitServerConnection(ctx, t) |
|
|
|
defer cleanup() |
|
|
|
|
|
|
|
req := api.EmbedRequest{ |
|
|
|
Model: "all-minilm", |
|
|
|
Input: []string{"why is the sky blue?", "why is the grass green?"}, |
|
|
|
} |
|
|
|
|
|
|
|
res, err := embedTestHelper(ctx, t, req) |
|
|
|
res, err := embedTestHelper(ctx, client, t, req) |
|
|
|
|
|
|
|
if err != nil { |
|
|
|
t.Fatalf("error: %v", err) |
|
|
|
@ -144,6 +150,8 @@ func TestAllMiniLMBatchEmbed(t *testing.T) { |
|
|
|
func TestAllMiniLMEmbedTruncate(t *testing.T) { |
|
|
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) |
|
|
|
defer cancel() |
|
|
|
client, _, cleanup := InitServerConnection(ctx, t) |
|
|
|
defer cleanup() |
|
|
|
|
|
|
|
truncTrue, truncFalse := true, false |
|
|
|
|
|
|
|
@ -182,7 +190,7 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) { |
|
|
|
res := make(map[string]*api.EmbedResponse) |
|
|
|
|
|
|
|
for _, req := range reqs { |
|
|
|
response, err := embedTestHelper(ctx, t, req.Request) |
|
|
|
response, err := embedTestHelper(ctx, client, t, req.Request) |
|
|
|
if err != nil { |
|
|
|
t.Fatalf("error: %v", err) |
|
|
|
} |
|
|
|
@ -198,7 +206,7 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) { |
|
|
|
} |
|
|
|
|
|
|
|
// check that truncate set to false returns an error if context length is exceeded
|
|
|
|
_, err := embedTestHelper(ctx, t, api.EmbedRequest{ |
|
|
|
_, err := embedTestHelper(ctx, client, t, api.EmbedRequest{ |
|
|
|
Model: "all-minilm", |
|
|
|
Input: "why is the sky blue?", |
|
|
|
Truncate: &truncFalse, |
|
|
|
@ -210,9 +218,7 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) { |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
func embeddingTestHelper(ctx context.Context, t *testing.T, req api.EmbeddingRequest) (*api.EmbeddingResponse, error) { |
|
|
|
client, _, cleanup := InitServerConnection(ctx, t) |
|
|
|
defer cleanup() |
|
|
|
func embeddingTestHelper(ctx context.Context, client *api.Client, t *testing.T, req api.EmbeddingRequest) (*api.EmbeddingResponse, error) { |
|
|
|
if err := PullIfMissing(ctx, client, req.Model); err != nil { |
|
|
|
t.Fatalf("failed to pull model %s: %v", req.Model, err) |
|
|
|
} |
|
|
|
@ -226,9 +232,7 @@ func embeddingTestHelper(ctx context.Context, t *testing.T, req api.EmbeddingReq |
|
|
|
return response, nil |
|
|
|
} |
|
|
|
|
|
|
|
func embedTestHelper(ctx context.Context, t *testing.T, req api.EmbedRequest) (*api.EmbedResponse, error) { |
|
|
|
client, _, cleanup := InitServerConnection(ctx, t) |
|
|
|
defer cleanup() |
|
|
|
func embedTestHelper(ctx context.Context, client *api.Client, t *testing.T, req api.EmbedRequest) (*api.EmbedResponse, error) { |
|
|
|
if err := PullIfMissing(ctx, client, req.Model); err != nil { |
|
|
|
t.Fatalf("failed to pull model %s: %v", req.Model, err) |
|
|
|
} |
|
|
|
|