Skip to content

Commit d2e1858

Browse files
authored
fix: qdrant unable to see index_name (#1705)
Signed-off-by: jupyterjazz <[email protected]>
1 parent 0ea6846 commit d2e1858

File tree

2 files changed

+52
-5
lines changed

2 files changed

+52
-5
lines changed

docarray/index/backends/qdrant.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -67,9 +67,6 @@ class QdrantDocumentIndex(BaseDocIndex, Generic[TSchema]):
6767

6868
def __init__(self, db_config=None, **kwargs):
6969
"""Initialize QdrantDocumentIndex"""
70-
if db_config is not None and getattr(db_config, 'index_name'):
71-
db_config.collection_name = db_config.index_name
72-
7370
super().__init__(db_config=db_config, **kwargs)
7471
self._db_config: QdrantDocumentIndex.DBConfig = cast(
7572
QdrantDocumentIndex.DBConfig, self._db_config
@@ -101,7 +98,11 @@ def collection_name(self):
10198
'To do so, use the syntax: QdrantDocumentIndex[DocumentType]'
10299
)
103100

104-
return self._db_config.collection_name or default_collection_name
101+
return (
102+
self._db_config.collection_name
103+
or self._db_config.index_name
104+
or default_collection_name
105+
)
105106

106107
@property
107108
def index_name(self):
@@ -563,7 +564,7 @@ def _text_search_batched(
563564

564565
def _filter_by_parent_id(self, id: str) -> Optional[List[str]]:
565566
response, _ = self._client.scroll(
566-
collection_name=self._db_config.collection_name, # type: ignore
567+
collection_name=self.collection_name, # type: ignore
567568
scroll_filter=rest.Filter(
568569
must=[
569570
rest.FieldCondition(
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
import numpy as np
2+
import pytest
3+
from pydantic import Field
4+
5+
from docarray import BaseDoc
6+
from docarray.index import QdrantDocumentIndex
7+
from docarray.typing import NdArray
8+
from tests.index.qdrant.fixtures import start_storage, tmp_collection_name # noqa: F401
9+
10+
11+
pytestmark = [pytest.mark.slow, pytest.mark.index]
12+
13+
14+
def test_configure_dim():
15+
class Schema1(BaseDoc):
16+
tens: NdArray = Field(dim=10)
17+
18+
index = QdrantDocumentIndex[Schema1](host='localhost')
19+
20+
docs = [Schema1(tens=np.random.random((10,))) for _ in range(10)]
21+
index.index(docs)
22+
23+
assert index.num_docs() == 10
24+
25+
class Schema2(BaseDoc):
26+
tens: NdArray[20]
27+
28+
index = QdrantDocumentIndex[Schema2](host='localhost')
29+
docs = [Schema2(tens=np.random.random((20,))) for _ in range(10)]
30+
index.index(docs)
31+
32+
assert index.num_docs() == 10
33+
34+
35+
def test_index_name():
36+
class Schema(BaseDoc):
37+
tens: NdArray = Field(dim=10)
38+
39+
index1 = QdrantDocumentIndex[Schema]()
40+
assert index1.index_name == 'schema'
41+
42+
index2 = QdrantDocumentIndex[Schema](index_name='my_index')
43+
assert index2.index_name == 'my_index'
44+
45+
index3 = QdrantDocumentIndex[Schema](collection_name='my_index')
46+
assert index3.index_name == 'my_index'

0 commit comments

Comments
 (0)