44
55import os
66from dataclasses import dataclass
7- from typing import TYPE_CHECKING
7+ from typing import List , Optional , TYPE_CHECKING
88
99import py7zr
10+ import py7zr .callbacks
11+ from tqdm import tqdm
1012
1113if TYPE_CHECKING :
1214 # This trick will avoid us from causing a cyclical reference within the class
1315 # Credit: https://stackoverflow.com/a/39757388
16+ from falconpy import RealTimeResponse
17+ from caracara import Client
1418 from caracara .modules .rtr .batch_session import RTRBatchSession
1519
1620
21+ class SevenZipExtractProgressBar (py7zr .callbacks .ExtractCallback , tqdm ):
22+ """A progress bar extractor for py7zr.
23+
24+ This code is heavily based on an example published on GitHub here:
25+ https://github.com/miurahr/py7zr/pull/558
26+ """
27+ def __init__ (self , * args , total_bytes : int , ** kwargs ):
28+ super ().__init__ (self , * args , total = total_bytes , ** kwargs )
29+
30+ def report_start_preparation (self ):
31+ pass
32+
33+ def report_start (self , processing_file_path , processing_bytes ):
34+ pass
35+
36+ def report_end (self , processing_file_path , wrote_bytes ):
37+ pass
38+
39+ def report_update (self , decompressed_bytes ):
40+ self .update (int (decompressed_bytes ))
41+
42+ def report_postprocess (self ):
43+ pass
44+
45+ def report_warning (self , message ):
46+ pass
47+
48+
1749@dataclass
1850class GetFile :
19- """
20- Represents a file uploaded to Falcon via a GET command.
51+ """Represent a file uploaded to Falcon via a GET command.
2152
2253 This class may be instantiated many times, with each object stored in a common list,
2354 in order to represent many files retrieved from a GET comamnd executed against a
2455 batch session.
56+
57+ Only one of batch_session, client, or rtr_api need to be set. When the rtr_api property
58+ is called, a RealTimeResponse object will be returned according to this priority list:
59+ - _rtr_api (i.e., rtr_api is set via the GetFile object's property setter)
60+ - client.rtr.rtr_api
61+ - batch_session.api
2562 """
2663
2764 device_id : str = None
@@ -30,19 +67,68 @@ class GetFile:
3067 sha256 : str = None
3168 size : int = None
3269 batch_session : RTRBatchSession = None
70+ client : Client = None
71+ _rtr_api : RealTimeResponse = None
72+
73+ @property
74+ def rtr_api (self ) -> RealTimeResponse :
75+ if self ._rtr_api :
76+ return self ._rtr_api
77+ if self .client :
78+ return self .client .rtr .rtr_api
79+ if self .batch_session :
80+ return self .batch_session .api
81+ raise AttributeError (
82+ "You must set at least one of batch_session, client, or rtr_api, so that this object "
83+ "can access the FalconPy Real Time Response API wrapper."
84+ )
85+
86+ @rtr_api .setter
87+ def rtr_api (self , v : RealTimeResponse ) -> None :
88+ self ._rtr_api = v
3389
3490 def download (
3591 self ,
3692 output_path : str ,
3793 extract : bool = True ,
3894 preserve_7z : bool = False ,
95+ show_download_progress : bool = False ,
96+ download_chunk_size : int = 1048576 , # 1MiB
3997 ):
40- """
41- Download a file to a specified path.
98+ """Download a file to a specified path.
4299
43100 If the path is a folder, the filename will be auto generated.
44101 If the path is a file, it'll be downloaded to that path and name.
102+
103+ All downloads require that the following three attributes be set in the GetFile object:
104+ - session_id (RTR Session ID)
105+ - sha256 (SHA256 hash of the extracted file)
106+ - filename (name of the extracted file)
107+ These three values are all returned when the status of a GET command is queried.
108+
109+ Additionally, this function requires that a Device ID (AID) be stored within the GetFile
110+ object to include the Device ID in the eventual filename. If this is not provided, and
111+ output_path is set to a directory, the AID will be excluded from the calculated eventual
112+ file name. Note that this is irrelevant if output_path is NOT a path to a directory, as
113+ the filename provided to this parameter will be used instead of one derived by the library.
114+
115+ Other parameters:
116+ - show_download_progress: Whether or not to draw the download progress via TQDM.
117+ - download_chunk_size: Size of each chunk to stream via requests. Defaults to 1MiB.
118+ If this is set to 0, chunking will not be used.
45119 """
120+ if (
121+ not self .session_id or
122+ not self .sha256 or
123+ not self .filename
124+ ):
125+ raise ValueError (
126+ "A session ID, SHA256 hash, and filename are all required to download a GET file. "
127+ "Ensure these values are set in the object before calling the download function."
128+ )
129+
130+ # Get the name of the uploaded file from the filename value, which actually contains the
131+ # full path to the file on the origin system's disk.
46132 if self .filename .startswith ("/" ):
47133 # macOS or *nix path
48134 filename = self .filename .rsplit ("/" , maxsplit = 1 )[- 1 ]
@@ -52,11 +138,20 @@ def download(
52138
53139 filename_noext , ext = os .path .splitext (filename )
54140
141+ # Figure out what the file should be named.
142+ # If a directory is provided as an output, we rename the file according to its name,
143+ # hash and origin device AID.
144+ # Otherwise, we use the exact filename provided as a parameter.
55145 if os .path .isdir (output_path ):
56146 # Output path is a folder, so we should compute the filename
147+ if self .device_id :
148+ output_filename = f"{ filename_noext } _{ self .sha256 } _{ self .device_id } { ext } " ,
149+ else :
150+ output_filename = f"{ filename_noext } _{ self .sha256 } { ext } "
151+
57152 full_output_path = os .path .join (
58153 output_path ,
59- f" { filename_noext } _ { self . sha256 } _ { self . device_id } { ext } " ,
154+ output_filename ,
60155 )
61156 full_output_path_7z = full_output_path + ".7z"
62157 else :
@@ -70,29 +165,95 @@ def download(
70165 else :
71166 full_output_path_7z = full_output_path + ".7z"
72167
73- file_contents = self .batch_session .api .get_extracted_file_contents (
74- session_id = self .session_id ,
75- sha256 = self .sha256 ,
76- filename = self .filename ,
77- )
168+ # Chunked downloads can be disabled by providing a 0 as the chunk size.
169+ # This is rarely advantageous, but is provided for compatability.
170+ # Non-chunked downloads do not support progress bars, so we just skip the tqdm invocation
171+ # if download_chunk_size = 0, and instead write the file straight to disk from memory.
172+ if download_chunk_size == 0 :
173+ file_contents = self .rtr_api .get_extracted_file_contents (
174+ session_id = self .session_id ,
175+ sha256 = self .sha256 ,
176+ filename = self .filename ,
177+ )
178+ with open (full_output_path_7z , "wb" ) as output_7z_file :
179+ output_7z_file .write (file_contents )
180+ else :
181+ get_file_response = self .rtr_api .get_extracted_file_contents (
182+ session_id = self .session_id ,
183+ sha256 = self .sha256 ,
184+ filename = self .filename ,
185+ stream = True ,
186+ )
187+
188+ if show_download_progress :
189+ with tqdm .wrapattr (
190+ stream = open (full_output_path_7z , "wb" ),
191+ method = "write" ,
192+ total = 0 , # Content-Length header is not sent by RTR, so tqdm will count upwards
193+ miniters = 1 ,
194+ bytes = True ,
195+ desc = os .path .basename (full_output_path_7z ),
196+ ) as output_7z_file :
197+ for chunk in get_file_response .iter_content (chunk_size = download_chunk_size ):
198+ output_7z_file .write (chunk )
78199
79- with open (full_output_path_7z , "wb" ) as output_7z_file :
80- output_7z_file .write (file_contents )
200+ # We have to manually close the file once the download is complete, or chunk
201+ # data will not be flushed to disk and the 7-Zip archive will be unreadable by
202+ # py7zr. See: https://github.com/tqdm/tqdm/issues/1247
203+ output_7z_file .close ()
204+ else :
205+ # If no progress bar is requested, we still use chunking to avoid storing the whole
206+ # 7-Zip archive into memory before writing it to disk.
207+ with open (full_output_path_7z , "wb" ) as output_7z_file :
208+ for chunk in get_file_response .iter_content (chunk_size = download_chunk_size ):
209+ output_7z_file .write (chunk )
81210
82211 if not extract :
83212 # Downloaded, so we're done now!
84213 return
85214
86- with py7zr .SevenZipFile ( # nosec - The password 'infected' is generic and always the same
215+ target_dir = os .path .dirname (full_output_path_7z )
216+
217+ with py7zr .SevenZipFile ( # nosec - The password "infected" is generic and always the same
87218 file = full_output_path_7z ,
88219 mode = "r" ,
89220 password = "infected" ,
90221 ) as archive :
91- target_dir = os .path .dirname (full_output_path_7z )
92- archive .extract (path = target_dir )
222+ archive_filenames = archive .getnames ()
223+ if show_download_progress :
224+ archive_info = archive .archiveinfo ()
225+ with SevenZipExtractProgressBar (
226+ unit = "B" ,
227+ unit_scale = True ,
228+ miniters = 1 ,
229+ total_bytes = archive_info .uncompressed ,
230+ desc = f"Extracting..." ,
231+ ascii = True ,
232+ ) as progress :
233+ # archive.extract() does not provide a callback parameter, but behaviour should
234+ # not differ since RTR 7-Zip archives should always contain exactly one inner
235+ # file.
236+ archive .extractall (path = target_dir , callback = progress )
237+ else :
238+ archive .extract (path = target_dir )
239+
240+ # Check that we truly only received exactly one output file in the 7-Zip archive.
241+ # If all is well, we rename the first (and hopefully only) output file to match the name
242+ # derived at the beginning of this function.
243+ if archive_filenames and len (archive_filenames ) == 1 :
244+ orig_filename = archive_filenames [0 ]
245+ orig_path = os .path .join (target_dir , orig_filename )
246+ os .rename (orig_path , full_output_path )
247+ else :
248+ raise ValueError (
249+ "The downloaded 7-Zip archive contains the wrong number of files. Contents: %s" ,
250+ str (archive_filenames ),
251+ )
93252
94- if not preserve_7z :
95- os .unlink (full_output_path_7z )
253+ # Delete the 7-Zip archive after extracting its contents if the developer told us it
254+ # does not need to be preserved.
255+ if not preserve_7z :
256+ os .unlink (full_output_path_7z )
96257
97258 def __str__ (self ):
98259 """Return a loggable string representing the contents of the object."""
0 commit comments