// blob: adefbf3215b5a3274d3de8d3423e0ad6a3479c20 [file] [log] [blame]
use itertools::Itertools;
use serde::{Deserialize, Deserializer, Serialize, de};
use crate::{
context::{self, GlobalContext},
intrinsic::Intrinsic,
predicate_forms::{PredicateForm, PredicationMask, PredicationMethods},
typekinds::TypeKind,
wildstring::WildString,
};
/// One element of an input set: a type, a predicate form, or an `_n` variant operand.
///
/// The enum is (de)serialized untagged. `PredicateForm` and `NVariantOp` are
/// marked `#[serde(skip)]`, so `Type` is the only variant serde reads or writes.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(untagged)]
pub enum InputType {
/// PredicateForm variant argument
#[serde(skip)] // Predicate forms have their own dedicated deserialization field. Skip.
PredicateForm(PredicateForm),
/// Operand from which to generate an N variant.
/// Holds `None` or `Some(operand name)`.
#[serde(skip)]
NVariantOp(Option<WildString>),
/// TypeKind variant argument
Type(TypeKind),
}
impl InputType {
    /// Returns the wrapped `PredicateForm`, or `None` for any other variant.
    pub fn predicate_form(&self) -> Option<&PredicateForm> {
        if let InputType::PredicateForm(form) = self {
            Some(form)
        } else {
            None
        }
    }

    /// Returns the wrapped `PredicateForm` mutably, or `None` for any other variant.
    pub fn predicate_form_mut(&mut self) -> Option<&mut PredicateForm> {
        if let InputType::PredicateForm(form) = self {
            Some(form)
        } else {
            None
        }
    }

    /// Returns the wrapped `TypeKind`, or `None` for any other variant.
    pub fn typekind(&self) -> Option<&TypeKind> {
        if let InputType::Type(kind) = self {
            Some(kind)
        } else {
            None
        }
    }

    /// Returns the operand of an `NVariantOp`.
    ///
    /// Yields `None` both for other variants and for `NVariantOp(None)`.
    pub fn n_variant_op(&self) -> Option<&WildString> {
        match self {
            InputType::NVariantOp(operand) => operand.as_ref(),
            _ => None,
        }
    }
}
impl PartialOrd for InputType {
    /// Delegates to the total order defined by `Ord`, so the result is always `Some`.
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(Ord::cmp(self, other))
    }
}
impl Ord for InputType {
    /// Total order over input types.
    ///
    /// Across variants the order is `Type` < `PredicateForm` < `NVariantOp`.
    /// Within a variant, `Type` and `PredicateForm` compare their payloads.
    /// `NVariantOp` only distinguishes `None` (smaller) from `Some(..)`; the
    /// operand itself is not compared.
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        /// Position of a variant in the cross-variant order.
        fn rank(input: &InputType) -> u8 {
            match input {
                InputType::Type(..) => 0,
                InputType::PredicateForm(..) => 1,
                InputType::NVariantOp(..) => 2,
            }
        }

        match (self, other) {
            (InputType::Type(lhs), InputType::Type(rhs)) => lhs.cmp(rhs),
            (InputType::PredicateForm(lhs), InputType::PredicateForm(rhs)) => lhs.cmp(rhs),
            // `false < true`, hence `None` sorts before `Some(..)`.
            (InputType::NVariantOp(lhs), InputType::NVariantOp(rhs)) => {
                lhs.is_some().cmp(&rhs.is_some())
            }
            _ => rank(self).cmp(&rank(other)),
        }
    }
}
/// Serde adapter for `Vec<T>` fields that may be written either as a single
/// value or as a sequence of values.
mod many_or_one {
    use serde::{Deserialize, Serialize, de::Deserializer, ser::Serializer};

    /// Serializes a one-element vector as the bare element and any other
    /// length (including empty) as a sequence.
    ///
    /// The `&Vec<T>` parameter matches what `#[serde(with = "...")]` passes for the field.
    pub fn serialize<T, S>(vec: &Vec<T>, serializer: S) -> Result<S::Ok, S::Error>
    where
        T: Serialize,
        S: Serializer,
    {
        match vec.as_slice() {
            [only] => only.serialize(serializer),
            _ => vec.serialize(serializer),
        }
    }

    /// Deserializes either a sequence of `T` or a single `T`, always
    /// returning a vector. The sequence form is tried first.
    pub fn deserialize<'de, T, D>(deserializer: D) -> Result<Vec<T>, D::Error>
    where
        T: Deserialize<'de>,
        D: Deserializer<'de>,
    {
        #[derive(Debug, Clone, Serialize, Deserialize)]
        #[serde(untagged)]
        enum ManyOrOne<T> {
            Many(Vec<T>),
            One(T),
        }

        let parsed: ManyOrOne<T> = ManyOrOne::deserialize(deserializer)?;
        Ok(match parsed {
            ManyOrOne::Many(items) => items,
            ManyOrOne::One(item) => vec![item],
        })
    }
}
/// An ordered list of [`InputType`] values.
///
/// The inner vector is (de)serialized through `many_or_one`, so a single
/// value and a list of values are both accepted on input.
#[derive(Debug, Clone, Default, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub struct InputSet(#[serde(with = "many_or_one")] Vec<InputType>);
impl InputSet {
    /// Returns the element at position `idx`, if there is one.
    pub fn get(&self, idx: usize) -> Option<&InputType> {
        self.0.get(idx)
    }

    /// Returns `true` when the set holds no elements at all.
    pub fn is_empty(&self) -> bool {
        self.0.is_empty()
    }

    /// Iterates over the elements by shared reference.
    pub fn iter(&self) -> impl Iterator<Item = &InputType> + '_ {
        self.0.iter()
    }

    /// Iterates over the elements by mutable reference.
    pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut InputType> + '_ {
        self.0.iter_mut()
    }

    /// Consumes the set and iterates over its elements by value.
    pub fn into_iter(self) -> impl Iterator<Item = InputType> + Clone {
        self.0.into_iter()
    }

    /// Counts the elements that are `InputType::Type`.
    pub fn types_len(&self) -> usize {
        self.0
            .iter()
            .filter(|arg| arg.typekind().is_some())
            .count()
    }

    /// Looks up a type in the set.
    ///
    /// * With `idx == None`, the set must contain exactly one type; the
    ///   element at position 0 is returned if it is a type.
    /// * With `idx == Some(i)`, the set must NOT contain exactly one type; the
    ///   element at position `i` is returned if it is a type.
    ///
    /// Every other combination, including an out-of-range index, gives `None`.
    pub fn typekind(&self, idx: Option<usize>) -> Option<TypeKind> {
        let has_single_type = self.types_len() == 1;
        // An explicit index and a single-type set are mutually exclusive.
        if idx.is_some() == has_single_type {
            return None;
        }
        self.get(idx.unwrap_or(0))?.typekind().cloned()
    }
}
/// One entry of the `types` list: a list of [`InputSet`]s.
///
/// The inner vector is (de)serialized through `many_or_one`, so a single
/// `InputSet` and a list of them are both accepted on input.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct InputSetEntry(#[serde(with = "many_or_one")] Vec<InputSet>);
impl InputSetEntry {
pub fn new(input: Vec<InputSet>) -> Self {
Self(input)
}
pub fn get(&self, idx: usize) -> Option<&InputSet> {
self.0.get(idx)
}
}
fn validate_types<'de, D>(deserializer: D) -> Result<Vec<InputSetEntry>, D::Error>
where
D: Deserializer<'de>,
{
let v: Vec<InputSetEntry> = Vec::deserialize(deserializer)?;
let mut it = v.iter();
if let Some(first) = it.next() {
it.try_fold(first, |last, cur| {
if last.0.len() == cur.0.len() {
Ok(cur)
} else {
Err("the length of the InputSets and the product lists must match".to_string())
}
})
.map_err(de::Error::custom)?;
}
Ok(v)
}
/// Input description of an intrinsic: its type entries, predication methods
/// and optional `_n` variant operand.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct IntrinsicInput {
/// Type entries. Defaults to an empty list when absent; on deserialization
/// `validate_types` rejects entries of differing lengths.
#[serde(default)]
#[serde(deserialize_with = "validate_types")]
pub types: Vec<InputSetEntry>,
/// Predication methods, read from the same level as the other fields
/// (flattened).
#[serde(flatten)]
pub predication_methods: PredicationMethods,
/// Generates a _n variant where the specified operand is a primitive type
/// that requires conversion to an SVE one. The `{_n}` wildcard is required
/// in the intrinsic's name, otherwise an error will be thrown.
#[serde(default)]
pub n_variant_op: WildString,
}
impl IntrinsicInput {
/// Extracts all the possible variants as an iterator.
pub fn variants(
&self,
intrinsic: &Intrinsic,
) -> context::Result<impl Iterator<Item = InputSet> + '_> {
let mut top_product = vec![];
if !self.types.is_empty() {
top_product.push(
self.types
.iter()
.flat_map(|ty_in| {
ty_in
.0
.iter()
.map(|v| v.clone().into_iter())
.multi_cartesian_product()
})
.collect_vec(),
)
}
if let Ok(mask) = PredicationMask::try_from(&intrinsic.signature.name) {
top_product.push(
PredicateForm::compile_list(&mask, &self.predication_methods)?
.into_iter()
.map(|pf| vec![InputType::PredicateForm(pf)])
.collect_vec(),
)
}
if !self.n_variant_op.is_empty() {
top_product.push(vec![
vec![InputType::NVariantOp(None)],
vec![InputType::NVariantOp(Some(self.n_variant_op.to_owned()))],
])
}
let it = top_product
.into_iter()
.map(|v| v.into_iter())
.multi_cartesian_product()
.filter(|set| !set.is_empty())
.map(|set| InputSet(set.into_iter().flatten().collect_vec()));
Ok(it)
}
}
/// Top-level generator input: a global context plus the list of intrinsics.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GeneratorInput {
/// Global context, read from the same level as `intrinsics` (flattened).
#[serde(flatten)]
pub ctx: GlobalContext,
/// The intrinsic definitions.
pub intrinsics: Vec<Intrinsic>,
}
// Unit tests for `IntrinsicInput` parsing and variant expansion.
#[cfg(test)]
mod tests {
use crate::{
input::*,
predicate_forms::{DontCareMethod, ZeroingMethod},
};
// An empty `types` list on a default intrinsic yields no variants.
#[test]
fn test_empty() {
let str = r#"types: []"#;
let input: IntrinsicInput = serde_yaml::from_str(str).expect("failed to parse");
let mut variants = input.variants(&Intrinsic::default()).unwrap().into_iter();
assert_eq!(variants.next(), None);
}
// Each `types` row is expanded as a product of its entries, and every
// resulting combination is paired with each predicate form of the `{_mx}` mask.
#[test]
fn test_product() {
let str = r#"types:
- [f64, f32]
- [i64, [f64, f32]]
"#;
let input: IntrinsicInput = serde_yaml::from_str(str).expect("failed to parse");
let mut intrinsic = Intrinsic::default();
intrinsic.signature.name = "test_intrinsic{_mx}".parse().unwrap();
let mut variants = input.variants(&intrinsic).unwrap().into_iter();
assert_eq!(
variants.next(),
Some(InputSet(vec![
InputType::Type("f64".parse().unwrap()),
InputType::Type("f32".parse().unwrap()),
InputType::PredicateForm(PredicateForm::Merging),
]))
);
assert_eq!(
variants.next(),
Some(InputSet(vec![
InputType::Type("f64".parse().unwrap()),
InputType::Type("f32".parse().unwrap()),
InputType::PredicateForm(PredicateForm::DontCare(DontCareMethod::AsMerging)),
]))
);
assert_eq!(
variants.next(),
Some(InputSet(vec![
InputType::Type("i64".parse().unwrap()),
InputType::Type("f64".parse().unwrap()),
InputType::PredicateForm(PredicateForm::Merging),
]))
);
assert_eq!(
variants.next(),
Some(InputSet(vec![
InputType::Type("i64".parse().unwrap()),
InputType::Type("f64".parse().unwrap()),
InputType::PredicateForm(PredicateForm::DontCare(DontCareMethod::AsMerging)),
]))
);
assert_eq!(
variants.next(),
Some(InputSet(vec![
InputType::Type("i64".parse().unwrap()),
InputType::Type("f32".parse().unwrap()),
InputType::PredicateForm(PredicateForm::Merging),
]))
);
assert_eq!(
variants.next(),
Some(InputSet(vec![
InputType::Type("i64".parse().unwrap()),
InputType::Type("f32".parse().unwrap()),
InputType::PredicateForm(PredicateForm::DontCare(DontCareMethod::AsMerging)),
])),
);
assert_eq!(variants.next(), None);
}
// Setting `n_variant_op` doubles each type combination: once with
// `NVariantOp(None)` and once with the named operand.
#[test]
fn test_n_variant() {
let str = r#"types:
- [f64, f32]
n_variant_op: op2
"#;
let input: IntrinsicInput = serde_yaml::from_str(str).expect("failed to parse");
let mut variants = input.variants(&Intrinsic::default()).unwrap().into_iter();
assert_eq!(
variants.next(),
Some(InputSet(vec![
InputType::Type("f64".parse().unwrap()),
InputType::Type("f32".parse().unwrap()),
InputType::NVariantOp(None),
]))
);
assert_eq!(
variants.next(),
Some(InputSet(vec![
InputType::Type("f64".parse().unwrap()),
InputType::Type("f32".parse().unwrap()),
InputType::NVariantOp(Some("op2".parse().unwrap())),
]))
);
assert_eq!(variants.next(), None)
}
// Entries of different lengths (one input set vs. two) must fail to deserialize.
#[test]
fn test_invalid_length() {
let str = r#"types: [i32, [[u64], [u32]]]"#;
serde_yaml::from_str::<IntrinsicInput>(str).expect_err("failure expected");
}
// A `{_mxz}` mask must make `variants` fail for this input.
// NOTE(review): presumably because no zeroing method is configured here — confirm
// against `PredicateForm::compile_list`.
#[test]
fn test_invalid_predication() {
let str = "types: []";
let input: IntrinsicInput = serde_yaml::from_str(str).expect("failed to parse");
let mut intrinsic = Intrinsic::default();
intrinsic.signature.name = "test_intrinsic{_mxz}".parse().unwrap();
input
.variants(&intrinsic)
.map(|v| v.collect_vec())
.expect_err("failure expected");
}
// The wildcards `{_mxy}` and `{_}` must be rejected when parsing a `WildString`.
#[test]
fn test_invalid_predication_mask() {
"test_intrinsic{_mxy}"
.parse::<WildString>()
.expect_err("failure expected");
"test_intrinsic{_}"
.parse::<WildString>()
.expect_err("failure expected");
}
// With a zeroing method configured, a `{_mxz}` mask yields three variants:
// merging, don't-care (as zeroing) and zeroing, in that order.
#[test]
fn test_zeroing_predication() {
let str = r#"types: [i64]
zeroing_method: { drop: inactive }"#;
let input: IntrinsicInput = serde_yaml::from_str(str).expect("failed to parse");
let mut intrinsic = Intrinsic::default();
intrinsic.signature.name = "test_intrinsic{_mxz}".parse().unwrap();
let mut variants = input.variants(&intrinsic).unwrap();
assert_eq!(
variants.next(),
Some(InputSet(vec![
InputType::Type("i64".parse().unwrap()),
InputType::PredicateForm(PredicateForm::Merging),
]))
);
assert_eq!(
variants.next(),
Some(InputSet(vec![
InputType::Type("i64".parse().unwrap()),
InputType::PredicateForm(PredicateForm::DontCare(DontCareMethod::AsZeroing)),
]))
);
assert_eq!(
variants.next(),
Some(InputSet(vec![
InputType::Type("i64".parse().unwrap()),
InputType::PredicateForm(PredicateForm::Zeroing(ZeroingMethod::Drop {
drop: "inactive".parse().unwrap()
})),
]))
);
assert_eq!(variants.next(), None)
}
}