In my current organization, I am working on a model to sort post-consumer clothes to make them easier to recycle into products of different grades.
For example, if a post-consumer garment has buttons and zippers and is of low quality, we use it to make products such as tablecloths.
Since sorting is a critical step in the process, achieving high accuracy in the model is very important.
Collecting enough clothes for training is time-consuming. To address this, we use synthetic data: the model is first trained on synthetic data, then fine-tuned on real-world data, and finally evaluated on real-world data. This mitigates the data shortage, because only the fine-tuning and evaluation sets have to be collected from real garments.
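To make the train, fine-tune, evaluate schedule concrete, here is a toy sketch. It is not the sorting model: a two-feature logistic-regression classifier stands in for the image classifier, and a constant feature shift stands in for the gap between synthetic and real images (both are assumptions for illustration). What it shows is the ordering and the data split: the fine-tuning set and the evaluation set are both real, and they are disjoint.

```python
import numpy as np

# Toy stand-in: logistic regression instead of the real image classifier,
# and a constant feature shift as an assumed synthetic-to-real domain gap.

def add_bias(X):
    return np.hstack([X, np.ones((len(X), 1))])

def train(X, y, w=None, lr=0.1, epochs=200):
    """Batch gradient descent on the logistic loss; continues from `w` if given."""
    Xb = add_bias(X)
    w = np.zeros(Xb.shape[1]) if w is None else w.copy()
    for _ in range(epochs):
        p = 1.0 / (1.0 + np.exp(-Xb @ w))
        w -= lr * Xb.T @ (p - y) / len(y)
    return w

def accuracy(w, X, y):
    return float(np.mean((add_bias(X) @ w > 0).astype(int) == y))

rng = np.random.default_rng(0)

def make_data(n, shift):
    """Two Gaussian classes; `shift` moves both classes to mimic a domain gap."""
    y = rng.integers(0, 2, n)
    X = rng.normal(size=(n, 2)) + np.where(y[:, None] == 1, 2.0, -2.0) + shift
    return X, y

X_syn, y_syn = make_data(2000, shift=0.0)   # plentiful synthetic data
X_ft, y_ft = make_data(100, shift=1.0)      # small real fine-tuning set
X_test, y_test = make_data(500, shift=1.0)  # held-out real evaluation set

w = train(X_syn, y_syn)                          # 1. train on synthetic data
w = train(X_ft, y_ft, w=w, lr=0.01, epochs=100)  # 2. fine-tune on real data, smaller learning rate
print(f"accuracy on held-out real data: {accuracy(w, X_test, y_test):.3f}")  # 3. evaluate on real data
```

Keeping the real evaluation set separate from the real fine-tuning set is what makes the final number an estimate of real-world performance rather than of fit to the fine-tuning data.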
We have developed a pipeline that generates synthetic data using generative AI models. We select the generative model based on its compute efficiency and generative accuracy; generative accuracy is assessed by running the model on a sample set of prompts and visually inspecting the quality of the output.
This approach aims to train the sorting model efficiently while keeping it accurate and adaptable to real-world data.
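The compute-efficiency half of model selection can be measured with a small timing harness. This is a hypothetical sketch: `time_per_image` and `rank_by_speed` are illustrative names, and each candidate is assumed to be a callable that generates one image from a prompt (for example a wrapper around `TextToImageGenerator.generate_image` from image_generator.py). Generative accuracy still needs the visual inspection described above.

```python
import statistics
import time
from typing import Callable, Dict, List, Tuple

def time_per_image(generate: Callable[[str], object], prompts: List[str], warmup: int = 1) -> float:
    """Median wall-clock seconds per image for one candidate model."""
    # Untimed warm-up calls absorb one-off costs such as lazy weight loading
    for prompt in prompts[:warmup]:
        generate(prompt)
    timings = []
    for prompt in prompts:
        start = time.perf_counter()
        generate(prompt)
        timings.append(time.perf_counter() - start)
    return statistics.median(timings)

def rank_by_speed(candidates: Dict[str, Callable[[str], object]],
                  prompts: List[str]) -> List[Tuple[str, float]]:
    """Return (name, seconds per image) pairs, fastest first."""
    results = [(name, time_per_image(fn, prompts)) for name, fn in candidates.items()]
    return sorted(results, key=lambda item: item[1])
```

The median is used instead of the mean so that a single slow call (for example one that triggers a cache miss) does not distort the comparison.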
The pipeline involves the following steps.
1. Define a prompt template with placeholders:
   [COLOR] [DRESS_TYPE] with [FASTENER] placed on a [LOCATION]
2. Replace the placeholders ([COLOR], [DRESS_TYPE], etc.) with actual values, for example:
   Red gown with buttons placed on a wooden table.
   Blue sundress with a zipper placed on a marble countertop.
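Filling the template is plain string substitution. A minimal sketch, assuming the placeholders are always upper-case names in square brackets as in the template above; `fill_template` is a hypothetical helper (main.py instead uses `str.format` with positional `{}` slots).

```python
import re

TEMPLATE = "[COLOR] [DRESS_TYPE] with [FASTENER] placed on a [LOCATION]."

def fill_template(template: str, values: dict) -> str:
    """Replace every [NAME] placeholder with values["NAME"].

    A KeyError is raised if the template uses a placeholder that has no value,
    so incomplete prompts fail loudly instead of reaching the image model.
    """
    return re.sub(r"\[([A-Z_]+)\]", lambda m: values[m.group(1)], template)

examples = [
    {"COLOR": "Red", "DRESS_TYPE": "gown", "FASTENER": "buttons", "LOCATION": "wooden table"},
    {"COLOR": "Blue", "DRESS_TYPE": "sundress", "FASTENER": "a zipper", "LOCATION": "marble countertop"},
]
prompts = [fill_template(TEMPLATE, values) for values in examples]
for prompt in prompts:
    print(prompt)
```

Named placeholders make it harder to swap two values by accident than positional slots do.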
I experimented with several models from the Hugging Face Hub and found that the models used in the code below (black-forest-labs/FLUX.1-dev with a LoRA adapter for image generation) best met my requirements for output quality and inference time. The results were evaluated through visual inspection on a sample dataset. Currently, I am exploring methods to verify whether the generated images contain out-of-distribution (OOD) objects that could reduce the accuracy of the classification model used for sorting.
One of the key issues I am addressing is ensuring that the generated image matches the prompt. For example, if the prompt specifies a dress on a conveyor belt but the generated image does not include the conveyor belt, the algorithm must detect this mismatch. To solve this, I am considering adding another layer of validation that uses segmentation models and filters out images with missing or incorrect objects.
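Once a segmentation model has run, the mismatch check reduces to comparing two sets of labels: the objects the prompt asks for and the objects that were found. A minimal sketch, assuming the segmentation step (not shown, and not part of the current code) already returns detected objects as lowercase label strings; `check_against_prompt`, `PromptCheck`, and the `ignore` set for harmless labels are hypothetical names.

```python
from dataclasses import dataclass, field
from typing import Iterable, Set

@dataclass
class PromptCheck:
    missing: Set[str] = field(default_factory=set)     # required by the prompt, not detected
    unexpected: Set[str] = field(default_factory=set)  # detected, not part of the prompt

    @property
    def ok(self) -> bool:
        return not self.missing and not self.unexpected

def check_against_prompt(expected: Iterable[str], detected: Iterable[str],
                         ignore: Iterable[str] = ()) -> PromptCheck:
    """Compare prompt objects with segmentation labels; labels in `ignore` are never flagged."""
    expected, detected, ignore = set(expected), set(detected), set(ignore)
    return PromptCheck(missing=expected - detected, unexpected=detected - expected - ignore)

result = check_against_prompt({"dress", "conveyor belt"}, {"dress", "cup"})
print(result.missing, result.unexpected, result.ok)
```

An image would be kept only when `ok` is true, so both a missing conveyor belt and an extra cup lead to rejection.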
For a more computationally efficient approach, I am investigating classical image-processing techniques such as the Histogram of Oriented Gradients (HOG). Since we know the dress color and the background surface (e.g., the conveyor belt), the dress region can be located by its color, and the rest of the image should match the gradient pattern of the known surface; additional objects would then show up as regions whose gradient histograms deviate from that pattern. This method provides a lightweight alternative for identifying discrepancies without relying on complex deep learning models.
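A minimal sketch of such a gradient-histogram check, assuming a grayscale image as a NumPy array. It is a simplified HOG: per-cell, magnitude-weighted histograms of unsigned gradient orientation, without HOG's block normalization or bin interpolation (scikit-image's `skimage.feature.hog` provides the full descriptor). Two inputs are assumptions: `surface_hist`, the per-cell histogram expected for the empty surface (for example averaged over cells of a photo of the bare surface), and `dress_mask`, a boolean mask of dress pixels that could come from thresholding on the known dress color. The threshold is in gradient-magnitude units and would need tuning.

```python
import numpy as np

def orientation_histograms(gray: np.ndarray, cell: int = 8, bins: int = 9) -> np.ndarray:
    """Per-cell histograms of unsigned gradient orientation (0-180 deg), weighted by magnitude."""
    gy, gx = np.gradient(gray.astype(float))  # derivatives along rows (y) and columns (x)
    mag = np.hypot(gx, gy)
    ang = np.rad2deg(np.arctan2(gy, gx)) % 180.0
    bin_idx = np.minimum((ang / (180.0 / bins)).astype(int), bins - 1)
    n_rows, n_cols = gray.shape[0] // cell, gray.shape[1] // cell
    hists = np.zeros((n_rows, n_cols, bins))
    for i in range(n_rows):
        for j in range(n_cols):
            rows, cols = slice(i * cell, (i + 1) * cell), slice(j * cell, (j + 1) * cell)
            hists[i, j] = np.bincount(bin_idx[rows, cols].ravel(),
                                      weights=mag[rows, cols].ravel(), minlength=bins)
    return hists

def flag_unexpected_cells(gray, surface_hist, dress_mask=None, cell=8, bins=9, threshold=1.0):
    """Return (row, col) indices of cells whose histogram deviates from the expected surface."""
    hists = orientation_histograms(gray, cell, bins)
    dist = np.abs(hists - surface_hist).sum(axis=-1)  # L1 distance per cell
    if dress_mask is not None:
        # Skip cells containing dress pixels: the dress is expected to have its own edges
        n_rows, n_cols = hists.shape[:2]
        covered = (dress_mask[:n_rows * cell, :n_cols * cell]
                   .reshape(n_rows, cell, n_cols, cell).any(axis=(1, 3)))
        dist[covered] = 0.0
    return np.argwhere(dist > threshold)
```

Any returned cell is a candidate location for an extra object; an image with flagged cells could be dropped or passed on to the heavier segmentation check.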
By integrating these techniques, we aim to create a robust pipeline that ensures the generated data aligns with the intended prompts, ultimately improving the performance and reliability of the sorting model.
main.py
```python
import yaml
from itertools import product
from image_generator import TextToImageGenerator
from image_validation import ImageValidation

# Load the YAML data
with open('data_config.yaml', 'r') as file:
    data = yaml.safe_load(file)

BASE_PROMPT = "{} {} with {} placed on a {} without any other disturbing objects on the table"

# Extract the categories
colors = data['colors']
dress_types = data['dress_types']
trims = data['trims']
locations = data['locations']

# Load the models once, outside the loop: reloading FLUX.1-dev for every
# prompt would dominate the runtime
generator = TextToImageGenerator('black-forest-labs/FLUX.1-dev')
# Load LoRA weights
generator.load_lora_weights('openfree/claude-monet', weight_name='claude-monet.safetensors')
validator = ImageValidation()

# Question to ask about every generated image
question = "Is the dress placed on a flat surface?"
valid_images = []

# Iterate over all combinations of the categories
for i, (color, dress_type, trim, location) in enumerate(product(colors, dress_types, trims, locations)):
    print(f"Color: {color}, Dress Type: {dress_type}, Trims: {trim}, Location: {location}")
    # Generate the prompt text using the provided combination
    prompt = BASE_PROMPT.format(color, dress_type, trim, location)
    print(prompt)

    # Generate an image from the prompt; a unique file name per combination
    # keeps earlier images from being overwritten
    img_path = f"generated_{i:04d}.png"
    generator.generate_image(prompt, output_file=img_path)

    # Image validation
    answer = validator.validate_image(img_path, question)
    print(f"Answer: {answer}")
    if answer.strip().lower() == 'yes':
        # It is a valid image; keep it for training
        valid_images.append(img_path)
```
image_generator.py
```python
from diffusers import AutoPipelineForText2Image
import torch


class TextToImageGenerator:
    def __init__(self, model_name, dtype=torch.bfloat16):
        """
        Initialize the text-to-image pipeline with the given model.
        """
        self.pipeline = AutoPipelineForText2Image.from_pretrained(model_name, torch_dtype=dtype)

    def load_lora_weights(self, lora_weights, weight_name=None):
        """
        Load LoRA weights into the pipeline.
        """
        self.pipeline.load_lora_weights(lora_weights, weight_name=weight_name)

    def generate_image(self, prompt, output_file="output_image.png"):
        """
        Generate an image based on the given prompt and save it to a file.
        """
        image = self.pipeline(prompt).images[0]
        image.save(output_file)
        print(f"Image saved to {output_file}")
```
image_validation.py
```python
from PIL import Image
from transformers import BlipProcessor, BlipForQuestionAnswering


class ImageValidation:
    def __init__(self, model_name="Salesforce/blip-vqa-base"):
        """
        Initialize the processor and model for Visual Question Answering.
        """
        self.processor = BlipProcessor.from_pretrained(model_name)
        self.model = BlipForQuestionAnswering.from_pretrained(model_name)

    def validate_image(self, img_path, question):
        """
        Validate an image based on a given question.

        Args:
            img_path (str): Path to the image.
            question (str): The question to ask about the image.

        Returns:
            str: The model's answer to the question.
        """
        try:
            # Load and preprocess the image
            raw_image = Image.open(img_path).convert('RGB')
        except Exception as e:
            return f"Error loading image: {e}"

        # Process the image and question
        inputs = self.processor(raw_image, question, return_tensors="pt")

        # Generate the answer
        output = self.model.generate(**inputs)
        answer = self.processor.decode(output[0], skip_special_tokens=True)
        return answer
```
data_config.yaml
```yaml
colors:
  - Red
  - Blue
  - Black
  - White
  - Yellow
  - Green
  - Pink
  - Purple
  - Orange
  - Beige
dress_types:
  - Gown
  - Sundress
  - Cocktail Dress
  - Wedding Dress
  - Party Dress
  - Maxi Dress
  - Sheath Dress
  - Evening Dress
  - Wrap Dress
  - Tutu Dress
trims:
  - Buttons
  - Zipper
locations:
  - Wooden Table
  - Marble Countertop
  - Glass Desk
  - Vintage Dresser
  - Clean Workbench
  - Picnic Table
  - Polished Coffee Table
  - Tiled Table
  - Metal Desk
  - Dining Table
```
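Because main.py iterates over the Cartesian product of these lists, the config directly determines how many images one run generates, with one FLUX.1-dev call and one VQA call per prompt. A quick count, with the lists inlined so the snippet runs without PyYAML:

```python
from itertools import product
from math import prod

# Same category lists as data_config.yaml
config = {
    "colors": ["Red", "Blue", "Black", "White", "Yellow", "Green", "Pink", "Purple", "Orange", "Beige"],
    "dress_types": ["Gown", "Sundress", "Cocktail Dress", "Wedding Dress", "Party Dress",
                    "Maxi Dress", "Sheath Dress", "Evening Dress", "Wrap Dress", "Tutu Dress"],
    "trims": ["Buttons", "Zipper"],
    "locations": ["Wooden Table", "Marble Countertop", "Glass Desk", "Vintage Dresser", "Clean Workbench",
                  "Picnic Table", "Polished Coffee Table", "Tiled Table", "Metal Desk", "Dining Table"],
}

# Number of prompts = product of the list lengths: 10 * 10 * 2 * 10
n_prompts = prod(len(values) for values in config.values())
assert n_prompts == len(list(product(*config.values())))
print(n_prompts)  # 2000
```

Adding one more location therefore adds 200 prompts, which is worth keeping in mind when estimating generation time.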
Prompt: Red Gown with Buttons placed on a Wooden Table without any other disturbing objects on the table
If an image generated for such a prompt contains out-of-distribution objects (for example apples, a fork, a cup, or a hanger), the visual question answering model can be used to detect them so that the image is filtered out.
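A minimal sketch of that filter, asking one yes/no question per object. `find_ood_objects` is a hypothetical helper, and `ask` is assumed to be any callable that answers a question about one image, such as `lambda q: validator.validate_image(img_path, q)` with `ImageValidation` from image_validation.py. How reliably BLIP answers these presence questions would have to be checked on real samples.

```python
from typing import Callable, Iterable, List

# Example objects taken from the text; extend as new failure cases appear
OOD_OBJECTS = ["apples", "forks", "cups", "hangers"]

def find_ood_objects(ask: Callable[[str], str], objects: Iterable[str] = OOD_OBJECTS) -> List[str]:
    """Ask one yes/no question per object and return the objects reported as present."""
    found = []
    for obj in objects:
        answer = ask(f"Are there any {obj} in the image?")
        if answer.strip().lower() == "yes":
            found.append(obj)
    return found
```

An image would be kept only if `find_ood_objects(...)` returns an empty list. Note that `validate_image` returns an error message string when the image cannot be loaded; that string is not "yes", so such an image would pass this check and should be handled separately.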
This is the framework we are currently using, and we are continuously exploring new techniques to enhance our synthetic data generation pipeline. I am excited to collaborate and contribute to any similar open-source projects. If you’re working on something aligned, please don’t hesitate to reach out to me—I’d love to connect and share ideas!
Written on December 27th, 2024 by Karthik