Skip to content

Commit

Permalink
[Fix] Configuring SSL proxy via openapi_config object (#321)
Browse files Browse the repository at this point in the history
## Problem

A few different problems being solved here:
- `Pinecone` class accepts an optional param, `openapi_config`. When
this is passed (usually as a vehicle for SSL configurations), it
currently clobbers the `api_key` param so the user sees an error message
about not providing an api_key (even though they did pass it) if they
attempt to perform a control plane operation

```python
from pinecone import Pinecone
from pinecone.core.client.configuration import Configuration as OpenApiConfiguration

openapi_config = OpenApiConfiguration()
openapi_config.ssl_ca_cert = '/path/to/cert'

pc = Pinecone(api_key='key', openapi_config=openapi_config)
pc.list_indexes()  # error: No api-key provided
```

- In a related issue, the `openapi_config` (with SSL configs) was not
being correctly passed through to the underlying `DataPlaneApi` for data
calls. So users with custom network configurations requiring SSL config
would see SSL validation failures when attempting data operations.

```python
from pinecone import Pinecone
from pinecone.core.client.configuration import Configuration as OpenApiConfiguration

openapi_config = OpenApiConfiguration()
openapi_config.ssl_ca_cert = '/path/to/cert'

pc = Pinecone(api_key='key', openapi_config=openapi_config)
index = pc.Index(host='my-index-host')
index.describe_index_stats()  # error: SSL validation failure
```

## Solution

- Adjust the ConfigBuilder to avoid clobbering API key
- Move some logic into a util function so that behavior will be
consistent across both data and control planes
- Ensure configuration is passed from the `Pinecone` object to the
index client
- deepcopy the openapi_config object before modifying it so that
index-specific host changes do not clobber the control plane config or
calls to other indexes.

## Future work

- In the future, we should deprecate `openapi_config` and have some way
of passing SSL config without all the baggage that comes with this
OpenApiConfiguration object. This config object is an undocumented
holdover from earlier versions of the client and breaks the abstraction
the client is trying to provide to smooth out the UX of the generated
SDK code.

## Type of Change

- [x] Bug fix (non-breaking change which fixes an issue)

## Test Plan

- Added tests
  • Loading branch information
jhamon authored Mar 14, 2024
1 parent b02388e commit a201cb6
Show file tree
Hide file tree
Showing 11 changed files with 175 additions and 55 deletions.
12 changes: 7 additions & 5 deletions pinecone/config/config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import NamedTuple, Optional, Dict
import os
import copy

from pinecone.exceptions import PineconeConfigurationError
from pinecone.config.openapi import OpenApiConfigFactory
Expand Down Expand Up @@ -46,10 +47,11 @@ def build(
if not host:
raise PineconeConfigurationError("You haven't specified a host.")

openapi_config = (
openapi_config
or kwargs.pop("openapi_config", None)
or OpenApiConfigFactory.build(api_key=api_key, host=host)
)
if openapi_config:
openapi_config = copy.deepcopy(openapi_config)
openapi_config.host = host
openapi_config.api_key = {"ApiKeyAuth": api_key}
else:
openapi_config = OpenApiConfigFactory.build(api_key=api_key, host=host)

return Config(api_key, host, openapi_config, additional_headers)
34 changes: 18 additions & 16 deletions pinecone/control/pinecone.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@
from pinecone.config import PineconeConfig, Config

from pinecone.core.client.api.manage_indexes_api import ManageIndexesApi
from pinecone.core.client.api_client import ApiClient
from pinecone.utils import get_user_agent, normalize_host
from pinecone.utils import normalize_host, setup_openapi_client
from pinecone.core.client.models import (
CreateCollectionRequest,
CreateIndexRequest,
Expand Down Expand Up @@ -85,25 +84,20 @@ def __init__(
or share with Pinecone support. **Be very careful with this option, as it will print out
your API key** which forms part of a required authentication header. Default: `false`
"""
if config or kwargs.get("config"):
configKwarg = config or kwargs.get("config")
if not isinstance(configKwarg, Config):
if config:
if not isinstance(config, Config):
raise TypeError("config must be of type pinecone.config.Config")
else:
self.config = configKwarg
self.config = config
else:
self.config = PineconeConfig.build(api_key=api_key, host=host, additional_headers=additional_headers, **kwargs)

self.pool_threads = pool_threads

if index_api:
self.index_api = index_api
else:
api_client = ApiClient(configuration=self.config.openapi_config, pool_threads=self.pool_threads)
api_client.user_agent = get_user_agent()
extra_headers = self.config.additional_headers or {}
for key, value in extra_headers.items():
api_client.set_default_header(key, value)
self.index_api = ManageIndexesApi(api_client)
self.index_api = setup_openapi_client(ManageIndexesApi, self.config, pool_threads)

self.index_host_store = IndexHostStore()
""" @private """
Expand Down Expand Up @@ -521,12 +515,20 @@ def Index(self, name: str = '', host: str = '', **kwargs):
raise ValueError("Either name or host must be specified")

pt = kwargs.pop('pool_threads', None) or self.pool_threads
api_key = self.config.api_key
openapi_config = self.config.openapi_config

if host != '':
# Use host url if it is provided
return Index(api_key=self.config.api_key, host=normalize_host(host), pool_threads=pt, **kwargs)

if name != '':
index_host=normalize_host(host)
else:
# Otherwise, get host url from describe_index using the index name
index_host = self.index_host_store.get_host(self.index_api, self.config, name)
return Index(api_key=self.config.api_key, host=index_host, pool_threads=pt, **kwargs)

return Index(
host=index_host,
api_key=api_key,
pool_threads=pt,
openapi_config=openapi_config,
**kwargs
)
26 changes: 11 additions & 15 deletions pinecone/data/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
ListResponse
)
from pinecone.core.client.api.data_plane_api import DataPlaneApi
from ..utils import get_user_agent
from ..utils import setup_openapi_client
from .vector_factory import VectorFactory

__all__ = [
Expand Down Expand Up @@ -75,27 +75,23 @@ def __init__(
host: str,
pool_threads: Optional[int] = 1,
additional_headers: Optional[Dict[str, str]] = {},
openapi_config = None,
**kwargs
):
self._config = ConfigBuilder.build(api_key=api_key, host=host, **kwargs)

api_client = ApiClient(configuration=self._config.openapi_config,
pool_threads=pool_threads)

# Configure request headers
api_client.user_agent = get_user_agent()
extra_headers = additional_headers or {}
for key, value in extra_headers.items():
api_client.set_default_header(key, value)

self._api_client = api_client
self._vector_api = DataPlaneApi(api_client=api_client)
self._config = ConfigBuilder.build(
api_key=api_key,
host=host,
additional_headers=additional_headers,
openapi_config=openapi_config,
**kwargs
)
self._vector_api = setup_openapi_client(DataPlaneApi, self._config, pool_threads)

def __enter__(self):
return self

def __exit__(self, exc_type, exc_value, traceback):
self._api_client.close()
self._vector_api.api_client.close()

@validate_and_convert_errors
def upsert(
Expand Down
3 changes: 2 additions & 1 deletion pinecone/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@
from .deprecation_notice import warn_deprecated
from .fix_tuple_length import fix_tuple_length
from .convert_to_list import convert_to_list
from .normalize_host import normalize_host
from .normalize_host import normalize_host
from .setup_openapi_client import setup_openapi_client
14 changes: 14 additions & 0 deletions pinecone/utils/setup_openapi_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from pinecone.core.client.api_client import ApiClient
from .user_agent import get_user_agent

def setup_openapi_client(api_klass, config, pool_threads):
    """Create an ``api_klass`` instance backed by a newly built ``ApiClient``.

    Args:
        api_klass: Callable (typically a generated API class) that accepts
            the ``ApiClient`` as its single positional argument.
        config: Object exposing ``openapi_config`` (handed to ``ApiClient``
            as its ``configuration``) and ``additional_headers`` (a mapping
            of header name to value, or a falsy value for none).
        pool_threads: Forwarded unchanged to ``ApiClient``.

    Returns:
        The object produced by ``api_klass(api_client)``.
    """
    transport = ApiClient(configuration=config.openapi_config, pool_threads=pool_threads)
    transport.user_agent = get_user_agent()

    # Caller-supplied headers are registered after the user agent is set.
    for header_name, header_value in (config.additional_headers or {}).items():
        transport.set_default_header(header_name, header_value)

    return api_klass(transport)
4 changes: 4 additions & 0 deletions tests/integration/data/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ def build_client():
from pinecone import Pinecone
return Pinecone(api_key=api_key(), additional_headers={'sdk-test-suite': 'pinecone-python-client'})

@pytest.fixture(scope='session')
def api_key_fixture():
return api_key()

@pytest.fixture(scope='session')
def client():
return build_client()
Expand Down
18 changes: 18 additions & 0 deletions tests/integration/data/test_openapi_configuration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import pytest
import os

from pinecone import Pinecone
from pinecone.core.client.configuration import Configuration as OpenApiConfiguration
from urllib3 import make_headers

@pytest.mark.skipif(os.getenv('USE_GRPC') != 'false', reason='Only test when using REST')
class TestIndexOpenapiConfig:
    """Integration check that a user-supplied openapi_config keeps the api key usable."""

    def test_passing_openapi_config(self, api_key_fixture, index_host):
        """Control plane and data plane calls both succeed when openapi_config is given."""
        client = Pinecone(
            api_key=api_key_fixture,
            openapi_config=OpenApiConfiguration.get_default_copy(),
        )
        assert client.config.api_key == api_key_fixture
        # Control plane request: should not throw.
        client.list_indexes()

        data_index = client.Index(host=index_host)
        assert data_index._config.api_key == api_key_fixture
        # Data plane request issued through the index client built above.
        data_index.describe_index_stats()
47 changes: 44 additions & 3 deletions tests/unit/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import pytest
import os

from urllib3 import make_headers

class TestConfig:
@pytest.fixture(autouse=True)
def run_before_and_after_tests(tmpdir):
Expand Down Expand Up @@ -49,13 +51,14 @@ def test_init_with_positional_args(self):
def test_init_with_kwargs(self):
api_key = "my-api-key"
controller_host = "my-controller-host"
openapi_config = OpenApiConfiguration(api_key="openapi-api-key")
openapi_config = OpenApiConfiguration()
openapi_config.ssl_ca_cert = 'path/to/cert'

config = PineconeConfig.build(api_key=api_key, host=controller_host, openapi_config=openapi_config)

assert config.api_key == api_key
assert config.host == 'https://' + controller_host
assert config.openapi_config == openapi_config
assert config.openapi_config.ssl_ca_cert == 'path/to/cert'

def test_resolution_order_kwargs_over_env_vars(self):
"""
Expand Down Expand Up @@ -84,5 +87,43 @@ def test_config_pool_threads(self):
pc = Pinecone(api_key="test-api-key", host="test-controller-host", pool_threads=10)
assert pc.index_api.api_client.pool_threads == 10
idx = pc.Index(host='my-index-host', name='my-index-name')
assert idx._api_client.pool_threads == 10
assert idx._vector_api.api_client.pool_threads == 10

def test_config_when_openapi_config_is_passed_merges_api_key(self):
oai_config = OpenApiConfiguration()
pc = Pinecone(api_key='asdf', openapi_config=oai_config)
assert pc.config.openapi_config.api_key == {'ApiKeyAuth': 'asdf'}

def test_ssl_config_passed_to_index_client(self):
oai_config = OpenApiConfiguration()
oai_config.ssl_ca_cert = 'path/to/cert'
proxy_headers = make_headers(proxy_basic_auth='asdf')
oai_config.proxy_headers = proxy_headers

pc = Pinecone(api_key='key', openapi_config=oai_config)

assert pc.config.openapi_config.ssl_ca_cert == 'path/to/cert'
assert pc.config.openapi_config.proxy_headers == proxy_headers

idx = pc.Index(host='host')
assert idx._vector_api.api_client.configuration.ssl_ca_cert == 'path/to/cert'
assert idx._vector_api.api_client.configuration.proxy_headers == proxy_headers

def test_host_config_not_clobbered_by_index(self):
oai_config = OpenApiConfiguration()
oai_config.ssl_ca_cert = 'path/to/cert'
proxy_headers = make_headers(proxy_basic_auth='asdf')
oai_config.proxy_headers = proxy_headers

pc = Pinecone(api_key='key', openapi_config=oai_config)

assert pc.config.openapi_config.ssl_ca_cert == 'path/to/cert'
assert pc.config.openapi_config.proxy_headers == proxy_headers
assert pc.config.openapi_config.host == 'https://api.pinecone.io'

idx = pc.Index(host='host')
assert idx._vector_api.api_client.configuration.ssl_ca_cert == 'path/to/cert'
assert idx._vector_api.api_client.configuration.proxy_headers == proxy_headers
assert idx._vector_api.api_client.configuration.host == 'https://host'

assert pc.config.openapi_config.host == 'https://api.pinecone.io'
36 changes: 36 additions & 0 deletions tests/unit/test_config_builder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import pytest

from pinecone.core.client.configuration import Configuration as OpenApiConfiguration
from pinecone.config import ConfigBuilder
from pinecone import PineconeConfigurationError

class TestConfigBuilder:
    """Unit tests for ``ConfigBuilder.build``: populated fields and required-argument errors."""

    def test_build_simple(self):
        """Building from api_key and host alone fills in a matching openapi config."""
        key = "my-api-key"
        url = "https://my-controller-host"

        built = ConfigBuilder.build(api_key=key, host=url)

        assert built.api_key == key
        assert built.host == url
        assert built.additional_headers == {}
        assert built.openapi_config.host == url
        assert built.openapi_config.api_key == {"ApiKeyAuth": key}

    def test_build_merges_key_and_host_when_openapi_config_provided(self):
        """A caller-provided openapi config still ends up carrying the api key and host."""
        key = "my-api-key"
        url = "https://my-controller-host"
        supplied = OpenApiConfiguration()

        built = ConfigBuilder.build(api_key=key, host=url, openapi_config=supplied)

        assert built.api_key == key
        assert built.host == url
        assert built.additional_headers == {}
        assert built.openapi_config.host == url
        assert built.openapi_config.api_key == {"ApiKeyAuth": key}

    def test_build_errors_when_no_api_key_is_present(self):
        """Calling build with no arguments reports the missing api key."""
        with pytest.raises(PineconeConfigurationError) as excinfo:
            ConfigBuilder.build()
        # Checked after the context manager exits, once the exception has been captured.
        assert str(excinfo.value) == "You haven't specified an Api-Key."

    def test_build_errors_when_no_host_is_present(self):
        """Calling build with only an api key reports the missing host."""
        with pytest.raises(PineconeConfigurationError) as excinfo:
            ConfigBuilder.build(api_key='my-api-key')
        # Checked after the context manager exits, once the exception has been captured.
        assert str(excinfo.value) == "You haven't specified a host."
14 changes: 10 additions & 4 deletions tests/unit/test_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from pinecone import Pinecone, PodSpec, ServerlessSpec
from pinecone.core.client.models import IndexList, IndexModel
from pinecone.core.client.api.manage_indexes_api import ManageIndexesApi
from pinecone.core.client.configuration import Configuration as OpenApiConfiguration

import time

@pytest.fixture
Expand Down Expand Up @@ -107,25 +109,29 @@ def test_list_indexes_returns_iterable(self, mocker, index_list_response):
response = p.list_indexes()
assert [i.name for i in response] == ["index1", "index2", "index3"]

def test_api_key_and_openapi_config(self, mocker):
p = Pinecone(api_key="123", openapi_config=OpenApiConfiguration.get_default_copy())
assert p.config.api_key == "123"

class TestIndexConfig:
def test_default_pool_threads(self):
pc = Pinecone(api_key="123-456-789")
index = pc.Index(host='my-host.svg.pinecone.io')
assert index._api_client.pool_threads == 1
assert index._vector_api.api_client.pool_threads == 1

def test_pool_threads_when_indexapi_passed(self):
pc = Pinecone(api_key="123-456-789", pool_threads=2, index_api=ManageIndexesApi())
index = pc.Index(host='my-host.svg.pinecone.io')
assert index._api_client.pool_threads == 2
assert index._vector_api.api_client.pool_threads == 2

def test_target_index_with_pool_threads_inherited(self):
pc = Pinecone(api_key="123-456-789", pool_threads=10, foo='bar')
index = pc.Index(host='my-host.svg.pinecone.io')
assert index._api_client.pool_threads == 10
assert index._vector_api.api_client.pool_threads == 10

def test_target_index_with_pool_threads_kwarg(self):
pc = Pinecone(api_key="123-456-789", pool_threads=10)
index = pc.Index(host='my-host.svg.pinecone.io', pool_threads=5)
assert index._api_client.pool_threads == 5
assert index._vector_api.api_client.pool_threads == 5


22 changes: 11 additions & 11 deletions tests/unit/test_index_initialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,18 @@ class TestIndexClientInitialization():
def test_no_additional_headers_leaves_useragent_only(self, additional_headers):
pc = Pinecone(api_key='YOUR_API_KEY')
index = pc.Index(host='myhost', additional_headers=additional_headers)
assert len(index._api_client.default_headers) == 1
assert 'User-Agent' in index._api_client.default_headers
assert 'python-client-' in index._api_client.default_headers['User-Agent']
assert len(index._vector_api.api_client.default_headers) == 1
assert 'User-Agent' in index._vector_api.api_client.default_headers
assert 'python-client-' in index._vector_api.api_client.default_headers['User-Agent']

def test_additional_headers_one_additional(self):
pc = Pinecone(api_key='YOUR_API_KEY')
index = pc.Index(
host='myhost',
additional_headers={'test-header': 'test-header-value'}
)
assert 'test-header' in index._api_client.default_headers
assert len(index._api_client.default_headers) == 2
assert 'test-header' in index._vector_api.api_client.default_headers
assert len(index._vector_api.api_client.default_headers) == 2

def test_multiple_additional_headers(self):
pc = Pinecone(api_key='YOUR_API_KEY')
Expand All @@ -34,9 +34,9 @@ def test_multiple_additional_headers(self):
'test-header2': 'test-header-value2'
}
)
assert 'test-header' in index._api_client.default_headers
assert 'test-header2' in index._api_client.default_headers
assert len(index._api_client.default_headers) == 3
assert 'test-header' in index._vector_api.api_client.default_headers
assert 'test-header2' in index._vector_api.api_client.default_headers
assert len(index._vector_api.api_client.default_headers) == 3

def test_overwrite_useragent(self):
# This doesn't seem like a common use case, but we may want to allow this
Expand All @@ -48,6 +48,6 @@ def test_overwrite_useragent(self):
'User-Agent': 'test-user-agent'
}
)
assert len(index._api_client.default_headers) == 1
assert 'User-Agent' in index._api_client.default_headers
assert index._api_client.default_headers['User-Agent'] == 'test-user-agent'
assert len(index._vector_api.api_client.default_headers) == 1
assert 'User-Agent' in index._vector_api.api_client.default_headers
assert index._vector_api.api_client.default_headers['User-Agent'] == 'test-user-agent'

0 comments on commit a201cb6

Please sign in to comment.