Skip to content

Commit 87bb247

Browse files
committed
Allow nc simulations without stress
1 parent 877fdf7 commit 87bb247

3 files changed

Lines changed: 124 additions & 74 deletions

File tree

src/KOKKOS/pair_metatomic_kokkos.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,11 +70,12 @@ void PairMetatomicKokkos<DeviceType>::init_style() {
7070
this->type_mapping_kk = Kokkos::View<int32_t*, Kokkos::LayoutRight, DeviceType>("type_mapping_kk", atom->ntypes + 1);
7171
Kokkos::deep_copy(this->type_mapping_kk, type_mapping_kk_host);
7272

73+
using NCMode = PairMetatomicData::NonConservativeMode;
7374
auto options = MetatomicSystemOptions{
7475
this->type_mapping_kk.data(),
7576
mta_data->max_cutoff,
7677
mta_data->check_consistency,
77-
!(mta_data->non_conservative),
78+
mta_data->non_conservative != NCMode::ON,
7879
};
7980

8081
// override the system adaptor with the kokkos version
@@ -112,6 +113,7 @@ void PairMetatomicKokkos<DeviceType>::pick_device(torch::Device& device, const c
112113

113114
template<class DeviceType>
114115
void PairMetatomicKokkos<DeviceType>::store_forces(const at::Tensor& forces_tensor) {
116+
using NCMode = PairMetatomicData::NonConservativeMode;
115117
assert(forces_tensor.scalar_type() == torch::kFloat64);
116118
auto forces = forces_tensor.contiguous();
117119

@@ -131,8 +133,8 @@ void PairMetatomicKokkos<DeviceType>::store_forces(const at::Tensor& forces_tens
131133
}
132134
);
133135

134-
// in non-conservative mode we do not need to update forces on ghost atoms
135-
if (!mta_data->non_conservative) {
136+
// ghost atom forces only exist when forces come from autograd (OFF/STRESS modes)
137+
if (mta_data->non_conservative == NCMode::OFF || mta_data->non_conservative == NCMode::STRESS) {
136138
auto system_adaptor_kk = dynamic_cast<MetatomicSystemAdaptorKokkos<DeviceType>*>(this->system_adaptor.get());
137139
assert(system_adaptor_kk != nullptr);
138140
auto mta_to_lmp_kk = UnmanagedView<int32_t*, DeviceType>(

src/ML-METATOMIC/metatomic_types.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,9 @@ struct PairMetatomicData: public CommonMetatomicData {
6969
metatomic_torch::ModelOutput nc_forces_output;
7070
metatomic_torch::ModelOutput nc_stress_output;
7171

72-
// whether non-conservative forces and stresses should be used
73-
bool non_conservative = false;
72+
// which non-conservative outputs to use
73+
enum class NonConservativeMode { OFF, ON, FORCES, STRESS };
74+
NonConservativeMode non_conservative = NonConservativeMode::OFF;
7475

7576
// energy key for the model
7677
std::string energy_key;

src/ML-METATOMIC/pair_metatomic.cpp

Lines changed: 116 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -152,13 +152,17 @@ void PairMetatomic::settings(int argc, char ** argv) {
152152
i += 1;
153153
} else if (strcmp(argv[i], "non_conservative") == 0) {
154154
if (i == argc - 1) {
155-
error->one(FLERR, "expected <on/off> after 'non_conservative' in pair_style metatomic, got nothing");
155+
error->one(FLERR, "expected <on/off/forces/stress> after 'non_conservative' in pair_style metatomic, got nothing");
156156
} else if (strcmp(argv[i + 1], "on") == 0) {
157-
mta_data->non_conservative = true;
157+
mta_data->non_conservative = PairMetatomicData::NonConservativeMode::ON;
158158
} else if (strcmp(argv[i + 1], "off") == 0) {
159-
mta_data->non_conservative = false;
159+
mta_data->non_conservative = PairMetatomicData::NonConservativeMode::OFF;
160+
} else if (strcmp(argv[i + 1], "forces") == 0) {
161+
mta_data->non_conservative = PairMetatomicData::NonConservativeMode::FORCES;
162+
} else if (strcmp(argv[i + 1], "stress") == 0) {
163+
mta_data->non_conservative = PairMetatomicData::NonConservativeMode::STRESS;
160164
} else {
161-
error->one(FLERR, "expected <on/off> after 'non_conservative' in pair_style metatomic, got '{}'", argv[i + 1]);
165+
error->one(FLERR, "expected <on/off/forces/stress> after 'non_conservative' in pair_style metatomic, got '{}'", argv[i + 1]);
162166
}
163167

164168
i += 1;
@@ -267,29 +271,61 @@ void PairMetatomic::settings(int argc, char ** argv) {
267271
}
268272

269273
// Handle non-conservative variants
270-
if (mta_data->non_conservative) {
271-
// Error if *both* nc-force and nc-stress were provided by user AND one is Null
272-
bool user_set_forces = (variant_nc_forces != nullptr);
273-
bool user_set_stress = (variant_nc_stress != nullptr);
274-
275-
if (user_set_forces && user_set_stress) {
276-
277-
bool forces_none = !normalize_variant(variant_nc_forces).has_value();
278-
bool stress_none = !normalize_variant(variant_nc_stress).has_value();
274+
using NCMode = PairMetatomicData::NonConservativeMode;
275+
const auto nc_mode = mta_data->non_conservative;
276+
277+
bool user_set_forces = (variant_nc_forces != nullptr);
278+
bool user_set_stress = (variant_nc_stress != nullptr);
279+
280+
// Warn if the user set an explicit variant for an output that the chosen
281+
// mode does not use.
282+
if (user_set_forces && nc_mode != NCMode::ON && nc_mode != NCMode::FORCES) {
283+
error->warning(FLERR,
284+
"'variant/non_conservative_forces' was set but the current 'non_conservative' mode "
285+
"does not use non-conservative forces; the variant will be ignored."
286+
);
287+
}
288+
if (user_set_stress && nc_mode != NCMode::ON && nc_mode != NCMode::STRESS) {
289+
error->warning(FLERR,
290+
"'variant/non_conservative_stress' was set but the current 'non_conservative' mode "
291+
"does not use non-conservative stress; the variant will be ignored."
292+
);
293+
}
279294

280-
if (forces_none != stress_none) {
281-
error->one(FLERR,
282-
"if both 'variant/non_conservative_stress' and "
283-
"'variant/non_conservative_forces' are present, they "
284-
"must either both be 'off' or both not 'off'");
285-
}
295+
// For 'on' mode: if both variants are explicitly set, they must both be
296+
// 'off' or both not 'off' (consistent on/off).
297+
if (nc_mode == NCMode::ON && user_set_forces && user_set_stress) {
298+
bool forces_none = !normalize_variant(variant_nc_forces).has_value();
299+
bool stress_none = !normalize_variant(variant_nc_stress).has_value();
300+
if (forces_none != stress_none) {
301+
error->one(FLERR,
302+
"if both 'variant/non_conservative_stress' and "
303+
"'variant/non_conservative_forces' are present with 'non_conservative on', "
304+
"they must either both be 'off' or both not 'off'");
286305
}
306+
}
287307

308+
bool do_nc_forces = (nc_mode == NCMode::ON || nc_mode == NCMode::FORCES);
309+
if (do_nc_forces) {
288310
try {
289311
mta_data->nc_forces_key = pick_output("non_conservative_forces", outputs, v_nc_forces);
312+
} catch (std::exception& e) {
313+
error->one(FLERR,
314+
"{} Consider using 'non_conservative stress' or 'non_conservative off' instead.",
315+
e.what()
316+
);
317+
}
318+
}
319+
320+
bool do_nc_stress = (nc_mode == NCMode::ON || nc_mode == NCMode::STRESS);
321+
if (do_nc_stress) {
322+
try {
290323
mta_data->nc_stress_key = pick_output("non_conservative_stress", outputs, v_nc_stress);
291324
} catch (std::exception& e) {
292-
error->one(FLERR, e.what());
325+
error->one(FLERR,
326+
"{} Consider using 'non_conservative forces' or 'non_conservative off' instead.",
327+
e.what()
328+
);
293329
}
294330
}
295331

@@ -331,38 +367,34 @@ void PairMetatomic::settings(int argc, char ** argv) {
331367
}
332368
}
333369

334-
if (mta_data->non_conservative) {
370+
if (do_nc_forces) {
335371
auto nc_forces = outputs.find(mta_data->nc_forces_key);
336-
if (nc_forces == outputs.end()) {
337-
error->one(FLERR,
338-
"the model at '{}' does not have a '{}' output, "
339-
"we can not enable non_conservative simulations",
340-
model_path, mta_data->nc_forces_key
341-
);
342-
}
343-
344372
if (!nc_forces->value()->per_atom) {
345373
error->one(FLERR,
346374
"the '{}' output of the model at '{}' "
347375
"can not produce per-atom output, we can not enable non_conservative simulations",
348376
mta_data->nc_forces_key, model_path
349377
);
350378
}
351-
352379
mta_data->nc_forces_output = torch::make_intrusive<metatomic_torch::ModelOutputHolder>();
353380
mta_data->nc_forces_output->set_quantity("force");
354381
mta_data->nc_forces_output->set_unit(this->energy_unit + "/" + this->length_unit);
355382
mta_data->nc_forces_output->per_atom = true;
383+
}
356384

385+
if (do_nc_stress) {
357386
auto nc_stress = outputs.find(mta_data->nc_stress_key);
358-
if (nc_stress != outputs.end()) {
359-
mta_data->nc_stress_output = torch::make_intrusive<metatomic_torch::ModelOutputHolder>();
360-
mta_data->nc_stress_output->set_quantity("pressure");
361-
mta_data->nc_stress_output->set_unit(this->energy_unit + "/" + this->length_unit + "^3");
362-
mta_data->nc_stress_output->per_atom = false;
363-
} else {
364-
mta_data->nc_stress_output = nullptr;
387+
if (nc_stress->value()->per_atom) {
388+
error->one(FLERR,
389+
"the '{}' output of the model at '{}' "
390+
"produces per-atom output, but a global stress is required",
391+
mta_data->nc_stress_key, model_path
392+
);
365393
}
394+
mta_data->nc_stress_output = torch::make_intrusive<metatomic_torch::ModelOutputHolder>();
395+
mta_data->nc_stress_output->set_quantity("pressure");
396+
mta_data->nc_stress_output->set_unit(this->energy_unit + "/" + this->length_unit + "^3");
397+
mta_data->nc_stress_output->per_atom = false;
366398
}
367399

368400
// Select the device to use based on the model's preference, the user choice
@@ -505,6 +537,9 @@ void PairMetatomic::coeff(int argc, char ** argv) {
505537

506538
// called when the run starts
507539
void PairMetatomic::init_style() {
540+
using NCMode = PairMetatomicData::NonConservativeMode;
541+
const auto nc_mode = mta_data->non_conservative;
542+
508543
// Require newton pair on since we need to communicate forces accumulated on
509544
// ghost atoms to neighboring domains. These forces contributions come from
510545
// gradient of a local descriptor w.r.t. domain ghosts (periodic images
@@ -549,7 +584,7 @@ void PairMetatomic::init_style() {
549584
this->type_mapping,
550585
mta_data->max_cutoff,
551586
mta_data->check_consistency,
552-
!(mta_data->non_conservative),
587+
nc_mode != NCMode::ON, // autograd needed for OFF/FORCES/STRESS
553588
};
554589
this->system_adaptor = std::make_unique<MetatomicSystemAdaptor>(lmp, options);
555590

@@ -576,6 +611,9 @@ void PairMetatomic::init_list(int id, NeighList *ptr) {
576611
}
577612

578613
void PairMetatomic::compute(int eflag, int vflag) {
614+
using NCMode = PairMetatomicData::NonConservativeMode;
615+
const auto nc_mode = mta_data->non_conservative;
616+
579617
if (std::getenv("LAMMPS_METATOMIC_PROFILE") != nullptr) {
580618
MetatomicTimer::enable(true);
581619
} else {
@@ -589,8 +627,15 @@ void PairMetatomic::compute(int eflag, int vflag) {
589627
mta_data->evaluation_options->outputs.clear();
590628
// we need an energy output if the energy was explicitly requested (through
591629
// `eflag_either`), or when running in standard/conservative mode, because
592-
// we'll get the forces as the gradient of the energy through autodiff.
593-
if (eflag_either || !mta_data->non_conservative) {
630+
// we'll get the forces and stress as the gradient of the energy through autodiff.
631+
auto need_energy_for_autograd = (nc_mode == NCMode::OFF
632+
|| nc_mode == NCMode::STRESS
633+
|| (nc_mode == NCMode::FORCES && vflag_global));
634+
635+
auto do_nc_forces = nc_mode == NCMode::ON || nc_mode == NCMode::FORCES;
636+
auto do_nc_stress = nc_mode == NCMode::ON || nc_mode == NCMode::STRESS;
637+
638+
if (eflag_either || need_energy_for_autograd) {
594639
if (eflag_atom) {
595640
if (!mta_data->is_energy_output_per_atom) {
596641
error->one(FLERR,
@@ -609,18 +654,11 @@ void PairMetatomic::compute(int eflag, int vflag) {
609654
mta_data->evaluation_options->outputs.insert(mta_data->energy_uq_key, mta_data->uncertainty_output);
610655
}
611656

612-
if (mta_data->non_conservative) {
657+
if (do_nc_forces) {
613658
mta_data->evaluation_options->outputs.insert(mta_data->nc_forces_key, mta_data->nc_forces_output);
614-
if (vflag_global) {
615-
if (mta_data->nc_stress_output == nullptr) {
616-
error->one(FLERR,
617-
"the model at '{}' does not have a '{}' output, "
618-
"we can not run non_conservative simulations that require computing the stress/virial",
619-
mta_data->model_path, mta_data->nc_stress_key
620-
);
621-
}
622-
mta_data->evaluation_options->outputs.insert(mta_data->nc_stress_key, mta_data->nc_stress_output);
623-
}
659+
}
660+
if (vflag_global && do_nc_stress) {
661+
mta_data->evaluation_options->outputs.insert(mta_data->nc_stress_key, mta_data->nc_stress_output);
624662
}
625663

626664
auto dtype = torch::kFloat64;
@@ -635,7 +673,7 @@ void PairMetatomic::compute(int eflag, int vflag) {
635673
// transform from LAMMPS to metatomic System
636674
auto system = this->system_adaptor->system_from_lmp(
637675
mta_list,
638-
static_cast<bool>(vflag_global),
676+
vflag_global && !do_nc_stress,
639677
dtype,
640678
mta_data->device
641679
);
@@ -712,7 +750,7 @@ void PairMetatomic::compute(int eflag, int vflag) {
712750

713751
// get the energy if we need to compute the energy, or if we are using it to
714752
// get the forces/virial with autograd
715-
if (eflag_either || !mta_data->non_conservative) {
753+
if (eflag_either || need_energy_for_autograd) {
716754
auto energy = results.at(mta_data->energy_key).toCustomClass<metatensor_torch::TensorMapHolder>();
717755
auto energy_block = metatensor_torch::TensorMapHolder::block_by_id(energy, 0);
718756
energy_tensor = energy_block->values();
@@ -722,30 +760,37 @@ void PairMetatomic::compute(int eflag, int vflag) {
722760
torch::Tensor forces_tensor;
723761
torch::Tensor virial_tensor;
724762

725-
if (mta_data->non_conservative) {
763+
// get nc forces
764+
if (do_nc_forces) {
726765
auto forces = results.at(mta_data->nc_forces_key).toCustomClass<metatensor_torch::TensorMapHolder>();
727766
auto forces_block = metatensor_torch::TensorMapHolder::block_by_id(forces, 0);
728767
forces_tensor = forces_block->values().squeeze(-1);
729768
forces_tensor = forces_tensor.to(torch::kCPU).to(torch::kFloat64);
769+
}
730770

731-
if (vflag_global) {
732-
auto stress = results.at(mta_data->nc_stress_key).toCustomClass<metatensor_torch::TensorMapHolder>();
733-
auto stress_block = metatensor_torch::TensorMapHolder::block_by_id(stress, 0);
734-
auto stress_tensor = stress_block->values().squeeze(0).squeeze(-1);
735-
virial_tensor = - stress_tensor * compute_volume(domain);
736-
virial_tensor = virial_tensor.to(torch::kCPU).to(torch::kFloat64);
737-
}
738-
} else {
739-
// compute forces/virial on device with backward propagation
740-
// reset gradients to zero before calling backward
771+
// get nc stress
772+
if (vflag_global && do_nc_stress) {
773+
auto stress = results.at(mta_data->nc_stress_key).toCustomClass<metatensor_torch::TensorMapHolder>();
774+
auto stress_block = metatensor_torch::TensorMapHolder::block_by_id(stress, 0);
775+
auto stress_tensor = stress_block->values().squeeze(0).squeeze(-1);
776+
virial_tensor = - stress_tensor * compute_volume(domain);
777+
virial_tensor = virial_tensor.to(torch::kCPU).to(torch::kFloat64);
778+
}
779+
780+
// compute conservative qunatities through autograd if needed
781+
if (need_energy_for_autograd) {
741782
this->system_adaptor->positions.mutable_grad() = torch::Tensor();
742783
this->system_adaptor->strain.mutable_grad() = torch::Tensor();
743784

744785
auto _ = MetatomicTimer("running Model::backward");
745786
energy_tensor.backward(-torch::ones_like(energy_tensor));
746787

747-
forces_tensor = this->system_adaptor->positions.grad();
748-
virial_tensor = this->system_adaptor->strain.grad();
788+
if (!do_nc_forces) {
789+
forces_tensor = this->system_adaptor->positions.grad();
790+
}
791+
if (vflag_global && !do_nc_stress) {
792+
virial_tensor = this->system_adaptor->strain.grad();
793+
}
749794
}
750795

751796
{
@@ -802,7 +847,7 @@ void PairMetatomic::compute(int eflag, int vflag) {
802847

803848
assert(!vflag_fdotr);
804849

805-
if (vflag_global) {
850+
if (vflag_global && virial_tensor.defined()) {
806851
auto virial_cpu = virial_tensor.to(torch::kCPU);
807852
assert(virial_cpu.is_cpu() && virial_cpu.scalar_type() == torch::kFloat64);
808853

@@ -833,8 +878,10 @@ void PairMetatomic::store_forces(const at::Tensor& forces_tensor) {
833878
atom->f[i][2] += this->scale * forces[i][2];
834879
}
835880

836-
// in non-conservative mode we do not need to update forces on ghost atoms
837-
if (!mta_data->non_conservative) {
881+
// ghost atom forces only exist when forces come from autograd (OFF/STRESS modes)
882+
using NCMode = PairMetatomicData::NonConservativeMode;
883+
const auto nc_mode = mta_data->non_conservative;
884+
if (nc_mode == NCMode::OFF || nc_mode == NCMode::STRESS) {
838885
const auto& mta_to_lmp = this->system_adaptor->mta_to_lmp;
839886
for (int i=atom->nlocal; i<forces.size(0); i++) {
840887
atom->f[mta_to_lmp[i]][0] += this->scale * forces[i][0];

0 commit comments

Comments
 (0)