-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathTrain.m
More file actions
58 lines (55 loc) · 2.2 KB
/
Train.m
File metadata and controls
58 lines (55 loc) · 2.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
classdef Train < handle
    % Train  Builds a meta-optimizer agent and its training environment from
    % names, then trains the agent and saves the result to disk.
    %
    % Usage:
    %   t = Train(moName, boName, envName, problemset);
    %   result = t.run();

    properties
        MetaOptimizer     % agent object created by feval(moName, obsInfo, actInfo)
        moName            % name used to construct MetaOptimizer; also used in saved file names
        BaseOptimizer     % object created by feval(boName)
        boName            % name used to construct BaseOptimizer
        TrainingSet       % first output of splitProblemSet(problemset); indexed as a cell array in run()
        TrainingSetName   % copied from problemset.psName
        epoch = 1         % not referenced by any method of this class
        env               % environment created by feval(envName, TrainingSet, BaseOptimizer, 'train')
    end

    methods
        function obj = Train(moName, boName, envName, problemset)
            % Train  Construct the base optimizer, environment and agent.
            %
            % Inputs:
            %   moName     - name of the meta-optimizer (agent) constructor,
            %                called as feval(moName, obsInfo, actInfo)
            %   boName     - name of the base-optimizer constructor, called
            %                with no arguments
            %   envName    - name of the environment constructor, called as
            %                feval(envName, trainingSet, baseOptimizer, 'train')
            %   problemset - object/struct exposing a psName field and
            %                accepted by splitProblemSet
            obj.moName = moName;
            obj.boName = boName;
            obj.BaseOptimizer = feval(boName);
            obj.TrainingSetName = problemset.psName;
            % Only the first output of the split is kept here; the second
            % output is discarded.
            [obj.TrainingSet, ~] = splitProblemSet(problemset);
            obj.env = feval(envName, obj.TrainingSet, obj.BaseOptimizer, 'train');
            % The agent is built from the environment's observation and
            % action specifications.
            obsInfo = getObservationInfo(obj.env);
            actInfo = getActionInfo(obj.env);
            obj.MetaOptimizer = feval(moName, obsInfo, actInfo);
        end

        function result = run(obj)
            % run  Train the agent and save it.
            %
            % Creates <classdir>/Data/TrainingLog/<agentClass>_<setName>_D<D>_<timestamp>,
            % trains obj.MetaOptimizer on obj.env, then saves the final
            % agent (variable name 'agent') both into that log folder and
            % into <classdir>/AgentModel as <moName>_finalAgent.mat.
            %
            % Output:
            %   result - struct with field trainingInfo holding the value
            %            returned by train().

            % NOTE(review): this changes a root-level default and is never
            % restored, so figures stay hidden for the rest of the MATLAB
            % session -- confirm this is intended.
            set(0, 'DefaultFigureVisible', 'off');

            time_str = datestr(datetime('now'), 'yyyymmddHHMMSS');
            folderName = [class(obj.MetaOptimizer), '_', ...
                obj.TrainingSetName, '_D', ...
                num2str(obj.TrainingSet{1}.D), '_', ...
                time_str];

            % Directory that contains this class file.
            currentDir = fileparts(mfilename('fullpath'));

            % Full path used as SaveAgentDirectory and for the final agent.
            saveLogDir = fullfile(currentDir, 'Data', 'TrainingLog', folderName);
            % Make sure the directory exists.
            if ~exist(saveLogDir, 'dir')
                mkdir(saveLogDir);
            end

            % Second save location; create it too so the save after a long
            % training run cannot fail on a missing folder.
            agentModelDir = fullfile(currentDir, 'AgentModel');
            if ~exist(agentModelDir, 'dir')
                mkdir(agentModelDir);
            end

            trainOpts = rlTrainingOptions(...
                'MaxEpisodes', 20000, ...
                'MaxStepsPerEpisode', 1000, ...
                'SaveAgentCriteria', 'EpisodeReward', ...
                'SaveAgentValue', 0, ...
                'ScoreAveragingWindowLength', 10, ...
                'SaveAgentDirectory', saveLogDir, ...
                'Plots', 'none'); % training plots disabled

            trainingInfo = train(obj.MetaOptimizer, obj.env, trainOpts);

            % The variable must be named 'agent' because that is the name
            % stored inside the MAT-files.
            agent = obj.MetaOptimizer;
            finalAgentFile = [obj.moName, '_finalAgent.mat'];
            % fullfile inserts the platform's file separator instead of a
            % hard-coded '\', which only works on Windows.
            save(fullfile(saveLogDir, finalAgentFile), 'agent');
            save(fullfile(agentModelDir, finalAgentFile), 'agent');

            result.trainingInfo = trainingInfo;
        end
    end
end