Skip to content

Commit 50415bd

Browse files
committed
list of files for input
1 parent 8502423 commit 50415bd

1 file changed

Lines changed: 58 additions & 41 deletions

File tree

src/graphnet/utilities/filesys.py

Lines changed: 58 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from pathlib import Path
44
import re
5+
import os
56
from typing import List, Optional, Tuple, Union
67

78

@@ -31,7 +32,7 @@ def has_extension(filename: str, extensions: List[str]) -> bool:
3132

3233

3334
def find_i3_files(
34-
directories: Union[str, List[str]],
35+
inputs: Union[str, List[str]],
3536
gcd_rescue: Optional[str] = None,
3637
recursive: Optional[bool] = True,
3738
) -> Tuple[List[str], List[str]]:
@@ -42,7 +43,8 @@ def find_i3_files(
4243
in the directory.
4344
4445
Args:
45-
directories: Directories to search recursively for I3 files.
46+
inputs: Directories to search recursively for I3 files.
47+
Or a list of I3 files, which requires `gcd_rescue`.
4648
gcd_rescue: Path to the GCD that will be default if no GCD is present
4749
in the directory.
4850
recursive: Whether or not to search the directories recursively.
@@ -51,46 +53,61 @@ def find_i3_files(
5153
i3_list: Paths to I3 files in `inputs`
5254
gcd_list: Paths to GCD files for each I3 file.
5355
"""
54-
if isinstance(directories, str):
55-
directories = [directories]
56+
if isinstance(inputs, str):
57+
inputs = [inputs]
5658

5759
# Output containers
5860
i3_files = []
5961
gcd_files = []
60-
61-
for directory in directories:
62-
63-
# Find all I3-like files in `directory`, may or may not be recursively.
64-
paths = []
65-
i3_patterns = ["*.bz2", "*.zst", "*.gz"]
66-
for i3_pattern in i3_patterns:
67-
if recursive:
68-
paths.extend(list(Path(directory).rglob(i3_pattern)))
69-
else:
70-
paths.extend(list(Path(directory).glob(i3_pattern)))
71-
72-
# Loop over all folders containing such I3-like files.
73-
folders = sorted(set([path.parent for path in paths]))
74-
for folder in folders:
75-
76-
# List all I3 and GCD files, respectively, in the current folder.
77-
folder_files = [
78-
str(path) for path in paths if path.parent == folder
79-
]
80-
folder_i3_files = list(filter(is_i3_file, folder_files))
81-
folder_gcd_files = list(filter(is_gcd_file, folder_files))
82-
83-
# Make sure that no more than one GCD file is found;
84-
# and use rescue file if none is found.
85-
assert len(folder_gcd_files) <= 1
86-
if len(folder_gcd_files) == 0:
87-
assert gcd_rescue is not None
88-
folder_gcd_files = [gcd_rescue]
89-
90-
# Store list of I3 files and corresponding GCD files.
91-
folder_gcd_files = folder_gcd_files * len(folder_i3_files)
92-
93-
gcd_files.extend(folder_gcd_files)
94-
i3_files.extend(folder_i3_files)
95-
96-
return i3_files, gcd_files
62+
if all([is_i3_file(input) for input in inputs]):
63+
print("Assuming list of files.")
64+
assert gcd_rescue is not None
65+
gcd_files = [gcd_rescue] * len(inputs)
66+
return inputs, gcd_files
67+
68+
elif all(os.path.isdir(input) for input in inputs):
69+
print("Assuming list of directories.")
70+
71+
for directory in inputs:
72+
73+
# Find all I3-like files in `directory`.
74+
paths = []
75+
i3_patterns = ["*.bz2", "*.zst", "*.gz"]
76+
for i3_pattern in i3_patterns:
77+
if recursive:
78+
paths.extend(list(Path(directory).rglob(i3_pattern)))
79+
else:
80+
paths.extend(list(Path(directory).glob(i3_pattern)))
81+
82+
# Loop over all folders containing such I3-like files.
83+
folders = sorted(set([path.parent for path in paths]))
84+
for folder in folders:
85+
86+
# List all I3 and GCD files, in the current folder.
87+
folder_files = [
88+
str(path) for path in paths if path.parent == folder
89+
]
90+
folder_i3_files = list(filter(is_i3_file, folder_files))
91+
folder_gcd_files = list(filter(is_gcd_file, folder_files))
92+
93+
# Make sure that no more than one GCD file is found;
94+
# and use rescue file if none is found.
95+
assert len(folder_gcd_files) <= 1
96+
if len(folder_gcd_files) == 0:
97+
assert gcd_rescue is not None
98+
folder_gcd_files = [gcd_rescue]
99+
100+
# Store list of I3 files and corresponding GCD files.
101+
folder_gcd_files = folder_gcd_files * len(folder_i3_files)
102+
103+
gcd_files.extend(folder_gcd_files)
104+
i3_files.extend(folder_i3_files)
105+
return i3_files, gcd_files
106+
else:
107+
if any([os.path.isdir(input) for input in inputs]):
108+
raise ValueError(
109+
"Inputs contains a mix of files and directories \
110+
which is not supported."
111+
)
112+
else:
113+
raise ValueError("Some inputs are not valid directories or files.")

0 commit comments

Comments
 (0)