Skip to content

Commit

Permalink
Add files via upload
Browse files Browse the repository at this point in the history
  • Loading branch information
ritaranx authored Jun 24, 2024
1 parent c5f7f34 commit ad783bd
Show file tree
Hide file tree
Showing 2 changed files with 140 additions and 0 deletions.
121 changes: 121 additions & 0 deletions eval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
from torch import Tensor

from typing import List, Dict, Union, Tuple
import numpy as np
import logging

from tqdm import trange

import torch

from torch import Tensor
from transformers import AutoTokenizer, AutoModel

from beir import util, LoggingHandler
from beir.retrieval import models
from beir.datasets.data_loader import GenericDataLoader
from beir.retrieval.evaluation import EvaluateRetrieval
from beir.retrieval.search.dense import DenseRetrievalExactSearch as DRES
from transformers.file_utils import PaddingStrategy

import logging
import pathlib, os

def last_token_pool(last_hidden_states: Tensor,
attention_mask: Tensor) -> Tensor:
last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
if left_padding:
embedding = last_hidden[:, -1]
else:
sequence_lengths = attention_mask.sum(dim=1) - 1
batch_size = last_hidden.shape[0]
embedding = last_hidden[torch.arange(batch_size, device=last_hidden.device), sequence_lengths]
return embedding


def get_detailed_instruct_query(task_description: str, query: str) -> str:
return f'{task_description}\nQuery: {query}'

def get_detailed_instruct_passage(passage: str) -> str:
return f'Represent this passage\npassage: {passage}'


logger = logging.getLogger(__name__)


class SentenceBERT:
def __init__(self, model_path: Union[str, Tuple] = "BMRetriever/BMRetriever-7B", sep: str = " ", **kwargs):
self.sep = sep
self.task = 'Given a scientific claim, retrieve documents that support or refute the claim'

self.model = AutoModel.from_pretrained(model_path, device_map="auto", torch_dtype=torch.float32, cache_dir="/localscratch/yueyu/cache")
self.tokenizer = AutoTokenizer.from_pretrained(model_path)
self.max_length = 512
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
self.tokenizer.padding_side="left"

def encode(self, input_texts, batch_size=8, **kwargs):
# Tokenize the input texts
embeddings = []
self.model.eval()
with torch.no_grad():
for i in trange(0, len(input_texts), batch_size):
input_text = input_texts[i: (i+batch_size)]
batch_dict = self.tokenizer(
input_text,
max_length=self.max_length-1,
return_attention_mask=False,
return_token_type_ids=False,
padding=PaddingStrategy.DO_NOT_PAD,
truncation=True
)
batch_dict['input_ids'] = [input_ids + [self.tokenizer.eos_token_id] for input_ids in batch_dict['input_ids']]
batch_dict = self.tokenizer.pad(batch_dict, padding=True, return_attention_mask=True, return_tensors='pt').to("cuda")
outputs = self.model(**batch_dict)
embedding = last_token_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
embeddings.append(embedding)
embeddings = torch.cat(embeddings, dim = 0)
print(embeddings.shape)
return embeddings


def encode_queries(self, queries: List[str], batch_size: int = 16, **kwargs) -> Union[List[Tensor], np.ndarray, Tensor]:
queries = [get_detailed_instruct_query(self.task, query) for query in queries]
return self.encode(queries, batch_size=batch_size, **kwargs)

def encode_corpus(self, corpus: Union[List[Dict[str, str]], Dict[str, List]], batch_size: int = 8, **kwargs) -> Union[List[Tensor], np.ndarray, Tensor]:
if type(corpus) is dict:
sentences = [(corpus["title"][i] + self.sep + corpus["text"][i]).strip() if "title" in corpus else corpus["text"][i].strip() for i in range(len(corpus['text']))]
else:
sentences = [(doc["title"] + self.sep + doc["text"]).strip() if "title" in doc else doc["text"].strip() for doc in corpus]
sentences = [get_detailed_instruct_passage(passage) for passage in sentences]
return self.encode(sentences, batch_size=batch_size, **kwargs)



#### Just some code to print debug information to stdout
logging.basicConfig(format='%(asctime)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S',
level=logging.INFO,
handlers=[LoggingHandler()])
#### /print debug information to stdout

#### Download scifact.zip dataset and unzip the dataset
dataset = "scifact"
url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(dataset)
out_dir = os.path.join(pathlib.Path(__file__).parent.absolute(), "datasets")
data_path = util.download_and_unzip(url, out_dir)

#### Provide the data_path where scifact has been downloaded and unzipped
corpus, queries, qrels = GenericDataLoader(data_folder=data_path).load(split="test")


#### Load the SBERT model and retrieve using cosine-similarity
model = DRES(SentenceBERT(model_path="BMRetriever/BMRetriever-1B"), batch_size=32)
retriever = EvaluateRetrieval(model, score_function="dot") # or "cos_sim" for cosine similarity #cos_sim
results = retriever.retrieve(corpus, queries)

#### Evaluate your model with NDCG@k, MAP@K, Recall@K and Precision@K where k = [1,3,5,10,100,1000]
ndcg, _map, recall, precision = retriever.evaluate(qrels, results, retriever.k_values)
19 changes: 19 additions & 0 deletions prompt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@


task_name_to_instruct = {
"gooaq": "Given a question, retrieve relevant passages that answer the question",
"eli5": "Provided a user question, retrieve the highest voted answers on Reddit ELI5 forum",
"squad": "Given a question, retrieve Wikipedia passages that answer the question",
"nli": "Given a premise, retrieve a hypothesis that is entailed by the premise",
'msmarco': 'Given a web search query, retrieve relevant passages that answer the query',
'nq': 'Given a question, retrieve Wikipedia passages that answer the question',
'quora': 'Given a question, retrieve questions that are semantically equivalent to the given question',
"fever": 'Given a claim, retrieve documents that support or refute the claim',
"hotpotqa": 'Given a multi-hop question, retrieve documents that can help answer the question',
"msmarco_doc": "Given a web search query, retrieve relevant documents that answer the query",
'nfcorpus': 'Given a question, retrieve relevant documents that best answer the question',
'scidocs': 'Given a scientific paper title, retrieve paper abstracts that are cited by the given paper',
'scifact': 'Given a scientific claim, retrieve documents that support or refute the claim',
'trec-covid': 'Given a query on COVID-19, retrieve documents that answer the query',
}

0 comments on commit ad783bd

Please sign in to comment.