Skip to content

Commit f1f650e

Browse files
committed
Added Cohere example [skip ci]
1 parent 1421ce3 commit f1f650e

File tree

3 files changed

+77
-0
lines changed

3 files changed

+77
-0
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ Follow the instructions for your database library:
1717
Or check out some examples:
1818

1919
- [Embeddings](https://github.com/pgvector/pgvector-rust/blob/master/examples/openai/src/main.rs) with OpenAI
20+
- [Binary embeddings](https://github.com/pgvector/pgvector-rust/blob/master/examples/cohere/src/main.rs) with Cohere
2021
- [Recommendations](https://github.com/pgvector/pgvector-rust/blob/master/examples/disco/src/main.rs) with Disco
2122
- [Bulk loading](https://github.com/pgvector/pgvector-rust/blob/master/examples/loading/src/main.rs) with `COPY`
2223

examples/cohere/Cargo.toml

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
[package]
2+
name = "example"
3+
version = "0.1.0"
4+
edition = "2021"
5+
publish = false
6+
7+
[dependencies]
8+
pgvector = { path = "../..", features = ["postgres"] }
9+
postgres = "0.19"
10+
serde_json = "1"
11+
ureq = { version = "2", features = ["json"] }

examples/cohere/src/main.rs

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
use pgvector::Bit;
2+
use postgres::{Client, NoTls};
3+
use serde_json::Value;
4+
use std::error::Error;
5+
6+
fn main() -> Result<(), Box<dyn Error>> {
7+
let mut client = Client::configure()
8+
.host("localhost")
9+
.dbname("pgvector_example")
10+
.user(std::env::var("USER")?.as_str())
11+
.connect(NoTls)?;
12+
13+
client.execute("CREATE EXTENSION IF NOT EXISTS vector", &[])?;
14+
client.execute("DROP TABLE IF EXISTS documents", &[])?;
15+
client.execute("CREATE TABLE documents (id serial PRIMARY KEY, content text, embedding bit(1024))", &[])?;
16+
17+
let input = [
18+
"The dog is barking",
19+
"The cat is purring",
20+
"The bear is growling",
21+
];
22+
let embeddings = fetch_embeddings(&input, "search_document")?;
23+
for (content, embedding) in input.iter().zip(embeddings) {
24+
let embedding = Bit::from_bytes(&embedding);
25+
client.execute("INSERT INTO documents (content, embedding) VALUES ($1, $2)", &[&content, &embedding])?;
26+
}
27+
28+
let query = "forest";
29+
let query_embedding = fetch_embeddings(&[query], "search_query")?;
30+
for row in client.query("SELECT content FROM documents ORDER BY embedding <~> $1 LIMIT 5", &[&Bit::from_bytes(&query_embedding[0])])? {
31+
let content: &str = row.get(0);
32+
println!("{}", content);
33+
}
34+
35+
Ok(())
36+
}
37+
38+
fn fetch_embeddings(texts: &[&str], input_type: &str) -> Result<Vec<Vec<u8>>, Box<dyn Error>> {
39+
let api_key = std::env::var("CO_API_KEY").or(Err("Set CO_API_KEY"))?;
40+
41+
let response: Value = ureq::post("https://api.cohere.com/v1/embed")
42+
.set("Authorization", &format!("Bearer {}", api_key))
43+
.send_json(ureq::json!({
44+
"texts": texts,
45+
"model": "embed-english-v3.0",
46+
"input_type": input_type,
47+
"embedding_types": &["ubinary"],
48+
}))?
49+
.into_json()?;
50+
51+
let embeddings = response["embeddings"]["ubinary"]
52+
.as_array()
53+
.unwrap()
54+
.iter()
55+
.map(|v| {
56+
v.as_array()
57+
.unwrap()
58+
.iter()
59+
.map(|v| v.as_f64().unwrap() as u8)
60+
.collect()
61+
})
62+
.collect();
63+
64+
Ok(embeddings)
65+
}

0 commit comments

Comments
 (0)