Error in train_model when running train_rl.py #5

Open
fffanrrr opened this issue Jan 15, 2024 · 6 comments

Comments

@fffanrrr

Traceback (most recent call last):
File "D:\downloads\StockFormer-main\StockFormer-main\code\train_rl.py", line 210, in <module>
trained_sac = agent.train_model(model=model_sac,
File "D:\downloads\StockFormer-main\StockFormer-main\code\MySAC\models\DRLAgent.py", line 151, in train_model
model = model.learn(
File "D:\downloads\StockFormer-main\StockFormer-main\code\MySAC\SAC\MAE_SAC.py", line 364, in learn
return super(SAC, self).learn(
File "D:\downloads\StockFormer-main\StockFormer-main\code\MySAC\SAC\off_policy_algorithm.py", line 352, in learn
rollout = self.collect_rollouts(
File "D:\downloads\StockFormer-main\StockFormer-main\code\MySAC\SAC\off_policy_algorithm.py", line 584, in collect_rollouts
if callback.on_step() is False:
File "D:\downloads\StockFormer-main\StockFormer-main\code\stable_baselines3\common\callbacks.py", line 88, in on_step
return self._on_step()
File "D:\downloads\StockFormer-main\StockFormer-main\code\stable_baselines3\common\callbacks.py", line 192, in _on_step
continue_training = callback.on_step() and continue_training
File "D:\downloads\StockFormer-main\StockFormer-main\code\stable_baselines3\common\callbacks.py", line 88, in on_step
return self._on_step()
File "D:\downloads\StockFormer-main\StockFormer-main\code\stable_baselines3\common\callbacks.py", line 379, in _on_step
episode_rewards, episode_lengths = evaluate_policy(
File "D:\downloads\StockFormer-main\StockFormer-main\code\stable_baselines3\common\evaluation.py", line 86, in evaluate_policy
observations, rewards, dones, infos = env.step(actions)
File "D:\downloads\StockFormer-main\StockFormer-main\code\stable_baselines3\common\vec_env\base_vec_env.py", line 163, in step
return self.step_wait()
File "D:\downloads\StockFormer-main\StockFormer-main\code\stable_baselines3\common\vec_env\vec_monitor.py", line 76, in step_wait
obs, rewards, dones, infos = self.venv.step_wait()
File "D:\downloads\StockFormer-main\StockFormer-main\code\stable_baselines3\common\vec_env\dummy_vec_env.py", line 43, in step_wait
obs, self.buf_rews[env_idx], self.buf_dones[env_idx], self.buf_infos[env_idx] = self.envs[env_idx].step(
File "D:\downloads\StockFormer-main\StockFormer-main\code\envs\env_stocktrading_hybrid_control.py", line 279, in step
plt.savefig(
File "C:\Users\lenovo\anaconda3\lib\site-packages\matplotlib\pyplot.py", line 1119, in savefig
res = fig.savefig(*args, **kwargs) # type: ignore[func-returns-value]
File "C:\Users\lenovo\anaconda3\lib\site-packages\matplotlib\figure.py", line 3390, in savefig
self.canvas.print_figure(fname, **kwargs)
File "C:\Users\lenovo\anaconda3\lib\site-packages\matplotlib\backend_bases.py", line 2193, in print_figure
result = print_method(
File "C:\Users\lenovo\anaconda3\lib\site-packages\matplotlib\backend_bases.py", line 2043, in <lambda>
print_method = functools.wraps(meth)(lambda *args, **kwargs: meth(
TypeError: print_png() got an unexpected keyword argument 'index'
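The TypeError is raised at the very bottom of a long chain: plt.savefig passes its keyword arguments through Figure.savefig and print_figure down to the backend's print_png, which only accepts a fixed set of options. The toy model below is a simplified sketch and not matplotlib's real code. Only the function names mirror the traceback; the bodies are hypothetical. It shows how an unknown argument such as index travels through every layer untouched until the last call rejects it.

```python
# Simplified model of the call chain in the traceback above.
# NOT matplotlib's implementation: the names mirror the traceback,
# the bodies are hypothetical.

def print_png(filename, *, metadata=None, pil_kwargs=None):
    """Backend writer: accepts only PNG-specific keyword arguments."""
    return f"wrote {filename}"


def print_figure(filename, **kwargs):
    """Generic layer: forwards every remaining keyword argument unchanged."""
    return print_png(filename, **kwargs)


def savefig(filename, **kwargs):
    """User-facing entry point, playing the role of plt.savefig."""
    return print_figure(filename, **kwargs)


if __name__ == "__main__":
    print(savefig("account_value.png"))  # known options only: works
    try:
        savefig("account_value.png", index=False)  # unknown option
    except TypeError as err:
        print("TypeError:", err)
```

Because neither savefig nor print_figure inspects the extra argument, the error names print_png even though the offending keyword was written at the plt.savefig call site.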

@fffanrrr
Author

Which version of matplotlib are you using?
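To compare environments precisely, the installed versions can be read from package metadata. A minimal sketch using the standard library's importlib.metadata; the helper name report_versions and the default package list are assumptions chosen for this thread, not part of the repository.

```python
from importlib.metadata import PackageNotFoundError, version


def report_versions(packages=("matplotlib", "torch", "numpy", "pandas")):
    """Return {package: installed version string, or 'not installed'}."""
    result = {}
    for name in packages:
        try:
            result[name] = version(name)  # reads the installed distribution's metadata
        except PackageNotFoundError:
            result[name] = "not installed"
    return result


if __name__ == "__main__":
    for name, ver in report_versions().items():
        print(f"{name}: {ver}")
```

Running this in the same environment that runs train_rl.py gives the exact versions to paste into the issue.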

@hugo2046

My fix was to comment out the index=False argument on lines 279 and 286 of code\envs\env_stocktrading_hybrid_control.py. After that it ran without problems.
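In code, the workaround amounts to dropping the argument from the savefig call. The snippet below is a hypothetical stand-in for the plotting code around lines 279 and 286; the function name save_account_plot and its parameters are invented for illustration. index=False looks like a leftover from pandas DataFrame.to_csv, where index is a real option; savefig has no index parameter.

```python
import matplotlib

matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt


def save_account_plot(values, path):
    """Plot a series and save it as a PNG file.

    Hypothetical stand-in for the code in env_stocktrading_hybrid_control.py;
    the real variable names and plot contents differ.
    """
    plt.figure()
    plt.plot(values)
    # Before (fails with the TypeError reported in this issue):
    #   plt.savefig(path, index=False)
    # After: savefig gets only options it understands.
    plt.savefig(path)
    plt.close()  # release the figure, since the env saves plots repeatedly
```

If index=False is also used in a nearby DataFrame.to_csv call, that one should stay; only the savefig calls need the change.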

@elven2016

My fix was to comment out the index=False argument on lines 279 and 286 of code\envs\env_stocktrading_hybrid_control.py. After that it ran without problems.

That's how I solved it too.

@trialbox

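My fix was to comment out the index=False argument on lines 279 and 286 of code\envs\env_stocktrading_hybrid_control.py. After that it ran without problems.

Did you get the whole pipeline to run? I'm stuck in mae_sac.py: temporal_feature_short and temporal_feature_long have the wrong dimensions, and self.query_projection(queries) then raises an error.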

@fffanrrr
Author

fffanrrr commented May 24, 2024 via email

@trialbox

I got it running. Could you show me a screenshot of your error?

---- Original message ----
From: @.>
Date: 2024-05-23 17:00
To: @.>
Cc: @.>@.>
Subject: Re: [gsyyysg/StockFormer] Error in train_model when running train_rl.py (Issue #5)

Logging to tensorboard_log/mysac/StockFormer/_1
Traceback (most recent call last):
File "/media/liruidev/data/WorkDir/Project/trade/StockFormer/code/train_rl.py", line 213, in <module>
trained_sac = agent.train_model(model=model_sac,
File "/media/liruidev/data/WorkDir/Project/trade/StockFormer/code/MySAC/models/DRLAgent.py", line 153, in train_model
model = model.learn(
File "/media/liruidev/data/WorkDir/Project/trade/StockFormer/code/MySAC/SAC/MAE_SAC.py", line 379, in learn
return super(SAC, self).learn(
File "/media/liruidev/data/WorkDir/Project/trade/StockFormer/code/MySAC/SAC/off_policy_algorithm.py", line 354, in learn
rollout = self.collect_rollouts(
File "/media/liruidev/data/WorkDir/Project/trade/StockFormer/code/MySAC/SAC/off_policy_algorithm.py", line 572, in collect_rollouts
action, buffer_action = self._sample_action(learning_starts, action_noise)
File "/media/liruidev/data/WorkDir/Project/trade/StockFormer/code/MySAC/SAC/off_policy_algorithm.py", line 412, in _sample_action
unscaled_action, _ = self.predict(self._last_obs, deterministic=False)
File "/media/liruidev/data/WorkDir/Project/trade/StockFormer/code/MySAC/SAC/MAE_SAC.py", line 405, in predict
state_tensor = self.actor_transformer(obs_tensor, temporal_short, temporal_long, holding)
File "/home/liruidev/App/anaconda3/envs/StockFormer/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "/media/liruidev/data/WorkDir/Project/trade/StockFormer/code/MySAC/SAC/policy_transformer.py", line 35, in forward
temporal_hybrid_feature, attn = self.attention(
File "/home/liruidev/App/anaconda3/envs/StockFormer/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "/media/liruidev/data/WorkDir/Project/trade/StockFormer/code/Transformer/models/attn.py", line 64, in forward
queries = self.query_projection(queries).view(B, L, H, -1)
File "/home/liruidev/App/anaconda3/envs/StockFormer/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
return forward_call(*input, **kwargs)
File "/home/liruidev/App/anaconda3/envs/StockFormer/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 114, in forward
return F.linear(input, self.weight, self.bias)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (2x37 and 128x128)
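The numbers in this RuntimeError follow the shape rule of torch.nn.Linear: the input's last dimension must equal in_features. In effect, PyTorch folds the leading input dimensions into the rows of mat1 and uses the transposed weight, of shape (in_features, out_features), as mat2. So "(2x37 and 128x128)" means the tensor reaching query_projection has 37 features while the layer was built for 128, which is consistent with the suspicion above that temporal_feature_short/temporal_feature_long have the wrong size. The torch-free sketch below reproduces the check; linear_output_shape is a hypothetical helper, not a PyTorch function.

```python
from math import prod


def linear_output_shape(input_shape, in_features, out_features):
    """Shape rule of nn.Linear / F.linear, checked without torch.

    The input is treated as a 2-D matrix ("mat1") of shape
    (product of leading dims, last dim) and multiplied by the transposed
    weight ("mat2") of shape (in_features, out_features).
    """
    rows = prod(input_shape[:-1])
    cols = input_shape[-1]
    if cols != in_features:
        raise ValueError(
            "mat1 and mat2 shapes cannot be multiplied "
            f"({rows}x{cols} and {in_features}x{out_features})"
        )
    # Leading dimensions are kept; only the last one changes.
    return (*input_shape[:-1], out_features)
```

For example, linear_output_shape((2, 1, 128), 128, 128) returns (2, 1, 128), while linear_output_shape((2, 37), 128, 128) raises with the same "(2x37 and 128x128)" text as the traceback. Printing the shapes of temporal_short, temporal_long and holding just before self.actor_transformer(...) in MAE_SAC.predict should show which input carries the 37.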
