mirror of https://gitee.com/namelin2022/ollama
Browse Source
* Reapply "feat: incremental gguf parser (#10822)" (#11114)
This reverts commit a6e64fbdf2.
* fix older ggufs
mxyng/gguf
committed by
GitHub
13 changed files with 1362 additions and 169 deletions
@ -0,0 +1,347 @@ |
|||
package gguf |
|||
|
|||
import ( |
|||
"bytes" |
|||
"cmp" |
|||
"encoding/binary" |
|||
"errors" |
|||
"fmt" |
|||
"io" |
|||
"iter" |
|||
"os" |
|||
"slices" |
|||
"strings" |
|||
) |
|||
|
|||
const ( |
|||
typeUint8 uint32 = iota |
|||
typeInt8 |
|||
typeUint16 |
|||
typeInt16 |
|||
typeUint32 |
|||
typeInt32 |
|||
typeFloat32 |
|||
typeBool |
|||
typeString |
|||
typeArray |
|||
typeUint64 |
|||
typeInt64 |
|||
typeFloat64 |
|||
) |
|||
|
|||
var ErrUnsupported = errors.New("unsupported") |
|||
|
|||
type File struct { |
|||
Magic [4]byte |
|||
Version uint32 |
|||
|
|||
keyValues *lazy[KeyValue] |
|||
tensors *lazy[TensorInfo] |
|||
offset int64 |
|||
|
|||
file *os.File |
|||
reader *bufferedReader |
|||
bts []byte |
|||
} |
|||
|
|||
func Open(path string) (f *File, err error) { |
|||
f = &File{bts: make([]byte, 4096)} |
|||
f.file, err = os.Open(path) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
|
|||
f.reader = newBufferedReader(f.file, 32<<10) |
|||
|
|||
if err := binary.Read(f.reader, binary.LittleEndian, &f.Magic); err != nil { |
|||
return nil, err |
|||
} |
|||
|
|||
if bytes.Equal(f.Magic[:], []byte("gguf")) { |
|||
return nil, fmt.Errorf("%w file type %v", ErrUnsupported, f.Magic) |
|||
} |
|||
|
|||
if err := binary.Read(f.reader, binary.LittleEndian, &f.Version); err != nil { |
|||
return nil, err |
|||
} |
|||
|
|||
if f.Version < 2 { |
|||
return nil, fmt.Errorf("%w version %v", ErrUnsupported, f.Version) |
|||
} |
|||
|
|||
f.tensors, err = newLazy(f, f.readTensor) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
|
|||
f.tensors.successFunc = func() error { |
|||
offset := f.reader.offset |
|||
|
|||
alignment := cmp.Or(f.KeyValue("general.alignment").Int(), 32) |
|||
f.offset = offset + (alignment-offset%alignment)%alignment |
|||
return nil |
|||
} |
|||
|
|||
f.keyValues, err = newLazy(f, f.readKeyValue) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
|
|||
return f, nil |
|||
} |
|||
|
|||
func (f *File) readTensor() (TensorInfo, error) { |
|||
name, err := readString(f) |
|||
if err != nil { |
|||
return TensorInfo{}, err |
|||
} |
|||
|
|||
dims, err := read[uint32](f) |
|||
if err != nil { |
|||
return TensorInfo{}, err |
|||
} |
|||
|
|||
shape := make([]uint64, dims) |
|||
for i := range dims { |
|||
shape[i], err = read[uint64](f) |
|||
if err != nil { |
|||
return TensorInfo{}, err |
|||
} |
|||
} |
|||
|
|||
type_, err := read[uint32](f) |
|||
if err != nil { |
|||
return TensorInfo{}, err |
|||
} |
|||
|
|||
offset, err := read[uint64](f) |
|||
if err != nil { |
|||
return TensorInfo{}, err |
|||
} |
|||
|
|||
return TensorInfo{ |
|||
Name: name, |
|||
Offset: offset, |
|||
Shape: shape, |
|||
Type: TensorType(type_), |
|||
}, nil |
|||
} |
|||
|
|||
func (f *File) readKeyValue() (KeyValue, error) { |
|||
key, err := readString(f) |
|||
if err != nil { |
|||
return KeyValue{}, err |
|||
} |
|||
|
|||
t, err := read[uint32](f) |
|||
if err != nil { |
|||
return KeyValue{}, err |
|||
} |
|||
|
|||
value, err := func() (any, error) { |
|||
switch t { |
|||
case typeUint8: |
|||
return read[uint8](f) |
|||
case typeInt8: |
|||
return read[int8](f) |
|||
case typeUint16: |
|||
return read[uint16](f) |
|||
case typeInt16: |
|||
return read[int16](f) |
|||
case typeUint32: |
|||
return read[uint32](f) |
|||
case typeInt32: |
|||
return read[int32](f) |
|||
case typeUint64: |
|||
return read[uint64](f) |
|||
case typeInt64: |
|||
return read[int64](f) |
|||
case typeFloat32: |
|||
return read[float32](f) |
|||
case typeFloat64: |
|||
return read[float64](f) |
|||
case typeBool: |
|||
return read[bool](f) |
|||
case typeString: |
|||
return readString(f) |
|||
case typeArray: |
|||
return readArray(f) |
|||
default: |
|||
return nil, fmt.Errorf("%w type %d", ErrUnsupported, t) |
|||
} |
|||
}() |
|||
if err != nil { |
|||
return KeyValue{}, err |
|||
} |
|||
|
|||
return KeyValue{ |
|||
Key: key, |
|||
Value: Value{value}, |
|||
}, nil |
|||
} |
|||
|
|||
func read[T any](f *File) (t T, err error) { |
|||
err = binary.Read(f.reader, binary.LittleEndian, &t) |
|||
return t, err |
|||
} |
|||
|
|||
func readString(f *File) (string, error) { |
|||
n, err := read[uint64](f) |
|||
if err != nil { |
|||
return "", err |
|||
} |
|||
|
|||
if int(n) > len(f.bts) { |
|||
f.bts = make([]byte, n) |
|||
} |
|||
|
|||
bts := f.bts[:n] |
|||
if _, err := io.ReadFull(f.reader, bts); err != nil { |
|||
return "", err |
|||
} |
|||
defer clear(bts) |
|||
|
|||
return string(bts), nil |
|||
} |
|||
|
|||
func readArray(f *File) (any, error) { |
|||
t, err := read[uint32](f) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
|
|||
n, err := read[uint64](f) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
|
|||
switch t { |
|||
case typeUint8: |
|||
return readArrayData[uint8](f, n) |
|||
case typeInt8: |
|||
return readArrayData[int8](f, n) |
|||
case typeUint16: |
|||
return readArrayData[uint16](f, n) |
|||
case typeInt16: |
|||
return readArrayData[int16](f, n) |
|||
case typeUint32: |
|||
return readArrayData[uint32](f, n) |
|||
case typeInt32: |
|||
return readArrayData[int32](f, n) |
|||
case typeUint64: |
|||
return readArrayData[uint64](f, n) |
|||
case typeInt64: |
|||
return readArrayData[int64](f, n) |
|||
case typeFloat32: |
|||
return readArrayData[float32](f, n) |
|||
case typeFloat64: |
|||
return readArrayData[float64](f, n) |
|||
case typeBool: |
|||
return readArrayData[bool](f, n) |
|||
case typeString: |
|||
return readArrayString(f, n) |
|||
default: |
|||
return nil, fmt.Errorf("%w type %d", ErrUnsupported, t) |
|||
} |
|||
} |
|||
|
|||
func readArrayData[T any](f *File, n uint64) (s []T, err error) { |
|||
s = make([]T, n) |
|||
for i := range n { |
|||
e, err := read[T](f) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
|
|||
s[i] = e |
|||
} |
|||
|
|||
return s, nil |
|||
} |
|||
|
|||
func readArrayString(f *File, n uint64) (s []string, err error) { |
|||
s = make([]string, n) |
|||
for i := range n { |
|||
e, err := readString(f) |
|||
if err != nil { |
|||
return nil, err |
|||
} |
|||
|
|||
s[i] = e |
|||
} |
|||
|
|||
return s, nil |
|||
} |
|||
|
|||
func (f *File) Close() error { |
|||
f.keyValues.stop() |
|||
f.tensors.stop() |
|||
return f.file.Close() |
|||
} |
|||
|
|||
func (f *File) KeyValue(key string) KeyValue { |
|||
if !strings.HasPrefix(key, "general.") && !strings.HasPrefix(key, "tokenizer.") { |
|||
key = f.KeyValue("general.architecture").String() + "." + key |
|||
} |
|||
|
|||
if index := slices.IndexFunc(f.keyValues.values, func(kv KeyValue) bool { |
|||
return kv.Key == key |
|||
}); index >= 0 { |
|||
return f.keyValues.values[index] |
|||
} |
|||
|
|||
for keyValue, ok := f.keyValues.next(); ok; keyValue, ok = f.keyValues.next() { |
|||
if keyValue.Key == key { |
|||
return keyValue |
|||
} |
|||
} |
|||
|
|||
return KeyValue{} |
|||
} |
|||
|
|||
func (f *File) NumKeyValues() int { |
|||
return int(f.keyValues.count) |
|||
} |
|||
|
|||
func (f *File) KeyValues() iter.Seq2[int, KeyValue] { |
|||
return f.keyValues.All() |
|||
} |
|||
|
|||
func (f *File) TensorInfo(name string) TensorInfo { |
|||
if index := slices.IndexFunc(f.tensors.values, func(t TensorInfo) bool { |
|||
return t.Name == name |
|||
}); index >= 0 { |
|||
return f.tensors.values[index] |
|||
} |
|||
|
|||
// fast-forward through key values if we haven't already
|
|||
_ = f.keyValues.rest() |
|||
for tensor, ok := f.tensors.next(); ok; tensor, ok = f.tensors.next() { |
|||
if tensor.Name == name { |
|||
return tensor |
|||
} |
|||
} |
|||
|
|||
return TensorInfo{} |
|||
} |
|||
|
|||
func (f *File) NumTensors() int { |
|||
return int(f.tensors.count) |
|||
} |
|||
|
|||
func (f *File) TensorInfos() iter.Seq2[int, TensorInfo] { |
|||
// fast forward through key values if we haven't already
|
|||
f.keyValues.rest() |
|||
return f.tensors.All() |
|||
} |
|||
|
|||
func (f *File) TensorReader(name string) (TensorInfo, io.Reader, error) { |
|||
t := f.TensorInfo(name) |
|||
if t.NumBytes() == 0 { |
|||
return TensorInfo{}, nil, fmt.Errorf("tensor %s not found", name) |
|||
} |
|||
|
|||
// fast forward through tensor info if we haven't already
|
|||
_ = f.tensors.rest() |
|||
return t, io.NewSectionReader(f.file, f.offset+int64(t.Offset), t.NumBytes()), nil |
|||
} |
|||
@ -0,0 +1,249 @@ |
|||
package gguf_test |
|||
|
|||
import ( |
|||
"bytes" |
|||
"os" |
|||
"strconv" |
|||
"strings" |
|||
"testing" |
|||
|
|||
"github.com/google/go-cmp/cmp" |
|||
"github.com/google/go-cmp/cmp/cmpopts" |
|||
"github.com/ollama/ollama/fs/ggml" |
|||
"github.com/ollama/ollama/fs/gguf" |
|||
) |
|||
|
|||
func createBinFile(tb testing.TB) string { |
|||
tb.Helper() |
|||
f, err := os.CreateTemp(tb.TempDir(), "") |
|||
if err != nil { |
|||
tb.Fatal(err) |
|||
} |
|||
defer f.Close() |
|||
|
|||
kv := ggml.KV{ |
|||
"general.architecture": "llama", |
|||
"llama.block_count": uint32(8), |
|||
"llama.embedding_length": uint32(3), |
|||
"llama.attention.head_count": uint32(2), |
|||
"llama.attention.head_count_kv": uint32(2), |
|||
"llama.attention.key_length": uint32(3), |
|||
"llama.rope.dimension_count": uint32(4), |
|||
"llama.rope.freq_base": float32(10000.0), |
|||
"llama.rope.freq_scale": float32(1.0), |
|||
"llama.attention.layer_norm_rms_epsilon": float32(1e-6), |
|||
"tokenizer.ggml.eos_token_id": uint32(0), |
|||
"tokenizer.ggml.eos_token_ids": []int32{1, 2, 3}, |
|||
"tokenizer.ggml.tokens": []string{"hello", "world"}, |
|||
"tokenizer.ggml.scores": []float32{0, 1}, |
|||
} |
|||
|
|||
tensors := []*ggml.Tensor{ |
|||
{ |
|||
Name: "token_embd.weight", |
|||
Kind: 0, |
|||
Shape: []uint64{2, 3}, |
|||
WriterTo: bytes.NewBuffer(make([]byte, 4*2*3)), |
|||
}, |
|||
{ |
|||
Name: "output.weight", |
|||
Kind: 0, |
|||
Shape: []uint64{3, 2}, |
|||
WriterTo: bytes.NewBuffer(make([]byte, 4*3*2)), |
|||
}, |
|||
} |
|||
|
|||
for i := range 8 { |
|||
tensors = append(tensors, &ggml.Tensor{ |
|||
Name: "blk." + strconv.Itoa(i) + ".attn_q.weight", |
|||
Kind: 0, |
|||
Shape: []uint64{3, 3}, |
|||
WriterTo: bytes.NewBuffer(make([]byte, 4*3*3)), |
|||
}, &ggml.Tensor{ |
|||
Name: "blk." + strconv.Itoa(i) + ".attn_k.weight", |
|||
Kind: 0, |
|||
Shape: []uint64{3, 3}, |
|||
WriterTo: bytes.NewBuffer(make([]byte, 4*3*3)), |
|||
}, &ggml.Tensor{ |
|||
Name: "blk." + strconv.Itoa(i) + ".attn_v.weight", |
|||
Kind: 0, |
|||
Shape: []uint64{3, 3}, |
|||
WriterTo: bytes.NewBuffer(make([]byte, 4*3*3)), |
|||
}, &ggml.Tensor{ |
|||
Name: "blk." + strconv.Itoa(i) + ".attn_output.weight", |
|||
Kind: 0, |
|||
Shape: []uint64{3, 3}, |
|||
WriterTo: bytes.NewBuffer(make([]byte, 4*3*3)), |
|||
}) |
|||
} |
|||
|
|||
if err := ggml.WriteGGUF(f, kv, tensors); err != nil { |
|||
tb.Fatal(err) |
|||
} |
|||
|
|||
return f.Name() |
|||
} |
|||
|
|||
func TestRead(t *testing.T) { |
|||
f, err := gguf.Open(createBinFile(t)) |
|||
if err != nil { |
|||
t.Fatal(err) |
|||
} |
|||
defer f.Close() |
|||
|
|||
if got := f.KeyValue("does.not.exist").Valid(); got { |
|||
t.Errorf(`KeyValue("does.not.exist").Exists() = %v, want false`, got) |
|||
} |
|||
|
|||
if got := f.KeyValue("general.architecture").String(); got != "llama" { |
|||
t.Errorf(`KeyValue("general.architecture").String() = %q, want %q`, got, "llama") |
|||
} |
|||
|
|||
if got := f.TensorInfo("token_embd.weight"); got.Name != "token_embd.weight" { |
|||
t.Errorf(`TensorInfo("token_embd.weight").Name = %q, want %q`, got.Name, "token_embd.weight") |
|||
} else if diff := cmp.Diff(got.Shape, []uint64{2, 3}); diff != "" { |
|||
t.Errorf(`TensorInfo("token_embd.weight").Shape mismatch (-got +want):\n%s`, diff) |
|||
} else if got.Type != gguf.TensorTypeF32 { |
|||
t.Errorf(`TensorInfo("token_embd.weight").Type = %d, want %d`, got.Type, gguf.TensorTypeF32) |
|||
} |
|||
|
|||
if got := f.KeyValue("block_count").Uint(); got != 8 { |
|||
t.Errorf(`KeyValue("block_count").Uint() = %d, want %d`, got, 8) |
|||
} |
|||
|
|||
if diff := cmp.Diff(f.KeyValue("tokenizer.ggml.tokens").Strings(), []string{"hello", "world"}); diff != "" { |
|||
t.Errorf("KeyValue(\"tokenizer.ggml.tokens\").Strings() mismatch (-got +want):\n%s", diff) |
|||
} |
|||
|
|||
if diff := cmp.Diff(f.KeyValue("tokenizer.ggml.scores").Floats(), []float64{0, 1}); diff != "" { |
|||
t.Errorf("KeyValue(\"tokenizer.ggml.scores\").Ints() mismatch (-got +want):\n%s", diff) |
|||
} |
|||
|
|||
var kvs []string |
|||
for _, kv := range f.KeyValues() { |
|||
if !kv.Valid() { |
|||
t.Error("found invalid key-value pair:", kv) |
|||
} |
|||
|
|||
kvs = append(kvs, kv.Key) |
|||
} |
|||
|
|||
if len(kvs) != f.NumKeyValues() { |
|||
t.Errorf("iterated key count = %d, want %d", len(kvs), f.NumKeyValues()) |
|||
} |
|||
|
|||
if diff := cmp.Diff(kvs, []string{ |
|||
"general.architecture", |
|||
"llama.block_count", |
|||
"llama.embedding_length", |
|||
"llama.attention.head_count", |
|||
"llama.attention.head_count_kv", |
|||
"llama.attention.key_length", |
|||
"llama.rope.dimension_count", |
|||
"llama.rope.freq_base", |
|||
"llama.rope.freq_scale", |
|||
"llama.attention.layer_norm_rms_epsilon", |
|||
"tokenizer.ggml.eos_token_id", |
|||
"tokenizer.ggml.eos_token_ids", |
|||
"tokenizer.ggml.tokens", |
|||
"tokenizer.ggml.scores", |
|||
}, cmpopts.SortSlices(strings.Compare)); diff != "" { |
|||
t.Errorf("KeyValues() mismatch (-got +want):\n%s", diff) |
|||
} |
|||
|
|||
var tis []string |
|||
for _, ti := range f.TensorInfos() { |
|||
if !ti.Valid() { |
|||
t.Error("found invalid tensor info:", ti) |
|||
} |
|||
|
|||
tis = append(tis, ti.Name) |
|||
} |
|||
|
|||
if len(tis) != f.NumTensors() { |
|||
t.Errorf("iterated tensor count = %d, want %d", len(tis), f.NumTensors()) |
|||
} |
|||
|
|||
if diff := cmp.Diff(tis, []string{ |
|||
"token_embd.weight", |
|||
"output.weight", |
|||
"blk.0.attn_q.weight", |
|||
"blk.0.attn_k.weight", |
|||
"blk.0.attn_v.weight", |
|||
"blk.0.attn_output.weight", |
|||
"blk.1.attn_q.weight", |
|||
"blk.1.attn_k.weight", |
|||
"blk.1.attn_v.weight", |
|||
"blk.1.attn_output.weight", |
|||
"blk.2.attn_q.weight", |
|||
"blk.2.attn_k.weight", |
|||
"blk.2.attn_v.weight", |
|||
"blk.2.attn_output.weight", |
|||
"blk.3.attn_q.weight", |
|||
"blk.3.attn_k.weight", |
|||
"blk.3.attn_v.weight", |
|||
"blk.3.attn_output.weight", |
|||
"blk.4.attn_q.weight", |
|||
"blk.4.attn_k.weight", |
|||
"blk.4.attn_v.weight", |
|||
"blk.4.attn_output.weight", |
|||
"blk.5.attn_q.weight", |
|||
"blk.5.attn_k.weight", |
|||
"blk.5.attn_v.weight", |
|||
"blk.5.attn_output.weight", |
|||
"blk.6.attn_q.weight", |
|||
"blk.6.attn_k.weight", |
|||
"blk.6.attn_v.weight", |
|||
"blk.6.attn_output.weight", |
|||
"blk.7.attn_q.weight", |
|||
"blk.7.attn_k.weight", |
|||
"blk.7.attn_v.weight", |
|||
"blk.7.attn_output.weight", |
|||
}, cmpopts.SortSlices(strings.Compare)); diff != "" { |
|||
t.Errorf("TensorInfos() mismatch (-got +want):\n%s", diff) |
|||
} |
|||
|
|||
ti, r, err := f.TensorReader("output.weight") |
|||
if err != nil { |
|||
t.Fatalf(`TensorReader("output.weight") error: %v`, err) |
|||
} |
|||
|
|||
if ti.Name != "output.weight" { |
|||
t.Errorf(`TensorReader("output.weight").Name = %q, want %q`, ti.Name, "output.weight") |
|||
} else if diff := cmp.Diff(ti.Shape, []uint64{3, 2}); diff != "" { |
|||
t.Errorf(`TensorReader("output.weight").Shape mismatch (-got +want):\n%s`, diff) |
|||
} else if ti.Type != gguf.TensorTypeF32 { |
|||
t.Errorf(`TensorReader("output.weight").Type = %d, want %d`, ti.Type, gguf.TensorTypeF32) |
|||
} |
|||
|
|||
var b bytes.Buffer |
|||
if _, err := b.ReadFrom(r); err != nil { |
|||
t.Fatalf(`ReadFrom TensorReader("output.weight") error: %v`, err) |
|||
} |
|||
|
|||
if b.Len() != int(ti.NumBytes()) { |
|||
t.Errorf(`ReadFrom TensorReader("output.weight") length = %d, want %d`, b.Len(), ti.NumBytes()) |
|||
} |
|||
} |
|||
|
|||
func BenchmarkRead(b *testing.B) { |
|||
b.ReportAllocs() |
|||
|
|||
p := createBinFile(b) |
|||
for b.Loop() { |
|||
f, err := gguf.Open(p) |
|||
if err != nil { |
|||
b.Fatal(err) |
|||
} |
|||
|
|||
if got := f.KeyValue("general.architecture").String(); got != "llama" { |
|||
b.Errorf("got = %q, want %q", got, "llama") |
|||
} |
|||
|
|||
// Iterate through some tensors
|
|||
for range f.TensorInfos() { |
|||
} |
|||
|
|||
f.Close() |
|||
} |
|||
} |
|||
@ -0,0 +1,90 @@ |
|||
package gguf |
|||
|
|||
import ( |
|||
"reflect" |
|||
"slices" |
|||
) |
|||
|
|||
type KeyValue struct { |
|||
Key string |
|||
Value |
|||
} |
|||
|
|||
func (kv KeyValue) Valid() bool { |
|||
return kv.Key != "" && kv.Value.value != nil |
|||
} |
|||
|
|||
type Value struct { |
|||
value any |
|||
} |
|||
|
|||
func value[T any](v Value, kinds ...reflect.Kind) (t T) { |
|||
vv := reflect.ValueOf(v.value) |
|||
if slices.Contains(kinds, vv.Kind()) { |
|||
t = vv.Convert(reflect.TypeOf(t)).Interface().(T) |
|||
} |
|||
return |
|||
} |
|||
|
|||
func values[T any](v Value, kinds ...reflect.Kind) (ts []T) { |
|||
switch vv := reflect.ValueOf(v.value); vv.Kind() { |
|||
case reflect.Slice: |
|||
if slices.Contains(kinds, vv.Type().Elem().Kind()) { |
|||
ts = make([]T, vv.Len()) |
|||
for i := range vv.Len() { |
|||
ts[i] = vv.Index(i).Convert(reflect.TypeOf(ts[i])).Interface().(T) |
|||
} |
|||
} |
|||
} |
|||
return |
|||
} |
|||
|
|||
// Int returns Value as a signed integer. If it is not a signed integer, it returns 0.
|
|||
func (v Value) Int() int64 { |
|||
return value[int64](v, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64) |
|||
} |
|||
|
|||
// Ints returns Value as a signed integer slice. If it is not a signed integer slice, it returns nil.
|
|||
func (v Value) Ints() (i64s []int64) { |
|||
return values[int64](v, reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64) |
|||
} |
|||
|
|||
// Uint converts an unsigned integer value to uint64. If the value is not a unsigned integer, it returns 0.
|
|||
func (v Value) Uint() uint64 { |
|||
return value[uint64](v, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64) |
|||
} |
|||
|
|||
// Uints returns Value as a unsigned integer slice. If it is not a unsigned integer slice, it returns nil.
|
|||
func (v Value) Uints() (u64s []uint64) { |
|||
return values[uint64](v, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64) |
|||
} |
|||
|
|||
// Float returns Value as a float. If it is not a float, it returns 0.
|
|||
func (v Value) Float() float64 { |
|||
return value[float64](v, reflect.Float32, reflect.Float64) |
|||
} |
|||
|
|||
// Floats returns Value as a float slice. If it is not a float slice, it returns nil.
|
|||
func (v Value) Floats() (f64s []float64) { |
|||
return values[float64](v, reflect.Float32, reflect.Float64) |
|||
} |
|||
|
|||
// Bool returns Value as a boolean. If it is not a boolean, it returns false.
|
|||
func (v Value) Bool() bool { |
|||
return value[bool](v, reflect.Bool) |
|||
} |
|||
|
|||
// Bools returns Value as a boolean slice. If it is not a boolean slice, it returns nil.
|
|||
func (v Value) Bools() (bools []bool) { |
|||
return values[bool](v, reflect.Bool) |
|||
} |
|||
|
|||
// String returns Value as a string. If it is not a string, it returns an empty string.
|
|||
func (v Value) String() string { |
|||
return value[string](v, reflect.String) |
|||
} |
|||
|
|||
// Strings returns Value as a string slice. If it is not a string slice, it returns nil.
|
|||
func (v Value) Strings() (strings []string) { |
|||
return values[string](v, reflect.String) |
|||
} |
|||
@ -0,0 +1,208 @@ |
|||
package gguf |
|||
|
|||
import ( |
|||
"testing" |
|||
|
|||
"github.com/google/go-cmp/cmp" |
|||
) |
|||
|
|||
func split(name string, values map[string][]any) (matched []any, unmatched []any) { |
|||
for key, value := range values { |
|||
if key == name { |
|||
matched = value |
|||
} else { |
|||
unmatched = append(unmatched, value...) |
|||
} |
|||
} |
|||
return |
|||
} |
|||
|
|||
func TestValue(t *testing.T) { |
|||
values := map[string][]any{ |
|||
"int64": {int(42), int8(42), int16(42), int32(42), int64(42)}, |
|||
"uint64": {uint(42), uint8(42), uint16(42), uint32(42), uint64(42)}, |
|||
"float64": {float32(42), float64(42)}, |
|||
"string": {"42", "hello"}, |
|||
"bool": {true, false}, |
|||
} |
|||
|
|||
t.Run("int64", func(t *testing.T) { |
|||
matched, unmatched := split("int64", values) |
|||
for _, v := range matched { |
|||
kv := KeyValue{"key", Value{v}} |
|||
if i64 := kv.Int(); i64 != 42 { |
|||
t.Errorf("expected 42, got %d", i64) |
|||
} |
|||
} |
|||
|
|||
for _, v := range unmatched { |
|||
kv := KeyValue{"key", Value{v}} |
|||
if i64 := kv.Int(); i64 != 0 { |
|||
t.Errorf("expected 42, got %d", i64) |
|||
} |
|||
} |
|||
}) |
|||
|
|||
t.Run("uint64", func(t *testing.T) { |
|||
matched, unmatched := split("uint64", values) |
|||
for _, v := range matched { |
|||
kv := KeyValue{"key", Value{v}} |
|||
if u64 := kv.Uint(); u64 != 42 { |
|||
t.Errorf("expected 42, got %d", u64) |
|||
} |
|||
} |
|||
|
|||
for _, v := range unmatched { |
|||
kv := KeyValue{"key", Value{v}} |
|||
if u64 := kv.Uint(); u64 != 0 { |
|||
t.Errorf("expected 42, got %d", u64) |
|||
} |
|||
} |
|||
}) |
|||
|
|||
t.Run("float64", func(t *testing.T) { |
|||
matched, unmatched := split("float64", values) |
|||
for _, v := range matched { |
|||
kv := KeyValue{"key", Value{v}} |
|||
if f64 := kv.Float(); f64 != 42 { |
|||
t.Errorf("expected 42, got %f", f64) |
|||
} |
|||
} |
|||
|
|||
for _, v := range unmatched { |
|||
kv := KeyValue{"key", Value{v}} |
|||
if f64 := kv.Float(); f64 != 0 { |
|||
t.Errorf("expected 42, got %f", f64) |
|||
} |
|||
} |
|||
}) |
|||
|
|||
t.Run("string", func(t *testing.T) { |
|||
matched, unmatched := split("string", values) |
|||
for _, v := range matched { |
|||
kv := KeyValue{"key", Value{v}} |
|||
if s := kv.String(); s != v { |
|||
t.Errorf("expected 42, got %s", s) |
|||
} |
|||
} |
|||
|
|||
for _, v := range unmatched { |
|||
kv := KeyValue{"key", Value{v}} |
|||
if s := kv.String(); s != "" { |
|||
t.Errorf("expected 42, got %s", s) |
|||
} |
|||
} |
|||
}) |
|||
|
|||
t.Run("bool", func(t *testing.T) { |
|||
matched, unmatched := split("bool", values) |
|||
for _, v := range matched { |
|||
kv := KeyValue{"key", Value{v}} |
|||
if b := kv.Bool(); b != v { |
|||
t.Errorf("expected true, got %v", b) |
|||
} |
|||
} |
|||
|
|||
for _, v := range unmatched { |
|||
kv := KeyValue{"key", Value{v}} |
|||
if b := kv.Bool(); b != false { |
|||
t.Errorf("expected false, got %v", b) |
|||
} |
|||
} |
|||
}) |
|||
} |
|||
|
|||
func TestValues(t *testing.T) { |
|||
values := map[string][]any{ |
|||
"int64s": {[]int{42}, []int8{42}, []int16{42}, []int32{42}, []int64{42}}, |
|||
"uint64s": {[]uint{42}, []uint8{42}, []uint16{42}, []uint32{42}, []uint64{42}}, |
|||
"float64s": {[]float32{42}, []float64{42}}, |
|||
"strings": {[]string{"42"}, []string{"hello"}}, |
|||
"bools": {[]bool{true}, []bool{false}}, |
|||
} |
|||
|
|||
t.Run("int64s", func(t *testing.T) { |
|||
matched, unmatched := split("int64s", values) |
|||
for _, v := range matched { |
|||
kv := KeyValue{"key", Value{v}} |
|||
if diff := cmp.Diff(kv.Ints(), []int64{42}); diff != "" { |
|||
t.Errorf("diff: %s", diff) |
|||
} |
|||
} |
|||
|
|||
for _, v := range unmatched { |
|||
kv := KeyValue{"key", Value{v}} |
|||
if i64s := kv.Ints(); i64s != nil { |
|||
t.Errorf("expected nil, got %v", i64s) |
|||
} |
|||
} |
|||
}) |
|||
|
|||
t.Run("uint64s", func(t *testing.T) { |
|||
matched, unmatched := split("uint64s", values) |
|||
for _, v := range matched { |
|||
kv := KeyValue{"key", Value{v}} |
|||
if diff := cmp.Diff(kv.Uints(), []uint64{42}); diff != "" { |
|||
t.Errorf("diff: %s", diff) |
|||
} |
|||
} |
|||
|
|||
for _, v := range unmatched { |
|||
kv := KeyValue{"key", Value{v}} |
|||
if u64s := kv.Uints(); u64s != nil { |
|||
t.Errorf("expected nil, got %v", u64s) |
|||
} |
|||
} |
|||
}) |
|||
|
|||
t.Run("float64s", func(t *testing.T) { |
|||
matched, unmatched := split("float64s", values) |
|||
for _, v := range matched { |
|||
kv := KeyValue{"key", Value{v}} |
|||
if diff := cmp.Diff(kv.Floats(), []float64{42}); diff != "" { |
|||
t.Errorf("diff: %s", diff) |
|||
} |
|||
} |
|||
|
|||
for _, v := range unmatched { |
|||
kv := KeyValue{"key", Value{v}} |
|||
if f64s := kv.Floats(); f64s != nil { |
|||
t.Errorf("expected nil, got %v", f64s) |
|||
} |
|||
} |
|||
}) |
|||
|
|||
t.Run("strings", func(t *testing.T) { |
|||
matched, unmatched := split("strings", values) |
|||
for _, v := range matched { |
|||
kv := KeyValue{"key", Value{v}} |
|||
if diff := cmp.Diff(kv.Strings(), v); diff != "" { |
|||
t.Errorf("diff: %s", diff) |
|||
} |
|||
} |
|||
|
|||
for _, v := range unmatched { |
|||
kv := KeyValue{"key", Value{v}} |
|||
if s := kv.Strings(); s != nil { |
|||
t.Errorf("expected nil, got %v", s) |
|||
} |
|||
} |
|||
}) |
|||
|
|||
t.Run("bools", func(t *testing.T) { |
|||
matched, unmatched := split("bools", values) |
|||
for _, v := range matched { |
|||
kv := KeyValue{"key", Value{v}} |
|||
if diff := cmp.Diff(kv.Bools(), v); diff != "" { |
|||
t.Errorf("diff: %s", diff) |
|||
} |
|||
} |
|||
|
|||
for _, v := range unmatched { |
|||
kv := KeyValue{"key", Value{v}} |
|||
if b := kv.Bools(); b != nil { |
|||
t.Errorf("expected nil, got %v", b) |
|||
} |
|||
} |
|||
}) |
|||
} |
|||
@ -0,0 +1,89 @@ |
|||
package gguf |
|||
|
|||
import ( |
|||
"encoding/binary" |
|||
"iter" |
|||
"log/slog" |
|||
) |
|||
|
|||
type lazy[T any] struct { |
|||
count uint64 |
|||
next func() (T, bool) |
|||
stop func() |
|||
values []T |
|||
|
|||
// successFunc is called when all values have been successfully read.
|
|||
successFunc func() error |
|||
} |
|||
|
|||
func newLazy[T any](f *File, fn func() (T, error)) (*lazy[T], error) { |
|||
it := lazy[T]{} |
|||
if err := binary.Read(f.reader, binary.LittleEndian, &it.count); err != nil { |
|||
return nil, err |
|||
} |
|||
|
|||
it.values = make([]T, 0) |
|||
it.next, it.stop = iter.Pull(func(yield func(T) bool) { |
|||
for i := range it.count { |
|||
t, err := fn() |
|||
if err != nil { |
|||
slog.Error("error reading tensor", "index", i, "error", err) |
|||
return |
|||
} |
|||
|
|||
it.values = append(it.values, t) |
|||
if !yield(t) { |
|||
break |
|||
} |
|||
} |
|||
|
|||
if it.successFunc != nil { |
|||
it.successFunc() |
|||
} |
|||
}) |
|||
|
|||
return &it, nil |
|||
} |
|||
|
|||
func (g *lazy[T]) Values() iter.Seq[T] { |
|||
return func(yield func(T) bool) { |
|||
for _, v := range g.All() { |
|||
if !yield(v) { |
|||
break |
|||
} |
|||
} |
|||
} |
|||
} |
|||
|
|||
func (g *lazy[T]) All() iter.Seq2[int, T] { |
|||
return func(yield func(int, T) bool) { |
|||
for i := range int(g.count) { |
|||
if i < len(g.values) { |
|||
if !yield(i, g.values[i]) { |
|||
break |
|||
} |
|||
} else { |
|||
t, ok := g.next() |
|||
if !ok { |
|||
break |
|||
} |
|||
|
|||
if !yield(i, t) { |
|||
break |
|||
} |
|||
} |
|||
} |
|||
} |
|||
} |
|||
|
|||
func (g *lazy[T]) rest() (collected bool) { |
|||
for { |
|||
_, ok := g.next() |
|||
collected = collected || ok |
|||
if !ok { |
|||
break |
|||
} |
|||
} |
|||
|
|||
return collected |
|||
} |
|||
@ -0,0 +1,23 @@ |
|||
package gguf |
|||
|
|||
import ( |
|||
"bufio" |
|||
"io" |
|||
) |
|||
|
|||
type bufferedReader struct { |
|||
offset int64 |
|||
*bufio.Reader |
|||
} |
|||
|
|||
func newBufferedReader(rs io.ReadSeeker, size int) *bufferedReader { |
|||
return &bufferedReader{ |
|||
Reader: bufio.NewReaderSize(rs, size), |
|||
} |
|||
} |
|||
|
|||
func (rs *bufferedReader) Read(p []byte) (n int, err error) { |
|||
n, err = rs.Reader.Read(p) |
|||
rs.offset += int64(n) |
|||
return n, err |
|||
} |
|||
@ -0,0 +1,288 @@ |
|||
package gguf |
|||
|
|||
import ( |
|||
"log/slog" |
|||
"strings" |
|||
) |
|||
|
|||
type TensorInfo struct { |
|||
Name string |
|||
Offset uint64 |
|||
Shape []uint64 |
|||
Type TensorType |
|||
} |
|||
|
|||
func (ti TensorInfo) Valid() bool { |
|||
return ti.Name != "" && ti.NumBytes() > 0 |
|||
} |
|||
|
|||
func (ti TensorInfo) NumValues() int64 { |
|||
var numItems int64 = 1 |
|||
for _, dim := range ti.Shape { |
|||
numItems *= int64(dim) |
|||
} |
|||
return numItems |
|||
} |
|||
|
|||
// NumBytes returns the number of bytes in the tensor.
|
|||
func (ti TensorInfo) NumBytes() int64 { |
|||
return int64(float64(ti.NumValues()) * ti.Type.NumBytes()) |
|||
} |
|||
|
|||
func (ti TensorInfo) LogValue() slog.Value { |
|||
return slog.GroupValue( |
|||
slog.String("name", ti.Name), |
|||
slog.Int64("offset", int64(ti.Offset)), |
|||
slog.Any("shape", ti.Shape), |
|||
slog.Int64("num_values", ti.NumValues()), |
|||
slog.Int64("num_bytes", ti.NumBytes()), |
|||
slog.Any("type", ti.Type), |
|||
) |
|||
} |
|||
|
|||
type TensorType uint32 |
|||
|
|||
const ( |
|||
TensorTypeF32 TensorType = iota |
|||
TensorTypeF16 |
|||
TensorTypeQ4_0 |
|||
TensorTypeQ4_1 |
|||
|
|||
// unexported // unused in gguf
|
|||
tensorTypeQ4_2 |
|||
tensorTypeQ4_3 |
|||
|
|||
TensorTypeQ5_0 |
|||
TensorTypeQ5_1 |
|||
TensorTypeQ8_0 |
|||
TensorTypeQ8_1 |
|||
TensorTypeQ2_K |
|||
TensorTypeQ3_K |
|||
TensorTypeQ4_K |
|||
TensorTypeQ5_K |
|||
TensorTypeQ6_K |
|||
TensorTypeQ8_K |
|||
|
|||
// unexported // unquantizable by ollama
|
|||
tensorTypeIQ2_XXS |
|||
tensorTypeIQ2_XS |
|||
tensorTypeIQ3_XXS |
|||
tensorTypeIQ1_S |
|||
tensorTypeIQ4_NL |
|||
tensorTypeIQ3_S |
|||
tensorTypeIQ2_S |
|||
tensorTypeIQ4_XS |
|||
|
|||
TensorTypeI8 |
|||
TensorTypeI16 |
|||
TensorTypeI32 |
|||
TensorTypeI64 |
|||
TensorTypeF64 |
|||
|
|||
// unexported // unquantizable by ollama
|
|||
tensorTypeIQ1_M |
|||
|
|||
TensorTypeBF16 |
|||
|
|||
// unexported // unused in gguf
|
|||
tensorTypeQ4_0_4_4 |
|||
tensorTypeQ4_0_4_8 |
|||
tensorTypeQ4_0_8_8 |
|||
|
|||
// unexported // unquantizable by ollama
|
|||
tensorTypeTQ1_0 |
|||
tensorTypeTQ2_0 |
|||
|
|||
// unexported // unused in gguf
|
|||
tensorTypeIQ4_NL_4_4 |
|||
tensorTypeIQ4_NL_4_8 |
|||
tensorTypeIQ4_NL_8_8 |
|||
) |
|||
|
|||
func (tt TensorType) NumBytes() float64 { |
|||
return float64(tt.typeSize()) / float64(tt.blockSize()) |
|||
} |
|||
|
|||
func (tt TensorType) typeSize() int64 { |
|||
switch tt { |
|||
case TensorTypeF32: |
|||
return 4 |
|||
case TensorTypeF16: |
|||
return 2 |
|||
case TensorTypeQ4_0: |
|||
return 2 + tt.blockSize()/2 |
|||
case TensorTypeQ4_1: |
|||
return 2 + 2 + tt.blockSize()/2 |
|||
case TensorTypeQ5_0: |
|||
return 2 + 4 + tt.blockSize()/2 |
|||
case TensorTypeQ5_1: |
|||
return 2 + 2 + 4 + tt.blockSize()/2 |
|||
case TensorTypeQ8_0: |
|||
return 2 + tt.blockSize() |
|||
case TensorTypeQ8_1: |
|||
return 2 + 2 + tt.blockSize() |
|||
case TensorTypeQ2_K: |
|||
return tt.blockSize()/16 + tt.blockSize()/4 + 2 + 2 |
|||
case TensorTypeQ3_K: |
|||
return tt.blockSize()/8 + tt.blockSize()/4 + 12 + 2 |
|||
case TensorTypeQ4_K: |
|||
return 2 + 2 + 12 + tt.blockSize()/2 |
|||
case TensorTypeQ5_K: |
|||
return 2 + 2 + 12 + tt.blockSize()/8 + tt.blockSize()/2 |
|||
case TensorTypeQ6_K: |
|||
return tt.blockSize()/2 + tt.blockSize()/4 + tt.blockSize()/16 + 2 |
|||
case TensorTypeQ8_K: |
|||
return 4 + tt.blockSize() + 2*tt.blockSize()/16 |
|||
case tensorTypeIQ2_XXS: |
|||
return 2 + 2*tt.blockSize()/8 |
|||
case tensorTypeIQ2_XS: |
|||
return 2 + 2*tt.blockSize()/8 + tt.blockSize()/32 |
|||
case tensorTypeIQ3_XXS: |
|||
return 2 + tt.blockSize()/4 + tt.blockSize()/8 |
|||
case tensorTypeIQ1_S: |
|||
return 2 + tt.blockSize()/8 + tt.blockSize()/16 |
|||
case tensorTypeIQ4_NL: |
|||
return 2 + tt.blockSize()/2 |
|||
case tensorTypeIQ3_S: |
|||
return 2 + tt.blockSize()/4 + tt.blockSize()/8 + tt.blockSize()/32 + 4 |
|||
case tensorTypeIQ2_S: |
|||
return 2 + tt.blockSize()/4 + tt.blockSize()/16 |
|||
case tensorTypeIQ4_XS: |
|||
return 2 + 2 + tt.blockSize()/2 + tt.blockSize()/64 |
|||
case TensorTypeI8: |
|||
return 1 |
|||
case TensorTypeI16: |
|||
return 2 |
|||
case TensorTypeI32: |
|||
return 4 |
|||
case TensorTypeI64: |
|||
return 8 |
|||
case TensorTypeF64: |
|||
return 8 |
|||
case tensorTypeIQ1_M: |
|||
return tt.blockSize()/8 + tt.blockSize()/16 + tt.blockSize()/32 |
|||
case TensorTypeBF16: |
|||
return 2 |
|||
default: |
|||
return 0 |
|||
} |
|||
} |
|||
|
|||
func (tt TensorType) blockSize() int64 { |
|||
switch tt { |
|||
case TensorTypeF32, |
|||
TensorTypeF16, |
|||
TensorTypeI8, |
|||
TensorTypeI16, |
|||
TensorTypeI32, |
|||
TensorTypeI64, |
|||
TensorTypeF64, |
|||
TensorTypeBF16: |
|||
return 1 |
|||
case TensorTypeQ4_0, |
|||
TensorTypeQ4_1, |
|||
TensorTypeQ5_0, |
|||
TensorTypeQ5_1, |
|||
TensorTypeQ8_0, |
|||
TensorTypeQ8_1, |
|||
tensorTypeIQ4_NL: |
|||
return 32 |
|||
default: |
|||
return 256 |
|||
} |
|||
} |
|||
|
|||
func (tt TensorType) String() string { |
|||
switch tt { |
|||
case TensorTypeF32: |
|||
return "f32" |
|||
case TensorTypeF16: |
|||
return "f16" |
|||
case TensorTypeQ4_0: |
|||
return "q4_0" |
|||
case TensorTypeQ4_1: |
|||
return "q4_1" |
|||
case tensorTypeQ4_2: |
|||
return "q4_2" |
|||
case tensorTypeQ4_3: |
|||
return "q4_3" |
|||
case TensorTypeQ5_0: |
|||
return "q5_0" |
|||
case TensorTypeQ5_1: |
|||
return "q5_1" |
|||
case TensorTypeQ8_0: |
|||
return "q8_0" |
|||
case TensorTypeQ8_1: |
|||
return "q8_1" |
|||
case TensorTypeQ2_K: |
|||
return "q2_k" |
|||
case TensorTypeQ3_K: |
|||
return "q3_k" |
|||
case TensorTypeQ4_K: |
|||
return "q4_k" |
|||
case TensorTypeQ5_K: |
|||
return "q5_k" |
|||
case TensorTypeQ6_K: |
|||
return "q6_k" |
|||
case TensorTypeQ8_K: |
|||
return "q8_k" |
|||
case tensorTypeIQ2_XXS: |
|||
return "iq2_xxs" |
|||
case tensorTypeIQ2_XS: |
|||
return "iq2_xs" |
|||
case tensorTypeIQ3_XXS: |
|||
return "iq3_xxs" |
|||
case tensorTypeIQ1_S: |
|||
return "iq1_s" |
|||
case tensorTypeIQ4_NL: |
|||
return "iq4_nl" |
|||
case tensorTypeIQ3_S: |
|||
return "iq3_s" |
|||
case tensorTypeIQ2_S: |
|||
return "iq2_s" |
|||
case tensorTypeIQ4_XS: |
|||
return "iq4_xs" |
|||
case TensorTypeI8: |
|||
return "i8" |
|||
case TensorTypeI16: |
|||
return "i16" |
|||
case TensorTypeI32: |
|||
return "i32" |
|||
case TensorTypeI64: |
|||
return "i64" |
|||
case TensorTypeF64: |
|||
return "f64" |
|||
case tensorTypeIQ1_M: |
|||
return "iq1_m" |
|||
case TensorTypeBF16: |
|||
return "bf16" |
|||
case tensorTypeQ4_0_4_4: |
|||
return "q4_0_4_4" |
|||
case tensorTypeQ4_0_4_8: |
|||
return "q4_0_4_8" |
|||
case tensorTypeQ4_0_8_8: |
|||
return "q4_0_8_8" |
|||
case tensorTypeTQ1_0: |
|||
return "tq1_0" |
|||
case tensorTypeTQ2_0: |
|||
return "tq2_0" |
|||
case tensorTypeIQ4_NL_4_4: |
|||
return "iq4_nl_4_4" |
|||
case tensorTypeIQ4_NL_4_8: |
|||
return "iq4_nl_4_8" |
|||
case tensorTypeIQ4_NL_8_8: |
|||
return "iq4_nl_8_8" |
|||
default: |
|||
return "unknown" |
|||
} |
|||
} |
|||
|
|||
func (tt TensorType) LogValue() slog.Value { |
|||
return slog.GroupValue( |
|||
slog.Uint64("value", uint64(tt)), |
|||
slog.String("name", strings.ToUpper(tt.String())), |
|||
slog.Int64("size", tt.typeSize()), |
|||
slog.Int64("block_size", tt.blockSize()), |
|||
slog.Float64("num_bytes", tt.NumBytes()), |
|||
) |
|||
} |
|||
Loading…
Reference in new issue