Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion feos-core/src/python/phase_equilibria.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ macro_rules! impl_phase_equilibrium {
/// A thermodynamic two phase equilibrium state.
#[pyclass(name = "PhaseEquilibrium")]
#[derive(Clone)]
pub struct PyPhaseEquilibrium(PhaseEquilibrium<$eos, 2>);
pub struct PyPhaseEquilibrium(pub PhaseEquilibrium<$eos, 2>);

#[pymethods]
impl PyPhaseEquilibrium {
Expand Down
172 changes: 45 additions & 127 deletions feos-derive/src/dft.rs
Original file line number Diff line number Diff line change
@@ -1,114 +1,45 @@
use crate::{implement, OPT_IMPLS};
use quote::quote;
use syn::DeriveInput;

use crate::implement;

const OPT_IMPLS: [&str; 4] = [
"bond_lengths",
"molar_weight",
"fluid_parameters",
"pair_potential",
];

pub(crate) fn expand_helmholtz_energy_functional(
input: DeriveInput,
) -> syn::Result<proc_macro2::TokenStream> {
let variants = match input.data {
syn::Data::Enum(syn::DataEnum { ref variants, .. }) => variants,
_ => panic!("this derive macro only works on enums"),
let syn::Data::Enum(syn::DataEnum { ref variants, .. }) = input.data else {
panic!("this derive macro only works on enums")
};

let from = impl_from(variants)?;
let functional = impl_helmholtz_energy_functional(variants)?;
let fluid_parameters = impl_fluid_parameters(variants)?;
let pair_potential = impl_pair_potential(variants)?;
let functional = impl_helmholtz_energy_functional(&input.ident, variants)?;
let fluid_parameters = impl_fluid_parameters(&input.ident, variants)?;
let pair_potential = impl_pair_potential(&input.ident, variants)?;
Ok(quote! {
#from
#functional
#fluid_parameters
#pair_potential
})
}

// extract the variant name and the name of the functional,
// i.e. PcSaft(PcSaftFunctional) will return (PcSaft, PcSaftFunctional)
fn extract_names(variant: &syn::Variant) -> syn::Result<(&syn::Ident, &syn::Ident)> {
let name = &variant.ident;
let field = if let syn::Fields::Unnamed(syn::FieldsUnnamed { ref unnamed, .. }) = variant.fields
{
if unnamed.len() != 1 {
return Err(syn::Error::new_spanned(
unnamed,
"expected tuple struct with single HelmholtzFunctional as variant",
));
}
&unnamed[0]
} else {
return Err(syn::Error::new_spanned(
name,
"expected variant with a HelmholtzFunctional as data",
));
};

let inner = if let syn::Type::Path(syn::TypePath { ref path, .. }) = &field.ty {
path.get_ident()
} else {
None
}
.ok_or_else(|| syn::Error::new_spanned(field, "expected HelmholtzFunctional"))?;
Ok((name, inner))
}

fn impl_from(
variants: &syn::punctuated::Punctuated<syn::Variant, syn::token::Comma>,
) -> syn::Result<proc_macro2::TokenStream> {
variants
.iter()
.map(|v| {
let (variant_name, functional_name) = extract_names(v)?;
Ok(quote! {
impl From<#functional_name> for FunctionalVariant {
fn from(f: #functional_name) -> Self {
Self::#variant_name(f)
}
}
})
})
.collect()
}

fn impl_helmholtz_energy_functional(
pub(crate) fn impl_helmholtz_energy_functional(
ident: &syn::Ident,
variants: &syn::punctuated::Punctuated<syn::Variant, syn::token::Comma>,
) -> syn::Result<proc_macro2::TokenStream> {
let molecule_shape = variants.iter().map(|v| {
let name = &v.ident;
quote! {
Self::#name(functional) => functional.molecule_shape()
}
});
let compute_max_density = variants.iter().map(|v| {
let name = &v.ident;
quote! {
Self::#name(functional) => functional.compute_max_density(moles)
}
});
let contributions = variants.iter().map(|v| {
let name = &v.ident;
quote! {
Self::#name(functional) => Box::new(functional.contributions().map(FunctionalContributionVariant::from))
}
});

let mut molar_weight = Vec::new();
let mut has_molar_weight = Vec::new();
let mut molecule_shape = Vec::new();
let mut contributions = Vec::new();
for v in variants.iter() {
if implement("molar_weight", v, &OPT_IMPLS)? {
let name = &v.ident;
molar_weight.push(quote! {
Self::#name(functional) => functional.molar_weight()
let name = &v.ident;
if implement("functional", v, &OPT_IMPLS)? {
molecule_shape.push(quote! {
Self::#name(functional) => functional.molecule_shape()
});
contributions.push(quote! {
Self::#name(functional) => Box::new(functional.contributions().map(FunctionalContributionVariant::from))
});
} else {
molecule_shape.push(quote! {
Self::#name(functional) => panic!("{} is not a Helmholtz energy functional!", stringify!(#name))
});
has_molar_weight.push(quote! {
Self::#name(functional) => true
contributions.push(quote! {
Self::#name(functional) => panic!("{} is not a Helmholtz energy functional!", stringify!(#name))
});
}
}
Expand All @@ -124,106 +55,93 @@ fn impl_helmholtz_energy_functional(
}

Ok(quote! {
impl HelmholtzEnergyFunctional for FunctionalVariant {
impl HelmholtzEnergyFunctional for #ident {
type Contribution = FunctionalContributionVariant;
fn molecule_shape(&self) -> MoleculeShape {
fn molecule_shape(&self) -> feos_dft::MoleculeShape {
match self {
#(#molecule_shape,)*
}
}
fn compute_max_density(&self, moles: &Array1<f64>) -> f64 {
match self {
#(#compute_max_density,)*
}
}
fn contributions(&self) -> Box<dyn Iterator<Item = FunctionalContributionVariant>> {
match self {
#(#contributions,)*
}
}
fn bond_lengths<N: DualNum<f64> + Copy>(&self, temperature: N) -> UnGraph<(), N> {
fn bond_lengths<N: DualNum<f64> + Copy>(&self, temperature: N) -> petgraph::graph::UnGraph<(), N> {
match self {
#(#bond_lengths,)*
_ => Graph::with_capacity(0, 0),
}
}
}

impl Molarweight for FunctionalVariant {
fn molar_weight(&self) -> MolarWeight<Array1<f64>> {
match self {
#(#molar_weight,)*
_ => unimplemented!()
}
}
}

impl FunctionalVariant {
pub fn has_molar_weight(&self) -> bool {
match self {
#(#has_molar_weight,)*
_ => false,
_ => petgraph::Graph::with_capacity(0, 0),
}
}
}
})
}

fn impl_fluid_parameters(
ident: &syn::Ident,
variants: &syn::punctuated::Punctuated<syn::Variant, syn::token::Comma>,
) -> syn::Result<proc_macro2::TokenStream> {
let mut epsilon_k_ff = Vec::new();
let mut sigma_ff = Vec::new();

for v in variants.iter() {
let name = &v.ident;
if implement("fluid_parameters", v, &OPT_IMPLS)? {
let name = &v.ident;
epsilon_k_ff.push(quote! {
Self::#name(functional) => functional.epsilon_k_ff()
});
sigma_ff.push(quote! {
Self::#name(functional) => functional.sigma_ff()
});
} else {
epsilon_k_ff.push(quote! {
Self::#name(functional) => panic!("{} does not support the automatic calculation of external potentials!", stringify!(#name))
});
sigma_ff.push(quote! {
Self::#name(functional) => panic!("{} does not support the automatic calculation of external potentials!", stringify!(#name))
});
}
}
Ok(quote! {
impl FluidParameters for FunctionalVariant {
impl feos_dft::adsorption::FluidParameters for #ident {
fn epsilon_k_ff(&self) -> Array1<f64> {
match self {
#(#epsilon_k_ff,)*
_ => unimplemented!()
}
}

fn sigma_ff(&self) -> &Array1<f64> {
match self {
#(#sigma_ff,)*
_ => unimplemented!()
}
}
}
})
}

fn impl_pair_potential(
ident: &syn::Ident,
variants: &syn::punctuated::Punctuated<syn::Variant, syn::token::Comma>,
) -> syn::Result<proc_macro2::TokenStream> {
let mut pair_potential = Vec::new();

for v in variants.iter() {
let name = &v.ident;
if implement("pair_potential", v, &OPT_IMPLS)? {
let name = &v.ident;
pair_potential.push(quote! {
Self::#name(functional) => functional.pair_potential(i, r, temperature)
});
} else {
pair_potential.push(quote! {
Self::#name(functional) => panic!("{} does not provide pair potentials!", stringify!(#name))
});
}
}
Ok(quote! {
impl PairPotential for FunctionalVariant {
fn pair_potential(&self, i: usize, r: &Array1<f64>, temperature: f64) -> Array2<f64> {
impl feos_dft::solvation::PairPotential for #ident {
fn pair_potential(&self, i: usize, r: &Array1<f64>, temperature: f64) -> ndarray::Array2<f64> {
match self {
#(#pair_potential,)*
_ => unimplemented!()
}
}
}
Expand Down
Loading