| """ |
| Client test. |
| |
| Run server: |
| |
| python generate.py --base_model=h2oai/h2ogpt-oig-oasst1-512-6_9b |
| |
| NOTE: For private models, add --use-auth_token=True |
| |
| NOTE: --use_gpu_id=True (default) must be used for multi-GPU setups if you see failures with cuda:x cuda:y device mismatches. |
| Currently, this will force the model to be on a single GPU. |
| |
| Then run this client as: |
| |
| python src/client_test.py |
| |
| |
| |
| For HF spaces: |
| |
| HOST="https://h2oai-h2ogpt-chatbot.hf.space" python src/client_test.py |
| |
| Result: |
| |
| Loaded as API: https://h2oai-h2ogpt-chatbot.hf.space ✔ |
| {'instruction_nochat': 'Who are you?', 'iinput_nochat': '', 'response': 'I am h2oGPT, a large language model developed by LAION.', 'sources': ''} |
| |
| |
| For demo: |
| |
| HOST="https://gpt.h2o.ai" python src/client_test.py |
| |
| Result: |
| |
| Loaded as API: https://gpt.h2o.ai ✔ |
| {'instruction_nochat': 'Who are you?', 'iinput_nochat': '', 'response': 'I am h2oGPT, a chatbot created by LAION.', 'sources': ''} |
| |
| NOTE: Raw output from API for nochat case is a string of a python dict and will remain so if other entries are added to dict: |
| |
| {'response': "I'm h2oGPT, a large language model by H2O.ai, the visionary leader in democratizing AI.", 'sources': ''} |
| |
| |
| """ |
| import ast |
| import time |
| import os |
| import markdown |
| import pytest |
| from bs4 import BeautifulSoup |
|
|
| from src.utils import is_gradio_version4 |
|
|
| try: |
| from enums import DocumentSubset, LangChainAction |
| except: |
| from src.enums import DocumentSubset, LangChainAction |
|
|
| from tests.utils import get_inf_server |
|
|
# When True, get_client() prints the server's API listing after connecting.
debug = False

# Set HF_HUB_DISABLE_TELEMETRY for this process (presumably turns off
# Hugging Face Hub telemetry -- confirm against the hub documentation).
os.environ.update(HF_HUB_DISABLE_TELEMETRY='1')
|
|
|
|
def get_client(serialize=not is_gradio_version4):
    """
    Build a gradio client for the server returned by get_inf_server().

    Args:
        serialize: Forwarded to gradio_client.Client. The default is the
            negation of is_gradio_version4, evaluated once at definition time.

    Returns:
        The constructed gradio_client.Client instance.
    """
    # Imported here rather than at module level, so gradio_client is only
    # needed once a client is actually requested.
    from gradio_client import Client as GradioClient

    server = get_inf_server()
    connection = GradioClient(server, serialize=serialize)
    if debug:
        # Debug aid: show every endpoint the server exposes.
        api_listing = connection.view_api(all_endpoints=True)
        print(api_listing)
    return connection
|
|
|
|
def get_args(prompt, prompt_type=None, chat=False, stream_output=False,
             max_new_tokens=50,
             top_k_docs=3,
             langchain_mode='Disabled',
             add_chat_history_to_context=True,
             langchain_action=LangChainAction.QUERY.value,
             langchain_agents=[],
             prompt_dict=None,
             version=None,
             h2ogpt_key=None,
             visible_models=None,
             system_prompt='',
             add_search_to_context=False,
             chat_conversation=None,
             text_context_list=None,
             document_choice=[],
             document_source_substrings=[],
             document_source_substrings_op='and',
             document_content_substrings=[],
             document_content_substrings_op='and',
             max_time=20,
             repetition_penalty=1.0,
             do_sample=True,
             ):
    """
    Build the full set of generation arguments sent to the server.

    The prompt is placed in ``instruction`` when ``chat`` is true and in
    ``instruction_nochat`` otherwise; the other slot is left as ''.  Most
    remaining entries are fixed values; the keyword arguments of this function
    override the corresponding entries.

    The key order of the returned mapping is significant: the second return
    value is ``list(kwargs.values())`` and callers pass it positionally to the
    server endpoints.

    Args:
        prompt: Text of the request.
        prompt_type, prompt_dict, system_prompt: Prompting options, forwarded as is.
        chat: Selects chat mode; also appends an empty ``chatbot`` list as the
            last entry.
        stream_output, max_new_tokens, max_time, repetition_penalty, do_sample:
            Generation options, forwarded as is.
        langchain_mode, langchain_action, langchain_agents, top_k_docs:
            Forwarded as is.
        document_choice, document_source_substrings, document_content_substrings:
            Forwarded; a falsy value (None or empty) becomes a fresh empty list.
        document_source_substrings_op, document_content_substrings_op:
            Forwarded; a falsy value falls back to 'and'.
        version: None is treated as 1.  For 0, exactly one name from
            eval_func_param_names is expected to be missing from the arguments;
            for >= 1, none.
        h2ogpt_key, visible_models, add_search_to_context, chat_conversation,
            text_context_list: Forwarded as is.

    Returns:
        Tuple ``(kwargs, args)``: the ordered mapping and the list of its values.

    Raises:
        AssertionError: if the number of names in eval_func_param_names that are
            missing from the arguments differs from what ``version`` expects.
    """
    from collections import OrderedDict
    kwargs = OrderedDict(instruction=prompt if chat else '',
                         iinput='',
                         context='',
                         stream_output=stream_output,
                         prompt_type=prompt_type,
                         prompt_dict=prompt_dict,
                         temperature=0.1,
                         top_p=0.75,
                         top_k=40,
                         penalty_alpha=0,
                         num_beams=1,
                         max_new_tokens=max_new_tokens,
                         min_new_tokens=0,
                         early_stopping=False,
                         max_time=max_time,
                         repetition_penalty=repetition_penalty,
                         num_return_sequences=1,
                         do_sample=do_sample,
                         chat=chat,
                         instruction_nochat=prompt if not chat else '',
                         iinput_nochat='',
                         langchain_mode=langchain_mode,
                         add_chat_history_to_context=add_chat_history_to_context,
                         langchain_action=langchain_action,
                         langchain_agents=langchain_agents,
                         top_k_docs=top_k_docs,
                         chunk=True,
                         chunk_size=512,
                         document_subset=DocumentSubset.Relevant.name,
                         # "value or default": the earlier "default or value" form made
                         # 'and' always win, so the *_op arguments were ignored.
                         document_choice=document_choice or [],
                         document_source_substrings=document_source_substrings or [],
                         document_source_substrings_op=document_source_substrings_op or 'and',
                         document_content_substrings=document_content_substrings or [],
                         document_content_substrings_op=document_content_substrings_op or 'and',
                         pre_prompt_query=None,
                         prompt_query=None,
                         pre_prompt_summary=None,
                         prompt_summary=None,
                         hyde_llm_prompt=None,
                         system_prompt=system_prompt,
                         image_audio_loaders=None,
                         pdf_loaders=None,
                         url_loaders=None,
                         jq_schema=None,
                         extract_frames=None,
                         llava_prompt=None,
                         visible_models=visible_models,
                         h2ogpt_key=h2ogpt_key,
                         add_search_to_context=add_search_to_context,
                         chat_conversation=chat_conversation,
                         text_context_list=text_context_list,
                         docs_ordering_type=None,
                         min_max_new_tokens=None,
                         max_input_tokens=None,
                         max_total_input_tokens=None,
                         docs_token_handling=None,
                         docs_joiner=None,
                         hyde_level=0,
                         hyde_template=None,
                         hyde_show_only_final=False,
                         doc_json_mode=False,
                         chatbot_role='None',
                         speaker='None',
                         tts_language='autodetect',
                         tts_speed=1.0,
                         )
    # diff = how many server parameter names are allowed to be absent from kwargs.
    diff = 0
    if version is None:
        version = 1
    if version == 0:
        # NOTE(review): system_prompt is already present in kwargs for every version,
        # so whether exactly one name is missing here depends on
        # eval_func_param_names -- confirm.
        diff = 1
    if version >= 1:
        kwargs['system_prompt'] = system_prompt
        diff = 0

    # Guard against drift between this argument list and the server's parameter names.
    from evaluate_params import eval_func_param_names
    assert len(set(eval_func_param_names).difference(set(list(kwargs.keys())))) == diff
    if chat:
        # Added after the check above, and last so it ends up as the final positional value.
        kwargs.update(dict(chatbot=[]))

    return kwargs, list(kwargs.values())
|
|
|
|
@pytest.mark.skip(reason="For manual use against some server, no server launched")
def test_client_basic(prompt_type='human_bot', version=None, visible_models=None, prompt='Who are you?',
                      h2ogpt_key=None):
    """Send one prompt through run_client_nochat with a 50 new-token limit and return its result."""
    nochat_options = dict(
        prompt=prompt,
        prompt_type=prompt_type,
        max_new_tokens=50,
        version=version,
        visible_models=visible_models,
        h2ogpt_key=h2ogpt_key,
    )
    return run_client_nochat(**nochat_options)
|
|
|
|
| """ |
| time HOST=https://gpt-internal.h2o.ai PYTHONPATH=. pytest -n 20 src/client_test.py::test_client_basic_benchmark |
| 32 seconds to answer 20 questions at once with 70B llama2 on 4x A100 80GB using TGI 0.9.3 |
| """ |
|
|
|
|
@pytest.mark.skip(reason="For manual use against some server, no server launched")
@pytest.mark.parametrize("id", range(20))
def test_client_basic_benchmark(id, prompt_type='human_bot', version=None):
    """
    Benchmark case: send a long pytest failure log followed by "What happened?".

    Parametrized over 20 ids so that 20 identical requests can be issued; the
    id value itself is not used in the request.  Returns whatever
    run_client_nochat returns.
    """
    # Fixed prompt text: a captured ImportError session used as a long input.
    failure_log_prompt = """
/nfs4/llm/h2ogpt/h2ogpt/bin/python /home/arno/pycharm-2022.2.2/plugins/python/helpers/pycharm/_jb_pytest_runner.py --target src/client_test.py::test_client_basic
Testing started at 8:41 AM ...
Launching pytest with arguments src/client_test.py::test_client_basic --no-header --no-summary -q in /nfs4/llm/h2ogpt

============================= test session starts ==============================
collecting ...
src/client_test.py:None (src/client_test.py)
ImportError while importing test module '/nfs4/llm/h2ogpt/src/client_test.py'.
Hint: make sure your test modules/packages have valid Python names.
Traceback:
h2ogpt/lib/python3.10/site-packages/_pytest/python.py:618: in _importtestmodule
mod = import_path(self.path, mode=importmode, root=self.config.rootpath)
h2ogpt/lib/python3.10/site-packages/_pytest/pathlib.py:533: in import_path
importlib.import_module(module_name)
/usr/lib/python3.10/importlib/__init__.py:126: in import_module
return _bootstrap._gcd_import(name[level:], package, level)
<frozen importlib._bootstrap>:1050: in _gcd_import
???
<frozen importlib._bootstrap>:1027: in _find_and_load
???
<frozen importlib._bootstrap>:1006: in _find_and_load_unlocked
???
<frozen importlib._bootstrap>:688: in _load_unlocked
???
h2ogpt/lib/python3.10/site-packages/_pytest/assertion/rewrite.py:168: in exec_module
exec(co, module.__dict__)
src/client_test.py:51: in <module>
from enums import DocumentSubset, LangChainAction
E ModuleNotFoundError: No module named 'enums'


collected 0 items / 1 error

=============================== 1 error in 0.14s ===============================
ERROR: not found: /nfs4/llm/h2ogpt/src/client_test.py::test_client_basic
(no name '/nfs4/llm/h2ogpt/src/client_test.py::test_client_basic' in any of [<Module client_test.py>])


Process finished with exit code 4

What happened?
"""
    return run_client_nochat(prompt=failure_log_prompt, prompt_type=prompt_type, max_new_tokens=100,
                             version=version)
|
|
|
|
def run_client_nochat(prompt, prompt_type, max_new_tokens, version=None, h2ogpt_key=None, visible_models=None):
    """
    Run one non-chat request against the '/submit_nochat' endpoint.

    The arguments built by get_args are passed positionally.  The raw result is
    printed, then converted with md_to_text.

    Returns:
        Tuple ``(res_dict, client)`` where res_dict has the keys 'prompt',
        'iinput' and 'response'.
    """
    kwargs, args = get_args(prompt, prompt_type, chat=False, max_new_tokens=max_new_tokens, version=version,
                            visible_models=visible_models, h2ogpt_key=h2ogpt_key)

    client = get_client(serialize=not is_gradio_version4)
    raw_result = client.predict(*args, api_name='/submit_nochat')
    print("Raw client result: %s" % raw_result, flush=True)

    res_dict = {
        'prompt': kwargs['instruction_nochat'],
        'iinput': kwargs['iinput_nochat'],
        'response': md_to_text(raw_result),
    }
    print(res_dict)
    return res_dict, client
|
|
|
|
@pytest.mark.skip(reason="For manual use against some server, no server launched")
def test_client_basic_api(prompt_type='human_bot', version=None, h2ogpt_key=None):
    """Ask 'Who are you?' through run_client_nochat_api with a 50 new-token limit and return its result."""
    api_options = dict(
        prompt='Who are you?',
        prompt_type=prompt_type,
        max_new_tokens=50,
        version=version,
        h2ogpt_key=h2ogpt_key,
    )
    return run_client_nochat_api(**api_options)
|
|
|
|
def run_client_nochat_api(prompt, prompt_type, max_new_tokens, version=None, h2ogpt_key=None):
    """
    Run one non-chat request against the '/submit_nochat_api' endpoint.

    The full argument mapping from get_args is sent as a single string (the
    str() of a plain dict).  The reply is a string holding a Python dict
    literal, read back with ast.literal_eval.

    Returns:
        Tuple ``(res_dict, client)`` where res_dict has the keys 'prompt',
        'iinput', 'response' (passed through md_to_text) and 'sources'.
    """
    kwargs, _args = get_args(prompt, prompt_type, chat=False, max_new_tokens=max_new_tokens, version=version,
                             h2ogpt_key=h2ogpt_key)

    client = get_client(serialize=not is_gradio_version4)
    payload = str(dict(kwargs))
    raw_result = client.predict(payload, api_name='/submit_nochat_api')
    print("Raw client result: %s" % raw_result, flush=True)

    res_dict = {'prompt': kwargs['instruction_nochat'], 'iinput': kwargs['iinput_nochat']}
    reply = ast.literal_eval(raw_result)
    res_dict['response'] = md_to_text(reply['response'])
    res_dict['sources'] = reply['sources']
    print(res_dict)
    return res_dict, client
|
|
|
|
@pytest.mark.skip(reason="For manual use against some server, no server launched")
def test_client_basic_api_lean(prompt='Who are you?', prompt_type='human_bot', version=None, h2ogpt_key=None,
                               chat_conversation=None, system_prompt=''):
    """Send one prompt through run_client_nochat_api_lean with a 50 new-token limit and return its result."""
    lean_options = dict(
        prompt=prompt,
        prompt_type=prompt_type,
        max_new_tokens=50,
        version=version,
        h2ogpt_key=h2ogpt_key,
        chat_conversation=chat_conversation,
        system_prompt=system_prompt,
    )
    return run_client_nochat_api_lean(**lean_options)
|
|
|
|
def run_client_nochat_api_lean(prompt, prompt_type, max_new_tokens, version=None, h2ogpt_key=None,
                               chat_conversation=None, system_prompt=''):
    """
    Call '/submit_nochat_api' with a minimal payload.

    Only instruction_nochat, h2ogpt_key, chat_conversation and system_prompt
    are sent.  prompt_type, max_new_tokens and version are accepted but not
    used by this function.

    Returns:
        Tuple ``(res_dict, client)`` where res_dict has the keys 'prompt',
        'response' (passed through md_to_text), 'sources' and 'h2ogpt_key'.
    """
    kwargs = {
        'instruction_nochat': prompt,
        'h2ogpt_key': h2ogpt_key,
        'chat_conversation': chat_conversation,
        'system_prompt': system_prompt,
    }

    client = get_client(serialize=not is_gradio_version4)
    payload = str(dict(kwargs))
    raw_result = client.predict(payload, api_name='/submit_nochat_api')
    print("Raw client result: %s" % raw_result, flush=True)

    res_dict = {'prompt': kwargs['instruction_nochat']}
    # The reply is a string containing a Python dict literal.
    reply = ast.literal_eval(raw_result)
    res_dict['response'] = md_to_text(reply['response'])
    res_dict['sources'] = reply['sources']
    res_dict['h2ogpt_key'] = h2ogpt_key
    print(res_dict)
    return res_dict, client
|
|
|
|
@pytest.mark.skip(reason="For manual use against some server, no server launched")
def test_client_basic_api_lean_morestuff(prompt_type='human_bot', version=None, h2ogpt_key=None):
    """Ask 'Who are you?' through run_client_nochat_api_lean_morestuff (50 new tokens requested)."""
    morestuff_options = dict(
        prompt='Who are you?',
        prompt_type=prompt_type,
        max_new_tokens=50,
        version=version,
        h2ogpt_key=h2ogpt_key,
    )
    return run_client_nochat_api_lean_morestuff(**morestuff_options)
|
|
|
|
def run_client_nochat_api_lean_morestuff(prompt, prompt_type='human_bot', max_new_tokens=512, version=None,
                                         h2ogpt_key=None):
    """
    Call '/submit_nochat_api' with a hand-written, partial set of arguments.

    Unlike get_args, the payload is spelled out here and covers only a subset
    of the generation options.

    Args:
        prompt: Text sent as instruction_nochat.
        prompt_type: Forwarded as is.
        max_new_tokens: Forwarded as the generation limit.  (Previously this
            argument was ignored and 1024 was always sent.)
        version: Accepted but not used by this function.
        h2ogpt_key: Forwarded, and echoed back in the result.

    Returns:
        Tuple ``(res_dict, client)`` where res_dict has the keys 'prompt',
        'response' (passed through md_to_text), 'sources' and 'h2ogpt_key'.
    """
    kwargs = dict(
        instruction='',
        iinput='',
        context='',
        stream_output=False,
        prompt_type=prompt_type,
        temperature=0.1,
        top_p=0.75,
        top_k=40,
        penalty_alpha=0,
        num_beams=1,
        # Honor the caller's limit instead of a hard-coded value.
        max_new_tokens=max_new_tokens,
        min_new_tokens=0,
        early_stopping=False,
        max_time=20,
        repetition_penalty=1.0,
        num_return_sequences=1,
        do_sample=True,
        chat=False,
        instruction_nochat=prompt,
        iinput_nochat='',
        langchain_mode='Disabled',
        add_chat_history_to_context=True,
        langchain_action=LangChainAction.QUERY.value,
        langchain_agents=[],
        top_k_docs=4,
        document_subset=DocumentSubset.Relevant.name,
        document_choice=[],
        document_source_substrings=[],
        document_source_substrings_op='and',
        document_content_substrings=[],
        document_content_substrings_op='and',
        h2ogpt_key=h2ogpt_key,
        add_search_to_context=False,
    )

    api_name = '/submit_nochat_api'
    client = get_client(serialize=not is_gradio_version4)
    res = client.predict(
        str(dict(kwargs)),
        api_name=api_name,
    )
    print("Raw client result: %s" % res, flush=True)
    # The reply is a string containing a Python dict literal; parse it once.
    reply = ast.literal_eval(res)
    res_dict = dict(prompt=kwargs['instruction_nochat'],
                    response=md_to_text(reply['response']),
                    sources=reply['sources'],
                    h2ogpt_key=h2ogpt_key)
    print(res_dict)
    return res_dict, client
|
|
|
|
@pytest.mark.skip(reason="For manual use against some server, no server launched")
def test_client_chat(prompt_type='human_bot', version=None, h2ogpt_key=None):
    """Ask 'Who are you?' in chat mode without streaming (50 new tokens) and return run_client_chat's result."""
    chat_options = dict(
        prompt='Who are you?',
        prompt_type=prompt_type,
        stream_output=False,
        max_new_tokens=50,
        langchain_mode='Disabled',
        langchain_action=LangChainAction.QUERY.value,
        langchain_agents=[],
        version=version,
        h2ogpt_key=h2ogpt_key,
    )
    return run_client_chat(**chat_options)
|
|
|
|
@pytest.mark.skip(reason="For manual use against some server, no server launched")
def test_client_chat_stream(prompt_type='human_bot', version=None, h2ogpt_key=None):
    """Request a long story in chat mode with streaming (512 new tokens) and return run_client_chat's result."""
    chat_options = dict(
        prompt="Tell a very long kid's story about birds.",
        prompt_type=prompt_type,
        stream_output=True,
        max_new_tokens=512,
        langchain_mode='Disabled',
        langchain_action=LangChainAction.QUERY.value,
        langchain_agents=[],
        version=version,
        h2ogpt_key=h2ogpt_key,
    )
    return run_client_chat(**chat_options)
|
|
|
|
def run_client_chat(prompt='',
                    stream_output=None,
                    max_new_tokens=128,
                    langchain_mode='Disabled',
                    langchain_action=LangChainAction.QUERY.value,
                    langchain_agents=[],
                    prompt_type=None, prompt_dict=None,
                    version=None,
                    h2ogpt_key=None,
                    chat_conversation=None,
                    system_prompt='',
                    document_choice=[],
                    document_content_substrings=[],
                    document_content_substrings_op='and',
                    document_source_substrings=[],
                    document_source_substrings_op='and',
                    top_k_docs=3,
                    max_time=20,
                    repetition_penalty=1.0,
                    do_sample=True):
    """
    Run one chat-mode request.

    Creates a client with serialize=False, builds the chat arguments with
    get_args (chat=True) from the options given here, and hands both to
    run_client.

    Returns:
        Whatever run_client returns, i.e. a ``(res_dict, client)`` tuple.
    """
    # The client is created before the arguments are assembled.
    client = get_client(serialize=False)

    chat_options = dict(
        chat=True,
        stream_output=stream_output,
        max_new_tokens=max_new_tokens,
        langchain_mode=langchain_mode,
        langchain_action=langchain_action,
        langchain_agents=langchain_agents,
        prompt_dict=prompt_dict,
        version=version,
        h2ogpt_key=h2ogpt_key,
        chat_conversation=chat_conversation,
        system_prompt=system_prompt,
        document_choice=document_choice,
        document_source_substrings=document_source_substrings,
        document_source_substrings_op=document_source_substrings_op,
        document_content_substrings=document_content_substrings,
        document_content_substrings_op=document_content_substrings_op,
        top_k_docs=top_k_docs,
        max_time=max_time,
        repetition_penalty=repetition_penalty,
        do_sample=do_sample,
    )
    kwargs, args = get_args(prompt, prompt_type, **chat_options)
    return run_client(client, prompt, args, kwargs)
|
|
|
|
def run_client(client, prompt, args, kwargs, do_md_to_text=True, verbose=False):
    """
    Execute a chat request and collect the response.

    When is_gradio_version4 is true, five source-display flags are set on
    kwargs, the request is delegated to run_client_gen, and str() of the
    result's 'sources_str' is appended to its 'response'.

    Otherwise the request requires chat mode: '/instruction' is called first
    and the last element of its result is appended to args[-1] (presumably the
    chatbot history -- confirm), then '/instruction_bot' is called, either
    directly or as a polled streaming job depending on kwargs['stream_output'].

    Args:
        client: Client used for the calls.
        prompt: Stored under 'prompt' in the result.
        args: Positional endpoint arguments; args[-1] is modified in place.
        kwargs: Argument mapping; it is modified in place and is the same
            object as the returned res_dict.
        do_md_to_text: Forwarded to md_to_text.
        verbose: When true, prints the full job outputs in the streaming case.

    Returns:
        Tuple ``(res_dict, client)``.
    """
    if is_gradio_version4:
        kwargs.update(
            answer_with_sources=True,
            show_accordions=True,
            append_sources_to_answer=True,
            append_sources_to_chat=False,
            show_link_in_sources=True,
        )
        res_dict, client = run_client_gen(client, kwargs, do_md_to_text=do_md_to_text)
        res_dict['response'] += str(res_dict['sources_str'])
        return res_dict, client

    assert kwargs['chat'], "Chat mode only"
    instruction_result = client.predict(*args, api_name='/instruction')
    args[-1] += [instruction_result[-1]]

    # The result is the kwargs mapping itself, not a copy.
    res_dict = kwargs
    res_dict['prompt'] = prompt

    if kwargs['stream_output']:
        job = client.submit(*args, api_name='/instruction_bot')
        partial_text = ''
        # Poll until the job completes, printing the newest partial answer each time.
        while not job.done():
            outputs_list = job.communicator.job.outputs
            if outputs_list:
                newest = job.communicator.job.outputs[-1]
                partial_text = newest[0][-1][-1]
                partial_text = md_to_text(partial_text, do_md_to_text=do_md_to_text)
                print(partial_text)
            time.sleep(0.1)
        full_outputs = job.outputs()
        if verbose:
            print('job.outputs: %s' % str(full_outputs))
        # Take the answer from the last recorded output rather than the last polled value.
        res_dict['response'] = md_to_text(full_outputs[-1][0][0][1], do_md_to_text=do_md_to_text)
        return res_dict, client

    bot_result = client.predict(*args, api_name='/instruction_bot')
    res_dict['response'] = bot_result[0][-1][1]
    print(md_to_text(res_dict['response'], do_md_to_text=do_md_to_text))
    return res_dict, client
|
|
|
|
@pytest.mark.skip(reason="For manual use against some server, no server launched")
def test_client_nochat_stream(prompt_type='human_bot', version=None, h2ogpt_key=None):
    """Request a long story in non-chat mode with streaming (512 new tokens) via run_client_nochat_gen."""
    stream_options = dict(
        prompt="Tell a very long kid's story about birds.",
        prompt_type=prompt_type,
        stream_output=True,
        max_new_tokens=512,
        langchain_mode='Disabled',
        langchain_action=LangChainAction.QUERY.value,
        langchain_agents=[],
        version=version,
        h2ogpt_key=h2ogpt_key,
    )
    return run_client_nochat_gen(**stream_options)
|
|
|
|
def run_client_nochat_gen(prompt, prompt_type, stream_output, max_new_tokens,
                          langchain_mode, langchain_action, langchain_agents, version=None,
                          h2ogpt_key=None):
    """
    Run one non-chat request through run_client_gen.

    Creates a client with serialize=False, builds the arguments with get_args
    (chat=False) and passes only the keyword mapping on.

    Returns:
        Whatever run_client_gen returns, i.e. a ``(res_dict, client)`` tuple.
    """
    # The client is created before the arguments are assembled.
    client = get_client(serialize=False)

    nochat_options = dict(
        chat=False,
        stream_output=stream_output,
        max_new_tokens=max_new_tokens,
        langchain_mode=langchain_mode,
        langchain_action=langchain_action,
        langchain_agents=langchain_agents,
        version=version,
        h2ogpt_key=h2ogpt_key,
    )
    kwargs, _args = get_args(prompt, prompt_type, **nochat_options)
    return run_client_gen(client, kwargs)
|
|
|
|
def run_client_gen(client, kwargs, do_md_to_text=True):
    """
    Send kwargs to '/submit_nochat_api' and merge the reply into them.

    'prompt' is set on kwargs (from 'instruction', falling back to
    'instruction_nochat') before the mapping is stringified, so it is part of
    what is sent.  The reply is a string holding a Python dict literal, parsed
    with ast.literal_eval and merged into the same mapping.

    With kwargs['stream_output'] false a single blocking call is made and the
    response is printed via md_to_text.  Otherwise the request is submitted as
    a job, polled every 0.1 s while printing partial responses, and the last
    output is used as the final reply.

    Args:
        client: Client used for the call.
        kwargs: Argument mapping; modified in place and returned as res_dict.
        do_md_to_text: Forwarded to md_to_text (used in the non-streaming case).

    Returns:
        Tuple ``(res_dict, client)``.

    Raises:
        AssertionError: in the streaming case, if the job produced no outputs.
    """
    # The result is the kwargs mapping itself, not a copy.
    res_dict = kwargs
    res_dict['prompt'] = kwargs['instruction'] or kwargs['instruction_nochat']

    streaming = kwargs['stream_output']
    payload = str(dict(kwargs))

    if not streaming:
        raw_reply = client.predict(payload, api_name='/submit_nochat_api')
        res_dict.update(ast.literal_eval(raw_reply))
        print(md_to_text(res_dict['response'], do_md_to_text=do_md_to_text))
        return res_dict, client

    job = client.submit(payload, api_name='/submit_nochat_api')
    while not job.done():
        outputs_list = job.communicator.job.outputs
        if outputs_list:
            newest = job.communicator.job.outputs[-1]
            partial_reply = ast.literal_eval(newest)
            print('Stream: %s' % partial_reply['response'])
        time.sleep(0.1)

    res_list = job.outputs()
    assert len(res_list) > 0, "No response, check server"
    final_reply = ast.literal_eval(res_list[-1])
    print('Final: %s' % final_reply['response'])
    res_dict.update(final_reply)
    return res_dict, client
|
|
|
|
def md_to_text(md, do_md_to_text=True):
    """
    Convert markdown to plain text.

    The markdown is rendered to HTML and the text content of that HTML is
    returned.  With do_md_to_text false the input is returned untouched
    (including None).

    Raises:
        AssertionError: if conversion is requested and md is None.
    """
    if do_md_to_text:
        assert md is not None, "Markdown is None"
        rendered_html = markdown.markdown(md)
        parsed = BeautifulSoup(rendered_html, features='html.parser')
        return parsed.get_text()
    return md
|
|
|
|
def run_client_many(prompt_type='human_bot', version=None, h2ogpt_key=None):
    """
    Call the seven client test functions in sequence with the same options.

    Each function returns a ``(result, client)`` pair; only the results are kept.

    Returns:
        Tuple of the seven results, in call order.
    """
    shared_options = dict(prompt_type=prompt_type, version=version, h2ogpt_key=h2ogpt_key)
    scenarios = (
        test_client_chat,
        test_client_chat_stream,
        test_client_nochat_stream,
        test_client_basic,
        test_client_basic_api,
        test_client_basic_api_lean,
        test_client_basic_api_lean_morestuff,
    )
    results = []
    for scenario in scenarios:
        result, _client = scenario(**shared_options)
        results.append(result)
    return tuple(results)
|
|
|
|
if __name__ == '__main__':
    # Executed as a script: run run_client_many() with its default options.
    run_client_many()
|
|