Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions embodichain/agents/hierarchy/agent_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,11 @@
# limitations under the License.
# ----------------------------------------------------------------------------

from abc import ABCMeta, abstractmethod
from abc import ABCMeta
import os
import cv2
from embodichain.utils.utility import load_json, load_txt
from embodichain.agents.mllm.prompt import *
from embodichain.utils.utility import load_txt
import embodichain.agents.mllm.prompt as mllm_prompt
from embodichain.data import database_agent_prompt_dir, database_2d_dir
from embodichain.utils.utility import encode_image


class AgentBase(metaclass=ABCMeta):
Expand Down
11 changes: 1 addition & 10 deletions embodichain/agents/hierarchy/code_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,22 +18,13 @@
from langchain_core.prompts import ChatPromptTemplate
import os
import numpy as np
import functools
from typing import Dict, Tuple, Any
from embodichain.toolkits.code_generation import (
ExecutableOutputParser,
OutputFormatting,
)
from embodichain.toolkits.toolkits import ToolkitsBase
from typing import Dict, Tuple
from embodichain.agents.mllm.prompt import CodePrompt
from embodichain.data import database_agent_prompt_dir
from pathlib import Path
import re
import importlib.util
from langchain_core.messages import HumanMessage
from datetime import datetime
from embodichain.utils.utility import encode_image
import base64


def format_execution_history(execution_history):
Expand Down
9 changes: 1 addition & 8 deletions embodichain/agents/hierarchy/task_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,12 @@
from embodichain.agents.hierarchy.agent_base import AgentBase
from langchain_core.prompts import ChatPromptTemplate
from embodichain.data import database_2d_dir
from embodichain.utils.utility import load_txt, encode_image
from embodichain.utils.utility import load_txt
from embodichain.agents.mllm.prompt import TaskPrompt
from embodichain.data import database_agent_prompt_dir
from pathlib import Path
from langchain_core.messages import HumanMessage
import numpy as np

# from openai import OpenAI
import os
import time
import cv2
import glob
import json
import re

USEFUL_INFO = """The error may be caused by:
Expand Down
2 changes: 1 addition & 1 deletion embodichain/agents/mllm/prompt/code_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
ChatPromptTemplate,
HumanMessagePromptTemplate,
)
from embodichain.utils.utility import encode_image, encode_image_from_path
from embodichain.utils.utility import encode_image


class CodePrompt:
Expand Down
3 changes: 1 addition & 2 deletions embodichain/agents/mllm/prompt/task_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,12 @@
# ----------------------------------------------------------------------------

import torch
import numpy as np
from langchain_core.messages import SystemMessage
from langchain_core.prompts import (
ChatPromptTemplate,
HumanMessagePromptTemplate,
)
from embodichain.utils.utility import encode_image, encode_image_from_path
from embodichain.utils.utility import encode_image


class TaskPrompt:
Expand Down
10 changes: 1 addition & 9 deletions embodichain/toolkits/code_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,8 @@
# limitations under the License.
# ----------------------------------------------------------------------------

from typing import List, Dict, Tuple, Any
from typing import Dict, Tuple
from langchain_core.output_parsers import BaseOutputParser
from pygments import highlight
from pygments.lexers import PythonLexer
from pygments.formatters import TerminalFormatter

import numpy as np

Expand Down Expand Up @@ -81,11 +78,6 @@ def parse(self, text: str) -> Tuple[str, Dict, Dict]:
gvars = merge_dicts([self._fixed_vars, self.variable_vars])
lvars = None

#
# banned_phrases = ["import", "__"]
# for phrase in banned_phrases:
# assert phrase not in code_str

if gvars is None:
gvars = {}
if lvars is None:
Expand Down
9 changes: 1 addition & 8 deletions embodichain/toolkits/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,22 +14,16 @@
# limitations under the License.
# ----------------------------------------------------------------------------

from typing import List, Dict
import numpy as np
from embodichain.toolkits.toolkits import ToolkitsBase
from embodichain.utils.logger import log_info, log_warning, log_error
from copy import deepcopy
from embodichain.lab.gym.utils.misc import (
mul_linear_expand,
is_qpos_flip,
get_rotation_replaced_pose,
)
from embodichain.toolkits.graspkit.pg_grasp.antipodal import GraspSelectMethod
from matplotlib import pyplot as plt
import torch
from tqdm import tqdm
from embodichain.lab.sim.planners.motion_generator import MotionGenerator
from embodichain.data.enum import ControlParts, EndEffector, JointType
from scipy.spatial.transform import Rotation as R
from embodichain.utils.utility import encode_image
import ast
Expand Down Expand Up @@ -390,7 +384,6 @@ def grasp(
delta_xy = target_obj_pose[:2, 3] - select_arm_base_pose[:2, 3]
dx, dy = delta_xy[0], delta_xy[1]
aim_horizontal_angle = np.arctan2(dy, dx)
delta_angle = abs(select_arm_current_qpos[0] - aim_horizontal_angle)
select_arm_aim_qpos = deepcopy(select_arm_current_qpos)
select_arm_aim_qpos[0] = aim_horizontal_angle

Expand Down Expand Up @@ -1213,7 +1206,7 @@ def drive(
actions = list(actions.unbind(dim=0))
for i in tqdm(range(len(actions))):
action = actions[i]
obs, reward, terminated, truncated, info = env.step(action)
env.step(action)
return actions


Expand Down
1 change: 0 additions & 1 deletion embodichain/toolkits/toolkits.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

from abc import ABCMeta, abstractmethod
import os
import cv2
from embodichain.utils.utility import load_json


Expand Down