
Reproducing InstructBLIP on Flickr30K #719

Open
gabrielsantosrv opened this issue Jun 25, 2024 · 0 comments
@gabrielsantosrv

Hi,

I'm trying to reproduce the results reported in "InstructBLIP: Towards General-purpose Vision-Language Models with Instruction Tuning", but I'm having difficulty reproducing the InstructBLIP (Vicuna-7B) results on the Flickr30K test set for the image captioning task.

I'm using the model from Hugging Face and running the code snippet below, and I'm getting a CIDEr score of 60.9, while the reported one is 82.4.

I'm using the prompt reported in the paper, "A short image description: ", and the decoding hyperparameters from the Hugging Face example. Am I using the correct hyperparameters and prompt?

PS: using the same hyperparameters with the prompt "A short image caption." increases the CIDEr score to 83.1.

    import torch
    import tqdm
    from transformers import InstructBlipProcessor, InstructBlipForConditionalGeneration

    device = "cuda" if torch.cuda.is_available() else "cpu"

    model_name = "Salesforce/instructblip-vicuna-7b"
    processor = InstructBlipProcessor.from_pretrained(model_name, torch_dtype=torch.float16)
    model = InstructBlipForConditionalGeneration.from_pretrained(model_name, torch_dtype=torch.float16)
    model.to(device)
    model.eval()

    prompt = ["A short image description: "] * config.batch
    transform = lambda img: processor(images=img, text=prompt, return_tensors="pt")
    # load_dataset and config are my own Flickr30K helpers, omitted here for brevity
    dataset = load_dataset(config=config, transform=transform)

    results = []
    for batch in tqdm.tqdm(dataset, desc="Inference"):
        img_ids, images, _ = batch
        # cast pixel_values to fp16 to match the model weights
        inputs = images.to(device, torch.float16)
        outputs = model.generate(
            **inputs,
            do_sample=False,
            num_beams=5,
            max_length=256,
            min_length=1,
            top_p=0.9,
            repetition_penalty=1.5,
            length_penalty=1.0,
            temperature=1,
        )
        generated_text = processor.batch_decode(outputs, skip_special_tokens=True)
        # collect (image_id, caption) pairs for scoring
        results.extend(zip(img_ids, generated_text))
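
For completeness, here is roughly how I compute the CIDEr score afterwards (a sketch using pycocoevalcap; `references`, mapping image id to the list of ground-truth Flickr30K captions, comes from my own loader and is a placeholder here):

    from pycocoevalcap.cider.cider import Cider
    from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer

    # references: {image_id: [reference caption strings]} from my Flickr30K loader
    tokenizer = PTBTokenizer()  # shells out to Java; same tokenization as the COCO eval toolkit
    gts = tokenizer.tokenize(
        {img_id: [{"caption": c} for c in caps] for img_id, caps in references.items()}
    )
    res = tokenizer.tokenize(
        {img_id: [{"caption": caption}] for img_id, caption in results}
    )

    score, _ = Cider().compute_score(gts, res)
    print(f"CIDEr: {100 * score:.1f}")  # the paper reports CIDEr x 100

Note that with do_sample=False, top_p and temperature have no effect (generate falls back to plain beam search), so decoding is deterministic under both prompts.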