Skip to content

Commit 6eab53e

Browse files
authored
Update cli.py
1 parent 84a4093 commit 6eab53e

1 file changed

Lines changed: 208 additions & 25 deletions

File tree

src/conezen/cli.py

Lines changed: 208 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,114 @@
1+
#!/usr/bin/env python
12
# conezen/cli.py
23

34
import sys
45
import shutil
56
from pathlib import Path
67
import pandas as pd
78
import matplotlib.pyplot as plt
9+
import numpy as np
10+
import re
811

12+
# This is the logic from your other file, now part of ConeZen
913
from . import logic
1014

15+
# --- GRADIENT & NAC EXTRACTION HELPERS ---
16+
17+
def _generate_state_headers(num_singlets=20, num_triplets=20):
18+
"""Generates a dictionary mapping common state names to their file header patterns."""
19+
headers = {}
20+
for n in range(1, num_singlets + 1):
21+
state_name = f"S{n-1}"
22+
headers[state_name] = f"m1 1 s1 {n} ms1 0"
23+
for n in range(1, num_triplets + 1):
24+
state_name_base = f"T{n}"
25+
headers[f"{state_name_base}_ms-1"] = f"m1 3 s1 {n} ms1 -1"
26+
headers[f"{state_name_base}_ms0"] = f"m1 3 s1 {n} ms1 0"
27+
headers[f"{state_name_base}_ms+1"] = f"m1 3 s1 {n} ms1 1"
28+
return headers
29+
30+
def _extract_gradient(file_path, state_name, state_headers):
31+
"""
32+
Extracts a gradient for a given state, returning the header and numpy array.
33+
"""
34+
header_to_find = state_headers.get(state_name.upper())
35+
if not header_to_find:
36+
print(f"❌ Error: State '{state_name}' is not defined in the header dictionary.")
37+
return None, None
38+
39+
try:
40+
with open(file_path, 'r') as f:
41+
lines = f.readlines()
42+
except FileNotFoundError:
43+
print(f"❌ Error: The gradient source file '{file_path}' was not found.")
44+
return None, None
45+
46+
for i, line in enumerate(lines):
47+
# We only search for the gradient header part, as it's unique enough
48+
if header_to_find in line and "m2" not in line:
49+
found_header_line = line.strip()
50+
gradient_data = []
51+
try:
52+
num_atoms = int(line.strip().split()[0])
53+
if i + 1 + num_atoms > len(lines):
54+
return None, None
55+
for j in range(num_atoms):
56+
data_line = lines[i + 1 + j].strip().split()
57+
gradient_data.append([float(x) for x in data_line])
58+
return found_header_line, np.array(gradient_data)
59+
except (ValueError, IndexError):
60+
return None, None
61+
return None, None
62+
63+
def _format_nac_part2(state_header_part):
64+
"""Converts a gradient-style header part to a NAC part 2 style."""
65+
# e.g., "m1 1 s1 4 ms1 0" -> "m2 1 s2 4 ms2 0"
66+
parts = state_header_part.split()
67+
parts[0] = 'm2'
68+
parts[2] = 's2'
69+
parts[4] = 'ms2'
70+
return " ".join(parts)
71+
72+
def _extract_nac_vector(file_path, state1_name, state2_name, state_headers):
73+
"""
74+
Extracts the NAC vector between two states, accounting for commutative headers.
75+
"""
76+
part1_raw = state_headers.get(state1_name.upper())
77+
part2_raw = state_headers.get(state2_name.upper())
78+
79+
if not part1_raw or not part2_raw:
80+
print(f"❌ Error: One or both states ('{state1_name}', '{state2_name}') are not defined.")
81+
return None, None
82+
83+
# Construct the two possible header combinations, e.g., (S2,S3) and (S3,S2)
84+
header_combo1 = f"{part1_raw} {_format_nac_part2(part2_raw)}"
85+
header_combo2 = f"{part2_raw} {_format_nac_part2(part1_raw)}"
86+
87+
try:
88+
with open(file_path, 'r') as f:
89+
lines = f.readlines()
90+
except FileNotFoundError:
91+
print(f"❌ Error: The NAC source file '{file_path}' was not found.")
92+
return None, None
93+
94+
for i, line in enumerate(lines):
95+
# Check for both possible header orders
96+
if header_combo1 in line or header_combo2 in line:
97+
found_header_line = line.strip()
98+
nac_data = []
99+
try:
100+
num_atoms = int(line.strip().split()[0])
101+
if i + 1 + num_atoms > len(lines):
102+
return None, None
103+
for j in range(num_atoms):
104+
data_line = lines[i + 1 + j].strip().split()
105+
nac_data.append([float(x) for x in data_line])
106+
return found_header_line, np.array(nac_data)
107+
except (ValueError, IndexError):
108+
return None, None
109+
return None, None
110+
111+
11112
# --- USER INTERACTION HELPERS ---
12113
def print_about():
13114
"""Display script info and citation."""
@@ -53,6 +154,15 @@ def get_file(prompt: str, must_exist=True, default=None):
53154
print(f"❌ File '{fname}' not found."); continue
54155
return path
55156

157+
def get_state_name(prompt: str) -> str:
158+
"""Get a non-empty state name from the user."""
159+
while True:
160+
state_name = input(prompt).strip().upper()
161+
if state_name:
162+
return state_name
163+
print("❌ State name cannot be empty. Please enter a value.")
164+
165+
56166
def safe_save_path(prompt: str, default: str):
57167
"""Prompt for a save path, warn if exists."""
58168
while True:
@@ -68,63 +178,133 @@ def main():
68178
"""Main function to run the command-line tool."""
69179
try:
70180
print_about()
71-
# -- Input files
72-
grad_fileA = get_file("Enter the gradient file name for State A", default="gradientA.out")
73-
grad_fileB = get_file("Enter the gradient file name for State B", default="gradientB.out")
74-
nac_file = get_file("Enter the NAC vector file name", default="NAC.out")
75-
181+
182+
grad_fileA, grad_fileB, nac_file = None, None, None
183+
184+
if ask_yes_no("(Do you want to automatically extract gradients and NACs from a QM output file? (only for sharc-molcas output QM.out file)"):
185+
# --- Integrated Extraction Workflow ---
186+
grad_source_file = get_file("Enter the source QM file name", default="QM.out")
187+
188+
num_singlets = get_numeric("Enter the total number of singlet states", default=20, type_cast=int)
189+
num_triplets = get_numeric("Enter the total number of triplet states", default=20, type_cast=int)
190+
191+
state_headers = _generate_state_headers(num_singlets=num_singlets, num_triplets=num_triplets)
192+
193+
lower_state = get_state_name("Enter the lower state (e.g., S2): ")
194+
upper_state = get_state_name("Enter the upper state (e.g., S3): ")
195+
196+
# Process Lower State (State A)
197+
header_A, grad_A_data = _extract_gradient(grad_source_file, lower_state, state_headers)
198+
if grad_A_data is not None:
199+
grad_fileA = Path(f"{lower_state}_gradient.out")
200+
with open(grad_fileA, 'w') as f:
201+
f.write(header_A + '\n')
202+
np.savetxt(f, grad_A_data, fmt='%18.10f')
203+
print(f"✅ Successfully extracted and saved '{grad_fileA}'")
204+
else:
205+
print(f"❌ Failed to extract gradient for '{lower_state}'. Exiting.")
206+
sys.exit(1)
207+
208+
# Process Upper State (State B)
209+
header_B, grad_B_data = _extract_gradient(grad_source_file, upper_state, state_headers)
210+
if grad_B_data is not None:
211+
grad_fileB = Path(f"{upper_state}_gradient.out")
212+
with open(grad_fileB, 'w') as f:
213+
f.write(header_B + '\n')
214+
np.savetxt(f, grad_B_data, fmt='%18.10f')
215+
print(f"✅ Successfully extracted and saved '{grad_fileB}'")
216+
else:
217+
print(f"❌ Failed to extract gradient for '{upper_state}'. Exiting.")
218+
sys.exit(1)
219+
220+
# Automatically extract NAC vector
221+
header_NAC, nac_data = _extract_nac_vector(grad_source_file, lower_state, upper_state, state_headers)
222+
if nac_data is not None:
223+
nac_file = Path(f"NAC_{lower_state}_{upper_state}.out")
224+
with open(nac_file, 'w') as f:
225+
f.write(header_NAC + '\n')
226+
np.savetxt(f, nac_data, fmt='%18.10f')
227+
print(f"✅ Automatically extracted and saved '{nac_file}'")
228+
else:
229+
print(f"❌ Failed to automatically extract NAC vector for {lower_state}-{upper_state}. Please provide it manually.")
230+
nac_file = get_file("Enter the NAC vector file name", default="NAC.out")
231+
232+
else:
233+
# --- Original Workflow ---
234+
grad_fileA = get_file("Enter the gradient file name for State A", default="gradientA.out")
235+
grad_fileB = get_file("Enter the gradient file name for State B", default="gradientB.out")
236+
nac_file = get_file("Enter the NAC vector file name", default="NAC.out")
237+
238+
# -- Continue with the rest of the script as before --
76239
grad_A, skipped_A = logic.load_vector_file(grad_fileA)
77240
grad_B, skipped_B = logic.load_vector_file(grad_fileB)
78241
h, skipped_h = logic.load_vector_file(nac_file)
79242

80243
if any([skipped_A, skipped_B, skipped_h]):
81244
print(f"⚠️ Skipped malformed lines in one or more input files.")
82245

83-
# -- Numeric/energy inputs
84246
E_X = get_numeric("Enter the energy of the intersection point (Hartree)", default=logic.DEFAULT_EX)
85247

86-
# -- Core calculation
87248
params = logic.get_branching_plane_vectors(grad_A, grad_B, h)
88249
print("\n" + "="*40)
89250
print(" Branching Plane Key Quantities")
90-
print(f"theta_s (θ_s) in degrees: {logic.np.degrees(params['theta_s_rad']):.6f}")
251+
print(f"theta_s (θ_s) in degrees: {np.degrees(params['theta_s_rad']):.6f}")
91252
print(f"del_gh (δ_gh): {params['del_gh']:.6f}")
92253
print(f"delta_gh (Δ_gh): {params['delta_gh']:.6f}")
93254
print(f"sigma (σ): {params['sigma']:.6f}")
94255
print("="*40 + "\n")
95256

96-
# -- (Optional) Save key quantities to a file
97257
if ask_yes_no("Save branching plane key quantities to a file?"):
98258
params_path = safe_save_path("Enter filename for parameters", default="ci_parameters.txt")
99259
with open(params_path, "w") as f:
100-
f.write("Branching Plane Key Quantities\n")
101-
f.write("="*40 + "\n")
102-
f.write(f"theta_s (θ_s) in degrees: {logic.np.degrees(params['theta_s_rad']):.6f}\n")
260+
f.write("Branching Plane Key Quantities\n" + "="*40 + "\n")
261+
f.write(f"theta_s (θ_s) in degrees: {np.degrees(params['theta_s_rad']):.6f}\n")
103262
f.write(f"del_gh (δ_gh): {params['del_gh']:.6f}\n")
104263
f.write(f"delta_gh (Δ_gh): {params['delta_gh']:.6f}\n")
105264
f.write(f"sigma (σ): {params['sigma']:.6f}\n")
106265
print(f"✅ Key quantities saved to '{params_path}'\n")
107266

108-
xyz_file = get_file("Enter the xyz file name for atom labels", default="orca.xyz")
109-
atom_list = logic.extract_atom_symbols(xyz_file)
110-
111-
N = len(atom_list)
267+
# --- MODIFIED: Optional XYZ file for atom labels ---
268+
N = grad_A.shape[0] # Get number of atoms from the loaded gradient
112269
x_hat_2d = params['x_hat'].reshape(N, 3)
113270
y_hat_2d = params['y_hat'].reshape(N, 3)
114271

115-
with open("x_vectors.out", "w") as f:
116-
f.write("atoms x vectors\n")
117-
pd.DataFrame({'Atom': atom_list, 'x': x_hat_2d[:,0], 'y': x_hat_2d[:,1], 'z': x_hat_2d[:,2]}).to_csv(f, sep=' ', index=False, header=False, float_format="%.10f")
118-
with open("y_vectors.out", "w") as f:
119-
f.write("atoms y vectors\n")
120-
pd.DataFrame({'Atom': atom_list, 'x': y_hat_2d[:,0], 'y': y_hat_2d[:,1], 'z': y_hat_2d[:,2]}).to_csv(f, sep=' ', index=False, header=False, float_format="%.10f")
121-
print("✅ x_hat and y_hat vectors saved to x_vectors.out and y_vectors.out")
272+
if ask_yes_no("Add atom labels from an XYZ file to the output vectors?"):
273+
xyz_file = get_file("Enter the xyz file name for atom labels", default="orca.xyz")
274+
atom_list = logic.extract_atom_symbols(xyz_file)
275+
276+
if len(atom_list) != N:
277+
print(f"⚠️ Warning: XYZ file has {len(atom_list)} atoms, but gradient files have {N}. Output may be misaligned.")
278+
279+
# Save with atom labels
280+
with open("x_vectors.out", "w") as f:
281+
f.write("atoms x y z\n")
282+
df = pd.DataFrame(x_hat_2d, columns=['x', 'y', 'z'])
283+
df.insert(0, 'Atom', atom_list)
284+
df.to_csv(f, sep=' ', index=False, header=False, float_format="%.10f")
285+
286+
with open("y_vectors.out", "w") as f:
287+
f.write("atoms x y z\n")
288+
df = pd.DataFrame(y_hat_2d, columns=['x', 'y', 'z'])
289+
df.insert(0, 'Atom', atom_list)
290+
df.to_csv(f, sep=' ', index=False, header=False, float_format="%.10f")
291+
else:
292+
# Save without atom labels
293+
with open("x_vectors.out", "w") as f:
294+
f.write("x y z\n")
295+
np.savetxt(f, x_hat_2d, fmt='%18.10f')
122296

297+
with open("y_vectors.out", "w") as f:
298+
f.write("x y z\n")
299+
np.savetxt(f, y_hat_2d, fmt='%18.10f')
300+
301+
print("✅ x_hat and y_hat vectors saved to x_vectors.out and y_vectors.out")
302+
303+
# --- Continue with plotting and animation ---
123304
X, Y, E_A, E_B, had_neg_sqrt = logic.compute_surfaces(params, E_X)
124305
if had_neg_sqrt:
125306
print("⚠️ Some negative values in sqrt term were set to zero during surface calculation.")
126307

127-
# -- Plotting & Animation
128308
if ask_yes_no("Show 3D surface plot now?"):
129309
fig_w = get_numeric("Figure width (inches)", default=logic.DEFAULT_FIGSIZE[0])
130310
fig_h = get_numeric("Figure height (inches)", default=logic.DEFAULT_FIGSIZE[1])
@@ -136,7 +316,7 @@ def main():
136316
dpi = get_numeric("DPI", default=logic.DEFAULT_DPI, type_cast=int)
137317
fig, ax = logic.plot_surfaces(X, Y, E_A, E_B, logic.DEFAULT_FIGSIZE[0], logic.DEFAULT_FIGSIZE[1])
138318
fig.savefig(outpath, dpi=dpi, bbox_inches='tight')
139-
plt.close(fig) # Close the figure to free up memory
319+
plt.close(fig)
140320
print(f"✅ Saved static image as '{outpath}'")
141321

142322
if ask_yes_no("Create a rotation animation?"):
@@ -152,3 +332,6 @@ def main():
152332
except KeyboardInterrupt:
153333
print("\nInterrupted. Exiting gracefully.")
154334
sys.exit(0)
335+
336+
if __name__ == "__main__":
337+
main()

0 commit comments

Comments
 (0)