1use crate::geometry::Geometry;
6use nalgebra::{DMatrix, DVector};
7use std::error::Error;
8use std::fmt;
9
/// A geometric constraint enforced during optimization.
///
/// Atom indices are zero-based (status output prints them 1-based). Target
/// units follow `validate_constraints`: Angstrom for bonds, radians for angles
/// and dihedrals.
#[derive(Debug, Clone, PartialEq)]
pub enum Constraint {
    /// Fixes the distance between two atoms.
    Bond {
        /// The pair of atoms whose separation is constrained.
        atoms: (usize, usize),
        /// Target distance in Angstrom.
        target: f64,
    },
    /// Fixes the angle `atoms.0 - atoms.1 - atoms.2`; `atoms.1` is the vertex.
    Angle {
        /// The three atoms defining the angle, vertex in the middle.
        atoms: (usize, usize, usize),
        /// Target angle in radians, within `[0, π]`.
        target: f64,
    },
    /// Fixes the torsion of `atoms.0 - atoms.1 - atoms.2 - atoms.3` about the
    /// `atoms.1 - atoms.2` bond.
    Dihedral {
        /// The four atoms defining the torsion, in chain order.
        atoms: (usize, usize, usize, usize),
        /// Target dihedral in radians, within `[-π, π]`.
        target: f64,
    },
}
39
40pub fn evaluate_constraints(geometry: &Geometry, constraints: &[Constraint]) -> DVector<f64> {
43 let violations: Vec<f64> = constraints
44 .iter()
45 .map(|c| match c {
46 Constraint::Bond { atoms, target } => {
47 let p1 = geometry.get_atom_coords(atoms.0);
48 let p2 = geometry.get_atom_coords(atoms.1);
49 let dx = p1[0] - p2[0];
50 let dy = p1[1] - p2[1];
51 let dz = p1[2] - p2[2];
52 (dx * dx + dy * dy + dz * dz).sqrt() - target
53 }
54 Constraint::Angle { atoms, target } => {
55 let p1 = geometry.get_atom_coords(atoms.0);
56 let p2 = geometry.get_atom_coords(atoms.1);
57 let p3 = geometry.get_atom_coords(atoms.2);
58 let v21 = [p1[0] - p2[0], p1[1] - p2[1], p1[2] - p2[2]];
59 let v23 = [p3[0] - p2[0], p3[1] - p2[1], p3[2] - p2[2]];
60 let dot = v21[0] * v23[0] + v21[1] * v23[1] + v21[2] * v23[2];
61 let n21 = (v21[0].powi(2) + v21[1].powi(2) + v21[2].powi(2)).sqrt();
62 let n23 = (v23[0].powi(2) + v23[1].powi(2) + v23[2].powi(2)).sqrt();
63 (dot / (n21 * n23)).acos() - target
64 }
65 Constraint::Dihedral { atoms, target } => {
66 let (a1, a2, a3, a4) = *atoms;
67 let current_dihedral = calculate_dihedral(geometry, a1, a2, a3, a4);
68 current_dihedral - target
69 }
70 })
71 .collect();
72
73 DVector::from_vec(violations)
74}
75
76pub fn build_constraint_jacobian(geometry: &Geometry, constraints: &[Constraint]) -> DMatrix<f64> {
80 let num_constraints = constraints.len();
81 let num_dof = geometry.num_atoms * 3;
82 let mut jacobian = DMatrix::zeros(num_constraints, num_dof);
83
84 for (i, constraint) in constraints.iter().enumerate() {
85 match constraint {
86 Constraint::Bond {
87 atoms: (a1, a2), ..
88 } => {
89 let grad = calculate_bond_gradient(geometry, *a1, *a2);
90 for j in 0..3 {
92 jacobian[(i, a1 * 3 + j)] = grad[j];
93 jacobian[(i, a2 * 3 + j)] = -grad[j];
94 }
95 }
96 Constraint::Angle {
97 atoms: (a1, a2, a3),
98 ..
99 } => {
100 let (grad1, grad2, grad3) = calculate_angle_gradient(geometry, *a1, *a2, *a3);
101 for j in 0..3 {
103 jacobian[(i, a1 * 3 + j)] = grad1[j];
104 jacobian[(i, a2 * 3 + j)] = grad2[j];
105 jacobian[(i, a3 * 3 + j)] = grad3[j];
106 }
107 }
108 Constraint::Dihedral { atoms, .. } => {
109 let (a1, a2, a3, a4) = *atoms;
110 let (grad1, grad2, grad3, grad4) =
111 calculate_dihedral_gradient(geometry, a1, a2, a3, a4);
112 for j in 0..3 {
114 jacobian[(i, a1 * 3 + j)] = grad1[j];
115 jacobian[(i, a2 * 3 + j)] = grad2[j];
116 jacobian[(i, a3 * 3 + j)] = grad3[j];
117 jacobian[(i, a4 * 3 + j)] = grad4[j];
118 }
119 }
120 }
121 }
122 jacobian
123}
124
125fn calculate_bond_gradient(geometry: &Geometry, a1: usize, a2: usize) -> [f64; 3] {
127 let pos1 = geometry.get_atom_coords(a1);
128 let pos2 = geometry.get_atom_coords(a2);
129 let vec = [pos1[0] - pos2[0], pos1[1] - pos2[1], pos1[2] - pos2[2]];
130 let norm = (vec[0].powi(2) + vec[1].powi(2) + vec[2].powi(2)).sqrt();
131 if norm == 0.0 {
132 return [0.0, 0.0, 0.0];
133 }
134 [vec[0] / norm, vec[1] / norm, vec[2] / norm]
135}
136
137fn calculate_angle_gradient(
139 geometry: &Geometry,
140 i: usize,
141 j: usize,
142 k: usize,
143) -> ([f64; 3], [f64; 3], [f64; 3]) {
144 let pi = geometry.get_atom_coords(i);
145 let pj = geometry.get_atom_coords(j);
146 let pk = geometry.get_atom_coords(k);
147
148 let r_ji = [pi[0] - pj[0], pi[1] - pj[1], pi[2] - pj[2]];
149 let r_jk = [pk[0] - pj[0], pk[1] - pj[1], pk[2] - pj[2]];
150
151 let n_ji = (r_ji[0].powi(2) + r_ji[1].powi(2) + r_ji[2].powi(2)).sqrt();
152 let n_jk = (r_jk[0].powi(2) + r_jk[1].powi(2) + r_jk[2].powi(2)).sqrt();
153
154 let dot = r_ji[0] * r_jk[0] + r_ji[1] * r_jk[1] + r_ji[2] * r_jk[2];
155 let cos_theta = dot / (n_ji * n_jk);
156
157 let cos_theta = cos_theta.clamp(-1.0, 1.0);
159 let sin_theta = (1.0 - cos_theta.powi(2)).sqrt();
160
161 if sin_theta.abs() < 1e-6 {
162 return ([0.0; 3], [0.0; 3], [0.0; 3]);
163 }
164
165 let prefactor = -1.0 / (n_ji * n_jk * sin_theta);
166
167 let term_i = [
168 prefactor * (r_jk[0] / n_jk - cos_theta * r_ji[0] / n_ji),
169 prefactor * (r_jk[1] / n_jk - cos_theta * r_ji[1] / n_ji),
170 prefactor * (r_jk[2] / n_jk - cos_theta * r_ji[2] / n_ji),
171 ];
172
173 let term_k = [
174 prefactor * (r_ji[0] / n_ji - cos_theta * r_jk[0] / n_jk),
175 prefactor * (r_ji[1] / n_ji - cos_theta * r_jk[1] / n_jk),
176 prefactor * (r_ji[2] / n_ji - cos_theta * r_jk[2] / n_jk),
177 ];
178
179 let term_j = [
180 -term_i[0] - term_k[0],
181 -term_i[1] - term_k[1],
182 -term_i[2] - term_k[2],
183 ];
184
185 (term_i, term_j, term_k)
186}
187
/// Right-handed cross product `a × b` of two 3-vectors.
fn cross_product(a: &[f64; 3], b: &[f64; 3]) -> [f64; 3] {
    let [ax, ay, az] = *a;
    let [bx, by, bz] = *b;
    let x = ay * bz - az * by;
    let y = az * bx - ax * bz;
    let z = ax * by - ay * bx;
    [x, y, z]
}
196
/// Euclidean inner product of two 3-vectors, summed x, then y, then z.
fn dot_product(a: &[f64; 3], b: &[f64; 3]) -> f64 {
    let [ax, ay, az] = *a;
    let [bx, by, bz] = *b;
    ax * bx + ay * by + az * bz
}
201
202pub fn calculate_dihedral(geometry: &Geometry, a1: usize, a2: usize, a3: usize, a4: usize) -> f64 {
219 let p1 = geometry.get_atom_coords(a1);
220 let p2 = geometry.get_atom_coords(a2);
221 let p3 = geometry.get_atom_coords(a3);
222 let p4 = geometry.get_atom_coords(a4);
223
224 let v1 = [p2[0] - p1[0], p2[1] - p1[1], p2[2] - p1[2]];
225 let v2 = [p3[0] - p2[0], p3[1] - p2[1], p3[2] - p2[2]];
226 let v3 = [p4[0] - p3[0], p4[1] - p3[1], p4[2] - p3[2]];
227
228 let n1 = cross_product(&v1, &v2);
229 let n2 = cross_product(&v2, &v3);
230
231 let n1_norm = (n1[0].powi(2) + n1[1].powi(2) + n1[2].powi(2)).sqrt();
232 let n2_norm = (n2[0].powi(2) + n2[1].powi(2) + n2[2].powi(2)).sqrt();
233
234 if n1_norm < 1e-10 || n2_norm < 1e-10 {
235 return 0.0;
236 }
237
238 let n1_unit = [n1[0] / n1_norm, n1[1] / n1_norm, n1[2] / n1_norm];
239 let n2_unit = [n2[0] / n2_norm, n2[1] / n2_norm, n2[2] / n2_norm];
240
241 let cos_phi = dot_product(&n1_unit, &n2_unit);
243
244 let cross_n1_n2 = cross_product(&n1_unit, &n2_unit);
246 let sin_phi =
247 dot_product(&cross_n1_n2, &v2) / (v2[0].powi(2) + v2[1].powi(2) + v2[2].powi(2)).sqrt();
248
249 sin_phi.atan2(cos_phi)
251}
252
253pub fn calculate_dihedral_gradient(
274 geometry: &Geometry,
275 a1: usize,
276 a2: usize,
277 a3: usize,
278 a4: usize,
279) -> ([f64; 3], [f64; 3], [f64; 3], [f64; 3]) {
280 let p1 = geometry.get_atom_coords(a1);
281 let p2 = geometry.get_atom_coords(a2);
282 let p3 = geometry.get_atom_coords(a3);
283 let p4 = geometry.get_atom_coords(a4);
284
285 let x1 = p1[0];
286 let y1 = p1[1];
287 let z1 = p1[2];
288 let x2 = p2[0];
289 let y2 = p2[1];
290 let z2 = p2[2];
291 let x3 = p3[0];
292 let y3 = p3[1];
293 let z3 = p3[2];
294 let x4 = p4[0];
295 let y4 = p4[1];
296 let z4 = p4[2];
297
298 let n1x = (y2 - y1) * (z3 - z2) - (z2 - z1) * (y3 - y2);
300 let n1y = (z2 - z1) * (x3 - x2) - (x2 - x1) * (z3 - z2);
301 let n1z = (x2 - x1) * (y3 - y2) - (y2 - y1) * (x3 - x2);
302
303 let n2x = (y3 - y2) * (z4 - z3) - (z3 - z2) * (y4 - y3);
305 let n2y = (z3 - z2) * (x4 - x3) - (x3 - x2) * (z4 - z3);
306 let n2z = (x3 - x2) * (y4 - y3) - (y3 - y2) * (x4 - x3);
307
308 let len_n1_sq = n1x.powi(2) + n1y.powi(2) + n1z.powi(2);
309 let len_n2_sq = n2x.powi(2) + n2y.powi(2) + n2z.powi(2);
310
311 if len_n1_sq < 1e-20 || len_n2_sq < 1e-20 {
312 return ([0.0; 3], [0.0; 3], [0.0; 3], [0.0; 3]);
313 }
314
315 let len_n1 = len_n1_sq.sqrt();
316 let len_n2 = len_n2_sq.sqrt();
317
318 let dot = n1x * n2x + n1y * n2y + n1z * n2z;
319 let u = dot / (len_n1 * len_n2);
320
321 let sin_a_sq = 1.0 - u.powi(2);
322 if sin_a_sq < 1e-10 {
323 return ([0.0; 3], [0.0; 3], [0.0; 3], [0.0; 3]);
324 }
325 let sin_a = sin_a_sq.sqrt();
326
327 let da_dn1x = -(n2x - dot * n1x / len_n1_sq) / (len_n1 * len_n2 * sin_a);
329 let da_dn1y = -(n2y - dot * n1y / len_n1_sq) / (len_n1 * len_n2 * sin_a);
330 let da_dn1z = -(n2z - dot * n1z / len_n1_sq) / (len_n1 * len_n2 * sin_a);
331
332 let da_dn2x = -(n1x - dot * n2x / len_n2_sq) / (len_n1 * len_n2 * sin_a);
333 let da_dn2y = -(n1y - dot * n2y / len_n2_sq) / (len_n1 * len_n2 * sin_a);
334 let da_dn2z = -(n1z - dot * n2z / len_n2_sq) / (len_n1 * len_n2 * sin_a);
335
336 let sigma = -(n1x * (x4 - x3) + n1y * (y4 - y3) + n1z * (z4 - z3)).signum();
338
339 let grad1_x = sigma * (da_dn1y * (z3 - z2) + da_dn1z * (y2 - y3));
341 let grad1_y = sigma * (da_dn1z * (x3 - x2) + da_dn1x * (z2 - z3));
342 let grad1_z = sigma * (da_dn1x * (y3 - y2) + da_dn1y * (x2 - x3));
343
344 let grad2_x = sigma
345 * (da_dn1y * (z1 - z3) + da_dn1z * (y3 - y1) + da_dn2y * (z3 - z4) + da_dn2z * (y4 - y3));
346 let grad2_y = sigma
347 * (da_dn1z * (x1 - x3) + da_dn1x * (z3 - z1) + da_dn2z * (x3 - x4) + da_dn2x * (z4 - z3));
348 let grad2_z = sigma
349 * (da_dn1x * (y3 - y1) + da_dn1y * (x1 - x3) + da_dn2x * (y4 - y3) + da_dn2y * (x3 - x4));
350
351 let grad3_x = sigma
352 * (da_dn1y * (z2 - z1) + da_dn1z * (y1 - y2) + da_dn2y * (z4 - z2) + da_dn2z * (y2 - y4));
353 let grad3_y = sigma
354 * (da_dn1z * (x2 - x1) + da_dn1x * (z1 - z2) + da_dn2z * (x4 - x2) + da_dn2x * (z2 - z4));
355 let grad3_z = sigma
356 * (da_dn1x * (y1 - y2) + da_dn1y * (x2 - x1) + da_dn2x * (y2 - y4) + da_dn2y * (x4 - x2));
357
358 let grad4_x = sigma * (da_dn2y * (z2 - z3) + da_dn2z * (y3 - y2));
359 let grad4_y = sigma * (da_dn2z * (x2 - x3) + da_dn2x * (z3 - z2));
360 let grad4_z = sigma * (da_dn2x * (y2 - y3) + da_dn2y * (x3 - x2));
361
362 (
363 [grad1_x, grad1_y, grad1_z],
364 [grad2_x, grad2_y, grad2_z],
365 [grad3_x, grad3_y, grad3_z],
366 [grad4_x, grad4_y, grad4_z],
367 )
368}
369
/// Failure modes of constraint validation and enforcement.
#[derive(Debug)]
pub enum ConstraintError {
    /// The constraint Jacobian matrix is singular.
    SingularJacobian,
    /// Constraint solving hit a numerical problem; the string gives details.
    NumericalInstability(String),
    /// A constraint was specified incorrectly; the string says which one and why.
    InvalidConstraint(String),
}

impl fmt::Display for ConstraintError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::SingularJacobian => f.write_str(
                "Constraint Jacobian matrix is singular - constraints may be linearly dependent",
            ),
            Self::NumericalInstability(detail) => {
                write!(f, "Numerical instability in constraint solving: {}", detail)
            }
            Self::InvalidConstraint(detail) => {
                write!(f, "Invalid constraint specification: {}", detail)
            }
        }
    }
}

// Leaf error: relies on the default `source()`, which reports no underlying cause.
impl Error for ConstraintError {}
398
399pub fn add_constraint_lagrange(
437 geometry: &Geometry,
438 forces: DVector<f64>,
439 constraints: &[Constraint],
440 lambdas: &mut Vec<f64>,
441) -> Result<(DVector<f64>, DVector<f64>), ConstraintError> {
442 if constraints.is_empty() {
444 lambdas.clear();
445 return Ok((forces, DVector::zeros(0)));
446 }
447
448 let n_constraints = constraints.len();
449 lambdas.resize(n_constraints, 0.0);
450
451 let jacobian = build_constraint_jacobian(geometry, constraints);
453 let violations = evaluate_constraints(geometry, constraints);
454
455 let first_step = lambdas.iter().all(|&l| l == 0.0);
458 if first_step {
459 let g = -forces.clone(); let total_constraints = n_constraints as f64; for i in 0..n_constraints {
462 let c_i = jacobian.row(i);
463 let c_vec = DVector::from_vec(c_i.iter().cloned().collect());
464 let c_dot_g = c_vec.dot(&g);
465 let c_dot_c = c_vec.dot(&c_vec);
466 if c_dot_c > 1e-12 {
467 lambdas[i] = (-c_dot_g / c_dot_c) * total_constraints; } else {
469 lambdas[i] = 0.0;
470 }
471 }
472 }
473
474 let lambda_vec = DVector::from_vec(lambdas.clone());
476 let constraint_forces = jacobian.transpose() * lambda_vec;
477 let modified_forces = forces + constraint_forces;
478
479 Ok((modified_forces, violations))
482}
483
484pub fn validate_constraints(
500 constraints: &[Constraint],
501 num_atoms: usize,
502) -> Result<(), ConstraintError> {
503 for (i, constraint) in constraints.iter().enumerate() {
504 match constraint {
505 Constraint::Bond {
506 atoms: (a1, a2),
507 target,
508 } => {
509 if *a1 >= num_atoms || *a2 >= num_atoms {
510 return Err(ConstraintError::InvalidConstraint(format!(
511 "Bond constraint {}: atom indices {} or {} exceed number of atoms {}",
512 i, a1, a2, num_atoms
513 )));
514 }
515 if a1 == a2 {
516 return Err(ConstraintError::InvalidConstraint(format!(
517 "Bond constraint {}: cannot constrain atom {} to itself",
518 i, a1
519 )));
520 }
521 if *target <= 0.0 || *target > 10.0 {
522 return Err(ConstraintError::InvalidConstraint(format!(
523 "Bond constraint {}: unreasonable target distance {:.3} Angstrom",
524 i, target
525 )));
526 }
527 }
528 Constraint::Angle {
529 atoms: (a1, a2, a3),
530 target,
531 } => {
532 if *a1 >= num_atoms || *a2 >= num_atoms || *a3 >= num_atoms {
533 return Err(ConstraintError::InvalidConstraint(format!(
534 "Angle constraint {}: atom indices {}, {}, or {} exceed number of atoms {}",
535 i, a1, a2, a3, num_atoms
536 )));
537 }
538 if a1 == a2 || a2 == a3 || a1 == a3 {
539 return Err(ConstraintError::InvalidConstraint(format!(
540 "Angle constraint {}: duplicate atom indices {}, {}, {}",
541 i, a1, a2, a3
542 )));
543 }
544 if *target < 0.0 || *target > std::f64::consts::PI {
545 return Err(ConstraintError::InvalidConstraint(format!(
546 "Angle constraint {}: target angle {:.3} rad is outside valid range [0, π]",
547 i, target
548 )));
549 }
550 }
551 Constraint::Dihedral {
552 atoms: (a1, a2, a3, a4),
553 target,
554 } => {
555 if *a1 >= num_atoms || *a2 >= num_atoms || *a3 >= num_atoms || *a4 >= num_atoms {
556 return Err(ConstraintError::InvalidConstraint(
557 format!("Dihedral constraint {}: atom indices {}, {}, {}, or {} exceed number of atoms {}",
558 i, a1, a2, a3, a4, num_atoms)
559 ));
560 }
561 let atoms = [*a1, *a2, *a3, *a4];
562 for j in 0..4 {
563 for k in (j + 1)..4 {
564 if atoms[j] == atoms[k] {
565 return Err(ConstraintError::InvalidConstraint(format!(
566 "Dihedral constraint {}: duplicate atom indices {}, {}, {}, {}",
567 i, a1, a2, a3, a4
568 )));
569 }
570 }
571 }
572 if *target < -std::f64::consts::PI || *target > std::f64::consts::PI {
573 return Err(ConstraintError::InvalidConstraint(
574 format!("Dihedral constraint {}: target angle {:.3} rad is outside valid range [-π, π]",
575 i, target)
576 ));
577 }
578 }
579 }
580 }
581 Ok(())
582}
583
584pub fn report_constraint_status(
611 geometry: &Geometry,
612 constraints: &[Constraint],
613 lambdas: &[f64],
614 step: usize,
615) {
616 if constraints.is_empty() {
617 return;
618 }
619
620 println!("\n--- Constraint Status (Step {}) ---", step);
621
622 let violations = evaluate_constraints(geometry, constraints);
623 let max_violation = violations.iter().fold(0.0f64, |acc, &x| acc.max(x.abs()));
624
625 println!("Maximum constraint violation: {:.6}", max_violation);
626
627 for (i, constraint) in constraints.iter().enumerate() {
628 let violation = violations[i];
629 let lambda = lambdas.get(i).copied().unwrap_or(0.0);
630
631 match constraint {
632 Constraint::Bond {
633 atoms: (a, b),
634 target,
635 } => {
636 let current = calculate_bond_distance(geometry, *a, *b);
637 println!(
638 " Bond {}-{}: current={:.4} Angstrom, target={:.4} Angstrom, violation={:.6}, λ={:.6}",
639 a + 1,
640 b + 1,
641 current,
642 target,
643 violation,
644 lambda
645 );
646 }
647 Constraint::Angle {
648 atoms: (a, b, c),
649 target,
650 } => {
651 let current = calculate_bond_angle(geometry, *a, *b, *c);
652 println!(
653 " Angle {}-{}-{}: current={:.2}°, target={:.2}°, violation={:.6}, λ={:.6}",
654 a + 1,
655 b + 1,
656 c + 1,
657 current.to_degrees(),
658 target.to_degrees(),
659 violation,
660 lambda
661 );
662 }
663 Constraint::Dihedral {
664 atoms: (a, b, c, d),
665 target,
666 } => {
667 let current = calculate_dihedral(geometry, *a, *b, *c, *d);
668 println!(
669 " Dihedral {}-{}-{}-{}: current={:.2}°, target={:.2}°, violation={:.6}, λ={:.6}",
670 a + 1, b + 1, c + 1, d + 1,
671 current.to_degrees(), target.to_degrees(), violation, lambda
672 );
673 }
674 }
675 }
676
677 let converged = max_violation < 1e-6;
679 println!(
680 "Constraint convergence: {}",
681 if converged {
682 "CONVERGED"
683 } else {
684 "NOT CONVERGED"
685 }
686 );
687 println!("--- End Constraint Status ---\n");
688}
689
690fn calculate_bond_distance(geometry: &Geometry, a: usize, b: usize) -> f64 {
692 let pos_a = geometry.get_atom_coords(a);
693 let pos_b = geometry.get_atom_coords(b);
694 let dx = pos_a[0] - pos_b[0];
695 let dy = pos_a[1] - pos_b[1];
696 let dz = pos_a[2] - pos_b[2];
697 (dx * dx + dy * dy + dz * dz).sqrt()
698}
699
700fn calculate_bond_angle(geometry: &Geometry, a: usize, b: usize, c: usize) -> f64 {
702 let pos_a = geometry.get_atom_coords(a);
703 let pos_b = geometry.get_atom_coords(b);
704 let pos_c = geometry.get_atom_coords(c);
705
706 let v_ba = [
707 pos_a[0] - pos_b[0],
708 pos_a[1] - pos_b[1],
709 pos_a[2] - pos_b[2],
710 ];
711 let v_bc = [
712 pos_c[0] - pos_b[0],
713 pos_c[1] - pos_b[1],
714 pos_c[2] - pos_b[2],
715 ];
716
717 let dot = v_ba[0] * v_bc[0] + v_ba[1] * v_bc[1] + v_ba[2] * v_bc[2];
718 let norm_ba = (v_ba[0].powi(2) + v_ba[1].powi(2) + v_ba[2].powi(2)).sqrt();
719 let norm_bc = (v_bc[0].powi(2) + v_bc[1].powi(2) + v_bc[2].powi(2)).sqrt();
720
721 if norm_ba < 1e-10 || norm_bc < 1e-10 {
722 return 0.0;
723 }
724
725 let cos_angle = (dot / (norm_ba * norm_bc)).clamp(-1.0, 1.0);
726 cos_angle.acos()
727}