Skip to content

Commit

Permalink
Fixing black and mypy
Browse files Browse the repository at this point in the history
  • Loading branch information
JosselinSomervilleRoberts committed Oct 18, 2023
1 parent 068b199 commit 427e723
Show file tree
Hide file tree
Showing 12 changed files with 59 additions and 54 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def generate_requests(self, eval_instance: Instance) -> List[RequestState]:
)

request = Request(
model=self.adapter_spec.model,
model=self.adapter_spec.model_deployment,
multimodal_prompt=prompt.multimedia_object,
num_completions=self.adapter_spec.num_outputs,
temperature=self.adapter_spec.temperature,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def generate_requests(self, eval_instance: Instance) -> List[RequestState]:
)

request = Request(
model=self.adapter_spec.model,
model=self.adapter_spec.model_deployment,
multimodal_prompt=prompt.multimedia_object,
num_completions=self.adapter_spec.num_outputs,
temperature=self.adapter_spec.temperature,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def teardown_method(self, _):

def test_construct_prompt(self):
adapter_spec: AdapterSpec = AdapterSpec(
model="simple/model1",
model_deployment="simple/model1",
method=ADAPT_GENERATION_MULTIMODAL,
global_prefix="[START]",
instructions="Please answer the following question about the images.",
Expand Down Expand Up @@ -90,7 +90,7 @@ def test_construct_prompt(self):

def test_construct_prompt_multi_label(self):
adapter_spec: AdapterSpec = AdapterSpec(
model="simple/model1",
model_deployment="simple/model1",
method=ADAPT_GENERATION_MULTIMODAL,
global_prefix="[START]",
instructions="Please answer the following question about the images.",
Expand Down Expand Up @@ -170,7 +170,7 @@ def test_construct_prompt_idefics_instruct_example(self):
Constructing the same prompt from this example: https://huggingface.co/blog/idefics
"""
adapter_spec: AdapterSpec = AdapterSpec(
model="simple/model1",
model_deployment="simple/model1",
method=ADAPT_GENERATION_MULTIMODAL,
input_prefix="User: ",
input_suffix="<end_of_utterance>",
Expand Down
4 changes: 3 additions & 1 deletion src/helm/benchmark/metrics/basic_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -801,7 +801,9 @@ def compute_logprob_and_length(request_state: RequestState, window_service: Wind
num_choices = len(references)

tokenizer_service: TokenizerService = metric_service
window_service: WindowService = WindowServiceFactory.get_window_service(adapter_spec.model_deployment, tokenizer_service)
window_service: WindowService = WindowServiceFactory.get_window_service(
adapter_spec.model_deployment, tokenizer_service
)
reference_stats: Dict[ReferenceKey, ReferenceStat] = {}
for request_state in reference_request_states:
assert request_state.reference_index is not None and request_state.request_mode is not None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,5 +39,5 @@ def estimate_tokens(self, request: Request, metric_service: MetricService) -> in
"""
Estimate the number of tokens for a given request based on the organization.
"""
token_cost_estimator: TokenCostEstimator = self._get_estimator(request.model_organization)
token_cost_estimator: TokenCostEstimator = self._get_estimator(request.model_host)
return token_cost_estimator.estimate_tokens(request, metric_service)
15 changes: 8 additions & 7 deletions src/helm/benchmark/model_deployment_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,10 @@ def register_model_deployment(model_deployment: ModelDeployment) -> None:

try:
model_metadata: ModelMetadata = get_model_metadata(model_name)
if model_deployment.name not in model_metadata.deployment_names:
deployment_names: List[str] = model_metadata.deployment_names or [model_metadata.name]
if model_deployment.name not in deployment_names:
if model_metadata.deployment_names is None:
model_metadata.deployment_names = []
model_metadata.deployment_names.append(model_deployment.name)
except ValueError:
# No model metadata exists for this model name.
Expand Down Expand Up @@ -151,20 +154,18 @@ def maybe_register_model_deployments_from_base_path(base_path: str) -> None:


# ===================== UTIL FUNCTIONS ==================== #
def get_model_deployment(name: str) -> Optional[ModelDeployment]:
def get_model_deployment(name: str) -> ModelDeployment:
if name not in DEPLOYMENT_NAME_TO_MODEL_DEPLOYMENT:
raise ValueError(f"Model deployment {name} not found")
return DEPLOYMENT_NAME_TO_MODEL_DEPLOYMENT.get(name)
return DEPLOYMENT_NAME_TO_MODEL_DEPLOYMENT[name]


def get_model_deployments_by_host_group(host_group: str) -> List[str]:
"""
Gets models by host group.
Example: together => TODO(PR)
"""
return [
deployment.name for deployment in DEPLOYMENT_NAME_TO_MODEL_DEPLOYMENT if deployment.host_group == host_group
]
return [deployment.name for deployment in ALL_MODEL_DEPLOYMENTS if deployment.host_group == host_group]


def get_model_deployment_host_group(name: str) -> str:
Expand All @@ -188,7 +189,7 @@ def get_default_deployment_for_model(model_metadata: ModelMetadata) -> ModelDepl
"""
if model_metadata.name in DEPLOYMENT_NAME_TO_MODEL_DEPLOYMENT:
return DEPLOYMENT_NAME_TO_MODEL_DEPLOYMENT[model_metadata.name]
elif len(model_metadata.deployment_names) > 0:
elif model_metadata.deployment_names is not None and len(model_metadata.deployment_names) > 0:
deployment_name: str = model_metadata.deployment_names[0]
if deployment_name in DEPLOYMENT_NAME_TO_MODEL_DEPLOYMENT:
return DEPLOYMENT_NAME_TO_MODEL_DEPLOYMENT[deployment_name]
Expand Down
12 changes: 9 additions & 3 deletions src/helm/benchmark/model_metadata_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,16 @@
# Some models can follow instructions.
INSTRUCTION_FOLLOWING_MODEL_TAG: str = "instruction_following"

# For Vision-langauge models (VLMs)
VISION_LANGUAGE_MODEL_TAG: str = "vision_language"


MODEL_METADATA_FILE = "model_metadata.yaml"


@dataclass(frozen=True)
# Frozen is set to false as the model_deployment_registry.py file
# might populate the deployment_names field.
@dataclass(frozen=False)
class ModelMetadata:
# Name of the model group (e.g. "openai/davinci").
# This is the name of the model, not the name of the deployment.
Expand Down Expand Up @@ -144,7 +149,8 @@ class ModelMetadataList:
ModelMetadata(
name="anthropic/claude-v1.3",
display_name="Anthropic Claude v1.3",
description="A 52B parameter language model, trained using reinforcement learning from human feedback [paper](https://arxiv.org/pdf/2204.05862.pdf).",
description="A 52B parameter language model, trained using reinforcement learning from human feedback "
"[paper](https://arxiv.org/pdf/2204.05862.pdf).",
access="limited",
num_parameters=52000000000,
release_date="2023-03-17",
Expand Down Expand Up @@ -216,7 +222,7 @@ def get_models_by_creator_organization(organization: str) -> List[str]:
Gets models by creator organization.
Example: ai21 => ai21/j1-jumbo, ai21/j1-grande, ai21-large.
"""
return [model.name for model in MODEL_NAME_TO_MODEL_METADATA if model.creator_organization == organization]
return [model.name for model in ALL_MODELS_METADATA if model.creator_organization == organization]


def get_model_names_with_tag(tag: str) -> List[str]:
Expand Down
6 changes: 3 additions & 3 deletions src/helm/benchmark/presentation/summarize.py
Original file line number Diff line number Diff line change
Expand Up @@ -886,11 +886,11 @@ def _adapter_spec_sort_key(spec):
rows = []
for adapter_spec, info in zip(adapter_specs, infos):
deployment: str = adapter_spec.model_deployment
metadata: ModelMetadata = get_metadata_for_deployment(deployment)
model_name: str = metadata.name
model_metadata: ModelMetadata = get_metadata_for_deployment(deployment)
model_name: str = model_metadata.name

runs = adapter_to_runs[adapter_spec]
display_name = get_method_display_name(metadata.display_name, info)
display_name = get_method_display_name(model_metadata.display_name, info)

# Link to all the runs under this model
if link_to_runs:
Expand Down
4 changes: 2 additions & 2 deletions src/helm/proxy/clients/auto_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@
from helm.proxy.tokenizers.tokenizer import Tokenizer
from helm.proxy.tokenizers.huggingface_tokenizer import HuggingFaceTokenizer

from .http_model_client import HTTPModelClient
# TODO(PR): Remove this
# from .http_model_client import HTTPModelClient

if TYPE_CHECKING:
import helm.proxy.clients.huggingface_client
Expand Down Expand Up @@ -69,7 +70,6 @@ def _get_client(self, model: str) -> Client:
if client is None:
host_group: str = model.split("/")[0]
cache_config: CacheConfig = self._build_cache_config(host_group)
tokenizer: Tokenizer = self._get_tokenizer(host_group)

# TODO: Migrate all clients to use model deployments
# TODO(PR): Remove the TODO above.
Expand Down
22 changes: 8 additions & 14 deletions src/helm/proxy/models.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# TODO(PR): Delete this file.

from dataclasses import dataclass, field
from typing import Dict, List
from typing import List

# TODO(PR): Move these to model_metadata_registry.py
# # Different modalities
Expand Down Expand Up @@ -65,11 +65,9 @@

# # Some models can follow instructions.
# INSTRUCTION_FOLLOWING_MODEL_TAG: str = "instruction_following"
from typing import Dict, List, Optional
from helm.benchmark.model_deployment_registry import ModelDeployment

# For Vision-langauge models (VLMs)
VISION_LANGUAGE_MODEL_TAG: str = "vision_language"
# # For Vision-langauge models (VLMs)
# VISION_LANGUAGE_MODEL_TAG: str = "vision_language"


@dataclass
Expand Down Expand Up @@ -110,12 +108,6 @@ class Model:
# Tags corresponding to the properties of the model.
tags: List[str] = field(default_factory=list)

# List of the model deployments for this model.
# Should at least contain one model deployment.
# Refers to the field "name" in the ModelDeployment class.
# Defaults to a single model deployment with the same name as the model.
deployment_names: Optional[List[str]] = None

@property
def creator_organization(self) -> str:
"""
Expand All @@ -140,7 +132,7 @@ def engine(self) -> str:

# TODO(PR): Port all these models to the new format.

ALL_MODELS = [
ALL_MODELS: List[Model] = [
# # Local Model
# Model(
# group="neurips",
Expand Down Expand Up @@ -320,12 +312,14 @@ def engine(self) -> str:
# Model(
# group="cohere",
# name="cohere/command-medium-beta",
# tags=[TEXT_MODEL_TAG, FULL_FUNCTIONALITY_TEXT_MODEL_TAG, COHERE_TOKENIZER_TAG, INSTRUCTION_FOLLOWING_MODEL_TAG],
# tags=[TEXT_MODEL_TAG,
# FULL_FUNCTIONALITY_TEXT_MODEL_TAG, COHERE_TOKENIZER_TAG, INSTRUCTION_FOLLOWING_MODEL_TAG],
# ),
# Model(
# group="cohere",
# name="cohere/command-xlarge-beta",
# tags=[TEXT_MODEL_TAG, FULL_FUNCTIONALITY_TEXT_MODEL_TAG, COHERE_TOKENIZER_TAG, INSTRUCTION_FOLLOWING_MODEL_TAG],
# tags=[TEXT_MODEL_TAG,
# FULL_FUNCTIONALITY_TEXT_MODEL_TAG, COHERE_TOKENIZER_TAG, INSTRUCTION_FOLLOWING_MODEL_TAG],
# ),
# # EleutherAI
# Model(
Expand Down
36 changes: 19 additions & 17 deletions src/helm/proxy/test_models.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,29 @@
from .models import get_model, get_model_group, get_models_by_organization, get_all_code_models, Model
# TODO(PR): Remove this and create appropriate tests

# from .models import get_model, get_model_group, get_models_by_organization, get_all_code_models, Model

def test_get_model():
model: Model = get_model("ai21/j1-jumbo")
assert model.organization == "ai21"
assert model.engine == "j1-jumbo"

# def test_get_model():
# model: Model = get_model("ai21/j1-jumbo")
# assert model.organization == "ai21"
# assert model.engine == "j1-jumbo"

def test_get_model_with_invalid_model_name():
try:
get_model("invalid/model")
assert False, "Expected to throw ValueError"
except ValueError:
pass

# def test_get_model_with_invalid_model_name():
# try:
# get_model("invalid/model")
# assert False, "Expected to throw ValueError"
# except ValueError:
# pass

def test_get_model_group():
assert get_model_group("openai/text-curie-001") == "gpt3"

# def test_get_model_group():
# assert get_model_group("openai/text-curie-001") == "gpt3"

def test_get_models_by_organization():
assert get_models_by_organization("simple") == ["simple/model1"]

# def test_get_models_by_organization():
# assert get_models_by_organization("simple") == ["simple/model1"]

def test_all_code_models():
assert "openai/code-davinci-002" in get_all_code_models()

# def test_all_code_models():
# assert "openai/code-davinci-002" in get_all_code_models()
2 changes: 1 addition & 1 deletion src/helm/proxy/token_counters/auto_token_counter.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,5 +38,5 @@ def count_tokens(self, request: Request, completions: List[Sequence]) -> int:
"""
Counts tokens based on the organization.
"""
token_counter: TokenCounter = self.get_token_counter(request.model_organization)
token_counter: TokenCounter = self.get_token_counter(request.model_host)
return token_counter.count_tokens(request, completions)

0 comments on commit 427e723

Please sign in to comment.