You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
59 lines
1.8 KiB
59 lines
1.8 KiB
#ifndef _RISCV_ZVBDOT_H
|
|
#define _RISCV_ZVBDOT_H
|
|
|
|
#include "bulknormdot.h"
|
|
#include <vector>
|
|
#include <algorithm>
|
|
|
|
// Adds two float32_t values with the SoftFloat rounding mode temporarily set
// to softfloat_round_odd.
//
// The caller's rounding mode is restored before returning. The global
// exception flags are cleared around the addition so that only flags raised
// by this addition are observed; of those, only overflow and invalid are
// merged back into the caller's previously accumulated flags. Any other flag
// raised by the addition is discarded.
//
// When the addition reports overflow, the result's bit pattern is bumped by
// one (FLT_MAX -> INF, per the inline note).
static inline float32_t f32_add_odd(float32_t a, float32_t b)
{
  const auto saved_mode = softfloat_roundingMode;
  const auto saved_flags = softfloat_exceptionFlags;

  // Run the addition from a clean flag state, in round-to-odd mode.
  softfloat_exceptionFlags = 0;
  softfloat_roundingMode = softfloat_round_odd;
  auto sum = f32_add(a, b);
  softfloat_roundingMode = saved_mode;

  // Snapshot what this addition raised before the global flags are rewritten.
  const auto raised = softfloat_exceptionFlags;

  if ((raised & softfloat_flag_overflow) != 0) {
    sum.v += 1; // FLT_MAX -> INF
  }

  // Propagate only overflow/invalid on top of the caller's earlier flags.
  softfloat_exceptionFlags =
      saved_flags | (raised & (softfloat_flag_overflow | softfloat_flag_invalid));

  return sum;
}
|
|
|
|
// Dot product of two vectors of 16-bit elements with float32 accumulation.
//
// Each uint16_t element of `a` and `b` is converted to bf16_t by assignment
// (presumably these are raw bfloat16 bit patterns -- confirm against callers).
// The converted arrays are passed to bulk_norm_dot_bf16, the flags it reports
// are OR-ed into softfloat_exceptionFlags, and its output (wrapped with f32())
// is added to the accumulator `c` via f32_add_odd.
//
// NOTE(review): the DotConfig is sized from `a` alone; `b` is presumably
// expected to have the same length -- confirm against callers.
static inline float32_t zvfwbdot16bf_dot_acc(const std::vector<uint16_t>& a, const std::vector<uint16_t>& b, float32_t c)
{
  std::vector<bf16_t> fa(a.size());
  std::transform(a.begin(), a.end(), fa.begin(), [](uint16_t raw) { return raw; });

  std::vector<bf16_t> fb(b.size());
  std::transform(b.begin(), b.end(), fb.begin(), [](uint16_t raw) { return raw; });

  const auto n = a.size();
  // int_log2(n), plus one when n has more than one bit set (not a power of two).
  const auto log2_n = int_log2(n) + ((n & (n - 1)) != 0);
  DotConfig cfg(n, log2_n);

  // data() rather than &v[0]: indexing element 0 of an empty vector is
  // undefined behavior, whereas data() is always valid to call.
  auto res = bulk_norm_dot_bf16(cfg, fa.data(), fb.data());
  softfloat_exceptionFlags |= res.flags;

  return f32_add_odd(f32(res.out), c);
}
|
|
|
|
template<typename A, typename B>
|
|
float32_t zvfqbdot8f_dot_acc(const std::vector<uint8_t>& a, const std::vector<uint8_t>& b, float32_t c)
|
|
{
|
|
std::vector<A> fa(a.size());
|
|
std::transform(a.begin(), a.end(), fa.begin(), [](auto f) { return f; });
|
|
|
|
std::vector<B> fb(b.size());
|
|
std::transform(b.begin(), b.end(), fb.begin(), [](auto f) { return f; });
|
|
|
|
DotConfig cfg(a.size(), int_log2(a.size()) + ((a.size() & (a.size() - 1)) != 0));
|
|
auto res = bulk_norm_dot_ofp8(cfg, &fa[0], &fb[0]);
|
|
softfloat_exceptionFlags |= res.flags;
|
|
return f32_add_odd(f32(res.out), c);
|
|
}
|
|
|
|
#endif
|
|
|