RuntimeError: expected scalar type long int but found float #43

Open
Ethereal1679 opened this issue Oct 29, 2024 · 1 comment

@Ethereal1679

Traceback (most recent call last):
  File "train.py", line 43, in <module>
    train(args)
  File "train.py", line 38, in train
    ppo_runner, train_cfg = task_registry.make_alg_runner(env=env, name=args.task, args=args)
  File "/home/yyds/桌面/Gym5_human/humanoid-gym-main/humanoid/utils/task_registry.py", line 152, in make_alg_runner
    runner = runner_class(env, all_cfg, log_dir, device=args.rl_device)
  File "/home/yyds/桌面/Gym5_human/humanoid-gym-main/humanoid/algo/ppo/on_policy_runner.py", line 91, in __init__
    _, _ = self.env.reset()
  File "/home/yyds/桌面/Gym5_human/humanoid-gym-main/humanoid/envs/base/legged_robot.py", line 115, in reset
    obs, privileged_obs, _, _, _ = self.step(torch.zeros(
  File "/home/yyds/桌面/Gym5_human/humanoid-gym-main/humanoid/envs/custom/humanoid_env.py", line 197, in step
    return super().step(actions)
  File "/home/yyds/桌面/Gym5_human/humanoid-gym-main/humanoid/envs/base/legged_robot.py", line 102, in step
    self.post_physics_step()
  File "/home/yyds/桌面/Gym5_human/humanoid-gym-main/humanoid/envs/base/legged_robot.py", line 142, in post_physics_step
    self.compute_reward()
  File "/home/yyds/桌面/Gym5_human/humanoid-gym-main/humanoid/envs/base/legged_robot.py", line 226, in compute_reward
    rew = self.reward_functions[i]() * self.reward_scales[name]
  File "/home/yyds/桌面/Gym5_human/humanoid-gym-main/humanoid/envs/custom/humanoid_env.py", line 343, in _reward_feet_contact_number
    reward = torch.where(contact == stance_mask, 1, -0.3)
RuntimeError: expected scalar type long int but found float
####################################################
It seems something is wrong with torch.where. Can you help me solve it?
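A quick way to see the mismatch the error message names: the two branch values passed to `torch.where` are a Python `int` (`1`) and a Python `float` (`-0.3`), and PyTorch turns these into tensors of different dtypes. The snippet below only illustrates that dtype difference; it is not a trace of `torch.where`'s internals, and it assumes the default floating dtype is still `float32`.

```python
import torch

# A Python int literal becomes an int64 ("long int") tensor,
# while a Python float literal becomes the default floating dtype.
int_branch = torch.tensor(1)
float_branch = torch.tensor(-0.3)
print(int_branch.dtype, float_branch.dtype)

# torch.result_type reports the dtype a type-promoting operation would choose
# for these two operands.
print(torch.result_type(int_branch, float_branch))
```

Newer PyTorch releases apply type promotion inside `torch.where`, so the original call may work there; the error suggests the installed version expects both branches to share one dtype.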

@zlw21gxy
Contributor

zlw21gxy commented Nov 1, 2024

Have you changed the definitions of contact or stance_mask? Does this happen with the default settings? Try this line: reward = torch.where(contact == stance_mask, 1.0, -0.3). Thank you for your feedback.
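A minimal sketch of the suggested fix, written as a hypothetical standalone function so it can run outside the environment class. The function names, the tensor shapes `(num_envs, num_feet)`, and the averaging over feet with `torch.mean(..., dim=1)` are assumptions for illustration, not the repository's exact code. The second variant builds both branches as explicit `float32` tensors, which avoids relying on how a particular PyTorch version handles scalar arguments to `torch.where`.

```python
import torch


def feet_contact_number_reward(contact: torch.Tensor, stance_mask: torch.Tensor) -> torch.Tensor:
    """Hypothetical standalone version of the reward term from the traceback.

    contact:     bool tensor (num_envs, num_feet), True where a foot touches the ground
    stance_mask: tensor of the same shape, 1 where the gait expects the foot in stance
    """
    # Both branch values are float literals, so torch.where never has to
    # reconcile an int64 branch with a floating-point branch.
    reward = torch.where(contact == stance_mask, 1.0, -0.3)
    # Assumed aggregation: average over feet to get one value per environment.
    return torch.mean(reward, dim=1)


def feet_contact_number_reward_explicit(contact: torch.Tensor, stance_mask: torch.Tensor) -> torch.Tensor:
    """Same reward, with both branches materialized as float32 tensors."""
    match = contact == stance_mask
    # Explicit tensors with the mask's shape: no scalar-to-tensor conversion needed.
    hit = torch.ones_like(match, dtype=torch.float32)
    miss = torch.full_like(hit, -0.3)
    reward = torch.where(match, hit, miss)
    return reward.mean(dim=1)
```

With `contact = [[True, False], [True, True]]` and `stance_mask = [[1., 0.], [0., 1.]]`, the first environment matches on both feet and the second on only one, so both variants return approximately `[1.0, 0.35]`.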
