Skip to content

Commit 269d37b

Browse files
Adds streaming file downloads and progress bars for RTR GET file downloads
1 parent cd82a8f commit 269d37b

3 files changed

Lines changed: 282 additions & 93 deletions

File tree

caracara/modules/rtr/get_file.py

Lines changed: 179 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -4,24 +4,61 @@
44

55
import os
66
from dataclasses import dataclass
7-
from typing import TYPE_CHECKING
7+
from typing import List, Optional, TYPE_CHECKING
88

99
import py7zr
10+
import py7zr.callbacks
11+
from tqdm import tqdm
1012

1113
if TYPE_CHECKING:
1214
# This trick stops us from causing a cyclical reference within the class
1315
# Credit: https://stackoverflow.com/a/39757388
16+
from falconpy import RealTimeResponse
17+
from caracara import Client
1418
from caracara.modules.rtr.batch_session import RTRBatchSession
1519

1620

21+
class SevenZipExtractProgressBar(py7zr.callbacks.ExtractCallback, tqdm):
    """A tqdm progress bar that doubles as a py7zr extraction callback.

    This code is heavily based on an example published on GitHub here:
    https://github.com/miurahr/py7zr/pull/558

    Only report_update does any work: it advances the bar by the number of bytes
    reported. Every other callback hook is implemented as a no-op.
    """

    def __init__(self, *args, total_bytes: int, **kwargs):
        """Create the progress bar.

        total_bytes is forwarded to the parent initialiser as ``total``. All other
        positional and keyword arguments are forwarded unchanged.
        """
        # super() already binds self, so it must not be passed a second time; doing so
        # would hand the bar to the parent initialiser as its own first positional argument.
        super().__init__(*args, total=total_bytes, **kwargs)

    def report_start_preparation(self):
        """Ignore the start-of-preparation notification."""

    def report_start(self, processing_file_path, processing_bytes):
        """Ignore the notification that processing of a file has started."""

    def report_end(self, processing_file_path, wrote_bytes):
        """Ignore the notification that processing of a file has ended."""

    def report_update(self, decompressed_bytes):
        """Advance the progress bar by decompressed_bytes (coerced to int)."""
        self.update(int(decompressed_bytes))

    def report_postprocess(self):
        """Ignore the post-processing notification."""

    def report_warning(self, message):
        """Ignore warning messages."""
47+
48+
1749
@dataclass
1850
class GetFile:
19-
"""
20-
Represents a file uploaded to Falcon via a GET command.
51+
"""Represent a file uploaded to Falcon via a GET command.
2152
2253
This class may be instantiated many times, with each object stored in a common list,
2354
in order to represent many files retrieved from a GET command executed against a
2455
batch session.
56+
57+
Only one of batch_session, client, or rtr_api needs to be set. When the rtr_api property
58+
is accessed, a RealTimeResponse object will be returned according to this priority list:
59+
- _rtr_api (i.e., rtr_api is set via the GetFile object's property setter)
60+
- client.rtr.rtr_api
61+
- batch_session.api
2562
"""
2663

2764
device_id: str = None
@@ -30,19 +67,68 @@ class GetFile:
3067
sha256: str = None
3168
size: int = None
3269
batch_session: RTRBatchSession = None
70+
client: Client = None
71+
_rtr_api: RealTimeResponse = None
72+
73+
@property
def rtr_api(self) -> RealTimeResponse:
    """Return the Real Time Response API object this GetFile should use.

    The first truthy source wins, in this order: the explicitly assigned _rtr_api,
    then the client (client.rtr.rtr_api), then the batch session (batch_session.api).

    Raises AttributeError when none of the three sources has been set.
    """
    resolved = self._rtr_api
    if not resolved:
        if self.client:
            resolved = self.client.rtr.rtr_api
        elif self.batch_session:
            resolved = self.batch_session.api
        else:
            raise AttributeError(
                "You must set at least one of batch_session, client, or rtr_api, so that this object "
                "can access the FalconPy Real Time Response API wrapper."
            )
    return resolved

@rtr_api.setter
def rtr_api(self, v: RealTimeResponse) -> None:
    """Store v in _rtr_api, the source the getter consults before any other."""
    self._rtr_api = v
3389

3490
def download(
3591
self,
3692
output_path: str,
3793
extract: bool = True,
3894
preserve_7z: bool = False,
95+
show_download_progress: bool = False,
96+
download_chunk_size: int = 1048576, # 1MiB
3997
):
40-
"""
41-
Download a file to a specified path.
98+
"""Download a file to a specified path.
4299
43100
If the path is a folder, the filename will be auto generated.
44101
If the path is a file, it'll be downloaded to that path and name.
102+
103+
All downloads require that the following three attributes be set in the GetFile object:
104+
- session_id (RTR Session ID)
105+
- sha256 (SHA256 hash of the extracted file)
106+
- filename (name of the extracted file)
107+
These three values are all returned when the status of a GET command is queried.
108+
109+
Additionally, this function requires that a Device ID (AID) be stored within the GetFile
110+
object to include the Device ID in the eventual filename. If this is not provided, and
111+
output_path is set to a directory, the AID will be excluded from the calculated eventual
112+
file name. Note that this is irrelevant if output_path is NOT a path to a directory, as
113+
the filename provided to this parameter will be used instead of one derived by the library.
114+
115+
Other parameters:
116+
- show_download_progress: Whether or not to draw the download progress via TQDM.
117+
- download_chunk_size: Size of each chunk to stream via requests. Defaults to 1MiB.
118+
If this is set to 0, chunking will not be used.
45119
"""
120+
if (
121+
not self.session_id or
122+
not self.sha256 or
123+
not self.filename
124+
):
125+
raise ValueError(
126+
"A session ID, SHA256 hash, and filename are all required to download a GET file. "
127+
"Ensure these values are set in the object before calling the download function."
128+
)
129+
130+
# Get the name of the uploaded file from the filename value, which actually contains the
131+
# full path to the file on the origin system's disk.
46132
if self.filename.startswith("/"):
47133
# macOS or *nix path
48134
filename = self.filename.rsplit("/", maxsplit=1)[-1]
@@ -52,11 +138,20 @@ def download(
52138

53139
filename_noext, ext = os.path.splitext(filename)
54140

141+
# Figure out what the file should be named.
142+
# If a directory is provided as an output, we rename the file according to its name,
143+
# hash and origin device AID.
144+
# Otherwise, we use the exact filename provided as a parameter.
55145
if os.path.isdir(output_path):
56146
# Output path is a folder, so we should compute the filename
147+
if self.device_id:
148+
output_filename = f"{filename_noext}_{self.sha256}_{self.device_id}{ext}",
149+
else:
150+
output_filename = f"{filename_noext}_{self.sha256}{ext}"
151+
57152
full_output_path = os.path.join(
58153
output_path,
59-
f"{filename_noext}_{self.sha256}_{self.device_id}{ext}",
154+
output_filename,
60155
)
61156
full_output_path_7z = full_output_path + ".7z"
62157
else:
@@ -70,29 +165,95 @@ def download(
70165
else:
71166
full_output_path_7z = full_output_path + ".7z"
72167

73-
file_contents = self.batch_session.api.get_extracted_file_contents(
74-
session_id=self.session_id,
75-
sha256=self.sha256,
76-
filename=self.filename,
77-
)
168+
# Chunked downloads can be disabled by providing a 0 as the chunk size.
169+
# This is rarely advantageous, but is provided for compatibility.
170+
# Non-chunked downloads do not support progress bars, so we just skip the tqdm invocation
171+
# if download_chunk_size = 0, and instead write the file straight to disk from memory.
172+
if download_chunk_size == 0:
173+
file_contents = self.rtr_api.get_extracted_file_contents(
174+
session_id=self.session_id,
175+
sha256=self.sha256,
176+
filename=self.filename,
177+
)
178+
with open(full_output_path_7z, "wb") as output_7z_file:
179+
output_7z_file.write(file_contents)
180+
else:
181+
get_file_response = self.rtr_api.get_extracted_file_contents(
182+
session_id=self.session_id,
183+
sha256=self.sha256,
184+
filename=self.filename,
185+
stream=True,
186+
)
187+
188+
if show_download_progress:
189+
with tqdm.wrapattr(
190+
stream=open(full_output_path_7z, "wb"),
191+
method="write",
192+
total=0, # Content-Length header is not sent by RTR, so tqdm will count upwards
193+
miniters=1,
194+
bytes=True,
195+
desc=os.path.basename(full_output_path_7z),
196+
) as output_7z_file:
197+
for chunk in get_file_response.iter_content(chunk_size=download_chunk_size):
198+
output_7z_file.write(chunk)
78199

79-
with open(full_output_path_7z, "wb") as output_7z_file:
80-
output_7z_file.write(file_contents)
200+
# We have to manually close the file once the download is complete, or chunk
201+
# data will not be flushed to disk and the 7-Zip archive will be unreadable by
202+
# py7zr. See: https://github.com/tqdm/tqdm/issues/1247
203+
output_7z_file.close()
204+
else:
205+
# If no progress bar is requested, we still use chunking to avoid storing the whole
206+
# 7-Zip archive in memory before writing it to disk.
207+
with open(full_output_path_7z, "wb") as output_7z_file:
208+
for chunk in get_file_response.iter_content(chunk_size=download_chunk_size):
209+
output_7z_file.write(chunk)
81210

82211
if not extract:
83212
# Downloaded, so we're done now!
84213
return
85214

86-
with py7zr.SevenZipFile( # nosec - The password 'infected' is generic and always the same
215+
target_dir = os.path.dirname(full_output_path_7z)
216+
217+
with py7zr.SevenZipFile( # nosec - The password "infected" is generic and always the same
87218
file=full_output_path_7z,
88219
mode="r",
89220
password="infected",
90221
) as archive:
91-
target_dir = os.path.dirname(full_output_path_7z)
92-
archive.extract(path=target_dir)
222+
archive_filenames = archive.getnames()
223+
if show_download_progress:
224+
archive_info = archive.archiveinfo()
225+
with SevenZipExtractProgressBar(
226+
unit="B",
227+
unit_scale=True,
228+
miniters=1,
229+
total_bytes=archive_info.uncompressed,
230+
desc=f"Extracting...",
231+
ascii=True,
232+
) as progress:
233+
# archive.extract() does not provide a callback parameter, but behaviour should
234+
# not differ since RTR 7-Zip archives should always contain exactly one inner
235+
# file.
236+
archive.extractall(path=target_dir, callback=progress)
237+
else:
238+
archive.extract(path=target_dir)
239+
240+
# Check that we truly only received exactly one output file in the 7-Zip archive.
241+
# If all is well, we rename the first (and hopefully only) output file to match the name
242+
# derived at the beginning of this function.
243+
if archive_filenames and len(archive_filenames) == 1:
244+
orig_filename = archive_filenames[0]
245+
orig_path = os.path.join(target_dir, orig_filename)
246+
os.rename(orig_path, full_output_path)
247+
else:
248+
raise ValueError(
249+
"The downloaded 7-Zip archive contains the wrong number of files. Contents: %s",
250+
str(archive_filenames),
251+
)
93252

94-
if not preserve_7z:
95-
os.unlink(full_output_path_7z)
253+
# Delete the 7-Zip archive after extracting its contents if the developer told us it
254+
# does not need to be preserved.
255+
if not preserve_7z:
256+
os.unlink(full_output_path_7z)
96257

97258
def __str__(self):
98259
"""Return a loggable string representing the contents of the object."""

0 commit comments

Comments
 (0)