Browse Source

gpu: Group GPU Library sets by variant (#6483)

The recent cuda variant changes uncovered a bug in ByLibrary
which failed to group by common variant for GPU types.
pdevine/import-docs v0.3.7-rc6
Daniel Hiltgen 2 years ago
committed by GitHub
parent
commit
69be940bf6
No known key found for this signature in database GPG Key ID: B5690EEEBB952194
  1. 25
      gpu/gpu_test.go
  2. 2
      gpu/types.go

25
gpu/gpu_test.go

@ -32,4 +32,29 @@ func TestCPUMemInfo(t *testing.T) {
}
}
func TestByLibrary(t *testing.T) {
type testCase struct {
input []GpuInfo
expect int
}
testCases := map[string]*testCase{
"empty": {input: []GpuInfo{}, expect: 0},
"cpu": {input: []GpuInfo{{Library: "cpu"}}, expect: 1},
"cpu + GPU": {input: []GpuInfo{{Library: "cpu"}, {Library: "cuda"}}, expect: 2},
"cpu + 2 GPU no variant": {input: []GpuInfo{{Library: "cpu"}, {Library: "cuda"}, {Library: "cuda"}}, expect: 2},
"cpu + 2 GPU same variant": {input: []GpuInfo{{Library: "cpu"}, {Library: "cuda", Variant: "v11"}, {Library: "cuda", Variant: "v11"}}, expect: 2},
"cpu + 2 GPU diff variant": {input: []GpuInfo{{Library: "cpu"}, {Library: "cuda", Variant: "v11"}, {Library: "cuda", Variant: "v12"}}, expect: 3},
}
for k, v := range testCases {
t.Run(k, func(t *testing.T) {
resp := (GpuInfoList)(v.input).ByLibrary()
if len(resp) != v.expect {
t.Fatalf("expected length %d, got %d => %+v", v.expect, len(resp), resp)
}
})
}
}
// TODO - add some logic to figure out card type through other means and actually verify we got back what we expected

2
gpu/types.go

@ -94,7 +94,7 @@ func (l GpuInfoList) ByLibrary() []GpuInfoList {
}
}
if !found {
libs = append(libs, info.Library)
libs = append(libs, requested)
resp = append(resp, []GpuInfo{info})
}
}

Loading…
Cancel
Save