@@ -152,13 +152,17 @@ void PairMetatomic::settings(int argc, char ** argv) {
152152 i += 1 ;
153153 } else if (strcmp (argv[i], " non_conservative" ) == 0 ) {
154154 if (i == argc - 1 ) {
155- error->one (FLERR, " expected <on/off> after 'non_conservative' in pair_style metatomic, got nothing" );
155+ error->one (FLERR, " expected <on/off/forces/stress > after 'non_conservative' in pair_style metatomic, got nothing" );
156156 } else if (strcmp (argv[i + 1 ], " on" ) == 0 ) {
157- mta_data->non_conservative = true ;
157+ mta_data->non_conservative = PairMetatomicData::NonConservativeMode::ON ;
158158 } else if (strcmp (argv[i + 1 ], " off" ) == 0 ) {
159- mta_data->non_conservative = false ;
159+ mta_data->non_conservative = PairMetatomicData::NonConservativeMode::OFF;
160+ } else if (strcmp (argv[i + 1 ], " forces" ) == 0 ) {
161+ mta_data->non_conservative = PairMetatomicData::NonConservativeMode::FORCES;
162+ } else if (strcmp (argv[i + 1 ], " stress" ) == 0 ) {
163+ mta_data->non_conservative = PairMetatomicData::NonConservativeMode::STRESS;
160164 } else {
161- error->one (FLERR, " expected <on/off> after 'non_conservative' in pair_style metatomic, got '{}'" , argv[i + 1 ]);
165+ error->one (FLERR, " expected <on/off/forces/stress > after 'non_conservative' in pair_style metatomic, got '{}'" , argv[i + 1 ]);
162166 }
163167
164168 i += 1 ;
@@ -267,29 +271,61 @@ void PairMetatomic::settings(int argc, char ** argv) {
267271 }
268272
269273 // Handle non-conservative variants
270- if (mta_data->non_conservative ) {
271- // Error if *both* nc-force and nc-stress were provided by user AND one is Null
272- bool user_set_forces = (variant_nc_forces != nullptr );
273- bool user_set_stress = (variant_nc_stress != nullptr );
274-
275- if (user_set_forces && user_set_stress) {
276-
277- bool forces_none = !normalize_variant (variant_nc_forces).has_value ();
278- bool stress_none = !normalize_variant (variant_nc_stress).has_value ();
274+ using NCMode = PairMetatomicData::NonConservativeMode;
275+ const auto nc_mode = mta_data->non_conservative ;
276+
277+ bool user_set_forces = (variant_nc_forces != nullptr );
278+ bool user_set_stress = (variant_nc_stress != nullptr );
279+
280+ // Warn if the user set an explicit variant for an output that the chosen
281+ // mode does not use.
282+ if (user_set_forces && nc_mode != NCMode::ON && nc_mode != NCMode::FORCES) {
283+ error->warning (FLERR,
284+ " 'variant/non_conservative_forces' was set but the current 'non_conservative' mode "
285+ " does not use non-conservative forces; the variant will be ignored."
286+ );
287+ }
288+ if (user_set_stress && nc_mode != NCMode::ON && nc_mode != NCMode::STRESS) {
289+ error->warning (FLERR,
290+ " 'variant/non_conservative_stress' was set but the current 'non_conservative' mode "
291+ " does not use non-conservative stress; the variant will be ignored."
292+ );
293+ }
279294
280- if (forces_none != stress_none) {
281- error->one (FLERR,
282- " if both 'variant/non_conservative_stress' and "
283- " 'variant/non_conservative_forces' are present, they "
284- " must either both be 'off' or both not 'off'" );
285- }
295+ // For 'on' mode: if both variants are explicitly set, they must both be
296+ // 'off' or both not 'off' (consistent on/off).
297+ if (nc_mode == NCMode::ON && user_set_forces && user_set_stress) {
298+ bool forces_none = !normalize_variant (variant_nc_forces).has_value ();
299+ bool stress_none = !normalize_variant (variant_nc_stress).has_value ();
300+ if (forces_none != stress_none) {
301+ error->one (FLERR,
302+ " if both 'variant/non_conservative_stress' and "
303+ " 'variant/non_conservative_forces' are present with 'non_conservative on', "
304+ " they must either both be 'off' or both not 'off'" );
286305 }
306+ }
287307
308+ bool do_nc_forces = (nc_mode == NCMode::ON || nc_mode == NCMode::FORCES);
309+ if (do_nc_forces) {
288310 try {
289311 mta_data->nc_forces_key = pick_output (" non_conservative_forces" , outputs, v_nc_forces);
312+ } catch (std::exception& e) {
313+ error->one (FLERR,
314+ " {} Consider using 'non_conservative stress' or 'non_conservative off' instead." ,
315+ e.what ()
316+ );
317+ }
318+ }
319+
320+ bool do_nc_stress = (nc_mode == NCMode::ON || nc_mode == NCMode::STRESS);
321+ if (do_nc_stress) {
322+ try {
290323 mta_data->nc_stress_key = pick_output (" non_conservative_stress" , outputs, v_nc_stress);
291324 } catch (std::exception& e) {
292- error->one (FLERR, e.what ());
325+ error->one (FLERR,
326+ " {} Consider using 'non_conservative forces' or 'non_conservative off' instead." ,
327+ e.what ()
328+ );
293329 }
294330 }
295331
@@ -331,38 +367,34 @@ void PairMetatomic::settings(int argc, char ** argv) {
331367 }
332368 }
333369
334- if (mta_data-> non_conservative ) {
370+ if (do_nc_forces ) {
335371 auto nc_forces = outputs.find (mta_data->nc_forces_key );
336- if (nc_forces == outputs.end ()) {
337- error->one (FLERR,
338- " the model at '{}' does not have a '{}' output, "
339- " we can not enable non_conservative simulations" ,
340- model_path, mta_data->nc_forces_key
341- );
342- }
343-
344372 if (!nc_forces->value ()->per_atom ) {
345373 error->one (FLERR,
346374 " the '{}' output of the model at '{}' "
347375 " can not produce per-atom output, we can not enable non_conservative simulations" ,
348376 mta_data->nc_forces_key , model_path
349377 );
350378 }
351-
352379 mta_data->nc_forces_output = torch::make_intrusive<metatomic_torch::ModelOutputHolder>();
353380 mta_data->nc_forces_output ->set_quantity (" force" );
354381 mta_data->nc_forces_output ->set_unit (this ->energy_unit + " /" + this ->length_unit );
355382 mta_data->nc_forces_output ->per_atom = true ;
383+ }
356384
385+ if (do_nc_stress) {
357386 auto nc_stress = outputs.find (mta_data->nc_stress_key );
358- if (nc_stress != outputs.end ()) {
359- mta_data->nc_stress_output = torch::make_intrusive<metatomic_torch::ModelOutputHolder>();
360- mta_data->nc_stress_output ->set_quantity (" pressure" );
361- mta_data->nc_stress_output ->set_unit (this ->energy_unit + " /" + this ->length_unit + " ^3" );
362- mta_data->nc_stress_output ->per_atom = false ;
363- } else {
364- mta_data->nc_stress_output = nullptr ;
387+ if (nc_stress->value ()->per_atom ) {
388+ error->one (FLERR,
389+ " the '{}' output of the model at '{}' "
390+ " produces per-atom output, but a global stress is required" ,
391+ mta_data->nc_stress_key , model_path
392+ );
365393 }
394+ mta_data->nc_stress_output = torch::make_intrusive<metatomic_torch::ModelOutputHolder>();
395+ mta_data->nc_stress_output ->set_quantity (" pressure" );
396+ mta_data->nc_stress_output ->set_unit (this ->energy_unit + " /" + this ->length_unit + " ^3" );
397+ mta_data->nc_stress_output ->per_atom = false ;
366398 }
367399
368400 // Select the device to use based on the model's preference, the user choice
@@ -505,6 +537,9 @@ void PairMetatomic::coeff(int argc, char ** argv) {
505537
506538// called when the run starts
507539void PairMetatomic::init_style () {
540+ using NCMode = PairMetatomicData::NonConservativeMode;
541+ const auto nc_mode = mta_data->non_conservative ;
542+
508543 // Require newton pair on since we need to communicate forces accumulated on
509544 // ghost atoms to neighboring domains. These forces contributions come from
510545 // gradient of a local descriptor w.r.t. domain ghosts (periodic images
@@ -549,7 +584,7 @@ void PairMetatomic::init_style() {
549584 this ->type_mapping ,
550585 mta_data->max_cutoff ,
551586 mta_data->check_consistency ,
552- !(mta_data-> non_conservative ),
587+ nc_mode != NCMode::ON, // autograd needed for OFF/FORCES/STRESS
553588 };
554589 this ->system_adaptor = std::make_unique<MetatomicSystemAdaptor>(lmp, options);
555590
@@ -576,6 +611,9 @@ void PairMetatomic::init_list(int id, NeighList *ptr) {
576611}
577612
578613void PairMetatomic::compute (int eflag, int vflag) {
614+ using NCMode = PairMetatomicData::NonConservativeMode;
615+ const auto nc_mode = mta_data->non_conservative ;
616+
579617 if (std::getenv (" LAMMPS_METATOMIC_PROFILE" ) != nullptr ) {
580618 MetatomicTimer::enable (true );
581619 } else {
@@ -589,8 +627,15 @@ void PairMetatomic::compute(int eflag, int vflag) {
589627 mta_data->evaluation_options ->outputs .clear ();
590628 // we need an energy output if the energy was explicitly requested (through
591629 // `eflag_either`), or when running in standard/conservative mode, because
592- // we'll get the forces as the gradient of the energy through autodiff.
593- if (eflag_either || !mta_data->non_conservative ) {
630+ // we'll get the forces and stress as the gradient of the energy through autodiff.
631+ auto need_energy_for_autograd = (nc_mode == NCMode::OFF
632+ || nc_mode == NCMode::STRESS
633+ || (nc_mode == NCMode::FORCES && vflag_global));
634+
635+ auto do_nc_forces = nc_mode == NCMode::ON || nc_mode == NCMode::FORCES;
636+ auto do_nc_stress = nc_mode == NCMode::ON || nc_mode == NCMode::STRESS;
637+
638+ if (eflag_either || need_energy_for_autograd) {
594639 if (eflag_atom) {
595640 if (!mta_data->is_energy_output_per_atom ) {
596641 error->one (FLERR,
@@ -609,18 +654,11 @@ void PairMetatomic::compute(int eflag, int vflag) {
609654 mta_data->evaluation_options ->outputs .insert (mta_data->energy_uq_key , mta_data->uncertainty_output );
610655 }
611656
612- if (mta_data-> non_conservative ) {
657+ if (do_nc_forces ) {
613658 mta_data->evaluation_options ->outputs .insert (mta_data->nc_forces_key , mta_data->nc_forces_output );
614- if (vflag_global) {
615- if (mta_data->nc_stress_output == nullptr ) {
616- error->one (FLERR,
617- " the model at '{}' does not have a '{}' output, "
618- " we can not run non_conservative simulations that require computing the stress/virial" ,
619- mta_data->model_path , mta_data->nc_stress_key
620- );
621- }
622- mta_data->evaluation_options ->outputs .insert (mta_data->nc_stress_key , mta_data->nc_stress_output );
623- }
659+ }
660+ if (vflag_global && do_nc_stress) {
661+ mta_data->evaluation_options ->outputs .insert (mta_data->nc_stress_key , mta_data->nc_stress_output );
624662 }
625663
626664 auto dtype = torch::kFloat64 ;
@@ -635,7 +673,7 @@ void PairMetatomic::compute(int eflag, int vflag) {
635673 // transform from LAMMPS to metatomic System
636674 auto system = this ->system_adaptor ->system_from_lmp (
637675 mta_list,
638- static_cast < bool >( vflag_global) ,
676+ vflag_global && !do_nc_stress ,
639677 dtype,
640678 mta_data->device
641679 );
@@ -712,7 +750,7 @@ void PairMetatomic::compute(int eflag, int vflag) {
712750
713751 // get the energy if we need to compute the energy, or if we are using it to
714752 // get the forces/virial with autograd
715- if (eflag_either || !mta_data-> non_conservative ) {
753+ if (eflag_either || need_energy_for_autograd ) {
716754 auto energy = results.at (mta_data->energy_key ).toCustomClass <metatensor_torch::TensorMapHolder>();
717755 auto energy_block = metatensor_torch::TensorMapHolder::block_by_id (energy, 0 );
718756 energy_tensor = energy_block->values ();
@@ -722,30 +760,37 @@ void PairMetatomic::compute(int eflag, int vflag) {
722760 torch::Tensor forces_tensor;
723761 torch::Tensor virial_tensor;
724762
725- if (mta_data->non_conservative ) {
763+ // get nc forces
764+ if (do_nc_forces) {
726765 auto forces = results.at (mta_data->nc_forces_key ).toCustomClass <metatensor_torch::TensorMapHolder>();
727766 auto forces_block = metatensor_torch::TensorMapHolder::block_by_id (forces, 0 );
728767 forces_tensor = forces_block->values ().squeeze (-1 );
729768 forces_tensor = forces_tensor.to (torch::kCPU ).to (torch::kFloat64 );
769+ }
730770
731- if (vflag_global) {
732- auto stress = results.at (mta_data->nc_stress_key ).toCustomClass <metatensor_torch::TensorMapHolder>();
733- auto stress_block = metatensor_torch::TensorMapHolder::block_by_id (stress, 0 );
734- auto stress_tensor = stress_block->values ().squeeze (0 ).squeeze (-1 );
735- virial_tensor = - stress_tensor * compute_volume (domain);
736- virial_tensor = virial_tensor.to (torch::kCPU ).to (torch::kFloat64 );
737- }
738- } else {
739- // compute forces/virial on device with backward propagation
740- // reset gradients to zero before calling backward
771+ // get nc stress
772+ if (vflag_global && do_nc_stress) {
773+ auto stress = results.at (mta_data->nc_stress_key ).toCustomClass <metatensor_torch::TensorMapHolder>();
774+ auto stress_block = metatensor_torch::TensorMapHolder::block_by_id (stress, 0 );
775+ auto stress_tensor = stress_block->values ().squeeze (0 ).squeeze (-1 );
776+ virial_tensor = - stress_tensor * compute_volume (domain);
777+ virial_tensor = virial_tensor.to (torch::kCPU ).to (torch::kFloat64 );
778+ }
779+
780+ // compute conservative qunatities through autograd if needed
781+ if (need_energy_for_autograd) {
741782 this ->system_adaptor ->positions .mutable_grad () = torch::Tensor ();
742783 this ->system_adaptor ->strain .mutable_grad () = torch::Tensor ();
743784
744785 auto _ = MetatomicTimer (" running Model::backward" );
745786 energy_tensor.backward (-torch::ones_like (energy_tensor));
746787
747- forces_tensor = this ->system_adaptor ->positions .grad ();
748- virial_tensor = this ->system_adaptor ->strain .grad ();
788+ if (!do_nc_forces) {
789+ forces_tensor = this ->system_adaptor ->positions .grad ();
790+ }
791+ if (vflag_global && !do_nc_stress) {
792+ virial_tensor = this ->system_adaptor ->strain .grad ();
793+ }
749794 }
750795
751796 {
@@ -802,7 +847,7 @@ void PairMetatomic::compute(int eflag, int vflag) {
802847
803848 assert (!vflag_fdotr);
804849
805- if (vflag_global) {
850+ if (vflag_global && virial_tensor. defined () ) {
806851 auto virial_cpu = virial_tensor.to (torch::kCPU );
807852 assert (virial_cpu.is_cpu () && virial_cpu.scalar_type () == torch::kFloat64 );
808853
@@ -833,8 +878,10 @@ void PairMetatomic::store_forces(const at::Tensor& forces_tensor) {
833878 atom->f [i][2 ] += this ->scale * forces[i][2 ];
834879 }
835880
836- // in non-conservative mode we do not need to update forces on ghost atoms
837- if (!mta_data->non_conservative ) {
881+ // ghost atom forces only exist when forces come from autograd (OFF/STRESS modes)
882+ using NCMode = PairMetatomicData::NonConservativeMode;
883+ const auto nc_mode = mta_data->non_conservative ;
884+ if (nc_mode == NCMode::OFF || nc_mode == NCMode::STRESS) {
838885 const auto & mta_to_lmp = this ->system_adaptor ->mta_to_lmp ;
839886 for (int i=atom->nlocal ; i<forces.size (0 ); i++) {
840887 atom->f [mta_to_lmp[i]][0 ] += this ->scale * forces[i][0 ];
0 commit comments