Skip to content

Commit

Permalink
Standardize test groups
Browse files Browse the repository at this point in the history
- Use same test groups for benchmarking and evaluation
- Add a custom enum class with initiutive methods to dynamically create test groups
- Use custom enum to reduce manually creation of test groups
- Update benchmark cli args to accept test group argument
- Add pydantic validator to validate test group and test categories
  • Loading branch information
devanshamin committed Jul 9, 2024
1 parent 893c9af commit 4d8a86c
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 47 deletions.
32 changes: 21 additions & 11 deletions berkeley-function-call-leaderboard/bfcl/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from bfcl.model_handler.base import BaseHandler, ModelStyle
from bfcl.types import (LeaderboardCategory, LeaderboardCategories,
LeaderboardVersion, ModelType)
LeaderboardVersion, ModelType, LeaderboardCategoryGroup)

load_dotenv()

Expand Down Expand Up @@ -37,18 +37,25 @@ def get_args() -> argparse.Namespace:
parser.add_argument(
'--model-type',
type=ModelType,
choices=[category.value for category in ModelType],
choices=[mtype.value for mtype in ModelType],
default=ModelType.PROPRIETARY.value,
help="Model type: Open-source or Proprietary (default: 'proprietary')"
)
parser.add_argument(
'--test-category',
'--test-group',
type=LeaderboardCategoryGroup,
choices=[group.value for group in LeaderboardCategoryGroup],
default=None,
help='Test category group (default: None)'
)
parser.add_argument(
'--test-categories',
type=str,
default=LeaderboardCategory.ALL.value,
default=None,
help=(
'Comma-separated list of test categories '
f"({','.join(category.value for category in LeaderboardCategory)}). "
"(default: 'all')"
f"({','.join(cat.value for cat in LeaderboardCategory)}). "
"(default: None)"
)
)
parser.add_argument(
Expand All @@ -68,15 +75,18 @@ def get_args() -> argparse.Namespace:


def _get_test_categories(args) -> LeaderboardCategories:
if args.test_category == LeaderboardCategory.ALL.value:
categories = [category for category in LeaderboardCategory if category != LeaderboardCategory.ALL]
else:
if args.test_categories:
categories = []
for value in args.test_category.split(','):
for value in args.test_categories.split(','):
if value not in LeaderboardCategory._value2member_map_:
raise ValueError(f'Invalid test category: "{value}"!')
categories.append(LeaderboardCategory(value))
return LeaderboardCategories(categories=categories, version=args.version)
args.test_categories = categories
return LeaderboardCategories(
test_group=args.test_group,
test_categories=args.test_categories,
version=args.version
)


def _get_model_handler(args) -> BaseHandler:
Expand Down
96 changes: 60 additions & 36 deletions berkeley-function-call-leaderboard/bfcl/types.py
Original file line number Diff line number Diff line change
@@ -1,68 +1,92 @@
import json
import hashlib
from enum import Enum
from typing import Any, List, Dict
from pathlib import Path
from typing import Any, List, Dict, Type

from pydantic import BaseModel
from pydantic import BaseModel, model_validator
from huggingface_hub import hf_hub_download

from bfcl.utils import CustomEnum


class ModelType(str, Enum):
OSS = 'oss'
PROPRIETARY = 'proprietary'

class LeaderboardNonPythonCategory(str, CustomEnum):
JAVA = 'java'
JAVASCRIPT = 'javascript'

class LeaderboardExecutableCategory(str, Enum):
EXEC_SIMPLE = 'executable_simple'
EXEC_PARALLEL_FUNCTION = 'executable_parallel_function'
EXEC_MULTIPLE_FUNCTION = 'executable_multiple_function'
EXEC_PARALLEL_MULTIPLE_FUNCTION = 'executable_parallel_multiple_function'
REST = 'rest'


class LeaderboardAstCategory(str, Enum):
class LeaderboardAstCategory(str, CustomEnum):
SIMPLE = 'simple'
RELEVANCE = 'relevance'
PARALLEL_FUNCTION = 'parallel_function'
MULTIPLE_FUNCTION = 'multiple_function'
PARALLEL_FUNCTION = 'parallel_function'
PARALLEL_MULTIPLE_FUNCTION = 'parallel_multiple_function'
JAVA = 'java'
JAVASCRIPT = 'javascript'


class LeaderboardCategory(str, Enum):
EXEC_SIMPLE = LeaderboardExecutableCategory.EXEC_SIMPLE.value
EXEC_PARALLEL_FUNCTION = LeaderboardExecutableCategory.EXEC_PARALLEL_FUNCTION.value
EXEC_MULTIPLE_FUNCTION = LeaderboardExecutableCategory.EXEC_MULTIPLE_FUNCTION.value
EXEC_PARALLEL_MULTIPLE_FUNCTION = LeaderboardExecutableCategory.EXEC_PARALLEL_MULTIPLE_FUNCTION.value
REST = LeaderboardExecutableCategory.REST.value
SIMPLE = LeaderboardAstCategory.SIMPLE.value
RELEVANCE = LeaderboardAstCategory.RELEVANCE.value
PARALLEL_FUNCTION = LeaderboardAstCategory.PARALLEL_FUNCTION.value
MULTIPLE_FUNCTION = LeaderboardAstCategory.MULTIPLE_FUNCTION.value
PARALLEL_MULTIPLE_FUNCTION = LeaderboardAstCategory.PARALLEL_MULTIPLE_FUNCTION.value
JAVA = LeaderboardAstCategory.JAVA.value
JAVASCRIPT = LeaderboardAstCategory.JAVASCRIPT.value
SQL = 'sql'
CHATABLE = 'chatable'
ALL = 'all' # Adding the 'ALL' category
JAVA = LeaderboardNonPythonCategory.JAVA
JAVASCRIPT = LeaderboardNonPythonCategory.JAVASCRIPT

class LeaderboardExecutableCategory(str, CustomEnum):
EXECUTABLE_SIMPLE = 'executable_simple'
EXECUTABLE_PARALLEL_FUNCTION = 'executable_parallel_function'
EXECUTABLE_MULTIPLE_FUNCTION = 'executable_multiple_function'
EXECUTABLE_PARALLEL_MULTIPLE_FUNCTION = 'executable_parallel_multiple_function'
REST = 'rest'

LeaderboardPythonCategory: Type[CustomEnum] = (
LeaderboardAstCategory
.add(LeaderboardExecutableCategory)
.subtract(LeaderboardNonPythonCategory)
.rename('LeaderboardPythonCategory')
)

LeaderboardCategory: Type[CustomEnum] = (
LeaderboardPythonCategory
.add(LeaderboardNonPythonCategory)
.rename('LeaderboardCategory')
.update(dict(SQL='sql', CHATABLE='chatable'))
)

class LeaderboardCategoryGroup(str, Enum):
AST = 'ast'
EXECUTABLE = 'executable'
NON_PYTHON = 'non_python'
PYTHON = 'python'
ALL = 'all'

CATEGORY_GROUP_MAPPING = {
LeaderboardCategoryGroup.AST: LeaderboardAstCategory,
LeaderboardCategoryGroup.EXECUTABLE: LeaderboardExecutableCategory,
LeaderboardCategoryGroup.NON_PYTHON: LeaderboardNonPythonCategory,
LeaderboardCategoryGroup.PYTHON: LeaderboardPythonCategory,
LeaderboardCategoryGroup.ALL: LeaderboardCategory
}

class LeaderboardVersion(str, Enum):
V1 = 'v1'


class LeaderboardCategories(BaseModel):
categories: List[LeaderboardCategory]
test_group: LeaderboardCategoryGroup | None = None
test_categories: List[LeaderboardCategory] | None = None # type: ignore
version: LeaderboardVersion = LeaderboardVersion.V1
cache_dir: Path | str = '.cache'

@model_validator(mode='before')
@classmethod
def check_either_field_provided(cls, values):
if values.get('test_group') is not None and values.get('test_categories') is not None:
raise ValueError("Provide either 'test_group' or 'test_categories', not both")
elif values.get('test_group') is None and values.get('test_categories') is None:
raise ValueError("Provide either 'test_group' or 'test_categories'")
return values

def model_post_init(self, __context: Any) -> None:
if LeaderboardCategory.ALL in self.categories:
self.categories = [cat for cat in LeaderboardCategory if cat != LeaderboardCategory.ALL]
if self.test_group:
self.test_categories = [cat for cat in CATEGORY_GROUP_MAPPING[self.test_group]]
self.cache_dir = Path.cwd() / self.cache_dir

@property
def output_file_path(self) -> Path:
uid = self._generate_hash(self.model_dump_json())
Expand Down
25 changes: 25 additions & 0 deletions berkeley-function-call-leaderboard/bfcl/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from enum import Enum


class CustomEnum(Enum):
@classmethod
def add(cls, other):
combined_members = {member.name: member.value for member in cls}
combined_members.update({member.name: member.value for member in other})
return __class__(cls.__name__, combined_members)

@classmethod
def subtract(cls, other):
remaining_members = {member.name: member.value for member in cls if member not in other}
return __class__(cls.__name__, remaining_members)

@classmethod
def rename(cls, new_name):
members = {member.name: member.value for member in cls}
return __class__(new_name, members)

@classmethod
def update(cls, new_members):
members = {member.name: member.value for member in cls}
members.update(new_members)
return __class__(cls.__name__, members)

0 comments on commit 4d8a86c

Please sign in to comment.