-
Notifications
You must be signed in to change notification settings - Fork 245
Expand file tree
/
Copy pathapi.py
More file actions
2079 lines (1848 loc) · 85.2 KB
/
api.py
File metadata and controls
2079 lines (1848 loc) · 85.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
from enum import Enum
import json
import logging
from pathlib import Path as PyPath
import traceback
import yaml
import sys
import shutil
import zipfile
import os # Import os module
import tempfile # Import the tempfile module
import time # Import time module
import subprocess # Import subprocess module
from fastapi import (
FastAPI,
HTTPException,
UploadFile,
File,
BackgroundTasks,
Path as FastApiPath,
Query,
Body,
)
from fastapi.responses import FileResponse, StreamingResponse, JSONResponse
from typing import List, Optional, Dict, Any
from pydantic import BaseModel, Field
from fastapi.middleware.cors import CORSMiddleware # Import CORS middleware
# Remove unused make_id, replaced by Huey task ID
# from augmentoolkit.utils.make_id import make_id
from huey_config import huey # Import the Huey instance
from tasks import (
run_pipeline_task,
set_final_status,
) # Import the Huey task and set_final_status
from huey.exceptions import (
HueyException,
TaskException,
) # Import Huey exceptions for status check
from redis_config import (
get_progress,
redis_client,
set_progress,
) # Import redis_client and potentially set_progress
import signal # Import signal module
# Import path resolution logic from run_augmentoolkit
from resolve_path import resolve_path
# Import helpers
from file_operation_helpers import (
get_safe_path,
zip_directory,
get_dir_structure,
FileStructure,
MoveItemRequest,
handle_get_structure,
handle_download_item,
handle_delete_item,
handle_move_item,
handle_create_directory, # <-- Add new helper
)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# --- Configuration Loading ---
SUPER_CONFIG_PATH = PyPath("super_config.yaml")
# Alias name -> path mapping read from super_config.yaml. It stays empty when the
# file is missing or malformed so the API can still start (aliases just won't work).
PATH_ALIASES = {}
try:
    with open(SUPER_CONFIG_PATH, "r", encoding="utf-8") as f:
        # yaml.safe_load returns None for an empty file; treat that as an empty
        # mapping instead of crashing module import on None.get(...).
        super_config = yaml.safe_load(f) or {}
    if not isinstance(super_config, dict):
        logger.error(
            f"ERROR: Super config file {SUPER_CONFIG_PATH} does not contain a mapping. Path aliases will not work."
        )
    else:
        # A bare "path_aliases:" key yields None; normalise it to an empty mapping.
        _configured_aliases = super_config.get("path_aliases") or {}
        if isinstance(_configured_aliases, dict):
            PATH_ALIASES = _configured_aliases
        else:
            logger.error(
                f"ERROR: path_aliases in {SUPER_CONFIG_PATH} is not a dictionary. Path aliases will not work."
            )
except FileNotFoundError:
    logger.error(
        f"ERROR: Super config file not found at {SUPER_CONFIG_PATH}. Path aliases will not work."
    )
    # Consider if the API should fail to start if super_config is essential
    # sys.exit(1)
except OSError as e:
    # Other read failures (permissions, path is a directory, ...) degrade the
    # same way as a missing file instead of aborting startup.
    logger.error(
        f"ERROR: Could not read super config file {SUPER_CONFIG_PATH}: {e}. Path aliases will not work."
    )
except yaml.YAMLError as e:
    logger.error(f"ERROR: Error parsing super config file {SUPER_CONFIG_PATH}: {e}")
    # sys.exit(1)

# Base directories, resolved against the current working directory at import time.
INPUTS_DIR = PyPath("./inputs").resolve()
OUTPUTS_DIR = PyPath("./outputs").resolve()
GENERATION_DIR = PyPath("./generation").resolve()
CONFIGS_DIR = PyPath("./external_configs").resolve()
# Size limit for uploads/downloads: 1000 MiB (~1.05 GB). Adjust as needed.
MAX_FILE_SIZE = 1024 * 1024 * 1000
# Define LOGS_DIR consistent with tasks.py
LOGS_DIR = PyPath("./logs").resolve()

# Ensure base directories exist (parents are not created).
for _base_dir in (INPUTS_DIR, OUTPUTS_DIR, GENERATION_DIR, CONFIGS_DIR, LOGS_DIR):
    _base_dir.mkdir(exist_ok=True)
# --- Pydantic Models ---
class PipelineRunRequest(BaseModel):
    """Request body for ``POST /pipelines/run``.

    All three fields are forwarded unchanged to ``run_pipeline_task``.
    """

    node_path: str  # pipeline entry point to run; forwarded as-is
    config_path: Optional[str] = None  # Relative to CONFIGS_DIR or absolute (TODO confirm: not resolved by the endpoint itself)
    parameters: Optional[Dict[str, Any]] = None  # optional overrides passed to the task
class PipelineRunResponse(BaseModel):
    """Response for a successfully queued pipeline run."""

    pipeline_id: str  # The Huey task ID; use it with the /tasks/{task_id}/... endpoints
    message: str  # human-readable confirmation
class PipelineStatus(str, Enum):
    """Lifecycle states reported for a pipeline task.

    Values are the exact upper-case strings stored in / compared against the
    Redis final-status records.
    """

    PENDING = "PENDING"  # Task is waiting in the queue
    RUNNING = "RUNNING"  # Task is being executed by a worker
    COMPLETED = "COMPLETED"  # Task finished successfully
    FAILED = "FAILED"  # Task failed with an exception
    # Set when a task is revoked while pending or interrupted via the API
    # (the interrupt endpoint records "REVOKED" as the final status).
    REVOKED = "REVOKED"  # Task was explicitly revoked
class PipelineStatusResponse(BaseModel):
    """Status report returned by ``GET /tasks/{task_id}/status``."""

    task_id: str
    status: PipelineStatus
    message: Optional[str] = None  # human-readable description of the state
    # Fraction complete; validation rejects values outside [0.0, 1.0].
    progress: Optional[float] = Field(None, ge=0.0, le=1.0)
    # Progress is only known when the task itself reports it (via the Redis
    # progress tracker); otherwise terminal states report 1.0 and queued ones 0.0.
    # Extra information: the stored "details" from a final-status record, or a
    # {"source": ...} marker naming which check produced this status.
    details: Optional[Dict[str, Any]] = None
# Removed unused models
# class FileContent(BaseModel):
# content: str
#
# class QueueStateResponse(BaseModel):
# pending_tasks: List[Dict[str, Any]]
# running_tasks: List[Dict[str, Any]]
# Removed in-memory queue variables
# pipeline_executions: Dict[str, Dict[str, Any]] = {}
# pipeline_queue: List[Dict[str, Any]] = []
# --- Pydantic Models (New Response Model) ---
class QueueStatusResponse(BaseModel):
    """Snapshot of task IDs waiting in the Huey queue (``GET /tasks/queue``).

    Tasks that are actively executing are not listed.
    """

    pending_tasks: List[str] = Field(
        ..., description="List of task IDs currently pending execution."
    )
    scheduled_tasks: List[str] = Field(
        ..., description="List of task IDs scheduled for future execution."
    )
    message: str  # summary with the pending/scheduled counts
class CreateDirectoryRequest(BaseModel):
    """Request body naming a directory to create.

    NOTE(review): the base directory the path is relative to is chosen by the
    endpoint that consumes this model (presumably via handle_create_directory);
    confirm there.
    """

    relative_path: str = Field(
        ...,
        description="The relative path within the base directory where the new directory should be created.",
    )
# --- New Pydantic model for config duplication ---
class DuplicateConfigRequest(BaseModel):
    """Request body for copying an aliased config file into external_configs."""

    # Key under path_aliases in super_config.yaml identifying the file to copy.
    source_alias: str = Field(
        ...,
        description="The alias from super_config.yaml pointing to the source config file.",
    )
    # Target path (with filename), relative to the external_configs directory.
    destination_relative_path: str = Field(
        ...,
        description="The desired relative path (including filename) for the duplicated config within the external_configs directory.",
    )
# --- Pydantic model for task parameters response ---
class TaskParametersResponse(BaseModel):
    """Response for ``GET /tasks/{task_id}/parameters``."""

    task_id: str
    # Decoded JSON object stored in Redis for the task (overrides and config).
    parameters: Dict[str, Any]
# --- FastAPI App ---
app = FastAPI(
    title="Augmentoolkit API",
    description="API for managing and running Augmentoolkit dataset generation pipelines using Huey.",
    version="1.0",
)

# --- CORS Middleware ---
# Local development frontends. Browsers treat "localhost" and "127.0.0.1" as
# different origins, so every port is allowed under both hostnames.
_FRONTEND_DEV_PORTS = (5173, 5174, 3000)
_FRONTEND_DEV_HOSTS = ("localhost", "127.0.0.1")
origins = [
    f"http://{host}:{port}"
    for port in _FRONTEND_DEV_PORTS
    for host in _FRONTEND_DEV_HOSTS
]
# Add other origins if needed, e.g. a production URL: origins.append("https://...")

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,  # Allows cookies to be included in requests
    allow_methods=["*"],  # Allows all standard methods (GET, POST, etc.)
    allow_headers=["*"],  # Allows all headers
)
# TODO make sure that pipeline settings in the config make it run the proper pipeline
# You know what to do -- get this shippable
# edit the two videos, and record new one with the new running process
# configs polish and checking
# documentation update to match new features and utility pipelines
# also will have to rerecord interface thing so be it, because of new chat window
# and censor the history
@app.get("/", summary="Health Check")
async def read_root():
"""Basic health check endpoint."""
return {"message": "Augmentoolkit API is running."}
@app.post(
    "/pipelines/run",
    response_model=PipelineRunResponse,
    status_code=202,
    summary="Queue a dataset generation pipeline for execution.",
)
def queue_pipeline_run(request: PipelineRunRequest):
    """
    Enqueue a pipeline run as a Huey task and return its task ID.

    ``node_path``, ``config_path`` and ``parameters`` are forwarded to
    ``run_pipeline_task`` unchanged. This endpoint does not resolve paths or
    aliases itself (TODO confirm: presumably the task does). The returned
    ``pipeline_id`` is the Huey task ID, usable with ``/tasks/{task_id}/...``.

    Raises:
        HTTPException(500): if enqueueing the task fails.
    """
    # Keep this handler thin: path handling and execution belong to the task.
    logger.info(
        "Queuing pipeline task: node_path=%s config_path=%s",
        request.node_path,
        request.config_path,
    )
    # Parameters may be large or sensitive, so only log them at debug level.
    logger.debug("Pipeline task parameters: %s", request.parameters)
    try:
        task = run_pipeline_task(
            node_path=request.node_path,
            config_path=request.config_path,
            parameters=request.parameters,
        )
    except Exception as e:
        logger.exception("Failed to enqueue pipeline task")
        # Detail text kept stable for existing clients.
        raise HTTPException(
            status_code=500,
            detail=f"Failed to resolve paths or enqueue pipeline task: {e}",
        ) from e
    return PipelineRunResponse(
        pipeline_id=task.id,
        message="Pipeline run queued successfully.",
    )
@app.get(
    "/pipelines/available",
    response_model=List[str],
    summary="Get available pipeline aliases from super_config.yaml.",
)
def get_available_pipelines():
    """
    List the runnable pipeline aliases declared in super_config.yaml.

    An alias counts as a pipeline when its ``path_aliases`` value is a string
    that does not end in ``.yaml`` (case-insensitive, surrounding whitespace
    ignored). Aliases for ``.yaml`` files are config aliases instead.

    Returns:
        The matching alias names, sorted. An empty list is returned when the
        super config file is missing or ``path_aliases`` is not a mapping.

    Raises:
        HTTPException(500): if the file cannot be parsed or read.
    """
    if not SUPER_CONFIG_PATH.exists():
        # An empty list degrades more gracefully for the UI than a 500.
        logger.error(
            f"Super config file not found at {SUPER_CONFIG_PATH} for '/pipelines/available' endpoint."
        )
        return []

    try:
        with open(SUPER_CONFIG_PATH, "r", encoding="utf-8") as config_file:
            loaded = yaml.safe_load(config_file) or {}  # empty file -> {}
        alias_map = loaded.get("path_aliases", {})
        if not isinstance(alias_map, dict):
            logger.warning(
                f"path_aliases in {SUPER_CONFIG_PATH} is not a dictionary. Returning empty list."
            )
            return []

        pipeline_aliases = sorted(
            name
            for name, target in alias_map.items()
            if isinstance(target, str)
            and not target.strip().lower().endswith(".yaml")
        )
        logger.info(f"Found {len(pipeline_aliases)} available pipeline aliases.")
        return pipeline_aliases
    except yaml.YAMLError as e:
        logger.error(f"Error parsing super config file {SUPER_CONFIG_PATH}: {e}")
        raise HTTPException(
            status_code=500, detail=f"Error parsing super configuration file: {e}"
        )
    except Exception as e:
        logger.error(
            f"Unexpected error reading super config for available pipelines: {e}",
            exc_info=True,
        )
        raise HTTPException(
            status_code=500,
            detail="An unexpected error occurred while retrieving available pipelines.",
        )
# Note: the API does not need to expose task return values, because the main
# outputs of the pipelines are files.
@app.get(
    "/configs/aliases",
    response_model=List[str],
    summary="Get available config file aliases from super_config.yaml.",
)
def get_available_config_aliases():
    """
    List the config-file aliases declared in super_config.yaml.

    An alias counts as a config alias when its ``path_aliases`` value is a
    string ending in ``.yaml`` (case-insensitive, surrounding whitespace
    ignored).

    Returns:
        The matching alias names, sorted. An empty list is returned when the
        super config file is missing or ``path_aliases`` is not a mapping.

    Raises:
        HTTPException(500): if the file cannot be parsed or read.
    """
    if not SUPER_CONFIG_PATH.exists():
        logger.error(
            f"Super config file not found at {SUPER_CONFIG_PATH} for '/configs/aliases' endpoint."
        )
        return []

    try:
        with open(SUPER_CONFIG_PATH, "r", encoding="utf-8") as config_file:
            loaded = yaml.safe_load(config_file) or {}  # empty file -> {}
        alias_map = loaded.get("path_aliases", {})
        if not isinstance(alias_map, dict):
            logger.warning(
                f"path_aliases in {SUPER_CONFIG_PATH} is not a dictionary. Returning empty list."
            )
            return []

        # Kept in declaration order for the log line; sorted only on return.
        yaml_aliases = [
            name
            for name, target in alias_map.items()
            if isinstance(target, str) and target.strip().lower().endswith(".yaml")
        ]
        logger.info(f"Found {len(yaml_aliases)} available config aliases.")
        logger.info(yaml_aliases)
        return sorted(yaml_aliases)
    except yaml.YAMLError as e:
        logger.error(f"Error parsing super config file {SUPER_CONFIG_PATH}: {e}")
        raise HTTPException(
            status_code=500, detail=f"Error parsing super configuration file: {e}"
        )
    except Exception as e:
        logger.error(
            f"Unexpected error reading super config for available config aliases: {e}",
            exc_info=True,
        )
        raise HTTPException(
            status_code=500,
            detail="An unexpected error occurred while retrieving available config aliases.",
        )
@app.get(
    "/tasks/queue",
    response_model=QueueStatusResponse,
    summary="Get lists of pending and scheduled tasks.",
)
def get_queue_status():
    """
    List the IDs of tasks waiting in the Huey queue.

    Both pending tasks and tasks scheduled for later execution are reported.
    Tasks that are currently executing are not reliably included.

    Raises:
        HTTPException(500): if the Huey queue cannot be read.
    """
    try:
        pending_ids = [queued.id for queued in huey.pending()]
        scheduled_ids = [queued.id for queued in huey.scheduled()]
        n_pending, n_scheduled = len(pending_ids), len(scheduled_ids)
        logger.info(
            f"Retrieved queue status: {n_pending} pending, {n_scheduled} scheduled."
        )
        summary = f"Found {n_pending} pending and {n_scheduled} scheduled tasks."
        return QueueStatusResponse(
            pending_tasks=pending_ids,
            scheduled_tasks=scheduled_ids,
            message=summary,
        )
    except Exception as e:
        logger.error(f"Error retrieving Huey queue status: {e}", exc_info=True)
        raise HTTPException(
            status_code=500, detail=f"Failed to retrieve queue status: {e}"
        )
def _clamp_progress(value: Any) -> float:
    """Coerce a reported progress value into the [0.0, 1.0] range.

    ``PipelineStatusResponse.progress`` is validated with ``ge=0.0, le=1.0``,
    so an out-of-range, NaN or non-numeric value would otherwise turn a status
    check into a 500 error. Non-numeric values and NaN map to 0.0; values
    above 1.0 are capped at 1.0.
    """
    try:
        progress = float(value)
    except (TypeError, ValueError):
        return 0.0
    # `not (x >= 0.0)` is true for negative numbers AND for NaN.
    if not (progress >= 0.0):
        return 0.0
    return min(progress, 1.0)


def _final_status_from_redis(task_id: str) -> Optional[PipelineStatusResponse]:
    """Build a status response from the task's final-status record in Redis.

    Reads ``status_for_task:{task_id}``. Returns None when there is no record,
    or (after logging the problem) when the record is not valid JSON, is not a
    JSON object, names a status outside ``PipelineStatus``, or fails response
    validation, so callers can fall back to other status sources.
    """
    redis_status_key = f"status_for_task:{task_id}"
    final_status_json = redis_client.get(redis_status_key)
    if not final_status_json:
        return None
    try:
        status_data = json.loads(final_status_json)
    except ValueError as e:  # JSONDecodeError and UnicodeDecodeError are ValueErrors
        logger.error(
            f"Status check for {task_id}: Failed to parse final status JSON from Redis ({redis_status_key}): {e}. Content: {final_status_json[:100]}..."
        )
        return None
    if not isinstance(status_data, dict):
        logger.error(
            f"Status check for {task_id}: Final status record in Redis ({redis_status_key}) is not a JSON object."
        )
        return None
    raw_status = status_data.get("status")
    try:
        # Normalise case (as the interrupt endpoint does) so e.g. "completed"
        # still maps onto the enum.
        status_enum = PipelineStatus(str(raw_status).upper())
    except ValueError as e:
        logger.error(
            f"Status check for {task_id}: Invalid status value '{raw_status}' found in Redis ({redis_status_key}): {e}"
        )
        return None
    # Terminal states report full progress. A record wrongly marked RUNNING
    # (final records should only hold terminal states) reports 0.0.
    progress_value = 0.0 if status_enum == PipelineStatus.RUNNING else 1.0
    try:
        return PipelineStatusResponse(
            task_id=task_id,
            status=status_enum,
            message=status_data.get("message", "Status retrieved from final record."),
            progress=progress_value,
            details=status_data.get("details"),
        )
    except Exception as e:  # e.g. a stored "details" that is not an object
        logger.error(
            f"Status check for {task_id}: Error processing final status from Redis ({redis_status_key}): {e}",
            exc_info=True,
        )
        return None


@app.get(
    "/tasks/{task_id}/status",
    response_model=PipelineStatusResponse,
    summary="Get the status of a pipeline run.",
)
def get_pipeline_status(task_id: str):
    """
    Report the status of a pipeline run task.

    Sources are consulted in order and the first hit wins:

    1. The final-status record in Redis (``status_for_task:{task_id}``).
    2. Progress info from ``get_progress``, reported as RUNNING.
    3. Huey's pending/scheduled queues, reported as PENDING.
    4. ``huey.is_revoked``, reported as REVOKED. A final record written in the
       meantime takes precedence.
    5. One last re-read of the Redis final record.

    Raises:
        HTTPException(404): no source knows about the task (or Huey raises
            HueyException while checking revocation).
        HTTPException(500): an unexpected error occurs.
    """
    try:
        # 1. Final status, written once the task has finished, failed or been revoked.
        final_response = _final_status_from_redis(task_id)
        if final_response is not None:
            logger.info(
                f"Status check for {task_id}: Found final status in Redis: {final_response.status.value}"
            )
            return final_response

        # 2. A live progress record means the task is running.
        pipeline_progress = get_progress(task_id)
        if pipeline_progress:
            logger.info(
                f"Status check for {task_id}: Found progress info in Redis. Assuming RUNNING."
            )
            return PipelineStatusResponse(
                task_id=task_id,
                status=PipelineStatus.RUNNING,
                message=pipeline_progress.get("message", "Task is running."),
                progress=_clamp_progress(pipeline_progress.get("progress", 0.0)),
                details={"source": "progress_tracker"},
            )

        # 3. Still waiting in Huey. This scans the full pending/scheduled lists,
        #    so cost grows with queue length.
        try:
            queued_ids = {queued.id for queued in huey.pending()}
            queued_ids.update(queued.id for queued in huey.scheduled())
        except Exception as e:
            logger.error(
                f"Status check for {task_id}: Error checking Huey pending/scheduled queues: {e}",
                exc_info=True,
            )
        else:
            if task_id in queued_ids:
                logger.info(
                    f"Status check for {task_id}: Found in Huey pending/scheduled list."
                )
                return PipelineStatusResponse(
                    task_id=task_id,
                    status=PipelineStatus.PENDING,
                    message="Task is pending in the queue.",
                    progress=0.0,
                    details={"source": "huey_queue"},
                )

        # 4. Revoked in Huey.
        try:
            revoked = huey.is_revoked(task_id)
        except HueyException:
            # Most likely the task ID is no longer known to Huey at all.
            logger.warning(
                f"Status check for {task_id}: HueyException checking revoked status. Task ID likely invalid or expired from Huey storage."
            )
            raise HTTPException(
                status_code=404,
                detail=f"Pipeline task with ID '{task_id}' not found or status expired.",
            )
        except Exception as e:
            logger.error(
                f"Status check for {task_id}: Error checking Huey revoked status: {e}",
                exc_info=True,
            )
            revoked = False
        if revoked:
            logger.info(
                f"Status check for {task_id}: Found revoked status via huey.is_revoked()."
            )
            # A final record may have been written since step 1; prefer it.
            final_response = _final_status_from_redis(task_id)
            if final_response is not None:
                logger.warning(
                    f"Status check for {task_id}: Task is revoked in Huey, but also has a final status record in Redis. Preferring Redis record."
                )
                return final_response
            return PipelineStatusResponse(
                task_id=task_id,
                status=PipelineStatus.REVOKED,
                message="Task was revoked.",
                progress=0.0,  # Likely revoked before it ever ran
                details={"source": "huey_revoked"},
            )

        # 5. Nothing found. Re-check Redis once in case the task finished meanwhile.
        logger.warning(
            f"Status check for {task_id}: No status found in Redis (final/progress) or Huey (pending/scheduled/revoked). Task may be invalid or expired."
        )
        final_response = _final_status_from_redis(task_id)
        if final_response is not None:
            logger.info(
                f"Status check for {task_id}: Final status appeared in Redis on last check."
            )
            return final_response
        raise HTTPException(
            status_code=404,
            detail=f"Pipeline task with ID '{task_id}' not found or status could not be determined.",
        )
    except HTTPException:  # Propagate deliberate 404s untouched
        raise
    except Exception as e:
        logger.error(
            f"Unexpected ERROR during status check for task {task_id}: {e}",
            exc_info=True,
        )
        raise HTTPException(
            status_code=500, detail=f"Unexpected error retrieving task status: {e}"
        ) from e
def _parameters_not_found_detail(task_id: str) -> str:
    """Choose the 404 detail for a task whose parameters are not in Redis.

    Probes Huey's result store to give the most specific message available.
    NOTE(review): Huey's ``result()`` typically returns None for an unknown ID
    rather than raising, so this probe cannot always tell "never existed"
    apart from "existed but parameters are gone". Confirm which exceptions
    really signal an unknown task.
    """
    try:
        # preserve=True peeks instead of popping, so this existence check does
        # not delete a stored result that other callers may still need.
        huey.result(task_id, blocking=False, preserve=True)
    except TaskException:
        # The task ran and raised, so it definitely existed.
        pass
    except HueyException:
        return f"Task with ID '{task_id}' not found."
    except Exception as huey_check_e:
        logger.error(
            f"Error checking Huey task status for {task_id} while getting parameters: {huey_check_e}"
        )
        return f"Parameters for task {task_id} not found."
    return f"Parameters for task {task_id} not found. They may have expired or were not stored."


@app.get(
    "/tasks/{task_id}/parameters",
    response_model=TaskParametersResponse,
    summary="Get the parameters a task was executed with.",
)
def get_task_parameters(task_id: str):
    """
    Return the parameters (overrides and config) a task run was started with.

    Parameters are read as a JSON object from the Redis key
    ``parameters_for_task:{task_id}``, which is stored when the task begins
    execution.

    Raises:
        HTTPException(404): no parameters are stored for the task.
        HTTPException(500): the stored data is not a valid JSON object, or an
            unexpected error occurs.
    """
    redis_key = f"parameters_for_task:{task_id}"
    logger.info(
        f"Request for parameters for task {task_id} using Redis key: {redis_key}"
    )
    try:
        params_json = redis_client.get(redis_key)
        if params_json is None:
            logger.warning(
                f"Parameters not found in Redis for task {task_id} (key: {redis_key}). Checking task existence."
            )
            # The detail is computed first and raised here, outside any broad
            # `except Exception`, so the specific message reaches the client.
            raise HTTPException(
                status_code=404, detail=_parameters_not_found_detail(task_id)
            )

        corrupted_detail = (
            f"Error decoding parameters stored for task {task_id}. Data may be corrupted."
        )
        try:
            parameters = json.loads(params_json)
        except json.JSONDecodeError as e:
            logger.error(
                f"Failed to parse parameters JSON retrieved from Redis for task {task_id}. Key: {redis_key}. Content starts with: {params_json[:100]}... Error: {e}"
            )
            raise HTTPException(status_code=500, detail=corrupted_detail) from e
        if not isinstance(parameters, dict):
            # Valid JSON but not an object: the response model needs a dict.
            logger.error(
                f"Parameters stored in Redis for task {task_id} (key: {redis_key}) are not a JSON object (got {type(parameters).__name__})."
            )
            raise HTTPException(status_code=500, detail=corrupted_detail)

        logger.info(f"Successfully retrieved and parsed parameters for task {task_id}")
        return TaskParametersResponse(task_id=task_id, parameters=parameters)
    except HTTPException:  # Deliberate 404/500 responses pass through unchanged
        raise
    except Exception as e:
        logger.error(
            f"Unexpected error retrieving parameters for task {task_id}: {e}",
            exc_info=True,
        )
        raise HTTPException(
            status_code=500,
            detail=f"An unexpected error occurred while retrieving parameters for task {task_id}: {e}",
        ) from e
def _wait_for_process_exit(pid: int, timeout: float, interval: float) -> bool:
    """Poll until process ``pid`` no longer exists or ``timeout`` seconds pass.

    Returns True if the process disappeared within ``timeout``, otherwise False.
    A PermissionError from the probe means the PID exists but belongs to another
    user, so it is treated as still running (the same rule the interrupt
    endpoint uses when it polls).
    """
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        try:
            os.kill(pid, 0)  # signal 0 only checks that the PID exists
        except ProcessLookupError:
            return True
        except PermissionError:
            pass
        time.sleep(interval)
    return False


def _find_and_kill_run_augmentoolkit_process() -> Optional[Dict[str, Any]]:
    """
    Fallback: find a ``run_augmentoolkit.py`` process via ``ps -ef`` and stop it.

    Sends SIGINT, waits up to 8 seconds for a graceful exit, then escalates to
    SIGKILL. Only the first matching process that can be signalled is handled,
    and the API's own process is never signalled.
    NOTE(review): matching is by command line only, so with several pipeline
    runs active this may stop a run that belongs to a different task.

    Returns:
        ``{"pid": int, "killed": True, "method": str}`` for the stopped process.
        ``method`` is "SIGINT", "SIGINT_DELAYED" (exited just as SIGKILL was
        about to be sent) or "SIGKILL". Returns None if no matching process was
        found, signalling failed, or ``ps`` could not be run.
    """
    grace_period = 8  # seconds to wait for SIGINT before escalating
    check_interval = 0.3  # seconds between liveness checks
    own_pid = os.getpid()
    try:
        ps_output = subprocess.run(
            ["ps", "-ef"], capture_output=True, text=True, check=True
        ).stdout
        for line in ps_output.splitlines():
            if "run_augmentoolkit.py" not in line or "grep" in line:
                continue
            fields = line.split()
            if len(fields) < 2:
                continue
            try:
                pid = int(fields[1])  # ps -ef columns: UID PID PPID ...
                if pid == own_pid:
                    continue  # never signal the API process itself
                logger.warning(
                    f"Found run_augmentoolkit.py process via ps -ef fallback: PID {pid}"
                )
                # Try SIGINT first so the pipeline can shut down cleanly.
                os.kill(pid, signal.SIGINT)
                logger.info(
                    f"Sent SIGINT to run_augmentoolkit.py process (PID {pid}) via fallback"
                )
                started = time.monotonic()
                if _wait_for_process_exit(pid, grace_period, check_interval):
                    elapsed = time.monotonic() - started
                    logger.info(
                        f"Process {pid} terminated gracefully after {elapsed:.1f}s via fallback"
                    )
                    return {"pid": pid, "killed": True, "method": "SIGINT"}

                logger.warning(
                    f"Process {pid} did not respond to SIGINT after {grace_period}s. Escalating to SIGKILL via fallback."
                )
                try:
                    os.kill(pid, signal.SIGKILL)
                except ProcessLookupError:
                    # It exited between the last poll and the SIGKILL: a success.
                    logger.info(
                        f"Process {pid} terminated just before SIGKILL was sent (fallback)"
                    )
                    return {"pid": pid, "killed": True, "method": "SIGINT_DELAYED"}
                logger.info(
                    f"Successfully killed run_augmentoolkit.py process (PID {pid}) with SIGKILL via fallback"
                )
                return {"pid": pid, "killed": True, "method": "SIGKILL"}
            except (ValueError, ProcessLookupError, PermissionError) as e:
                logger.error(
                    f"Failed to kill run_augmentoolkit.py process found via fallback: {e}"
                )
                continue
        logger.info("No run_augmentoolkit.py process found via ps -ef fallback")
        return None
    except subprocess.CalledProcessError as e:
        logger.error(f"Failed to run ps -ef command in fallback: {e}")
        return None
    except Exception as e:
        logger.error(
            f"Unexpected error in run_augmentoolkit.py fallback kill mechanism: {e}"
        )
        return None
@app.post(
"/tasks/{task_id}/interrupt",
status_code=200,
summary="Interrupt a running pipeline task subprocess or revoke a pending task.",
)
def interrupt_or_revoke_task(task_id: str):
"""
Attempts to interrupt a RUNNING task's subprocess via SIGINT,
or revokes a PENDING task (setting final status in Redis).
Reports status if already finished/revoked based on Redis final status.
"""
redis_pid_key = f"worker_pid_for_task:{task_id}"
redis_status_key = f"status_for_task:{task_id}"
logger.info(f"Received interrupt request for task {task_id}")
try:
# === Step 1: Check Final Status in Redis ===
final_status_json = redis_client.get(redis_status_key)
if final_status_json:
try:
status_data = json.loads(final_status_json)
existing_status = status_data.get("status", "UNKNOWN").upper()
logger.warning(
f"Interrupt request for {task_id}: Task already has a final status in Redis: {existing_status}"
)
# Return 409 Conflict if already completed, failed, or revoked
if existing_status in ["COMPLETED", "FAILED", "REVOKED"]:
raise HTTPException(
status_code=409,
detail=f"Task {task_id} has already finished with status: {existing_status}.",
)
# Otherwise, log warning but proceed (e.g., if status was somehow inconsistent)
except (json.JSONDecodeError, ValueError, KeyError) as e:
logger.error(
f"Interrupt request for {task_id}: Error reading existing final status from Redis ({redis_status_key}): {e}. Proceeding with interrupt/revoke attempt."
)
# === Step 2: Check if Running (via Subprocess PID in Redis) ===
pid_bytes = redis_client.get(redis_pid_key)
if pid_bytes:
pid = None
try:
# --- Get PID ---
try:
pid = int(
pid_bytes
) # Assumes redis client returns bytes/string convertible to int
except ValueError:
logger.error(
f"Invalid non-integer PID string '{pid_bytes}' found in Redis for key {redis_pid_key}. Cleaning up."
)
redis_client.delete(redis_pid_key) # Clean up invalid data
# Proceed to check if it's pending, maybe the PID was wrong but task exists
# Try fallback mechanism to find and kill run_augmentoolkit.py
logger.info(
f"Attempting fallback mechanism to find run_augmentoolkit.py process for task {task_id}"
)
fallback_result = _find_and_kill_run_augmentoolkit_process()
if fallback_result and fallback_result.get("killed"):
set_final_status(
task_id,
"REVOKED",
f"Task {task_id} terminated via fallback mechanism (found PID {fallback_result['pid']}).",
details={
"source": "api_interrupt_fallback",
"pid": fallback_result["pid"],
"termination_method": f"{fallback_result.get('method', 'SIGKILL')}_FALLBACK",
},
)
return {
"message": f"Task {task_id} terminated via fallback mechanism after invalid Redis PID."
}
pass # Fall through to pending check
if pid is not None: # Only proceed if PID was valid
# --- Send SIGINT and wait for graceful termination ---
logger.info(
f"Task {task_id} appears to be running (Subprocess PID {pid} found). Sending SIGINT."
)
os.kill(pid, signal.SIGINT)
# Wait for graceful termination
grace_period = 8 # seconds to wait for SIGINT
check_interval = 0.3 # seconds between checks
start_time = time.time()
# Poll to see if process terminates gracefully
process_terminated_gracefully = False
while time.time() - start_time < grace_period:
try:
os.kill(
pid, 0
) # Check if process still exists (signal 0 = no-op)
time.sleep(check_interval)
except ProcessLookupError:
# Process has terminated
elapsed = time.time() - start_time
process_terminated_gracefully = True
logger.info(
f"Process {pid} terminated gracefully after {elapsed:.1f}s"
)
break
except PermissionError:
# Process exists but we can't check it - treat as still running
time.sleep(check_interval)
if process_terminated_gracefully:
# Graceful termination succeeded
elapsed = time.time() - start_time
set_final_status(
task_id,
"REVOKED",
f"Task {task_id} gracefully terminated with SIGINT after {elapsed:.1f}s.",
details={
"source": "api_interrupt",
"pid": pid,
"termination_method": "SIGINT",
},
)
return {
"message": f"Task {task_id} gracefully terminated with SIGINT after {elapsed:.1f}s."
}
# --- Escalate to SIGKILL ---
logger.warning(
f"Process {pid} did not respond to SIGINT after {grace_period}s. Escalating to SIGKILL."
)
try:
os.kill(pid, signal.SIGKILL)
logger.info(
f"Sent SIGKILL to subprocess {pid} for task {task_id}"
)
# Brief wait for SIGKILL to take effect
time.sleep(2)
# Verify SIGKILL worked
try:
os.kill(pid, 0)
# If we get here, process is STILL running after SIGKILL
logger.error(
f"Process {pid} still exists after SIGKILL! This should not happen."
)
message = f"Task {task_id} may still be running despite SIGKILL (PID {pid}). Manual intervention may be required."
termination_method = "SIGKILL_FAILED"
except ProcessLookupError:
# SIGKILL succeeded
logger.info(
f"Process {pid} successfully terminated with SIGKILL"
)
message = f"Task {task_id} forcefully terminated with SIGKILL after SIGINT timeout."
termination_method = "SIGKILL"
except ProcessLookupError:
# Process terminated between our grace period check and SIGKILL attempt
logger.info(
f"Process {pid} terminated just before SIGKILL was sent"
)
message = f"Task {task_id} terminated during escalation (process ended just after SIGINT timeout)."
termination_method = "SIGINT_DELAYED"
except PermissionError:
logger.error(
f"Permission denied sending SIGKILL to PID {pid} for task {task_id}"
)
message = f"Permission error escalating to SIGKILL for task {task_id} (PID {pid})."
termination_method = "PERMISSION_ERROR"
except Exception as kill_e:
logger.error(
f"Unexpected error sending SIGKILL to PID {pid}: {kill_e}"
)
message = (
f"Error escalating to SIGKILL for task {task_id}: {kill_e}"
)
termination_method = "SIGKILL_ERROR"
# Set final status with escalation details
set_final_status(
task_id,
"REVOKED",
message,
details={
"source": "api_interrupt",
"pid": pid,
"termination_method": termination_method,
"escalated_to_sigkill": True,
},
)
return {"message": message}
except ProcessLookupError:
logger.warning(
f"Subprocess with PID {pid} for task {task_id} not found when sending SIGINT. It likely finished/crashed recently."
)
# Don't immediately raise 404, check final status / pending status below.
# Try fallback mechanism before falling through
logger.info(
f"Attempting fallback mechanism to find run_augmentoolkit.py process for task {task_id}"
)
fallback_result = _find_and_kill_run_augmentoolkit_process()
if fallback_result and fallback_result.get("killed"):
set_final_status(
task_id,
"REVOKED",
f"Task {task_id} terminated via fallback mechanism (found PID {fallback_result['pid']}).",
details={
"source": "api_interrupt_fallback",
"pid": fallback_result["pid"],
"termination_method": f"{fallback_result.get('method', 'SIGKILL')}_FALLBACK",
"original_pid": pid,
},
)
return {
"message": f"Task {task_id} terminated via fallback mechanism after PID {pid} was not found."
}
pass # Fall through
except PermissionError:
logger.error(
f"Permission denied trying to send SIGINT to PID {pid} for task {task_id}."
)
raise HTTPException(
status_code=500,
detail=f"Permission error signaling subprocess {pid}.",
)
except Exception as e:
pid_str = str(pid) if pid is not None else "<conversion failed>"
logger.error(
f"Unexpected error processing interrupt for PID {pid_str} / task {task_id}: {e}",
exc_info=True,
)
# Fall through to check pending/final status
pass
# === Fallback: Try to find and kill run_augmentoolkit.py directly ===
# This handles cases where Redis PID was missing or other failures occurred
if not redis_client.exists(redis_pid_key):
logger.info(
f"No PID found in Redis for task {task_id}. Attempting fallback mechanism."
)
fallback_result = _find_and_kill_run_augmentoolkit_process()
if fallback_result and fallback_result.get("killed"):
set_final_status(