82 changes: 65 additions & 17 deletions private_gpt/ui/ui.py
@@ -51,6 +51,14 @@ class Modes(str, Enum):
]


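# Response styles offered in the UI: stream tokens as they are generated, or return the complete answer as a single message.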
class Styles(str, Enum):
    STREAMING = "Streaming"
    NON_STREAMING = "Non-Streaming"


STYLES: list[Styles] = [Styles.STREAMING, Styles.NON_STREAMING]


class Source(BaseModel):
file: str
page: str
@@ -105,6 +113,9 @@ def __init__(
)
self._system_prompt = self._get_default_system_prompt(self._default_mode)

# Initialize default response style: Streaming
self.response_style = STYLES[0]

def _chat(
self, message: str, history: list[list[str]], mode: Modes, *_: Any
) -> Any:
@@ -185,18 +196,30 @@ def build_history() -> list[ChatMessage]:
docs_ids.append(ingested_document.doc_id)
context_filter = ContextFilter(docs_ids=docs_ids)

query_stream = self._chat_service.stream_chat(
    messages=all_messages,
    use_context=True,
    context_filter=context_filter,
)
yield from yield_deltas(query_stream)
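# Branch on the selected response style: stream deltas as they arrive, or emit the full answer in one go.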
match self.response_style:
    case Styles.STREAMING:
        query_stream = self._chat_service.stream_chat(
            messages=all_messages,
            use_context=True,
            context_filter=context_filter,
        )
        yield from yield_deltas(query_stream)
    case Styles.NON_STREAMING:
        # chat() returns the complete response, so yield it as one message
        query_response = self._chat_service.chat(
            messages=all_messages,
            use_context=True,
            context_filter=context_filter,
        ).response
        yield query_response

case Modes.BASIC_CHAT_MODE:
llm_stream = self._chat_service.stream_chat(
    messages=all_messages,
    use_context=False,
)
yield from yield_deltas(llm_stream)
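# Same response-style switch for plain chat, with no document context.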
match self.response_style:
    case Styles.STREAMING:
        llm_stream = self._chat_service.stream_chat(
            messages=all_messages,
            use_context=False,
        )
        yield from yield_deltas(llm_stream)
    case Styles.NON_STREAMING:
        llm_response = self._chat_service.chat(
            messages=all_messages,
            use_context=False,
        ).response
        yield llm_response

case Modes.SEARCH_MODE:
response = self._chunks_service.retrieve_relevant(
@@ -224,12 +247,21 @@ def build_history() -> list[ChatMessage]:
docs_ids.append(ingested_document.doc_id)
context_filter = ContextFilter(docs_ids=docs_ids)

summary_stream = self._summarize_service.stream_summarize(
    use_context=True,
    context_filter=context_filter,
    instructions=message,
)
yield from yield_tokens(summary_stream)
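# Summarization follows the same response-style switch.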
match self.response_style:
    case Styles.STREAMING:
        summary_stream = self._summarize_service.stream_summarize(
            use_context=True,
            context_filter=context_filter,
            instructions=message,
        )
        yield from yield_tokens(summary_stream)
    case Styles.NON_STREAMING:
        # yield the complete summary as a single message
        summary_response = self._summarize_service.summarize(
            use_context=True,
            context_filter=context_filter,
            instructions=message,
        )
        yield summary_response

# On initialization and on mode change, this function sets the system prompt
# to the default prompt based on the mode (and user settings).
@@ -282,6 +314,9 @@ def _set_current_mode(self, mode: Modes) -> Any:
gr.update(value=self._explanation_mode),
]

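# Callback for the Response Style dropdown: store the user's selection for _chat to use.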
def _set_current_response_style(self, response_style: Styles) -> Any:
    self.response_style = response_style

def _list_ingested_files(self) -> list[list[str]]:
files = set()
for ingested_document in self._ingest_service.list_ingested():
@@ -405,6 +440,15 @@ def _build_ui_blocks(self) -> gr.Blocks:
max_lines=3,
interactive=False,
)
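# Dropdown letting the user choose between streamed and single-shot responses.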
default_response_style = STYLES[0]
response_style = gr.Dropdown(
    [response_style.value for response_style in STYLES],
    label="Response Style",
    value=default_response_style,
    interactive=True,
)
upload_button = gr.components.UploadButton(
"Upload File(s)",
type="filepath",
@@ -498,6 +542,10 @@ def _build_ui_blocks(self) -> gr.Blocks:
self._set_system_prompt,
inputs=system_prompt_input,
)
# When the response style changes, store the new selection
response_style.change(
    self._set_current_response_style, inputs=response_style
)

def get_model_label() -> str | None:
"""Get model label from llm mode setting YAML.