fixes for pydantic in gpt_langchain.py #1722
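This PR renames the pydantic field `stream` to `stream_output` on the `GradioInference` and `H2OHuggingFaceTextGenInference` LLM wrappers, keeps the old name available as a pydantic alias, and updates the internal references and the `get_llm` call site to match. A minimal sketch of the alias pattern (illustrative class and config, not code from this repo; pydantic v1-style config shown):

from pydantic import BaseModel, Field

class StreamingLLMConfig(BaseModel):
    # New canonical field name; alias="stream" keeps the old keyword working.
    stream_output: bool = Field(False, alias="stream")

    class Config:
        # pydantic v1 setting: also accept the field name itself at construction time.
        allow_population_by_field_name = True

print(StreamingLLMConfig(stream=True).stream_output)         # True, set via the alias
print(StreamingLLMConfig(stream_output=True).stream_output)  # True, set via the field name

Inside the classes the attribute is always read as `self.stream_output`, which is why the `_call` bodies and the `get_llm` keyword argument are updated in the same commit.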

Merged 1 commit on Aug 7, 2023
10 changes: 5 additions & 5 deletions apps/language_models/langchain/gpt_langchain.py
@@ -436,7 +436,7 @@ class GradioInference(LLM):
     chat_client: bool = False

     return_full_text: bool = True
-    stream: bool = False
+    stream_output: bool = Field(False, alias="stream")
     sanitize_bot_response: bool = False

     prompter: Any = None
@@ -481,7 +481,7 @@ def _call(
         # so server should get prompt_type or '', not plain
         # This is good, so gradio server can also handle stopping.py conditions
         # this is different than TGI server that uses prompter to inject prompt_type prompting
-        stream_output = self.stream
+        stream_output = self.stream_output
         gr_client = self.client
         client_langchain_mode = "Disabled"
         client_langchain_action = LangChainAction.QUERY.value
@@ -596,7 +596,7 @@ class H2OHuggingFaceTextGenInference(HuggingFaceTextGenInference):
     inference_server_url: str = ""
     timeout: int = 300
     headers: dict = None
-    stream: bool = False
+    stream_output: bool = Field(False, alias="stream")
     sanitize_bot_response: bool = False
     prompter: Any = None
     tokenizer: Any = None
@@ -663,7 +663,7 @@ def _call(
         # lower bound because client is re-used if multi-threading
         self.client.timeout = max(300, self.timeout)

-        if not self.stream:
+        if not self.stream_output:
             res = self.client.generate(
                 prompt,
                 **gen_server_kwargs,
@@ -852,7 +852,7 @@ def get_llm(
             top_p=top_p,
             # typical_p=top_p,
             callbacks=callbacks if stream_output else None,
-            stream=stream_output,
+            stream_output=stream_output,
             prompter=prompter,
             tokenizer=tokenizer,
             client=hf_client,
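For context, a simplified sketch of how the renamed flag is typically consumed in the `_call` paths touched above (hypothetical method and client calls, not the actual repository code):

def _call_sketch(self, prompt: str) -> str:
    if not self.stream_output:
        # Blocking path: one request, full completion returned at once.
        res = self.client.generate(prompt)
        return res.generated_text
    # Streaming path: accumulate tokens as the server emits them.
    text = ""
    for chunk in self.client.generate_stream(prompt):
        text += chunk.token.text
    return text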