diff --git a/README.md b/README.md index 182b5086..dae8b41a 100644 --- a/README.md +++ b/README.md @@ -6,15 +6,16 @@

| Research preview | Paper | Website | - +

**Latest News** 🔥 +- [2024/07] You can now install our package with `pip install knowledge-storm`! - [2024/07] We add `VectorRM` to support grounding on user-provided documents, complementing existing support of search engines (`YouRM`, `BingSearch`). (check out [#58](https://github.com/stanford-oval/storm/pull/58)) - [2024/07] We release demo light for developers a minimal user interface built with streamlit framework in Python, handy for local development and demo hosting (checkout [#54](https://github.com/stanford-oval/storm/pull/54)) - [2024/06] We will present STORM at NAACL 2024! Find us at Poster Session 2 on June 17 or check our [presentation material](assets/storm_naacl2024_slides.pdf). -- [2024/05] We add Bing Search support in [rm.py](src/rm.py). Test STORM with `GPT-4o` - we now configure the article generation part in our demo using `GPT-4o` model. -- [2024/04] We release refactored version of STORM codebase! We define [interface](src/interface.py) for STORM pipeline and reimplement STORM-wiki (check out [`src/storm_wiki`](src/storm_wiki)) to demonstrate how to instantiate the pipeline. We provide API to support customization of different language models and retrieval/search integration. +- [2024/05] We add Bing Search support in [rm.py](knowledge_storm/rm.py). Test STORM with `GPT-4o` - we now configure the article generation part in our demo using `GPT-4o` model. +- [2024/04] We release refactored version of STORM codebase! We define [interface](knowledge_storm/interface.py) for STORM pipeline and reimplement STORM-wiki (check out [`knowledge_storm/storm_wiki`](knowledge_storm/storm_wiki)) to demonstrate how to instantiate the pipeline. We provide API to support customization of different language models and retrieval/search integration. ## Overview [(Try STORM now!)](https://storm.genie.stanford.edu/) @@ -46,17 +47,17 @@ Based on the separation of the two stages, STORM is implemented in a highly modu -## Getting started +## Installation -### 1. 
Setup -Below, we provide a quick start guide to run STORM locally. +To install the knowledge storm library, use `pip install knowledge-storm`. +You could also install the source code which allows you to modify the behavior of STORM engine directly. 1. Clone the git repository. - ```shell - git clone https://github.com/stanford-oval/storm.git - cd storm - ``` + ```shell + git clone https://github.com/stanford-oval/storm.git + cd storm + ``` 2. Install the required packages. ```shell @@ -64,7 +65,71 @@ Below, we provide a quick start guide to run STORM locally. conda activate storm pip install -r requirements.txt ``` -3. Set up OpenAI API key (if you want to use OpenAI models to power STORM) and [You.com search API](https://api.you.com/) key. Create a file `secrets.toml` under the root directory and add the following content: + + +## API +The STORM knowledge curation engine is defined as a simple Python `STORMWikiRunner` class. + +As STORM is working in the information curation layer, you need to set up the information retrieval module and language model module to create a `STORMWikiRunner` instance. Here is an example of using You.com search engine and OpenAI models. +```python +import os +from knowledge_storm import STORMWikiRunnerArguments, STORMWikiRunner, STORMWikiLMConfigs +from knowledge_storm.lm import OpenAIModel +from knowledge_storm.rm import YouRM + +lm_configs = STORMWikiLMConfigs() +openai_kwargs = { + 'api_key': os.getenv("OPENAI_API_KEY"), + 'temperature': 1.0, + 'top_p': 0.9, +} +# STORM is a LM system so different components can be powered by different models to reach a good balance between cost and quality. +# For a good practice, choose a cheaper/faster model for `conv_simulator_lm` which is used to split queries, synthesize answers in the conversation. +# Choose a more powerful model for `article_gen_lm` to generate verifiable text with citations. 
+gpt_35 = OpenAIModel(model='gpt-3.5-turbo', max_tokens=500, **openai_kwargs) +gpt_4 = OpenAIModel(model='gpt-4o', max_tokens=3000, **openai_kwargs) +lm_configs.set_conv_simulator_lm(gpt_35) +lm_configs.set_question_asker_lm(gpt_35) +lm_configs.set_outline_gen_lm(gpt_4) +lm_configs.set_article_gen_lm(gpt_4) +lm_configs.set_article_polish_lm(gpt_4) +# Check out the STORMWikiRunnerArguments class for more configurations. +engine_args = STORMWikiRunnerArguments(...) +rm = YouRM(ydc_api_key=os.getenv('YDC_API_KEY'), k=engine_args.search_top_k) +runner = STORMWikiRunner(engine_args, lm_configs, rm) +``` + +Currently, our package supports: +- `OpenAIModel`, `AzureOpenAIModel`, `ClaudeModel`, `VLLMClient`, `TGIClient`, `TogetherClient`, `OllamaClient` as language model components +- `YouRM`, `BingSearch`, `VectorRM` as retrieval module components + +:star2: **PRs for integrating more language models into [knowledge_storm/lm.py](knowledge_storm/lm.py) and search engines/retrievers into [knowledge_storm/rm.py](knowledge_storm/rm.py) are highly appreciated!** + +The `STORMWikiRunner` instance can be invoked with the simple `run` method: +```python +topic = input('Topic: ') +runner.run( + topic=topic, + do_research=True, + do_generate_outline=True, + do_generate_article=True, + do_polish_article=True, +) +runner.post_run() +runner.summary() +``` +- `do_research`: if True, simulate conversations with different perspectives to collect information about the topic; otherwise, load the results. +- `do_generate_outline`: if True, generate an outline for the topic; otherwise, load the results. +- `do_generate_article`: if True, generate an article for the topic based on the outline and the collected information; otherwise, load the results. +- `do_polish_article`: if True, polish the article by adding a summarization section and (optionally) removing duplicate content; otherwise, load the results. 
+ + +## Quick Start with Example Scripts + +We provide scripts in our [examples folder](examples) as a quick start to run STORM with different configurations. + +**To run STORM with `gpt` family models with default configurations:** +1. We suggest using `secrets.toml` to set up the API keys. Create a file `secrets.toml` under the root directory and add the following content: ```shell # Set up OpenAI API key. OPENAI_API_KEY="your_openai_api_key" @@ -77,74 +142,31 @@ Below, we provide a quick start guide to run STORM locally. # Set up You.com search API key. YDC_API_KEY="your_youcom_api_key" ``` +2. Run the following command. + ``` + python examples/run_storm_wiki_gpt.py \ + --output-dir $OUTPUT_DIR \ + --retriever you \ + --do-research \ + --do-generate-outline \ + --do-generate-article \ + --do-polish-article + ``` +**To run STORM using your favorite language models or grounding on your own corpus:** Check out [examples/README.md](examples/README.md). -### 2. Running STORM-wiki locally - -**To run STORM with `gpt` family models with default configurations**: Make sure you have set up the OpenAI API key and run the following command. - -``` -python examples/run_storm_wiki_gpt.py \ - --output-dir $OUTPUT_DIR \ - --retriever you \ - --do-research \ - --do-generate-outline \ - --do-generate-article \ - --do-polish-article -``` -- `--do-research`: if True, simulate conversation to research the topic; otherwise, load the results. -- `--do-generate-outline`: If True, generate an outline for the topic; otherwise, load the results. -- `--do-generate-article`: If True, generate an article for the topic; otherwise, load the results. -- `--do-polish-article`: If True, polish the article by adding a summarization section and (optionally) removing duplicate content. - - -We provide more example scripts under [`examples`](examples) to demonstrate how you can run STORM using your favorite language models or grounding on your own corpus. 
- - -## Customize STORM -### Customization of the Pipeline +## Customization of the Pipeline -Besides running scripts in `examples`, you can customize STORM based on your own use case. STORM engine consists of 4 modules: +If you have installed the source code, you can customize STORM based on your own use case. STORM engine consists of 4 modules: 1. Knowledge Curation Module: Collects a broad coverage of information about the given topic. 2. Outline Generation Module: Organizes the collected information by generating a hierarchical outline for the curated knowledge. 3. Article Generation Module: Populates the generated outline with the collected information. 4. Article Polishing Module: Refines and enhances the written article for better presentation. -The interface for each module is defined in `src/interface.py`, while their implementations are instantiated in `src/storm_wiki/modules/*`. These modules can be customized according to your specific requirements (e.g., generating sections in bullet point format instead of full paragraphs). - -:star2: **You can share your customization of `Engine` by making PRs to this repo!** - -### Customization of Retriever Module - -As a knowledge curation engine, STORM grabs information from the Retriever module. The Retriever modules are implemented in [`src/rm.py`](src/rm.py). Currently, STORM supports the following retrievers: +The interface for each module is defined in `knowledge_storm/interface.py`, while their implementations are instantiated in `knowledge_storm/storm_wiki/modules/*`. These modules can be customized according to your specific requirements (e.g., generating sections in bullet point format instead of full paragraphs). 
-- `YouRM`: You.com search engine API -- `BingSearch`: Bing Search API -- `VectorRM`: a retrieval model that retrieves information from user provide corpus - -:star2: **PRs for integrating more search engines/retrievers are highly appreciated!** - -### Customization of Language Models - -STORM provides the following language model implementations in [`src/lm.py`](src/lm.py): - -- `OpenAIModel` -- `ClaudeModel` -- `VLLMClient` -- `TGIClient` -- `TogetherClient` - -:star2: **PRs for integrating more language model clients are highly appreciated!** - -:bulb: **For a good practice,** - -- choose a cheaper/faster model for `conv_simulator_lm` which is used to split queries, synthesize answers in the conversation. -- if you need to conduct the actual writing step, choose a more powerful model for `article_gen_lm`. Based on our experiments, weak models are bad at generating text with citations. -- for open models, adding one-shot example can help it better follow instructions. - -Please refer to the scripts in the [`examples`](examples) directory for concrete guidance on customizing the language model used in the pipeline. ## Replicate NAACL2024 result @@ -157,7 +179,7 @@ Please switch to the branch `NAACL-2024-code-backup` The FreshWiki dataset used in our experiments can be found in [./FreshWiki](FreshWiki). -Run the following commands under [./src](src). +Run the following commands under [./src](knowledge_storm). #### Pre-writing Stage For batch experiment on FreshWiki dataset: @@ -196,7 +218,7 @@ python -m scripts.run_writing --input-source console --engine gpt-4 --do-polish- The generated article will be saved in `{output_dir}/{topic}/storm_gen_article.txt` and the references corresponding to citation index will be saved in `{output_dir}/{topic}/url_to_info.json`. If `--do-polish-article` is set, the polished article will be saved in `{output_dir}/{topic}/storm_gen_article_polished.txt`. 
### Customize the STORM Configurations -We set up the default LLM configuration in `LLMConfigs` in [src/modules/utils.py](src/modules/utils.py). You can use `set_conv_simulator_lm()`,`set_question_asker_lm()`, `set_outline_gen_lm()`, `set_article_gen_lm()`, `set_article_polish_lm()` to override the default configuration. These functions take in an instance from `dspy.dsp.LM` or `dspy.dsp.HFModel`. +We set up the default LLM configuration in `LLMConfigs` in [src/modules/utils.py](knowledge_storm/modules/utils.py). You can use `set_conv_simulator_lm()`,`set_question_asker_lm()`, `set_outline_gen_lm()`, `set_article_gen_lm()`, `set_article_polish_lm()` to override the default configuration. These functions take in an instance from `dspy.dsp.LM` or `dspy.dsp.HFModel`. ### Automatic Evaluation @@ -224,7 +246,11 @@ For rubric grading, we use the [prometheus-13b-v1.0](https://huggingface.co/prom -## Contributions +## Roadmap & Contributions +Our team is actively working on: +1. Human-in-the-Loop Functionalities: Supporting user participation in the knowledge curation process. +2. Information Abstraction: Developing abstractions for curated information to support presentation formats beyond the Wikipedia-style report. + If you have any questions or suggestions, please feel free to open an issue or pull request. We welcome contributions to improve the system and the codebase! 
Contact person: [Yijia Shao](mailto:shaoyj@stanford.edu) and [Yucheng Jiang](mailto:yuchengj@stanford.edu) diff --git a/examples/run_storm_wiki_claude.py b/examples/run_storm_wiki_claude.py index f29ba1ce..31fef1e1 100644 --- a/examples/run_storm_wiki_claude.py +++ b/examples/run_storm_wiki_claude.py @@ -17,14 +17,12 @@ """ import os -import sys from argparse import ArgumentParser -sys.path.append('./src') -from lm import ClaudeModel -from rm import YouRM, BingSearch -from storm_wiki.engine import STORMWikiRunnerArguments, STORMWikiRunner, STORMWikiLMConfigs -from utils import load_api_key +from knowledge_storm import STORMWikiRunnerArguments, STORMWikiRunner, STORMWikiLMConfigs +from knowledge_storm.lm import ClaudeModel +from knowledge_storm.rm import YouRM, BingSearch +from knowledge_storm.utils import load_api_key def main(args): @@ -116,4 +114,4 @@ def main(args): parser.add_argument('--remove-duplicate', action='store_true', help='If True, remove duplicate content from the article.') - main(parser.parse_args()) \ No newline at end of file + main(parser.parse_args()) diff --git a/examples/run_storm_wiki_gpt.py b/examples/run_storm_wiki_gpt.py index f7639d69..b7968152 100644 --- a/examples/run_storm_wiki_gpt.py +++ b/examples/run_storm_wiki_gpt.py @@ -20,14 +20,12 @@ """ import os -import sys from argparse import ArgumentParser -sys.path.append('./src') -from lm import OpenAIModel -from rm import YouRM, BingSearch -from storm_wiki.engine import STORMWikiRunnerArguments, STORMWikiRunner, STORMWikiLMConfigs -from utils import load_api_key +from knowledge_storm import STORMWikiRunnerArguments, STORMWikiRunner, STORMWikiLMConfigs +from knowledge_storm.lm import OpenAIModel, AzureOpenAIModel +from knowledge_storm.rm import YouRM, BingSearch +from knowledge_storm.utils import load_api_key def main(args): @@ -35,23 +33,29 @@ def main(args): lm_configs = STORMWikiLMConfigs() openai_kwargs = { 'api_key': os.getenv("OPENAI_API_KEY"), - 'api_provider': 
os.getenv('OPENAI_API_TYPE'), 'temperature': 1.0, 'top_p': 0.9, - 'api_base': os.getenv('AZURE_API_BASE'), - 'api_version': os.getenv('AZURE_API_VERSION'), } + ModelClass = OpenAIModel if os.getenv('OPENAI_API_TYPE') == 'openai' else AzureOpenAIModel + # If you are using Azure service, make sure the model name matches your own deployed model name. + # The default name here is only used for demonstration and may not match your case. + gpt_35_model_name = 'gpt-3.5-turbo' if os.getenv('OPENAI_API_TYPE') == 'openai' else 'gpt-35-turbo' + gpt_4_model_name = 'gpt-4o' + if os.getenv('OPENAI_API_TYPE') == 'azure': + openai_kwargs['api_base'] = os.getenv('AZURE_API_BASE') + openai_kwargs['api_version'] = os.getenv('AZURE_API_VERSION') + # STORM is a LM system so different components can be powered by different models. # For a good balance between cost and quality, you can choose a cheaper/faster model for conv_simulator_lm # which is used to split queries, synthesize answers in the conversation. We recommend using stronger models # for outline_gen_lm which is responsible for organizing the collected information, and article_gen_lm # which is responsible for generating sections with citations. 
- conv_simulator_lm = OpenAIModel(model='gpt-3.5-turbo', max_tokens=500, **openai_kwargs) - question_asker_lm = OpenAIModel(model='gpt-3.5-turbo', max_tokens=500, **openai_kwargs) - outline_gen_lm = OpenAIModel(model='gpt-4-0125-preview', max_tokens=400, **openai_kwargs) - article_gen_lm = OpenAIModel(model='gpt-4-0125-preview', max_tokens=700, **openai_kwargs) - article_polish_lm = OpenAIModel(model='gpt-4-0125-preview', max_tokens=4000, **openai_kwargs) + conv_simulator_lm = ModelClass(model=gpt_35_model_name, max_tokens=500, **openai_kwargs) + question_asker_lm = ModelClass(model=gpt_35_model_name, max_tokens=500, **openai_kwargs) + outline_gen_lm = ModelClass(model=gpt_4_model_name, max_tokens=400, **openai_kwargs) + article_gen_lm = ModelClass(model=gpt_4_model_name, max_tokens=700, **openai_kwargs) + article_polish_lm = ModelClass(model=gpt_4_model_name, max_tokens=4000, **openai_kwargs) lm_configs.set_conv_simulator_lm(conv_simulator_lm) lm_configs.set_question_asker_lm(question_asker_lm) @@ -122,4 +126,4 @@ def main(args): parser.add_argument('--remove-duplicate', action='store_true', help='If True, remove duplicate content from the article.') - main(parser.parse_args()) \ No newline at end of file + main(parser.parse_args()) diff --git a/examples/run_storm_wiki_gpt_with_VectorRM.py b/examples/run_storm_wiki_gpt_with_VectorRM.py index c5dd4354..2c07ffc2 100644 --- a/examples/run_storm_wiki_gpt_with_VectorRM.py +++ b/examples/run_storm_wiki_gpt_with_VectorRM.py @@ -30,11 +30,10 @@ import sys from argparse import ArgumentParser -sys.path.append('./src') -from storm_wiki.engine import STORMWikiRunnerArguments, STORMWikiRunner, STORMWikiLMConfigs -from rm import VectorRM -from lm import OpenAIModel -from utils import load_api_key +from knowledge_storm import STORMWikiRunnerArguments, STORMWikiRunner, STORMWikiLMConfigs +from knowledge_storm.rm import VectorRM +from knowledge_storm.lm import OpenAIModel, AzureOpenAIModel +from knowledge_storm.utils import 
load_api_key def main(args): @@ -45,21 +44,29 @@ def main(args): engine_lm_configs = STORMWikiLMConfigs() openai_kwargs = { 'api_key': os.getenv("OPENAI_API_KEY"), - 'api_provider': os.getenv('OPENAI_API_TYPE'), 'temperature': 1.0, 'top_p': 0.9, } + ModelClass = OpenAIModel if os.getenv('OPENAI_API_TYPE') == 'openai' else AzureOpenAIModel + # If you are using Azure service, make sure the model name matches your own deployed model name. + # The default name here is only used for demonstration and may not match your case. + gpt_35_model_name = 'gpt-3.5-turbo' if os.getenv('OPENAI_API_TYPE') == 'openai' else 'gpt-35-turbo' + gpt_4_model_name = 'gpt-4o' + if os.getenv('OPENAI_API_TYPE') == 'azure': + openai_kwargs['api_base'] = os.getenv('AZURE_API_BASE') + openai_kwargs['api_version'] = os.getenv('AZURE_API_VERSION') + # STORM is a LM system so different components can be powered by different models. # For a good balance between cost and quality, you can choose a cheaper/faster model for conv_simulator_lm # which is used to split queries, synthesize answers in the conversation. We recommend using stronger models # for outline_gen_lm which is responsible for organizing the collected information, and article_gen_lm # which is responsible for generating sections with citations. 
- conv_simulator_lm = OpenAIModel(model='gpt-3.5-turbo', max_tokens=500, **openai_kwargs) - question_asker_lm = OpenAIModel(model='gpt-3.5-turbo', max_tokens=500, **openai_kwargs) - outline_gen_lm = OpenAIModel(model='gpt-4-0125-preview', max_tokens=400, **openai_kwargs) - article_gen_lm = OpenAIModel(model='gpt-4-0125-preview', max_tokens=700, **openai_kwargs) - article_polish_lm = OpenAIModel(model='gpt-4-0125-preview', max_tokens=4000, **openai_kwargs) + conv_simulator_lm = ModelClass(model=gpt_35_model_name, max_tokens=500, **openai_kwargs) + question_asker_lm = ModelClass(model=gpt_35_model_name, max_tokens=500, **openai_kwargs) + outline_gen_lm = ModelClass(model=gpt_4_model_name, max_tokens=400, **openai_kwargs) + article_gen_lm = ModelClass(model=gpt_4_model_name, max_tokens=700, **openai_kwargs) + article_polish_lm = ModelClass(model=gpt_4_model_name, max_tokens=4000, **openai_kwargs) engine_lm_configs.set_conv_simulator_lm(conv_simulator_lm) engine_lm_configs.set_question_asker_lm(question_asker_lm) diff --git a/examples/run_storm_wiki_mistral.py b/examples/run_storm_wiki_mistral.py index f7bc22dd..eb6a4ff6 100644 --- a/examples/run_storm_wiki_mistral.py +++ b/examples/run_storm_wiki_mistral.py @@ -16,16 +16,14 @@ storm_gen_article_polished.txt # Polished final article (if args.do_polish_article is True) """ import os -import sys from argparse import ArgumentParser from dspy import Example -sys.path.append('./src') -from lm import VLLMClient -from rm import YouRM, BingSearch -from storm_wiki.engine import STORMWikiRunnerArguments, STORMWikiRunner, STORMWikiLMConfigs -from utils import load_api_key +from knowledge_storm import STORMWikiRunnerArguments, STORMWikiRunner, STORMWikiLMConfigs +from knowledge_storm.lm import VLLMClient +from knowledge_storm.rm import YouRM, BingSearch +from knowledge_storm.utils import load_api_key def main(args): @@ -174,4 +172,4 @@ def main(args): parser.add_argument('--remove-duplicate', action='store_true', help='If True, 
remove duplicate content from the article.') - main(parser.parse_args()) \ No newline at end of file + main(parser.parse_args()) diff --git a/frontend/demo_light/README.md b/frontend/demo_light/README.md index 6a41c0e0..6ff58789 100644 --- a/frontend/demo_light/README.md +++ b/frontend/demo_light/README.md @@ -15,7 +15,8 @@ This is a minimal user interface for `STORMWikiRunner` which includes the follow

## Setup -1. Besides the required packages for `STORMWikiRunner`, you need to install additional packages: +1. Make sure you have installed `knowledge-storm` or set up the source code correctly. +2. Install additional packages required by the user interface: ```bash pip install -r requirements.txt ``` diff --git a/frontend/demo_light/demo_util.py b/frontend/demo_light/demo_util.py index d940aa09..e8a51823 100644 --- a/frontend/demo_light/demo_util.py +++ b/frontend/demo_light/demo_util.py @@ -1,20 +1,22 @@ import base64 import datetime -import io import json import os import re from typing import Optional import markdown -import pdfkit import pytz import streamlit as st -from lm import OpenAIModel -from rm import YouRM -from storm_wiki.engine import STORMWikiRunnerArguments, STORMWikiRunner, STORMWikiLMConfigs -from storm_wiki.modules.callback import BaseCallbackHandler +# If you install the source code instead of the `knowledge-storm` package, +# Uncomment the following lines: +# import sys +# sys.path.append('../../') +from knowledge_storm import STORMWikiRunnerArguments, STORMWikiRunner, STORMWikiLMConfigs +from knowledge_storm.lm import OpenAIModel +from knowledge_storm.rm import YouRM +from knowledge_storm.storm_wiki.modules.callback import BaseCallbackHandler from stoc import stoc @@ -529,7 +531,6 @@ def display_article_page(selected_article_name, selected_article_file_path_dict, _display_main_article(selected_article_file_path_dict) - class StreamlitCallbackHandler(BaseCallbackHandler): def __init__(self, status_container): self.status_container = status_container diff --git a/frontend/demo_light/storm.py b/frontend/demo_light/storm.py index 9a0ae663..c68b88cf 100644 --- a/frontend/demo_light/storm.py +++ b/frontend/demo_light/storm.py @@ -1,12 +1,8 @@ import os -import sys script_dir = os.path.dirname(os.path.abspath(__file__)) wiki_root_dir = os.path.dirname(os.path.dirname(script_dir)) -sys.path.append(os.path.normpath(os.path.join(script_dir, 
'../../src/storm_wiki'))) -sys.path.append(os.path.normpath(os.path.join(script_dir, '../../src'))) - import demo_util from pages_util import MyArticles, CreateNewArticle from streamlit_float import * diff --git a/knowledge_storm/__init__.py b/knowledge_storm/__init__.py new file mode 100644 index 00000000..f1fd18ea --- /dev/null +++ b/knowledge_storm/__init__.py @@ -0,0 +1,5 @@ +from .storm_wiki.engine import ( + STORMWikiLMConfigs, + STORMWikiRunnerArguments, + STORMWikiRunner +) diff --git a/src/interface.py b/knowledge_storm/interface.py similarity index 100% rename from src/interface.py rename to knowledge_storm/interface.py diff --git a/src/lm.py b/knowledge_storm/lm.py similarity index 92% rename from src/lm.py rename to knowledge_storm/lm.py index 6166cdd8..e1ec8e29 100644 --- a/src/lm.py +++ b/knowledge_storm/lm.py @@ -25,13 +25,10 @@ def __init__( self, model: str = "gpt-3.5-turbo-instruct", api_key: Optional[str] = None, - api_provider: Literal["openai", "azure"] = "openai", - api_base: Optional[str] = None, model_type: Literal["chat", "text"] = None, **kwargs ): - super().__init__(model=model, api_key=api_key, api_provider=api_provider, api_base=api_base, - model_type=model_type, **kwargs) + super().__init__(model=model, api_key=api_key, model_type=model_type, **kwargs) self._token_usage_lock = threading.Lock() self.prompt_tokens = 0 self.completion_tokens = 0 @@ -108,6 +105,44 @@ def __call__( return completions +class AzureOpenAIModel(dspy.AzureOpenAI): + """A wrapper class for dspy.AzureOpenAI.""" + def __init__( + self, + api_base: Optional[str] = None, + api_version: Optional[str] = None, + model: str = "gpt-3.5-turbo-instruct", + api_key: Optional[str] = None, + model_type: Literal["chat", "text"] = "chat", + **kwargs, + ): + super().__init__( + api_base=api_base, api_version=api_version, model=model, api_key=api_key, model_type=model_type, **kwargs) + self._token_usage_lock = threading.Lock() + self.prompt_tokens = 0 + self.completion_tokens = 0 
+ + def log_usage(self, response): + """Log the total tokens from the OpenAI API response. + Override log_usage() in dspy.AzureOpenAI for tracking accumulated token usage.""" + usage_data = response.get('usage') + if usage_data: + with self._token_usage_lock: + self.prompt_tokens += usage_data.get('prompt_tokens', 0) + self.completion_tokens += usage_data.get('completion_tokens', 0) + + def get_usage_and_reset(self): + """Get the total tokens used and reset the token usage.""" + usage = { + self.kwargs.get('model') or self.kwargs.get('engine'): + {'prompt_tokens': self.prompt_tokens, 'completion_tokens': self.completion_tokens} + } + self.prompt_tokens = 0 + self.completion_tokens = 0 + + return usage + + class ClaudeModel(dspy.dsp.modules.lm.LM): """Copied from dspy/dsp/modules/anthropic.py with the addition of tracking token usage.""" @@ -277,6 +312,7 @@ def _generate(self, prompt, **kwargs): print("Failed to parse JSON response:", response.text) raise Exception("Received invalid JSON response from server") + class OllamaClient(dspy.OllamaLocal): """A wrapper class for dspy.OllamaClient.""" diff --git a/src/rm.py b/knowledge_storm/rm.py similarity index 99% rename from src/rm.py rename to knowledge_storm/rm.py index 5126aa5a..86f59703 100644 --- a/src/rm.py +++ b/knowledge_storm/rm.py @@ -5,13 +5,13 @@ import dspy import pandas as pd import requests -from langchain_huggingface import HuggingFaceEmbeddings from langchain_core.documents import Document +from langchain_huggingface import HuggingFaceEmbeddings from langchain_qdrant import Qdrant from qdrant_client import QdrantClient, models from tqdm import tqdm -from utils import WebPageHelper +from .utils import WebPageHelper class YouRM(dspy.Retrieve): diff --git a/src/storm_wiki/__init__.py b/knowledge_storm/storm_wiki/__init__.py similarity index 100% rename from src/storm_wiki/__init__.py rename to knowledge_storm/storm_wiki/__init__.py diff --git a/src/storm_wiki/engine.py 
b/knowledge_storm/storm_wiki/engine.py similarity index 85% rename from src/storm_wiki/engine.py rename to knowledge_storm/storm_wiki/engine.py index 4191d35a..e0c8dfcc 100644 --- a/src/storm_wiki/engine.py +++ b/knowledge_storm/storm_wiki/engine.py @@ -5,17 +5,18 @@ from typing import Union, Literal, Optional import dspy -from interface import Engine, LMConfigs -from lm import OpenAIModel -from storm_wiki.modules.article_generation import StormArticleGenerationModule -from storm_wiki.modules.article_polish import StormArticlePolishingModule -from storm_wiki.modules.callback import BaseCallbackHandler -from storm_wiki.modules.knowledge_curation import StormKnowledgeCurationModule -from storm_wiki.modules.outline_generation import StormOutlineGenerationModule -from storm_wiki.modules.persona_generator import StormPersonaGenerator -from storm_wiki.modules.retriever import StormRetriever -from storm_wiki.modules.storm_dataclass import StormInformationTable, StormArticle -from utils import FileIOHelper, makeStringRed + +from .modules.article_generation import StormArticleGenerationModule +from .modules.article_polish import StormArticlePolishingModule +from .modules.callback import BaseCallbackHandler +from .modules.knowledge_curation import StormKnowledgeCurationModule +from .modules.outline_generation import StormOutlineGenerationModule +from .modules.persona_generator import StormPersonaGenerator +from .modules.retriever import StormRetriever +from .modules.storm_dataclass import StormInformationTable, StormArticle +from ..interface import Engine, LMConfigs +from ..lm import OpenAIModel +from ..utils import FileIOHelper, makeStringRed class STORMWikiLMConfigs(LMConfigs): @@ -233,16 +234,20 @@ def post_run(self): f.write(json.dumps(call) + '\n') def _load_information_table_from_local_fs(self, information_table_local_path): - assert os.path.exists(information_table_local_path), makeStringRed(f"{information_table_local_path} not exists. 
Please set --do-research argument to prepare the conversation_log.json for this topic.") + assert os.path.exists(information_table_local_path), makeStringRed( + f"{information_table_local_path} not exists. Please set --do-research argument to prepare the conversation_log.json for this topic.") return StormInformationTable.from_conversation_log_file(information_table_local_path) - + def _load_outline_from_local_fs(self, topic, outline_local_path): - assert os.path.exists(outline_local_path), makeStringRed(f"{outline_local_path} not exists. Please set --do-generate-outline argument to prepare the storm_gen_outline.txt for this topic.") + assert os.path.exists(outline_local_path), makeStringRed( + f"{outline_local_path} not exists. Please set --do-generate-outline argument to prepare the storm_gen_outline.txt for this topic.") return StormArticle.from_outline_file(topic=topic, file_path=outline_local_path) def _load_draft_article_from_local_fs(self, topic, draft_article_path, url_to_info_path): - assert os.path.exists(draft_article_path), makeStringRed(f"{draft_article_path} not exists. Please set --do-generate-article argument to prepare the storm_gen_article.txt for this topic.") - assert os.path.exists(url_to_info_path), makeStringRed(f"{url_to_info_path} not exists. Please set --do-generate-article argument to prepare the url_to_info.json for this topic.") + assert os.path.exists(draft_article_path), makeStringRed( + f"{draft_article_path} not exists. Please set --do-generate-article argument to prepare the storm_gen_article.txt for this topic.") + assert os.path.exists(url_to_info_path), makeStringRed( + f"{url_to_info_path} not exists. 
Please set --do-generate-article argument to prepare the url_to_info.json for this topic.") article_text = FileIOHelper.load_str(draft_article_path) references = FileIOHelper.load_json(url_to_info_path) return StormArticle.from_string(topic_name=topic, article_text=article_text, references=references) @@ -274,7 +279,8 @@ def run(self, callback_handler: A callback handler to handle the intermediate results. """ assert do_research or do_generate_outline or do_generate_article or do_polish_article, \ - makeStringRed("No action is specified. Please set at least one of --do-research, --do-generate-outline, --do-generate-article, --do-polish-article") + makeStringRed( + "No action is specified. Please set at least one of --do-research, --do-generate-outline, --do-generate-article, --do-polish-article") self.topic = topic self.article_dir_name = topic.replace(' ', '_').replace('/', '_') @@ -291,7 +297,8 @@ def run(self, if do_generate_outline: # load information table if it's not initialized if information_table is None: - information_table = self._load_information_table_from_local_fs(os.path.join(self.article_output_dir, 'conversation_log.json')) + information_table = self._load_information_table_from_local_fs( + os.path.join(self.article_output_dir, 'conversation_log.json')) outline = self.run_outline_generation_module(information_table=information_table, callback_handler=callback_handler) @@ -299,9 +306,12 @@ def run(self, draft_article: StormArticle = None if do_generate_article: if information_table is None: - information_table = self._load_information_table_from_local_fs(os.path.join(self.article_output_dir, 'conversation_log.json')) + information_table = self._load_information_table_from_local_fs( + os.path.join(self.article_output_dir, 'conversation_log.json')) if outline is None: - outline = self._load_outline_from_local_fs(topic=topic, outline_local_path=os.path.join(self.article_output_dir, 'storm_gen_outline.txt')) + outline = 
self._load_outline_from_local_fs(topic=topic, + outline_local_path=os.path.join(self.article_output_dir, + 'storm_gen_outline.txt')) draft_article = self.run_article_generation_module(outline=outline, information_table=information_table, callback_handler=callback_handler) @@ -311,5 +321,7 @@ def run(self, if draft_article is None: draft_article_path = os.path.join(self.article_output_dir, 'storm_gen_article.txt') url_to_info_path = os.path.join(self.article_output_dir, 'url_to_info.json') - draft_article = self._load_draft_article_from_local_fs(topic=topic, draft_article_path=draft_article_path, url_to_info_path=url_to_info_path) + draft_article = self._load_draft_article_from_local_fs(topic=topic, + draft_article_path=draft_article_path, + url_to_info_path=url_to_info_path) self.run_article_polishing_module(draft_article=draft_article, remove_duplicate=remove_duplicate) diff --git a/src/storm_wiki/modules/__init__.py b/knowledge_storm/storm_wiki/modules/__init__.py similarity index 74% rename from src/storm_wiki/modules/__init__.py rename to knowledge_storm/storm_wiki/modules/__init__.py index 51ee0121..9419a314 100644 --- a/src/storm_wiki/modules/__init__.py +++ b/knowledge_storm/storm_wiki/modules/__init__.py @@ -1,4 +1,4 @@ from .knowledge_curation import * from .persona_generator import * from .retriever import * -from .storm_dataclass import * \ No newline at end of file +from .storm_dataclass import * diff --git a/src/storm_wiki/modules/article_generation.py b/knowledge_storm/storm_wiki/modules/article_generation.py similarity index 96% rename from src/storm_wiki/modules/article_generation.py rename to knowledge_storm/storm_wiki/modules/article_generation.py index 0dfb76de..a114b3ec 100644 --- a/src/storm_wiki/modules/article_generation.py +++ b/knowledge_storm/storm_wiki/modules/article_generation.py @@ -5,10 +5,11 @@ from typing import List, Union import dspy -from interface import ArticleGenerationModule -from storm_wiki.modules.callback import 
BaseCallbackHandler -from storm_wiki.modules.storm_dataclass import StormInformationTable, StormArticle, StormInformation -from utils import ArticleTextProcessing + +from .callback import BaseCallbackHandler +from .storm_dataclass import StormInformationTable, StormArticle, StormInformation +from ...interface import ArticleGenerationModule +from ...utils import ArticleTextProcessing class StormArticleGenerationModule(ArticleGenerationModule): diff --git a/src/storm_wiki/modules/article_polish.py b/knowledge_storm/storm_wiki/modules/article_polish.py similarity index 96% rename from src/storm_wiki/modules/article_polish.py rename to knowledge_storm/storm_wiki/modules/article_polish.py index 5f38f058..b70bb834 100644 --- a/src/storm_wiki/modules/article_polish.py +++ b/knowledge_storm/storm_wiki/modules/article_polish.py @@ -2,9 +2,10 @@ from typing import Union import dspy -from interface import ArticlePolishingModule -from storm_wiki.modules.storm_dataclass import StormArticle -from utils import ArticleTextProcessing + +from .storm_dataclass import StormArticle +from ...interface import ArticlePolishingModule +from ...utils import ArticleTextProcessing class StormArticlePolishingModule(ArticlePolishingModule): diff --git a/src/storm_wiki/modules/callback.py b/knowledge_storm/storm_wiki/modules/callback.py similarity index 98% rename from src/storm_wiki/modules/callback.py rename to knowledge_storm/storm_wiki/modules/callback.py index 945a45db..a4b702d4 100644 --- a/src/storm_wiki/modules/callback.py +++ b/knowledge_storm/storm_wiki/modules/callback.py @@ -31,4 +31,4 @@ def on_direct_outline_generation_end(self, outline: str, **kwargs): def on_outline_refinement_end(self, outline: str, **kwargs): """Run when the outline refinement finishes.""" - pass \ No newline at end of file + pass diff --git a/src/storm_wiki/modules/knowledge_curation.py b/knowledge_storm/storm_wiki/modules/knowledge_curation.py similarity index 97% rename from 
src/storm_wiki/modules/knowledge_curation.py rename to knowledge_storm/storm_wiki/modules/knowledge_curation.py index 4fe7f159..8e881c65 100644 --- a/src/storm_wiki/modules/knowledge_curation.py +++ b/knowledge_storm/storm_wiki/modules/knowledge_curation.py @@ -5,14 +5,16 @@ from typing import Union, List, Tuple, Optional, Dict import dspy -from interface import KnowledgeCurationModule, Retriever -from storm_wiki.modules.callback import BaseCallbackHandler -from storm_wiki.modules.persona_generator import StormPersonaGenerator -from storm_wiki.modules.storm_dataclass import DialogueTurn, StormInformationTable, StormInformation -from utils import ArticleTextProcessing + +from .callback import BaseCallbackHandler +from .persona_generator import StormPersonaGenerator +from .storm_dataclass import DialogueTurn, StormInformationTable, StormInformation +from ...interface import KnowledgeCurationModule, Retriever +from ...utils import ArticleTextProcessing try: from streamlit.runtime.scriptrunner import add_script_run_ctx + streamlit_connection = True except ImportError as err: streamlit_connection = False diff --git a/src/storm_wiki/modules/outline_generation.py b/knowledge_storm/storm_wiki/modules/outline_generation.py similarity index 96% rename from src/storm_wiki/modules/outline_generation.py rename to knowledge_storm/storm_wiki/modules/outline_generation.py index 2b09d523..1f45b1c2 100644 --- a/src/storm_wiki/modules/outline_generation.py +++ b/knowledge_storm/storm_wiki/modules/outline_generation.py @@ -1,10 +1,11 @@ from typing import Union, Optional, Tuple import dspy -from interface import OutlineGenerationModule -from storm_wiki.modules.callback import BaseCallbackHandler -from storm_wiki.modules.storm_dataclass import StormInformationTable, StormArticle -from utils import ArticleTextProcessing + +from .callback import BaseCallbackHandler +from .storm_dataclass import StormInformationTable, StormArticle +from ...interface import OutlineGenerationModule +from 
...utils import ArticleTextProcessing class StormOutlineGenerationModule(OutlineGenerationModule): diff --git a/src/storm_wiki/modules/persona_generator.py b/knowledge_storm/storm_wiki/modules/persona_generator.py similarity index 100% rename from src/storm_wiki/modules/persona_generator.py rename to knowledge_storm/storm_wiki/modules/persona_generator.py diff --git a/src/storm_wiki/modules/internet_source_restrictions.json b/knowledge_storm/storm_wiki/modules/retriever.py similarity index 72% rename from src/storm_wiki/modules/internet_source_restrictions.json rename to knowledge_storm/storm_wiki/modules/retriever.py index 0a71e1ae..179ae99b 100644 --- a/src/storm_wiki/modules/internet_source_restrictions.json +++ b/knowledge_storm/storm_wiki/modules/retriever.py @@ -1,5 +1,15 @@ -{ - "generally_unreliable": [ +from typing import Union, List +from urllib.parse import urlparse + +import dspy + +from .storm_dataclass import StormInformation +from ...interface import Retriever, Information +from ...utils import ArticleTextProcessing + +# Internet source restrictions according to Wikipedia standard: +# https://en.wikipedia.org/wiki/Wikipedia:Reliable_sources/Perennial_sources +GENERALLY_UNRELIABLE = { "112_Ukraine", "Ad_Fontes_Media", "AlterNet", @@ -139,9 +149,8 @@ "WordPress.com", "Worldometer", "YouTube", - "ZDNet" - ], - "deprecated": [ + "ZDNet"} +DEPRECATED = { "Al_Mayadeen", "ANNA_News", "Baidu_Baike", @@ -189,8 +198,8 @@ "Voltaire_Network", "WorldNetDaily", "Zero_Hedge" - ], - "blacklisted": [ +} +BLACKLISTED = { "Advameg", "bestgore.com", "Breitbart_News", @@ -210,5 +219,32 @@ "Swarajya", "Veterans_Today", "ZoomInfo" - ] } + + +def is_valid_wikipedia_source(url): + parsed_url = urlparse(url) + # Check if the URL is from a reliable domain + combined_set = GENERALLY_UNRELIABLE | DEPRECATED | BLACKLISTED + for domain in combined_set: + if domain in parsed_url.netloc: + return False + + return True + + +class StormRetriever(Retriever): + def __init__(self, rm: 
dspy.Retrieve, k=3): + super().__init__(search_top_k=k) + self._rm = rm + if hasattr(rm, 'is_valid_source'): + rm.is_valid_source = is_valid_wikipedia_source + + def retrieve(self, query: Union[str, List[str]], exclude_urls: List[str] = []) -> List[Information]: + retrieved_data_list = self._rm(query_or_queries=query, exclude_urls=exclude_urls) + for data in retrieved_data_list: + for i in range(len(data['snippets'])): + # STORM generate the article with citations. We do not consider multi-hop citations. + # Remove citations in the source to avoid confusion. + data['snippets'][i] = ArticleTextProcessing.remove_citations(data['snippets'][i]) + return [StormInformation.from_dict(data) for data in retrieved_data_list] diff --git a/src/storm_wiki/modules/storm_dataclass.py b/knowledge_storm/storm_wiki/modules/storm_dataclass.py similarity index 99% rename from src/storm_wiki/modules/storm_dataclass.py rename to knowledge_storm/storm_wiki/modules/storm_dataclass.py index d75760ce..4f54ec46 100644 --- a/src/storm_wiki/modules/storm_dataclass.py +++ b/knowledge_storm/storm_wiki/modules/storm_dataclass.py @@ -4,10 +4,11 @@ from typing import Union, Optional, Any, List, Tuple, Dict import numpy as np -from interface import Information, InformationTable, Article, ArticleSectionNode from sentence_transformers import SentenceTransformer from sklearn.metrics.pairwise import cosine_similarity -from utils import ArticleTextProcessing, FileIOHelper + +from ...interface import Information, InformationTable, Article, ArticleSectionNode +from ...utils import ArticleTextProcessing, FileIOHelper class StormInformation(Information): diff --git a/src/utils.py b/knowledge_storm/utils.py similarity index 99% rename from src/utils.py rename to knowledge_storm/utils.py index cc1a6e58..5cf6f457 100644 --- a/src/utils.py +++ b/knowledge_storm/utils.py @@ -1,5 +1,6 @@ import concurrent.futures import json +import logging import os import pickle import re @@ -11,6 +12,8 @@ from 
langchain_text_splitters import RecursiveCharacterTextSplitter from trafilatura import extract +logging.getLogger("httpx").setLevel(logging.WARNING) # Disable INFO logging for httpx. + def load_api_key(toml_file_path): try: diff --git a/requirements.txt b/requirements.txt index b1bdfa51..8ac1b95b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,27 +1,7 @@ dspy_ai==2.4.9 -streamlit==1.31.1 wikipedia==1.4.0 -streamlit_authenticator==0.2.3 -streamlit_oauth==0.1.8 -streamlit-card -google-cloud==0.34.0 -google-cloud-vision==3.5.0 -google-cloud-storage==2.14.0 sentence_transformers toml -markdown -unidecode -extra-streamlit-components==0.1.60 -google-cloud-firestore==2.14.0 -firebase-admin==6.4.0 -streamlit_extras -streamlit_cookies_manager -deprecation==2.1.0 -st-pages==0.4.5 -streamlit-float -streamlit-option-menu -sentry-sdk -pdfkit==1.0.0 langchain-text-splitters trafilatura langchain-huggingface diff --git a/setup.py b/setup.py new file mode 100644 index 00000000..e4fa40c1 --- /dev/null +++ b/setup.py @@ -0,0 +1,39 @@ +import re + +from setuptools import setup, find_packages + +# Read the content of the README file +with open("README.md", encoding="utf-8") as f: + long_description = f.read() + # Remove p tags. + pattern = re.compile(r'.*?

', re.DOTALL) + long_description = re.sub(pattern, '', long_description) + +# Read the content of the requirements.txt file +with open("requirements.txt", encoding="utf-8") as f: + requirements = f.read().splitlines() + + +setup( + name="knowledge-storm", + version="0.2.3", + author="Yijia Shao, Yucheng Jiang", + author_email="shaoyj@stanford.edu, yuchengj@stanford.edu", + description="STORM: A language model-powered knowledge curation engine.", + long_description=long_description, + long_description_content_type="text/markdown", + url="https://github.com/stanford-oval/storm", + license="MIT License", + packages=find_packages(), + classifiers=[ + "Development Status :: 3 - Alpha", + "License :: OSI Approved :: MIT License", + "Operating System :: OS Independent", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + ], + python_requires='>=3.9', + install_requires=requirements, +) diff --git a/src/storm_wiki/modules/retriever.py b/src/storm_wiki/modules/retriever.py deleted file mode 100644 index 79cc2060..00000000 --- a/src/storm_wiki/modules/retriever.py +++ /dev/null @@ -1,45 +0,0 @@ -import json -import os -from typing import Union, List -from urllib.parse import urlparse - -import dspy -import storm_wiki.modules.storm_dataclass as storm_dataclass -from interface import Retriever, Information -from rm import YouRM -from utils import ArticleTextProcessing - -SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) -with open(os.path.join(SCRIPT_DIR, 'internet_source_restrictions.json')) as f: - domain_restriction_dict = json.load(f) - GENERALLY_UNRELIABLE = set(domain_restriction_dict["generally_unreliable"]) - DEPRECATED = set(domain_restriction_dict["deprecated"]) - BLACKLISTED = set(domain_restriction_dict["blacklisted"]) - - -def is_valid_wikipedia_source(url): - parsed_url = urlparse(url) - # Check if the URL is from a reliable domain - 
combined_set = GENERALLY_UNRELIABLE | DEPRECATED | BLACKLISTED - for domain in combined_set: - if domain in parsed_url.netloc: - return False - - return True - - -class StormRetriever(Retriever): - def __init__(self, rm: dspy.Retrieve, k=3): - super().__init__(search_top_k=k) - self._rm = rm - if hasattr(rm, 'is_valid_source'): - rm.is_valid_source = is_valid_wikipedia_source - - def retrieve(self, query: Union[str, List[str]], exclude_urls: List[str] = []) -> List[Information]: - retrieved_data_list = self._rm(query_or_queries=query, exclude_urls=exclude_urls) - for data in retrieved_data_list: - for i in range(len(data['snippets'])): - # STORM generate the article with citations. We do not consider multi-hop citations. - # Remove citations in the source to avoid confusion. - data['snippets'][i] = ArticleTextProcessing.remove_citations(data['snippets'][i]) - return [storm_dataclass.StormInformation.from_dict(data) for data in retrieved_data_list]