|
|
|
@ -9,6 +9,8 @@ import ( |
|
|
|
"strings" |
|
|
|
"testing" |
|
|
|
|
|
|
|
"github.com/google/go-cmp/cmp" |
|
|
|
"github.com/ollama/ollama/fs/ggml" |
|
|
|
"github.com/pdevine/tensor" |
|
|
|
) |
|
|
|
|
|
|
|
@ -302,3 +304,99 @@ func TestSplitDim(t *testing.T) { |
|
|
|
} |
|
|
|
}) |
|
|
|
} |
|
|
|
|
|
|
|
func TestMerge(t *testing.T) { |
|
|
|
unmatched := []Tensor{ |
|
|
|
&fakeTensor{ |
|
|
|
name: "a.0.b", |
|
|
|
shape: []uint64{5, 2}, |
|
|
|
data: []float32{10, 11, 12, 13, 14, 15, 16, 17, 18, 19}, |
|
|
|
}, |
|
|
|
&fakeTensor{ |
|
|
|
name: "a.1.b", |
|
|
|
shape: []uint64{5, 2}, |
|
|
|
data: []float32{20, 21, 22, 23, 24, 25, 26, 27, 28, 29}, |
|
|
|
}, |
|
|
|
&fakeTensor{ |
|
|
|
name: "c.0.d", |
|
|
|
shape: []uint64{5, 2}, |
|
|
|
data: []float32{30, 31, 32, 33, 34, 35, 36, 37, 38, 39}, |
|
|
|
}, |
|
|
|
&fakeTensor{ |
|
|
|
name: "c.1.d", |
|
|
|
shape: []uint64{5, 2}, |
|
|
|
data: []float32{40, 41, 42, 43, 44, 45, 46, 47, 48, 49}, |
|
|
|
}, |
|
|
|
&fakeTensor{ |
|
|
|
name: "e.0.f", |
|
|
|
shape: []uint64{5, 2}, |
|
|
|
data: []float32{50, 51, 52, 53, 54, 55, 56, 57, 58, 59}, |
|
|
|
}, |
|
|
|
} |
|
|
|
|
|
|
|
checkMatched := func(t *testing.T, n int, matched []*ggml.Tensor) { |
|
|
|
for i := range n { |
|
|
|
got := matched[i] |
|
|
|
if diff := cmp.Diff([]uint64{2, 5, 2}, got.Shape); diff != "" { |
|
|
|
t.Errorf("unexpected (-want +got):\n%s", diff) |
|
|
|
} |
|
|
|
|
|
|
|
var b bytes.Buffer |
|
|
|
if _, err := got.WriteTo(&b); err != nil { |
|
|
|
t.Fatal(err) |
|
|
|
} |
|
|
|
|
|
|
|
f32s := make([]float32, 20) |
|
|
|
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil { |
|
|
|
t.Fatal(err) |
|
|
|
} |
|
|
|
|
|
|
|
offset := 10 + (i * 20) |
|
|
|
want := make([]float32, 20) |
|
|
|
for j := range 20 { |
|
|
|
want[j] = float32(offset + j) |
|
|
|
} |
|
|
|
|
|
|
|
if diff := cmp.Diff(want, f32s); diff != "" { |
|
|
|
t.Errorf("unexpected data (-want +got):\n%s", diff) |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
t.Run("single merge", func(t *testing.T) { |
|
|
|
matched, unmatched := mergeTensors(unmatched, merge{"a.*.b", "a.b"}) |
|
|
|
if len(unmatched) != 3 { |
|
|
|
t.Error("expected 3 remaining tensors, got", len(unmatched)) |
|
|
|
} |
|
|
|
|
|
|
|
if len(matched) != 1 { |
|
|
|
t.Error("expected 1 merged tensor, got", len(matched)) |
|
|
|
} |
|
|
|
|
|
|
|
checkMatched(t, 1, matched) |
|
|
|
}) |
|
|
|
|
|
|
|
t.Run("multiple merges", func(t *testing.T) { |
|
|
|
matched, unmatched := mergeTensors(unmatched, merge{"a.*.b", "a.b"}, merge{"c.*.d", "c.d"}) |
|
|
|
if len(unmatched) != 1 { |
|
|
|
t.Error("expected 1 remaining tensors, got", len(unmatched)) |
|
|
|
} |
|
|
|
|
|
|
|
if len(matched) != 2 { |
|
|
|
t.Error("expected 2 merged tensor, got", len(matched)) |
|
|
|
} |
|
|
|
|
|
|
|
checkMatched(t, 2, matched) |
|
|
|
}) |
|
|
|
|
|
|
|
t.Run("no match", func(t *testing.T) { |
|
|
|
matched, unmatched := mergeTensors(unmatched, merge{"x.*.y", "x.y"}) |
|
|
|
if len(unmatched) != 5 { |
|
|
|
t.Error("expected 5 remaining tensors, got", len(unmatched)) |
|
|
|
} |
|
|
|
|
|
|
|
if len(matched) != 0 { |
|
|
|
t.Error("expected no merged tensors, got", len(matched)) |
|
|
|
} |
|
|
|
}) |
|
|
|
} |
|
|
|
|