| | import jinja2 |
| | from aiflows.base_flows import AtomicFlow |
| | from aiflows.utils import logging |
| | from aiflows.utils import general_helpers |
| | from typing import Dict,Any,Optional,List |
| | from aiflows.prompt_template import JinjaPrompt |
| | from copy import deepcopy |
| | from aiflows.messages import FlowMessage |
| | import os |
| | import hydra |
# Module-level logger (from aiflows' logging utilities), used e.g. to report how many demonstrations were loaded.
log = logging.get_logger(__name__)
| |
|
class DemonstrationsAtomicFlow(AtomicFlow):
    """This class implements a Demonstrations Atomic Flow. It is a flow which is usually used to pass demonstrations
    (of user-assistant interactions) to the ChatAtomicFlow.

    *Configuration Parameters*:

    - `name` (str): The name of the flow. Default: "DemonstrationsAtomicFlow"
    - `description` (str): A description of the flow. This description is used to generate the help message of the flow.
      Default: "A flow that passes demonstrations to the ChatFlow"
    - `data` (List[Dict[str, Any]]): The data of the demonstrations. Each datapoint must contain a `query_data` dict
      and a `response_data` dict (and an `id` if `params["ids_to_keep"]` is used).
      If `data` is None (or missing from the config), the data is loaded from the file specified by
      `params["data_dir"]` and `params["demonstrations_id"]`.
    - `params` (Dict[str, Any]): The parameters specific to the dataset of the demonstrations. Its parameters are:
        - `data_dir` (str): The directory where the demonstrations are stored. Only used if the data is not passed
          directly to the flow through `data`. Default: No default value; this field must be set in that case.
        - `demonstrations_id` (str): The id of the demonstrations (name of the data file, without the `.jsonl`
          extension). Only used if the data is not passed directly to the flow through `data`.
          Default: No default value; this field must be set in that case.
        - `demonstrations_k` (int): The number of demonstrations to pass to the ChatFlow. If None, or if larger than
          the number of available demonstrations, all the demonstrations are passed to the ChatFlow. Default: None
        - `ids_to_keep` (Optional[Union[str, List[str]]]): The ids of the demonstrations to keep, either as a list or
          as a comma-separated string. Only applied when the data is loaded from file. If None or empty, all the
          demonstrations are kept.
    - `query_prompt_template` (Dict[str, Any]): The prompt template used to generate the query of the demonstrations.
      By default it's of type aiflows.prompt_template.JinjaPrompt. None of the parameters of the prompt are defined by
      default and therefore need to be defined if one wants to use the query_prompt_template.
      Default parameters are defined in aiflows.prompt_template.JinjaPrompt.
    - `response_prompt_template` (Dict[str, Any]): The prompt template used to generate the response of the
      demonstrations. By default it's of type aiflows.prompt_template.JinjaPrompt. None of the parameters of the prompt
      are defined by default and therefore need to be defined if one wants to use the response_prompt_template.
      Default parameters are defined in aiflows.prompt_template.JinjaPrompt.

    *Input Interface*:

    - The input interface expected by its successor flow (typically ChatAtomicFlow, so the input interface is the one
      expected by ChatAtomicFlow).

    *Output Interface*:

    - Whatever data was passed in the input_message (typically the input interface expected by ChatAtomicFlow).
    - `demonstrations` (List[Dict[str, Any]]): A list of demonstrations. Each demonstration is a dictionary with the
      following keys:
        - idx (int): The index of the demonstration
        - query (str): The query of the demonstration
        - response (str): The response of the demonstration

    :param params: The parameters specific to the dataset of the demonstrations. It should contain the following keys:
        - 'data_dir' (str): The directory where the demonstrations are stored. This field is used if the data is not
          directly passed to the flow through the 'data' field.
        - 'demonstrations_id' (str): The id of the demonstrations (name of the data file). This field is used if the
          data is not directly passed to the flow through the 'data' field.
        - 'demonstrations_k' (int): The number of demonstrations to pass to the ChatFlow. If None, all the
          demonstrations are passed to the ChatFlow.
        - 'ids_to_keep' (Optional[Union[str, List[str]]]): The ids of the demonstrations to keep. If None, all the
          demonstrations are kept.
    :type params: Dict[str, Any]
    :param query_prompt_template: The prompt template used to generate the query of the demonstrations.
    :type query_prompt_template: JinjaPrompt
    :param response_prompt_template: The prompt template used to generate the response of the demonstrations.
    :type response_prompt_template: JinjaPrompt
    :param data: The data of the demonstrations. If None, the data is loaded from the file specified in the params.
    :type data: Optional[List[Dict[str, Any]]]
    """

    demonstrations_k: Optional[int] = None
    query_prompt_template: JinjaPrompt
    response_prompt_template: JinjaPrompt
    params: Dict

    def __init__(self, params, query_prompt_template, response_prompt_template, data=None, **kwargs):
        super().__init__(**kwargs)
        self.params = params
        self.data = data
        self.demonstrations_k = self.params.get("demonstrations_k", None)

        self.query_prompt_template = query_prompt_template
        self.response_prompt_template = response_prompt_template

        # No inline data was provided: read it from <data_dir>/<demonstrations_id>.jsonl.
        if self.data is None:
            self._load_data()

    @classmethod
    def _set_up_prompts(cls, config):
        """This method instantiates the prompt templates of the flow (used when instantiating the flow from a config file).

        :param config: The configuration of the flow.
        :type config: Dict[str, Any]
        :return: A dictionary of keyword arguments to pass to the constructor of the flow.
        :rtype: Dict[str, Any]
        """
        kwargs = {}
        kwargs["query_prompt_template"] = hydra.utils.instantiate(
            config["query_prompt_template"], _convert_="partial"
        )
        kwargs["response_prompt_template"] = hydra.utils.instantiate(
            config["response_prompt_template"], _convert_="partial"
        )
        return kwargs

    @classmethod
    def instantiate_from_config(cls, config):
        """This method instantiates the flow from a config file.

        :param config: The configuration of the flow. A missing `data` entry is treated as None, in which case the
            demonstrations are loaded from the file described by `params`.
        :type config: Dict[str, Any]
        :return: The instantiated flow.
        :rtype: Flow
        """
        flow_config = deepcopy(config)

        kwargs = {"flow_config": flow_config}
        kwargs.update(cls._set_up_prompts(flow_config))
        kwargs.update({"params": flow_config["params"]})
        # Use .get so a config without a `data` key falls back to loading from file
        # instead of raising KeyError (the constructor already defaults data to None).
        kwargs.update({"data": flow_config.get("data")})

        return cls(**kwargs)

    def _get_query_message_content(self, sample_data: Dict):
        """This method returns the query message content of a demonstration given the sample data
        (by rendering the query prompt template).

        :param sample_data: The sample data of the demonstration. Must contain every input variable of the query
            prompt template (a missing one raises KeyError).
        :type sample_data: Dict[str, Any]
        :return: The query message content of the demonstration.
        :rtype: str
        """
        input_variables = self.query_prompt_template.input_variables
        return self.query_prompt_template.format(**{k: sample_data[k] for k in input_variables})

    def _get_response_message_content(self, sample_data: Dict):
        """This method returns the response message content of a demonstration given the sample data
        (by rendering the response prompt template).

        :param sample_data: The sample data of the demonstration. Must contain every input variable of the response
            prompt template (a missing one raises KeyError).
        :type sample_data: Dict[str, Any]
        :return: The response message content of the demonstration.
        :rtype: str
        """
        input_variables = self.response_prompt_template.input_variables
        return self.response_prompt_template.format(**{k: sample_data[k] for k in input_variables})

    def _get_io_pair(self, idx):
        """This method, given the index of a demonstration, returns a query-response pair from the demonstrations data.

        :param idx: The index of the demonstration.
        :type idx: int
        :return: The query-response pair at idx from the demonstrations data.
        :rtype: Dict[str, Any]
        """
        dp = self.data[idx]

        query = self._get_query_message_content(dp["query_data"])
        response = self._get_response_message_content(dp["response_data"])

        return {"idx": idx, "query": query, "response": response}

    def _get_io_pairs(self, input_data: Dict[str, Any]) -> List[Any]:
        """This method returns the demonstrations that are passed to the destination flow (typically ChatAtomicFlow).

        The first `demonstrations_k` demonstrations are returned (all of them if `demonstrations_k` is None or exceeds
        the number of available demonstrations).

        :param input_data: The input data of the flow (currently unused).
        :type input_data: Dict[str, Any]
        :return: The demonstrations that are passed to the destination flow.
        :rtype: List[Any]
        """
        num_available = len(self.data)
        if self.demonstrations_k is None:
            demonstrations_k = num_available
        else:
            # Clamp to the available data: indexing past the end would raise IndexError.
            demonstrations_k = min(self.demonstrations_k, num_available)
            if self.demonstrations_k > num_available:
                log.warning(
                    "demonstrations_k=%d exceeds the %d available demonstrations; passing all of them",
                    self.demonstrations_k,
                    num_available,
                )
        return [self._get_io_pair(idx) for idx in range(demonstrations_k)]

    def _load_data(self):
        """This method loads the demonstrations from the file specified in the params.

        The file read is `<data_dir>/<demonstrations_id>.jsonl`. If `ids_to_keep` is set (a list of ids or a
        comma-separated string), only the datapoints whose `id` is in it are kept.
        """
        demonstrations_file = os.path.join(self.params["data_dir"], f"{self.params['demonstrations_id']}.jsonl")
        self.data = general_helpers.read_jsonlines(demonstrations_file)

        ids_to_keep = self.params.get("ids_to_keep", False)
        if ids_to_keep:
            if isinstance(ids_to_keep, str):
                # Strip whitespace and skip empty entries so "a, b" and trailing commas behave as expected.
                ids_to_keep = {id_.strip() for id_ in ids_to_keep.split(",") if id_.strip()}
            else:
                ids_to_keep = set(ids_to_keep)

            self.data = [d for d in self.data if d["id"] in ids_to_keep]

        log.info("Loaded the demonstrations for %d datapoints from %s", len(self.data), self.params["data_dir"])

    def run(self, input_message: FlowMessage):
        """This method runs the flow. It replies with the data of the input_message with the demonstrations added to it.

        :param input_message: The input message of the flow.
        :type input_message: FlowMessage
        """
        input_data = input_message.data
        # input_data is unpacked last, so any key it already carries (including "demonstrations")
        # takes precedence over the generated demonstrations.
        response = {**{"demonstrations": self._get_io_pairs(input_data=input_data)}, **input_data}
        reply = self.package_output_message(
            input_message=input_message,
            response=response,
        )
        self.send_message(reply)