1+ #!/usr/bin/env python
12# conezen/cli.py
23
34import sys
45import shutil
56from pathlib import Path
67import pandas as pd
78import matplotlib .pyplot as plt
9+ import numpy as np
10+ import re
811
12+ # This is the logic from your other file, now part of ConeZen
913from . import logic
1014
15+ # --- GRADIENT & NAC EXTRACTION HELPERS ---
16+
17+ def _generate_state_headers (num_singlets = 20 , num_triplets = 20 ):
18+ """Generates a dictionary mapping common state names to their file header patterns."""
19+ headers = {}
20+ for n in range (1 , num_singlets + 1 ):
21+ state_name = f"S{ n - 1 } "
22+ headers [state_name ] = f"m1 1 s1 { n } ms1 0"
23+ for n in range (1 , num_triplets + 1 ):
24+ state_name_base = f"T{ n } "
25+ headers [f"{ state_name_base } _ms-1" ] = f"m1 3 s1 { n } ms1 -1"
26+ headers [f"{ state_name_base } _ms0" ] = f"m1 3 s1 { n } ms1 0"
27+ headers [f"{ state_name_base } _ms+1" ] = f"m1 3 s1 { n } ms1 1"
28+ return headers
29+
30+ def _extract_gradient (file_path , state_name , state_headers ):
31+ """
32+ Extracts a gradient for a given state, returning the header and numpy array.
33+ """
34+ header_to_find = state_headers .get (state_name .upper ())
35+ if not header_to_find :
36+ print (f"❌ Error: State '{ state_name } ' is not defined in the header dictionary." )
37+ return None , None
38+
39+ try :
40+ with open (file_path , 'r' ) as f :
41+ lines = f .readlines ()
42+ except FileNotFoundError :
43+ print (f"❌ Error: The gradient source file '{ file_path } ' was not found." )
44+ return None , None
45+
46+ for i , line in enumerate (lines ):
47+ # We only search for the gradient header part, as it's unique enough
48+ if header_to_find in line and "m2" not in line :
49+ found_header_line = line .strip ()
50+ gradient_data = []
51+ try :
52+ num_atoms = int (line .strip ().split ()[0 ])
53+ if i + 1 + num_atoms > len (lines ):
54+ return None , None
55+ for j in range (num_atoms ):
56+ data_line = lines [i + 1 + j ].strip ().split ()
57+ gradient_data .append ([float (x ) for x in data_line ])
58+ return found_header_line , np .array (gradient_data )
59+ except (ValueError , IndexError ):
60+ return None , None
61+ return None , None
62+
63+ def _format_nac_part2 (state_header_part ):
64+ """Converts a gradient-style header part to a NAC part 2 style."""
65+ # e.g., "m1 1 s1 4 ms1 0" -> "m2 1 s2 4 ms2 0"
66+ parts = state_header_part .split ()
67+ parts [0 ] = 'm2'
68+ parts [2 ] = 's2'
69+ parts [4 ] = 'ms2'
70+ return " " .join (parts )
71+
72+ def _extract_nac_vector (file_path , state1_name , state2_name , state_headers ):
73+ """
74+ Extracts the NAC vector between two states, accounting for commutative headers.
75+ """
76+ part1_raw = state_headers .get (state1_name .upper ())
77+ part2_raw = state_headers .get (state2_name .upper ())
78+
79+ if not part1_raw or not part2_raw :
80+ print (f"❌ Error: One or both states ('{ state1_name } ', '{ state2_name } ') are not defined." )
81+ return None , None
82+
83+ # Construct the two possible header combinations, e.g., (S2,S3) and (S3,S2)
84+ header_combo1 = f"{ part1_raw } { _format_nac_part2 (part2_raw )} "
85+ header_combo2 = f"{ part2_raw } { _format_nac_part2 (part1_raw )} "
86+
87+ try :
88+ with open (file_path , 'r' ) as f :
89+ lines = f .readlines ()
90+ except FileNotFoundError :
91+ print (f"❌ Error: The NAC source file '{ file_path } ' was not found." )
92+ return None , None
93+
94+ for i , line in enumerate (lines ):
95+ # Check for both possible header orders
96+ if header_combo1 in line or header_combo2 in line :
97+ found_header_line = line .strip ()
98+ nac_data = []
99+ try :
100+ num_atoms = int (line .strip ().split ()[0 ])
101+ if i + 1 + num_atoms > len (lines ):
102+ return None , None
103+ for j in range (num_atoms ):
104+ data_line = lines [i + 1 + j ].strip ().split ()
105+ nac_data .append ([float (x ) for x in data_line ])
106+ return found_header_line , np .array (nac_data )
107+ except (ValueError , IndexError ):
108+ return None , None
109+ return None , None
110+
111+
11112# --- USER INTERACTION HELPERS ---
12113def print_about ():
13114 """Display script info and citation."""
@@ -53,6 +154,15 @@ def get_file(prompt: str, must_exist=True, default=None):
53154 print (f"❌ File '{ fname } ' not found." ); continue
54155 return path
55156
157+ def get_state_name (prompt : str ) -> str :
158+ """Get a non-empty state name from the user."""
159+ while True :
160+ state_name = input (prompt ).strip ().upper ()
161+ if state_name :
162+ return state_name
163+ print ("❌ State name cannot be empty. Please enter a value." )
164+
165+
56166def safe_save_path (prompt : str , default : str ):
57167 """Prompt for a save path, warn if exists."""
58168 while True :
@@ -68,63 +178,133 @@ def main():
68178 """Main function to run the command-line tool."""
69179 try :
70180 print_about ()
71- # -- Input files
72- grad_fileA = get_file ("Enter the gradient file name for State A" , default = "gradientA.out" )
73- grad_fileB = get_file ("Enter the gradient file name for State B" , default = "gradientB.out" )
74- nac_file = get_file ("Enter the NAC vector file name" , default = "NAC.out" )
75-
181+
182+ grad_fileA , grad_fileB , nac_file = None , None , None
183+
184+ if ask_yes_no ("(Do you want to automatically extract gradients and NACs from a QM output file? (only for sharc-molcas output QM.out file)" ):
185+ # --- Integrated Extraction Workflow ---
186+ grad_source_file = get_file ("Enter the source QM file name" , default = "QM.out" )
187+
188+ num_singlets = get_numeric ("Enter the total number of singlet states" , default = 20 , type_cast = int )
189+ num_triplets = get_numeric ("Enter the total number of triplet states" , default = 20 , type_cast = int )
190+
191+ state_headers = _generate_state_headers (num_singlets = num_singlets , num_triplets = num_triplets )
192+
193+ lower_state = get_state_name ("Enter the lower state (e.g., S2): " )
194+ upper_state = get_state_name ("Enter the upper state (e.g., S3): " )
195+
196+ # Process Lower State (State A)
197+ header_A , grad_A_data = _extract_gradient (grad_source_file , lower_state , state_headers )
198+ if grad_A_data is not None :
199+ grad_fileA = Path (f"{ lower_state } _gradient.out" )
200+ with open (grad_fileA , 'w' ) as f :
201+ f .write (header_A + '\n ' )
202+ np .savetxt (f , grad_A_data , fmt = '%18.10f' )
203+ print (f"✅ Successfully extracted and saved '{ grad_fileA } '" )
204+ else :
205+ print (f"❌ Failed to extract gradient for '{ lower_state } '. Exiting." )
206+ sys .exit (1 )
207+
208+ # Process Upper State (State B)
209+ header_B , grad_B_data = _extract_gradient (grad_source_file , upper_state , state_headers )
210+ if grad_B_data is not None :
211+ grad_fileB = Path (f"{ upper_state } _gradient.out" )
212+ with open (grad_fileB , 'w' ) as f :
213+ f .write (header_B + '\n ' )
214+ np .savetxt (f , grad_B_data , fmt = '%18.10f' )
215+ print (f"✅ Successfully extracted and saved '{ grad_fileB } '" )
216+ else :
217+ print (f"❌ Failed to extract gradient for '{ upper_state } '. Exiting." )
218+ sys .exit (1 )
219+
220+ # Automatically extract NAC vector
221+ header_NAC , nac_data = _extract_nac_vector (grad_source_file , lower_state , upper_state , state_headers )
222+ if nac_data is not None :
223+ nac_file = Path (f"NAC_{ lower_state } _{ upper_state } .out" )
224+ with open (nac_file , 'w' ) as f :
225+ f .write (header_NAC + '\n ' )
226+ np .savetxt (f , nac_data , fmt = '%18.10f' )
227+ print (f"✅ Automatically extracted and saved '{ nac_file } '" )
228+ else :
229+ print (f"❌ Failed to automatically extract NAC vector for { lower_state } -{ upper_state } . Please provide it manually." )
230+ nac_file = get_file ("Enter the NAC vector file name" , default = "NAC.out" )
231+
232+ else :
233+ # --- Original Workflow ---
234+ grad_fileA = get_file ("Enter the gradient file name for State A" , default = "gradientA.out" )
235+ grad_fileB = get_file ("Enter the gradient file name for State B" , default = "gradientB.out" )
236+ nac_file = get_file ("Enter the NAC vector file name" , default = "NAC.out" )
237+
238+ # -- Continue with the rest of the script as before --
76239 grad_A , skipped_A = logic .load_vector_file (grad_fileA )
77240 grad_B , skipped_B = logic .load_vector_file (grad_fileB )
78241 h , skipped_h = logic .load_vector_file (nac_file )
79242
80243 if any ([skipped_A , skipped_B , skipped_h ]):
81244 print (f"⚠️ Skipped malformed lines in one or more input files." )
82245
83- # -- Numeric/energy inputs
84246 E_X = get_numeric ("Enter the energy of the intersection point (Hartree)" , default = logic .DEFAULT_EX )
85247
86- # -- Core calculation
87248 params = logic .get_branching_plane_vectors (grad_A , grad_B , h )
88249 print ("\n " + "=" * 40 )
89250 print (" Branching Plane Key Quantities" )
90- print (f"theta_s (θ_s) in degrees: { logic . np .degrees (params ['theta_s_rad' ]):.6f} " )
251+ print (f"theta_s (θ_s) in degrees: { np .degrees (params ['theta_s_rad' ]):.6f} " )
91252 print (f"del_gh (δ_gh): { params ['del_gh' ]:.6f} " )
92253 print (f"delta_gh (Δ_gh): { params ['delta_gh' ]:.6f} " )
93254 print (f"sigma (σ): { params ['sigma' ]:.6f} " )
94255 print ("=" * 40 + "\n " )
95256
96- # -- (Optional) Save key quantities to a file
97257 if ask_yes_no ("Save branching plane key quantities to a file?" ):
98258 params_path = safe_save_path ("Enter filename for parameters" , default = "ci_parameters.txt" )
99259 with open (params_path , "w" ) as f :
100- f .write ("Branching Plane Key Quantities\n " )
101- f .write ("=" * 40 + "\n " )
102- f .write (f"theta_s (θ_s) in degrees: { logic .np .degrees (params ['theta_s_rad' ]):.6f} \n " )
260+ f .write ("Branching Plane Key Quantities\n " + "=" * 40 + "\n " )
261+ f .write (f"theta_s (θ_s) in degrees: { np .degrees (params ['theta_s_rad' ]):.6f} \n " )
103262 f .write (f"del_gh (δ_gh): { params ['del_gh' ]:.6f} \n " )
104263 f .write (f"delta_gh (Δ_gh): { params ['delta_gh' ]:.6f} \n " )
105264 f .write (f"sigma (σ): { params ['sigma' ]:.6f} \n " )
106265 print (f"✅ Key quantities saved to '{ params_path } '\n " )
107266
108- xyz_file = get_file ("Enter the xyz file name for atom labels" , default = "orca.xyz" )
109- atom_list = logic .extract_atom_symbols (xyz_file )
110-
111- N = len (atom_list )
267+ # --- MODIFIED: Optional XYZ file for atom labels ---
268+ N = grad_A .shape [0 ] # Get number of atoms from the loaded gradient
112269 x_hat_2d = params ['x_hat' ].reshape (N , 3 )
113270 y_hat_2d = params ['y_hat' ].reshape (N , 3 )
114271
115- with open ("x_vectors.out" , "w" ) as f :
116- f .write ("atoms x vectors\n " )
117- pd .DataFrame ({'Atom' : atom_list , 'x' : x_hat_2d [:,0 ], 'y' : x_hat_2d [:,1 ], 'z' : x_hat_2d [:,2 ]}).to_csv (f , sep = ' ' , index = False , header = False , float_format = "%.10f" )
118- with open ("y_vectors.out" , "w" ) as f :
119- f .write ("atoms y vectors\n " )
120- pd .DataFrame ({'Atom' : atom_list , 'x' : y_hat_2d [:,0 ], 'y' : y_hat_2d [:,1 ], 'z' : y_hat_2d [:,2 ]}).to_csv (f , sep = ' ' , index = False , header = False , float_format = "%.10f" )
121- print ("✅ x_hat and y_hat vectors saved to x_vectors.out and y_vectors.out" )
272+ if ask_yes_no ("Add atom labels from an XYZ file to the output vectors?" ):
273+ xyz_file = get_file ("Enter the xyz file name for atom labels" , default = "orca.xyz" )
274+ atom_list = logic .extract_atom_symbols (xyz_file )
275+
276+ if len (atom_list ) != N :
277+ print (f"⚠️ Warning: XYZ file has { len (atom_list )} atoms, but gradient files have { N } . Output may be misaligned." )
278+
279+ # Save with atom labels
280+ with open ("x_vectors.out" , "w" ) as f :
281+ f .write ("atoms x y z\n " )
282+ df = pd .DataFrame (x_hat_2d , columns = ['x' , 'y' , 'z' ])
283+ df .insert (0 , 'Atom' , atom_list )
284+ df .to_csv (f , sep = ' ' , index = False , header = False , float_format = "%.10f" )
285+
286+ with open ("y_vectors.out" , "w" ) as f :
287+ f .write ("atoms x y z\n " )
288+ df = pd .DataFrame (y_hat_2d , columns = ['x' , 'y' , 'z' ])
289+ df .insert (0 , 'Atom' , atom_list )
290+ df .to_csv (f , sep = ' ' , index = False , header = False , float_format = "%.10f" )
291+ else :
292+ # Save without atom labels
293+ with open ("x_vectors.out" , "w" ) as f :
294+ f .write ("x y z\n " )
295+ np .savetxt (f , x_hat_2d , fmt = '%18.10f' )
122296
297+ with open ("y_vectors.out" , "w" ) as f :
298+ f .write ("x y z\n " )
299+ np .savetxt (f , y_hat_2d , fmt = '%18.10f' )
300+
301+ print ("✅ x_hat and y_hat vectors saved to x_vectors.out and y_vectors.out" )
302+
303+ # --- Continue with plotting and animation ---
123304 X , Y , E_A , E_B , had_neg_sqrt = logic .compute_surfaces (params , E_X )
124305 if had_neg_sqrt :
125306 print ("⚠️ Some negative values in sqrt term were set to zero during surface calculation." )
126307
127- # -- Plotting & Animation
128308 if ask_yes_no ("Show 3D surface plot now?" ):
129309 fig_w = get_numeric ("Figure width (inches)" , default = logic .DEFAULT_FIGSIZE [0 ])
130310 fig_h = get_numeric ("Figure height (inches)" , default = logic .DEFAULT_FIGSIZE [1 ])
@@ -136,7 +316,7 @@ def main():
136316 dpi = get_numeric ("DPI" , default = logic .DEFAULT_DPI , type_cast = int )
137317 fig , ax = logic .plot_surfaces (X , Y , E_A , E_B , logic .DEFAULT_FIGSIZE [0 ], logic .DEFAULT_FIGSIZE [1 ])
138318 fig .savefig (outpath , dpi = dpi , bbox_inches = 'tight' )
139- plt .close (fig ) # Close the figure to free up memory
319+ plt .close (fig )
140320 print (f"✅ Saved static image as '{ outpath } '" )
141321
142322 if ask_yes_no ("Create a rotation animation?" ):
@@ -152,3 +332,6 @@ def main():
152332 except KeyboardInterrupt :
153333 print ("\n Interrupted. Exiting gracefully." )
154334 sys .exit (0 )
335+
336+ if __name__ == "__main__" :
337+ main ()
0 commit comments