Skip to content

Commit

Permalink
Merge pull request #55 from fraricci/react
Browse files Browse the repository at this point in the history
ReAct agent implementation
  • Loading branch information
jan-janssen authored Aug 30, 2024
2 parents 8d93861 + a5d1488 commit 8cf5e1e
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 28 deletions.
60 changes: 42 additions & 18 deletions langsim/llm.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from langchain.agents import AgentExecutor
from langchain.agents.output_parsers.openai_tools import OpenAIToolsAgentOutputParser
from langchain.agents.output_parsers import JSONAgentOutputParser
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.tools.render import render_text_description_and_args
from langchain.agents.format_scratchpad.openai_tools import (
format_to_openai_tool_messages,
)
Expand All @@ -13,40 +15,37 @@
get_atom_dict_bulk_structure,
get_atom_dict_equilibrated_structure,
)
from langsim.prompt import SYSTEM_PROMPT
from langsim.prompt import SYSTEM_PROMPT_ALT, SYSTEM_PROMPT
from langchain import hub


def get_executor(api_provider, api_key, api_url=None, api_model=None, api_temperature=0):
def get_executor(api_provider, api_key, api_url=None, api_model=None, api_temperature=0, agent_type="default"):
if api_provider.lower() == "openai":
from langchain_openai import ChatOpenAI
if api_model is None:
api_model = "gpt-4"
api_model = "gpt-4o"
llm = ChatOpenAI(
model=api_model, temperature=api_temperature, openai_api_key=api_key, base_url=api_url,
)
prompt = ChatPromptTemplate.from_messages(
[
("system", SYSTEM_PROMPT),
("human", "{conversation}"),
MessagesPlaceholder(variable_name="agent_scratchpad"),
]
)

elif api_provider.lower() == "anthropic":
from langchain_anthropic import ChatAnthropic
if api_model is None:
api_model = "claude-3-5-sonnet-20240620"
llm = ChatAnthropic(
model=api_model, temperature=api_temperature, anthropic_api_key=api_key,
)
prompt = ChatPromptTemplate.from_messages(
[
("system", SYSTEM_PROMPT),
("human", "{conversation}"),
MessagesPlaceholder(variable_name="agent_scratchpad"),
]

elif api_provider.lower() == "groq":
from langchain_groq import ChatGroq
if api_model is None:
api_model = "Llama-3-Groq-70B-Tool-Use"
llm = ChatGroq(
model=api_model, temperature=api_temperature, groq_api_key=api_key,
)
else:
raise ValueError()

tools = [
get_equilibrium_volume,
get_atom_dict_bulk_structure,
Expand All @@ -55,16 +54,41 @@ def get_executor(api_provider, api_key, api_url=None, api_model=None, api_temper
get_bulk_modulus,
get_experimental_elastic_property_wikipedia,
]

llm_with_tools = llm.bind_tools(tools)

if agent_type == "react":
REACT_PROMPT_FROM_HUB = hub.pull("langchain-ai/react-agent-template")
prompt = REACT_PROMPT_FROM_HUB.partial(instructions=SYSTEM_PROMPT_ALT,
tools=render_text_description_and_args(tools),
tool_names=", ".join([t.name for t in tools]),
)
input_key = "input"

elif agent_type == "default":
prompt = ChatPromptTemplate.from_messages(
[
("system", SYSTEM_PROMPT),
("human", "{conversation}"),
MessagesPlaceholder(variable_name="agent_scratchpad"),
]
)
input_key = "conversation"

else:
raise ValueError()


agent = (
{
"conversation": lambda x: x["conversation"],
input_key: lambda x: x[input_key],
"agent_scratchpad": lambda x: format_to_openai_tool_messages(
x["intermediate_steps"]
),
}
| prompt
| llm_with_tools
| OpenAIToolsAgentOutputParser()
# | JSONAgentOutputParser()
)
return AgentExecutor(agent=agent, tools=tools, verbose=True)
return AgentExecutor(agent=agent, tools=tools, verbose=True)
20 changes: 15 additions & 5 deletions langsim/magics.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,22 @@
from langsim.llm import get_executor


def get_output(messages, agent_type, temp):
    """Run the LangSim agent over the accumulated chat history and return the final chunk.

    Args:
        messages: conversation history as (role, text) tuples, streamed to the agent.
        agent_type: "react" selects the ReAct agent; any other value (including the
            boolean produced by the ``store_true`` magic flag) selects the default agent.
        temp: fallback sampling temperature used when LANGSIM_TEMP is not set.

    Returns:
        The last chunk emitted by ``agent_executor.stream`` (carries the 'output' key).
    """
    env = os.environ
    # Normalize: the -at magic flag is declared with action="store_true", so this
    # arrives as a bool; get_executor only accepts "react" or "default".
    resolved_agent_type = "react" if agent_type == "react" else "default"
    agent_executor = get_executor(
        api_provider=env.get("LANGSIM_PROVIDER", "OPENAI"),
        api_key=env.get("LANGSIM_API_KEY"),
        api_url=env.get("LANGSIM_API_URL", None),
        api_model=env.get("LANGSIM_MODEL", None),
        api_temperature=env.get("LANGSIM_TEMP", temp),
        # BUG FIX: agent_type was computed for input_key below but never forwarded,
        # so a "react" request built the default agent and then streamed with a
        # key its prompt does not define.
        agent_type=resolved_agent_type,
    )
    # The ReAct hub prompt expects "input"; the default prompt expects "conversation".
    input_key = "input" if resolved_agent_type == "react" else "conversation"
    return list(agent_executor.stream({input_key: messages}))[-1]


# Class to manage state and expose the main magics
Expand All @@ -50,7 +56,11 @@ def __init__(self, shell):
help="""Return output as raw text instead of rendering it as Markdown[Default: False].
"""
)
@argument('-T', '--temp', type=float, default=0.6,
@argument(
'-at', '--agent_type', action="store_true",
help="""Agent type: options are default or react."""
)
@argument('-T', '--temp', type=float, default=0.0,
help="""Temperature, float in [0,1]. Higher values push the algorithm
to generate more aggressive/"creative" output. [default=0.1].""")
@argument('prompt', nargs='*',
Expand All @@ -70,7 +80,7 @@ def chat(self, line, cell=None):
else:
prompt = cell
self.messages.append(("human", prompt))
response = get_output(self.messages)
response = get_output(self.messages,args.agent_type,args.temp)
output = response['output']
self.messages.append(("ai", output))
if args.raw:
Expand Down
9 changes: 9 additions & 0 deletions langsim/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,12 @@
- Do not start a calculation unless the human has chosen a model.
- Do not convert the AtomDict class to a python dictionary.
"""

# Alternative system prompt for the ReAct agent: injected into the hub
# "langchain-ai/react-agent-template" via its `instructions` slot (see llm.py);
# the default agent uses SYSTEM_PROMPT instead.
# Fixes over the original text: typo "commputational", missing article, and the
# ungrammatical rule "Ask to the human which model want to use".
SYSTEM_PROMPT_ALT = """Your name is LangSim and you are a very powerful agent in the field of computational materials science.
Rules:
- Do not start a calculation unless the human has chosen a model.
- Ask the human which model they want to use before running a calculation.
- Do not convert the AtomDict class to a python dictionary.
"""

26 changes: 21 additions & 5 deletions langsim/tools/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,11 @@ def get_atom_dict_equilibrated_structure(atom_dict: AtomsDict, calculator_str: s
Args:
atom_dict (AtomsDict): DataClass representing the atomic structure
calculator_str (str): selected model specified by the calculator string
calculator_str (str): a model from the following options: "emt", "mace"
- "emt": Effective medium theory is a computationally efficient, analytical model
that describes the macroscopic properties of composite materials.
- "mace": this is a machine learning force field for predicting many-body atomic
interactions that covers the periodic table.
Returns:
AtomsDict: DataClass representing the equilibrated atomic structure
Expand All @@ -55,7 +59,11 @@ def plot_equation_of_state(atom_dict: AtomsDict, calculator_str: str) -> str:
Args:
atom_dict (AtomsDict): DataClass representing the atomic structure
calculator_str (str): selected model specified by the calculator string
calculator_str (str): a model from the following options: "emt", "mace"
- "emt": Effective medium theory is a computationally efficient, analytical model
that describes the macroscopic properties of composite materials.
- "mace": this is a machine learning force field for predicting many-body atomic
interactions that covers the periodic table.
Returns:
str: plot of the equation of state
Expand All @@ -74,7 +82,11 @@ def get_bulk_modulus(atom_dict: AtomsDict, calculator_str: str) -> str:
Args:
atom_dict (AtomsDict): DataClass representing the atomic structure
calculator_str (str): selected model specified by the calculator string
calculator_str (str): a model from the following options: "emt", "mace"
- "emt": Effective medium theory is a computationally efficient, analytical model
that describes the macroscopic properties of composite materials.
- "mace": this is a machine learning force field for predicting many-body atomic
interactions that covers the periodic table.
Returns:
str: Bulk Modulus in GPa
Expand All @@ -97,7 +109,11 @@ def get_equilibrium_volume(atom_dict: AtomsDict, calculator_str: str) -> str:
Args:
atom_dict (AtomsDict): DataClass representing the atomic structure
calculator_str (str): selected model specified by the calculator string
calculator_str (str): a model from the following options: "emt", "mace"
- "emt": Effective medium theory is a computationally efficient, analytical model
that describes the macroscopic properties of composite materials.
- "mace": this is a machine learning force field for predicting many-body atomic
interactions that covers the periodic table.
Returns:
str: Equilibrium volume in Angstrom^3
Expand Down Expand Up @@ -191,4 +207,4 @@ def get_element_property_mendeleev(chemical_symbol: str, property: str) -> str:
else:
return str(property_value)
else:
return f"Property '{property}' is not available for the element '{chemical_symbol}'."
return f"Property '{property}' is not available for the element '{chemical_symbol}'."

0 comments on commit 8cf5e1e

Please sign in to comment.