From 717a6e275cd465651ee1a685936f4857e17eac39 Mon Sep 17 00:00:00 2001 From: Andrew Waterman Date: Wed, 3 Sep 2025 19:32:53 -0700 Subject: [PATCH] Add Zvldot extension support --- disasm/isa_parser.cc | 6 ++++++ riscv/encoding.h | 23 ++++++++++++++++++++++- riscv/insns/vfwldot_vv.h | 20 ++++++++++++++++++++ riscv/insns/vqldots_vv.h | 23 +++++++++++++++++++++++ riscv/insns/vqldotu_vv.h | 23 +++++++++++++++++++++++ riscv/isa_parser.h | 3 +++ riscv/riscv.mk.in | 6 ++++++ riscv/v_ext_macros.h | 28 ++++++++++++++++++++++++++++ riscv/vector_unit.cc | 6 ++++++ 9 files changed, 137 insertions(+), 1 deletion(-) create mode 100644 riscv/insns/vfwldot_vv.h create mode 100644 riscv/insns/vqldots_vv.h create mode 100644 riscv/insns/vqldotu_vv.h diff --git a/disasm/isa_parser.cc b/disasm/isa_parser.cc index 48df33e7..4b9049ec 100644 --- a/disasm/isa_parser.cc +++ b/disasm/isa_parser.cc @@ -334,6 +334,12 @@ isa_parser_t::isa_parser_t(const char* str, const char *priv) extension_table[EXT_ZVFWBDOT16BF] = true; } else if (ext_str == "zvfbdot32f") { extension_table[EXT_ZVFBDOT32F] = true; + } else if (ext_str == "zvqldot8i") { + extension_table[EXT_ZVQLDOT8I] = true; + } else if (ext_str == "zvqldot16i") { + extension_table[EXT_ZVQLDOT16I] = true; + } else if (ext_str == "zvfwldot16bf") { + extension_table[EXT_ZVFWLDOT16BF] = true; } else if (ext_str == "zvkt") { } else if (ext_str == "sstc") { extension_table[EXT_SSTC] = true; diff --git a/riscv/encoding.h b/riscv/encoding.h index 225aabba..13cbcf03 100644 --- a/riscv/encoding.h +++ b/riscv/encoding.h @@ -4,7 +4,7 @@ /* * This file is auto-generated by running 'make' in - * https://github.com/riscv/riscv-opcodes (1d4b38c) + * https://github.com/riscv/riscv-opcodes (74c1158) */ #ifndef RISCV_CSR_ENCODING_H @@ -1786,6 +1786,14 @@ #define MASK_VFNMSUB_VF 0xfc00707f #define MATCH_VFNMSUB_VV 0xac001057 #define MASK_VFNMSUB_VV 0xfc00707f +#define MATCH_VFQBDOT_ALT_VV 0xbc001077 +#define MASK_VFQBDOT_ALT_VV 0xfc00707f +#define 
MATCH_VFQBDOT_VV 0xb8001077 +#define MASK_VFQBDOT_VV 0xfc00707f +#define MATCH_VFQLDOT_ALT_VV 0x9c001077 +#define MASK_VFQLDOT_ALT_VV 0xfc00707f +#define MATCH_VFQLDOT_VV 0x98001077 +#define MASK_VFQLDOT_VV 0xfc00707f #define MATCH_VFRDIV_VF 0x84005057 #define MASK_VFRDIV_VF 0xfc00707f #define MATCH_VFREC7_V 0x4c029057 @@ -1850,6 +1858,8 @@ #define MASK_VFWCVT_XU_F_V 0xfc0ff07f #define MATCH_VFWCVTBF16_F_F_V 0x48069057 #define MASK_VFWCVTBF16_F_F_V 0xfc0ff07f +#define MATCH_VFWLDOT_VV 0x90001077 +#define MASK_VFWLDOT_VV 0xfc00707f #define MATCH_VFWMACC_VF 0xf0005057 #define MASK_VFWMACC_VF 0xfc00707f #define MATCH_VFWMACC_VV 0xf0001057 @@ -2190,6 +2200,10 @@ #define MASK_VQDOTU_VX 0xfc00707f #define MATCH_VQDOTUS_VX 0xb8006057 #define MASK_VQDOTUS_VX 0xfc00707f +#define MATCH_VQLDOTS_VV 0x9c000077 +#define MASK_VQLDOTS_VV 0xfc00707f +#define MATCH_VQLDOTU_VV 0x98000077 +#define MASK_VQLDOTU_VV 0xfc00707f #define MATCH_VREDAND_VS 0x4002057 #define MASK_VREDAND_VS 0xfc00707f #define MATCH_VREDMAX_VS 0x1c002057 @@ -3696,6 +3710,10 @@ DECLARE_INSN(vfnmsac_vf, MATCH_VFNMSAC_VF, MASK_VFNMSAC_VF) DECLARE_INSN(vfnmsac_vv, MATCH_VFNMSAC_VV, MASK_VFNMSAC_VV) DECLARE_INSN(vfnmsub_vf, MATCH_VFNMSUB_VF, MASK_VFNMSUB_VF) DECLARE_INSN(vfnmsub_vv, MATCH_VFNMSUB_VV, MASK_VFNMSUB_VV) +DECLARE_INSN(vfqbdot_alt_vv, MATCH_VFQBDOT_ALT_VV, MASK_VFQBDOT_ALT_VV) +DECLARE_INSN(vfqbdot_vv, MATCH_VFQBDOT_VV, MASK_VFQBDOT_VV) +DECLARE_INSN(vfqldot_alt_vv, MATCH_VFQLDOT_ALT_VV, MASK_VFQLDOT_ALT_VV) +DECLARE_INSN(vfqldot_vv, MATCH_VFQLDOT_VV, MASK_VFQLDOT_VV) DECLARE_INSN(vfrdiv_vf, MATCH_VFRDIV_VF, MASK_VFRDIV_VF) DECLARE_INSN(vfrec7_v, MATCH_VFREC7_V, MASK_VFREC7_V) DECLARE_INSN(vfredmax_vs, MATCH_VFREDMAX_VS, MASK_VFREDMAX_VS) @@ -3728,6 +3746,7 @@ DECLARE_INSN(vfwcvt_rtz_xu_f_v, MATCH_VFWCVT_RTZ_XU_F_V, MASK_VFWCVT_RTZ_XU_F_V) DECLARE_INSN(vfwcvt_x_f_v, MATCH_VFWCVT_X_F_V, MASK_VFWCVT_X_F_V) DECLARE_INSN(vfwcvt_xu_f_v, MATCH_VFWCVT_XU_F_V, MASK_VFWCVT_XU_F_V) DECLARE_INSN(vfwcvtbf16_f_f_v, 
MATCH_VFWCVTBF16_F_F_V, MASK_VFWCVTBF16_F_F_V) +DECLARE_INSN(vfwldot_vv, MATCH_VFWLDOT_VV, MASK_VFWLDOT_VV) DECLARE_INSN(vfwmacc_vf, MATCH_VFWMACC_VF, MASK_VFWMACC_VF) DECLARE_INSN(vfwmacc_vv, MATCH_VFWMACC_VV, MASK_VFWMACC_VV) DECLARE_INSN(vfwmaccbf16_vf, MATCH_VFWMACCBF16_VF, MASK_VFWMACCBF16_VF) @@ -3898,6 +3917,8 @@ DECLARE_INSN(vqdotsu_vx, MATCH_VQDOTSU_VX, MASK_VQDOTSU_VX) DECLARE_INSN(vqdotu_vv, MATCH_VQDOTU_VV, MASK_VQDOTU_VV) DECLARE_INSN(vqdotu_vx, MATCH_VQDOTU_VX, MASK_VQDOTU_VX) DECLARE_INSN(vqdotus_vx, MATCH_VQDOTUS_VX, MASK_VQDOTUS_VX) +DECLARE_INSN(vqldots_vv, MATCH_VQLDOTS_VV, MASK_VQLDOTS_VV) +DECLARE_INSN(vqldotu_vv, MATCH_VQLDOTU_VV, MASK_VQLDOTU_VV) DECLARE_INSN(vredand_vs, MATCH_VREDAND_VS, MASK_VREDAND_VS) DECLARE_INSN(vredmax_vs, MATCH_VREDMAX_VS, MASK_VREDMAX_VS) DECLARE_INSN(vredmaxu_vs, MATCH_VREDMAXU_VS, MASK_VREDMAXU_VS) diff --git a/riscv/insns/vfwldot_vv.h b/riscv/insns/vfwldot_vv.h new file mode 100644 index 00000000..3c95f677 --- /dev/null +++ b/riscv/insns/vfwldot_vv.h @@ -0,0 +1,20 @@ +VI_VFP_BASE; +ZVLDOT_INIT(2); + +switch (P.VU.vsew) { + case 16: { + if (P.VU.altfmt) { + // Although this implementation in IEEE 754 arithmetic is valid, most + // implementations will bulk-normalize on a VLEN-bit granule, then use + // f32_add_bulknorm_odd for the final steps (possibly in a tree). + // If a consensus emerges, we might change this implementation. 
+ require_extension(EXT_ZVFWLDOT16BF); + auto macc = [](auto a, auto b, auto c) { return f32_add_bulknorm_odd(c, f32_mul(bf16_to_f32(a), bf16_to_f32(b))); }; + ZVLDOT_GENERIC_LOOP(bfloat16_t, bfloat16_t, float32_t, macc); + } else { + require(false); + } + break; + } + default: require(false); +} diff --git a/riscv/insns/vqldots_vv.h b/riscv/insns/vqldots_vv.h new file mode 100644 index 00000000..ce6376ac --- /dev/null +++ b/riscv/insns/vqldots_vv.h @@ -0,0 +1,23 @@ +ZVLDOT_INIT(4); + +switch (P.VU.vsew) { + case 8: { + require_extension(EXT_ZVQLDOT8I); + if (P.VU.altfmt) { + ZVLDOT_SIMPLE_LOOP(int8_t, int8_t, uint32_t); + } else { + ZVLDOT_SIMPLE_LOOP(uint8_t, int8_t, uint32_t); + } + break; + } + case 16: { + require_extension(EXT_ZVQLDOT16I); + if (P.VU.altfmt) { + ZVLDOT_SIMPLE_LOOP(int16_t, int16_t, uint64_t); + } else { + ZVLDOT_SIMPLE_LOOP(uint16_t, int16_t, uint64_t); + } + break; + } + default: require(false); +} diff --git a/riscv/insns/vqldotu_vv.h b/riscv/insns/vqldotu_vv.h new file mode 100644 index 00000000..2b674b13 --- /dev/null +++ b/riscv/insns/vqldotu_vv.h @@ -0,0 +1,23 @@ +ZVLDOT_INIT(4); + +switch (P.VU.vsew) { + case 8: { + require_extension(EXT_ZVQLDOT8I); + if (P.VU.altfmt) { + ZVLDOT_SIMPLE_LOOP(int8_t, uint8_t, uint32_t); + } else { + ZVLDOT_SIMPLE_LOOP(uint8_t, uint8_t, uint32_t); + } + break; + } + case 16: { + require_extension(EXT_ZVQLDOT16I); + if (P.VU.altfmt) { + ZVLDOT_SIMPLE_LOOP(int16_t, uint16_t, uint64_t); + } else { + ZVLDOT_SIMPLE_LOOP(uint16_t, uint16_t, uint64_t); + } + break; + } + default: require(false); +} diff --git a/riscv/isa_parser.h b/riscv/isa_parser.h index 376a5285..b14166b2 100644 --- a/riscv/isa_parser.h +++ b/riscv/isa_parser.h @@ -72,6 +72,9 @@ typedef enum { EXT_ZVQBDOT16I, EXT_ZVFWBDOT16BF, EXT_ZVFBDOT32F, + EXT_ZVQLDOT8I, + EXT_ZVQLDOT16I, + EXT_ZVFWLDOT16BF, EXT_SSTC, EXT_ZAAMO, EXT_ZALRSC, diff --git a/riscv/riscv.mk.in b/riscv/riscv.mk.in index ab683357..b7f83447 100644 --- a/riscv/riscv.mk.in +++ 
b/riscv/riscv.mk.in @@ -1080,6 +1080,11 @@ riscv_insn_ext_zvbdot = \ vfwbdot_vv \ vfbdot_vv \ +riscv_insn_ext_zvldot = \ + vqldotu_vv \ + vqldots_vv \ + vfwldot_vv \ + riscv_insn_ext_zimop = \ mop_r_N \ mop_rr_N \ @@ -1137,6 +1142,7 @@ riscv_insn_list = \ $(riscv_insn_ext_zicond) \ $(riscv_insn_ext_zvk) \ $(riscv_insn_ext_zvbdot) \ + $(riscv_insn_ext_zvldot) \ $(riscv_insn_priv) \ $(riscv_insn_smrnmi) \ $(riscv_insn_svinval) \ diff --git a/riscv/v_ext_macros.h b/riscv/v_ext_macros.h index f20e599f..3b95aeda 100644 --- a/riscv/v_ext_macros.h +++ b/riscv/v_ext_macros.h @@ -2078,6 +2078,15 @@ VI_VX_ULOOP({ \ break; \ } +#define ZVLDOT_INIT(widen) \ + require_vector(true); \ + require(P.VU.vstart->read() == 0); \ + require_align(insn.rs1(), P.VU.vflmul); \ + require_align(insn.rs2(), P.VU.vflmul); \ + require_vm; \ + require_noover(insn.rd(), 1, insn.rs1(), P.VU.vflmul); \ + require_noover(insn.rd(), 1, insn.rs2(), P.VU.vflmul) + #define ZVBDOT_INIT(widen) \ unsigned vd_eew = P.VU.vsew * (widen); \ unsigned vd_emul = std::max(1U, unsigned((8 * vd_eew) / P.VU.VLEN)); \ @@ -2100,6 +2109,25 @@ c_t generic_dot_product(const std::vector<a_t>& a, const std::vector<b_t>& b, c_ return c; } +#define ZVLDOT_LOOP(a_t, b_t, c_t, dot) \ + std::vector<a_t> a(P.VU.vl->read(), a_t()); \ + std::vector<b_t> b(P.VU.vl->read(), b_t()); \ + for (reg_t i = 0; i < a.size(); i++) { \ + VI_LOOP_ELEMENT_SKIP(); \ + a[i] = P.VU.elt<a_t>(insn.rs1(), i); \ + b[i] = P.VU.elt<b_t>(insn.rs2(), i); \ + } \ + auto& acc = P.VU.elt<c_t>(insn.rd(), 0, true); \ + acc = dot(a, b, acc) + +#define ZVLDOT_GENERIC_LOOP(a_t, b_t, c_t, macc) \ + auto dot = std::bind(generic_dot_product<a_t, b_t, c_t, decltype(macc)>, std::placeholders::_1, std::placeholders::_2, std::placeholders::_3, macc); \ + ZVLDOT_LOOP(a_t, b_t, c_t, dot) + +#define ZVLDOT_SIMPLE_LOOP(a_t, b_t, c_t) \ + auto macc = [](auto a, auto b, auto c) { return c + decltype(c)(a) * decltype(c)(b); }; \ + ZVLDOT_GENERIC_LOOP(a_t, b_t, c_t, macc) + #define ZVBDOT_LOOP(a_t, b_t, c_t, dot) \ for (reg_t idx = 0; idx < 8;
idx++) { \ reg_t i = ci + idx; \ diff --git a/riscv/vector_unit.cc b/riscv/vector_unit.cc index b4ef6404..5fbab5ff 100644 --- a/riscv/vector_unit.cc +++ b/riscv/vector_unit.cc @@ -48,6 +48,12 @@ reg_t vectorUnit_t::vectorUnit_t::set_vl(int rd, int rs1, reg_t reqVL, reg_t new ill_altfmt = false; else if (p->extension_enabled(EXT_ZVFWBDOT16BF) && vsew == 16) ill_altfmt = false; + else if (p->extension_enabled(EXT_ZVQLDOT8I) && vsew == 8) + ill_altfmt = false; + else if (p->extension_enabled(EXT_ZVQLDOT16I) && vsew == 16) + ill_altfmt = false; + else if (p->extension_enabled(EXT_ZVFWLDOT16BF) && vsew == 16) + ill_altfmt = false; } vill = !(vflmul >= 0.125 && vflmul <= 8)