Browse Source

Add Zvldot extension support

pull/2065/head
Andrew Waterman 7 months ago
parent
commit
717a6e275c
  1. 6
      disasm/isa_parser.cc
  2. 23
      riscv/encoding.h
  3. 20
      riscv/insns/vfwldot_vv.h
  4. 23
      riscv/insns/vqldots_vv.h
  5. 23
      riscv/insns/vqldotu_vv.h
  6. 3
      riscv/isa_parser.h
  7. 6
      riscv/riscv.mk.in
  8. 28
      riscv/v_ext_macros.h
  9. 6
      riscv/vector_unit.cc

6
disasm/isa_parser.cc

@ -334,6 +334,12 @@ isa_parser_t::isa_parser_t(const char* str, const char *priv)
extension_table[EXT_ZVFWBDOT16BF] = true;
} else if (ext_str == "zvfbdot32f") {
extension_table[EXT_ZVFBDOT32F] = true;
} else if (ext_str == "zvqldot8i") {
extension_table[EXT_ZVQLDOT8I] = true;
} else if (ext_str == "zvqldot16i") {
extension_table[EXT_ZVQLDOT16I] = true;
} else if (ext_str == "zvfwldot16bf") {
extension_table[EXT_ZVFWLDOT16BF] = true;
} else if (ext_str == "zvkt") {
} else if (ext_str == "sstc") {
extension_table[EXT_SSTC] = true;

23
riscv/encoding.h

@ -4,7 +4,7 @@
/*
* This file is auto-generated by running 'make' in
* https://github.com/riscv/riscv-opcodes (1d4b38c)
* https://github.com/riscv/riscv-opcodes (74c1158)
*/
#ifndef RISCV_CSR_ENCODING_H
@ -1786,6 +1786,14 @@
#define MASK_VFNMSUB_VF 0xfc00707f
#define MATCH_VFNMSUB_VV 0xac001057
#define MASK_VFNMSUB_VV 0xfc00707f
#define MATCH_VFQBDOT_ALT_VV 0xbc001077
#define MASK_VFQBDOT_ALT_VV 0xfc00707f
#define MATCH_VFQBDOT_VV 0xb8001077
#define MASK_VFQBDOT_VV 0xfc00707f
#define MATCH_VFQLDOT_ALT_VV 0x9c001077
#define MASK_VFQLDOT_ALT_VV 0xfc00707f
#define MATCH_VFQLDOT_VV 0x98001077
#define MASK_VFQLDOT_VV 0xfc00707f
#define MATCH_VFRDIV_VF 0x84005057
#define MASK_VFRDIV_VF 0xfc00707f
#define MATCH_VFREC7_V 0x4c029057
@ -1850,6 +1858,8 @@
#define MASK_VFWCVT_XU_F_V 0xfc0ff07f
#define MATCH_VFWCVTBF16_F_F_V 0x48069057
#define MASK_VFWCVTBF16_F_F_V 0xfc0ff07f
#define MATCH_VFWLDOT_VV 0x90001077
#define MASK_VFWLDOT_VV 0xfc00707f
#define MATCH_VFWMACC_VF 0xf0005057
#define MASK_VFWMACC_VF 0xfc00707f
#define MATCH_VFWMACC_VV 0xf0001057
@ -2190,6 +2200,10 @@
#define MASK_VQDOTU_VX 0xfc00707f
#define MATCH_VQDOTUS_VX 0xb8006057
#define MASK_VQDOTUS_VX 0xfc00707f
#define MATCH_VQLDOTS_VV 0x9c000077
#define MASK_VQLDOTS_VV 0xfc00707f
#define MATCH_VQLDOTU_VV 0x98000077
#define MASK_VQLDOTU_VV 0xfc00707f
#define MATCH_VREDAND_VS 0x4002057
#define MASK_VREDAND_VS 0xfc00707f
#define MATCH_VREDMAX_VS 0x1c002057
@ -3696,6 +3710,10 @@ DECLARE_INSN(vfnmsac_vf, MATCH_VFNMSAC_VF, MASK_VFNMSAC_VF)
DECLARE_INSN(vfnmsac_vv, MATCH_VFNMSAC_VV, MASK_VFNMSAC_VV)
DECLARE_INSN(vfnmsub_vf, MATCH_VFNMSUB_VF, MASK_VFNMSUB_VF)
DECLARE_INSN(vfnmsub_vv, MATCH_VFNMSUB_VV, MASK_VFNMSUB_VV)
DECLARE_INSN(vfqbdot_alt_vv, MATCH_VFQBDOT_ALT_VV, MASK_VFQBDOT_ALT_VV)
DECLARE_INSN(vfqbdot_vv, MATCH_VFQBDOT_VV, MASK_VFQBDOT_VV)
DECLARE_INSN(vfqldot_alt_vv, MATCH_VFQLDOT_ALT_VV, MASK_VFQLDOT_ALT_VV)
DECLARE_INSN(vfqldot_vv, MATCH_VFQLDOT_VV, MASK_VFQLDOT_VV)
DECLARE_INSN(vfrdiv_vf, MATCH_VFRDIV_VF, MASK_VFRDIV_VF)
DECLARE_INSN(vfrec7_v, MATCH_VFREC7_V, MASK_VFREC7_V)
DECLARE_INSN(vfredmax_vs, MATCH_VFREDMAX_VS, MASK_VFREDMAX_VS)
@ -3728,6 +3746,7 @@ DECLARE_INSN(vfwcvt_rtz_xu_f_v, MATCH_VFWCVT_RTZ_XU_F_V, MASK_VFWCVT_RTZ_XU_F_V)
DECLARE_INSN(vfwcvt_x_f_v, MATCH_VFWCVT_X_F_V, MASK_VFWCVT_X_F_V)
DECLARE_INSN(vfwcvt_xu_f_v, MATCH_VFWCVT_XU_F_V, MASK_VFWCVT_XU_F_V)
DECLARE_INSN(vfwcvtbf16_f_f_v, MATCH_VFWCVTBF16_F_F_V, MASK_VFWCVTBF16_F_F_V)
DECLARE_INSN(vfwldot_vv, MATCH_VFWLDOT_VV, MASK_VFWLDOT_VV)
DECLARE_INSN(vfwmacc_vf, MATCH_VFWMACC_VF, MASK_VFWMACC_VF)
DECLARE_INSN(vfwmacc_vv, MATCH_VFWMACC_VV, MASK_VFWMACC_VV)
DECLARE_INSN(vfwmaccbf16_vf, MATCH_VFWMACCBF16_VF, MASK_VFWMACCBF16_VF)
@ -3898,6 +3917,8 @@ DECLARE_INSN(vqdotsu_vx, MATCH_VQDOTSU_VX, MASK_VQDOTSU_VX)
DECLARE_INSN(vqdotu_vv, MATCH_VQDOTU_VV, MASK_VQDOTU_VV)
DECLARE_INSN(vqdotu_vx, MATCH_VQDOTU_VX, MASK_VQDOTU_VX)
DECLARE_INSN(vqdotus_vx, MATCH_VQDOTUS_VX, MASK_VQDOTUS_VX)
DECLARE_INSN(vqldots_vv, MATCH_VQLDOTS_VV, MASK_VQLDOTS_VV)
DECLARE_INSN(vqldotu_vv, MATCH_VQLDOTU_VV, MASK_VQLDOTU_VV)
DECLARE_INSN(vredand_vs, MATCH_VREDAND_VS, MASK_VREDAND_VS)
DECLARE_INSN(vredmax_vs, MATCH_VREDMAX_VS, MASK_VREDMAX_VS)
DECLARE_INSN(vredmaxu_vs, MATCH_VREDMAXU_VS, MASK_VREDMAXU_VS)

20
riscv/insns/vfwldot_vv.h

@ -0,0 +1,20 @@
// vfwldot_vv: floating-point dot product of the rs1 and rs2 element vectors,
// accumulated into element 0 of rd (see ZVLDOT_LOOP).
// Only SEW=16 with altfmt set is implemented: both sources are read as
// bfloat16_t and the accumulator is a float32_t. Every other configuration
// ends in require(false).
VI_VFP_BASE;
ZVLDOT_INIT(2);
switch (P.VU.vsew) {
  case 16: {
    if (P.VU.altfmt) {
      // Although this implementation in IEEE 754 arithmetic is valid, most
      // implementations will bulk-normalize on a VLEN-bit granule, then use
      // f32_add_bulknorm_odd for the final steps (possibly in a tree).
      // If a consensus emerges, we might change this implementation.
      require_extension(EXT_ZVFWLDOT16BF);
      // Each BF16 pair is widened to FP32, multiplied with f32_mul and folded
      // into the accumulator with f32_add, i.e. the plain IEEE 754 sequence
      // the note above describes (not the bulk-normalized variant).
      auto macc = [](auto a, auto b, auto c) { return f32_add(c, f32_mul(bf16_to_f32(a), bf16_to_f32(b))); };
      ZVLDOT_GENERIC_LOOP(bfloat16_t, bfloat16_t, float32_t, macc);
    } else {
      // SEW=16 without altfmt has no defined behaviour here.
      require(false);
    }
    break;
  }
  default: require(false);
}

23
riscv/insns/vqldots_vv.h

@ -0,0 +1,23 @@
ZVLDOT_INIT(4);
switch (P.VU.vsew) {
case 8: {
require_extension(EXT_ZVQLDOT8I);
if (P.VU.altfmt) {
ZVLDOT_SIMPLE_LOOP(int8_t, int8_t, uint32_t);
} else {
ZVLDOT_SIMPLE_LOOP(uint8_t, int8_t, uint32_t);
}
break;
}
case 16: {
require_extension(EXT_ZVQLDOT16I);
if (P.VU.altfmt) {
ZVLDOT_SIMPLE_LOOP(int16_t, int16_t, uint64_t);
} else {
ZVLDOT_SIMPLE_LOOP(uint16_t, int16_t, uint64_t);
}
break;
}
default: require(false);
}

23
riscv/insns/vqldotu_vv.h

@ -0,0 +1,23 @@
ZVLDOT_INIT(4);
switch (P.VU.vsew) {
case 8: {
require_extension(EXT_ZVQLDOT8I);
if (P.VU.altfmt) {
ZVLDOT_SIMPLE_LOOP(int8_t, uint8_t, uint32_t);
} else {
ZVLDOT_SIMPLE_LOOP(uint8_t, uint8_t, uint32_t);
}
break;
}
case 16: {
require_extension(EXT_ZVQLDOT16I);
if (P.VU.altfmt) {
ZVLDOT_SIMPLE_LOOP(int16_t, uint16_t, uint64_t);
} else {
ZVLDOT_SIMPLE_LOOP(uint16_t, uint16_t, uint64_t);
}
break;
}
default: require(false);
}

3
riscv/isa_parser.h

@ -72,6 +72,9 @@ typedef enum {
EXT_ZVQBDOT16I,
EXT_ZVFWBDOT16BF,
EXT_ZVFBDOT32F,
EXT_ZVQLDOT8I,
EXT_ZVQLDOT16I,
EXT_ZVFWLDOT16BF,
EXT_SSTC,
EXT_ZAAMO,
EXT_ZALRSC,

6
riscv/riscv.mk.in

@ -1080,6 +1080,11 @@ riscv_insn_ext_zvbdot = \
vfwbdot_vv \
vfbdot_vv \
riscv_insn_ext_zvldot = \
vqldotu_vv \
vqldots_vv \
vfwldot_vv \
riscv_insn_ext_zimop = \
mop_r_N \
mop_rr_N \
@ -1137,6 +1142,7 @@ riscv_insn_list = \
$(riscv_insn_ext_zicond) \
$(riscv_insn_ext_zvk) \
$(riscv_insn_ext_zvbdot) \
$(riscv_insn_ext_zvldot) \
$(riscv_insn_priv) \
$(riscv_insn_smrnmi) \
$(riscv_insn_svinval) \

28
riscv/v_ext_macros.h

@ -2078,6 +2078,15 @@ VI_VX_ULOOP({ \
break; \
}
// ZVLDOT_INIT(widen): checks shared by the Zvldot instruction bodies.
//   - require_vector(true), and vstart must read as 0
//   - rs1 and rs2 are each passed to require_align together with vflmul
//   - require_vm
//   - rd (with a size argument of 1) is passed to require_noover against the
//     rs1 and rs2 groups of size vflmul -- presumably an overlap check between
//     a single destination register and each source group; confirm against
//     the require_noover definition.
// NOTE(review): `widen` is accepted but never referenced in the expansion,
// whereas ZVBDOT_INIT uses its argument to derive vd_eew. Confirm that no
// check on the widened element width is needed here.
// The expansion ends without a semicolon; the caller supplies it.
#define ZVLDOT_INIT(widen) \
require_vector(true); \
require(P.VU.vstart->read() == 0); \
require_align(insn.rs1(), P.VU.vflmul); \
require_align(insn.rs2(), P.VU.vflmul); \
require_vm; \
require_noover(insn.rd(), 1, insn.rs1(), P.VU.vflmul); \
require_noover(insn.rd(), 1, insn.rs2(), P.VU.vflmul)
#define ZVBDOT_INIT(widen) \
unsigned vd_eew = P.VU.vsew * (widen); \
unsigned vd_emul = std::max(1U, unsigned((8 * vd_eew) / P.VU.VLEN)); \
@ -2100,6 +2109,25 @@ c_t generic_dot_product(const std::vector<a_t>& a, const std::vector<b_t>& b, c_
return c;
}
// ZVLDOT_LOOP(a_t, b_t, c_t, dot): gathers the operands of one Zvldot
// instruction and performs the accumulate.
//   a_t, b_t - element types used to read rs1 and rs2, respectively
//   c_t      - accumulator type; the accumulator is element 0 of rd
//   dot      - callable invoked as dot(a, b, acc); its result is stored
//              back into acc
// Both operand vectors hold vl entries and start out value-initialized, so
// any index for which VI_LOOP_ELEMENT_SKIP() bypasses the two reads keeps
// its a_t() / b_t() initial value.
// NOTE(review): VI_LOOP_ELEMENT_SKIP is defined elsewhere; presumably it
// `continue`s past masked-off elements -- confirm.
// The macro declares `a`, `b` and `acc` in the enclosing scope, so it can be
// expanded at most once per block. The expansion ends without a semicolon.
#define ZVLDOT_LOOP(a_t, b_t, c_t, dot) \
std::vector<a_t> a(P.VU.vl->read(), a_t()); \
std::vector<b_t> b(P.VU.vl->read(), b_t()); \
for (reg_t i = 0; i < a.size(); i++) { \
VI_LOOP_ELEMENT_SKIP(); \
a[i] = P.VU.elt<a_t>(insn.rs1(), i); \
b[i] = P.VU.elt<b_t>(insn.rs2(), i); \
} \
auto& acc = P.VU.elt<c_t>(insn.rd(), 0, true); \
acc = dot(a, b, acc)
// ZVLDOT_GENERIC_LOOP(a_t, b_t, c_t, macc): builds a three-argument callable
// `dot` that forwards to generic_dot_product<a_t, b_t, c_t> with `macc` as the
// fourth argument, then expands ZVLDOT_LOOP with it.
// Declares `dot` in the enclosing scope; the expansion ends without a
// semicolon.
#define ZVLDOT_GENERIC_LOOP(a_t, b_t, c_t, macc) \
  auto dot = [&](const std::vector<a_t>& lhs, const std::vector<b_t>& rhs, c_t init) { \
    return generic_dot_product<a_t, b_t, c_t>(lhs, rhs, init, macc); \
  }; \
  ZVLDOT_LOOP(a_t, b_t, c_t, dot)
// ZVLDOT_SIMPLE_LOOP(a_t, b_t, c_t): ZVLDOT_GENERIC_LOOP with a
// multiply-accumulate built from the language's own `*` and `+`. Both factors
// are converted to the accumulator's type before multiplying, and the product
// is added to the running sum.
// Declares `macc` in the enclosing scope; the expansion ends without a
// semicolon.
#define ZVLDOT_SIMPLE_LOOP(a_t, b_t, c_t) \
  auto macc = [](auto lhs, auto rhs, auto sum) { \
    using acc_t = decltype(sum); \
    return sum + static_cast<acc_t>(lhs) * static_cast<acc_t>(rhs); \
  }; \
  ZVLDOT_GENERIC_LOOP(a_t, b_t, c_t, macc)
#define ZVBDOT_LOOP(a_t, b_t, c_t, dot) \
for (reg_t idx = 0; idx < 8; idx++) { \
reg_t i = ci + idx; \

6
riscv/vector_unit.cc

@ -48,6 +48,12 @@ reg_t vectorUnit_t::vectorUnit_t::set_vl(int rd, int rs1, reg_t reqVL, reg_t new
ill_altfmt = false;
else if (p->extension_enabled(EXT_ZVFWBDOT16BF) && vsew == 16)
ill_altfmt = false;
else if (p->extension_enabled(EXT_ZVQLDOT8I) && vsew == 8)
ill_altfmt = false;
else if (p->extension_enabled(EXT_ZVQLDOT16I) && vsew == 16)
ill_altfmt = false;
else if (p->extension_enabled(EXT_ZVFWLDOT16BF) && vsew == 16)
ill_altfmt = false;
}
vill = !(vflmul >= 0.125 && vflmul <= 8)

Loading…
Cancel
Save