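"""Classify universities by country with a local Ollama model.

Reads institution names from the first column of a CSV, sends them to a
locally running Ollama server in batches, and writes a CSV containing the
original name, a normalized name, and the model's country guess.
"""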
import csv
import time
import re
import requests
from typing import List, Tuple, Dict
import concurrent.futures
class UniversityClassifier:
    def __init__(self, model: str = "gemma3:12b-it-qat", ollama_url: str = "http://localhost:11434",
                 batch_size: int = 25, max_workers: int = 3):
        """
        Initialize the classifier with Ollama.

        Args:
            model: Ollama model to use
            ollama_url: URL where Ollama is running
            batch_size: Number of universities to process in one batch
            max_workers: Number of concurrent threads for processing
        """
        self.model = model
        self.ollama_url = ollama_url
        self.batch_size = batch_size
        self.max_workers = max_workers
        self._test_connection()
    def _test_connection(self):
        """Test the connection to the Ollama server."""
        try:
            response = requests.get(f"{self.ollama_url}/api/tags", timeout=5)
            if response.status_code == 200:
                models = [model['name'] for model in response.json().get('models', [])]
                if self.model in models:
                    print(f"✓ Connected to Ollama. Using model: {self.model}")
                else:
                    print(f"⚠ Model {self.model} not found. Available models: {models}")
                    print(f"Make sure to run: ollama pull {self.model}")
            else:
                print(f"✗ Cannot connect to Ollama at {self.ollama_url}")
        except requests.exceptions.RequestException as e:
            print(f"✗ Cannot connect to Ollama: {e}")
            print("Make sure Ollama is running with: ollama serve")
    def _call_ollama_batch(self, prompt: str) -> str:
        """Make a request to the Ollama API for batch processing."""
        data = {
            "model": self.model,
            "prompt": prompt,
            "stream": False,
            "options": {
                "temperature": 0.1,
                "top_p": 0.9,
                "num_ctx": 8192,  # reduced context window
                "num_gpu": -1,
                "num_thread": 8   # moderate CPU usage
                # No "format": "json" option, so the model can reply in free-form text
            }
        }
        try:
            response = requests.post(
                f"{self.ollama_url}/api/generate",
                json=data,
                timeout=180  # 3 minutes per batch
            )
            response.raise_for_status()
            return response.json()['response'].strip()
        except requests.exceptions.RequestException as e:
            raise Exception(f"Ollama API error: {e}")
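    # For reference: with "stream": False, /api/generate returns a single JSON
    # object whose "response" field holds the generated text. Abridged,
    # illustrative shape (values are not from a real run):
    #   {"model": "gemma3:12b-it-qat", "response": "UNIVERSITY: ...", "done": true}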
    def normalize_institution_name(self, raw_name: str) -> str:
        """Clean and normalize an institution name."""
        normalized = re.sub(r'\s+', ' ', raw_name.strip())
        replacements = {
            'U.': 'University', 'Univ.': 'University', 'Inst.': 'Institute',
            'Tech.': 'Technology', 'Coll.': 'College'
        }
        for abbrev, full in replacements.items():
            # (?!\w) instead of a trailing \b: \b never matches after the final
            # period (two non-word characters), so "Univ. of X" would be missed
            normalized = re.sub(r'\b' + re.escape(abbrev) + r'(?!\w)', full, normalized)
        return normalized
    def classify_batch(self, universities: List[str]) -> List[Tuple[str, str]]:
        """
        Classify a batch of universities at once using a simple line-by-line format.

        Returns:
            List of (normalized_name, country) tuples
        """
        # Build the batch prompt in a simple line format (no JSON)
        batch_prompt = f"""You are an expert on global universities. For each university below, write EXACTLY this format:

UNIVERSITY: full name | COUNTRY: country name

Examples:
UNIVERSITY: Harvard University | COUNTRY: United States
UNIVERSITY: University of Oxford | COUNTRY: United Kingdom
UNIVERSITY: Universidad Nacional Autónoma de México | COUNTRY: Mexico

Now classify these {len(universities)} universities:
"""
        for i, uni in enumerate(universities, 1):
            batch_prompt += f"{i}. {uni}\n"
        batch_prompt += "\nRespond with exactly one line per university in the format above:"

        try:
            response = self._call_ollama_batch(batch_prompt)
            return self._parse_simple_response(response, universities)
        except Exception as e:
            print(f"Error processing batch: {e}")
            # Fallback: return normalized names with "Unknown" country
            return [(self.normalize_institution_name(uni), "Unknown") for uni in universities]
    def _parse_simple_response(self, response: str, original_universities: List[str]) -> List[Tuple[str, str]]:
        """Parse the simple line-by-line response format."""
        results = []
        lines = [line.strip() for line in response.split('\n') if line.strip()]
        for line in lines:
            if 'UNIVERSITY:' in line and 'COUNTRY:' in line:
                try:
                    # Split on | to separate the university and country parts
                    parts = line.split('|')
                    if len(parts) >= 2:
                        uni_part = parts[0].strip()
                        country_part = parts[1].strip()
                        # Extract the values after the labels
                        if 'UNIVERSITY:' in uni_part:
                            institution = uni_part.split('UNIVERSITY:')[1].strip()
                        else:
                            institution = uni_part
                        if 'COUNTRY:' in country_part:
                            country = country_part.split('COUNTRY:')[1].strip()
                        else:
                            country = country_part
                        results.append((institution, country))
                except Exception as e:
                    print(f"Error parsing line: {line} - {e}")
                    continue
        # Pad any missing results with normalized names and "Unknown"
        while len(results) < len(original_universities):
            normalized = self.normalize_institution_name(original_universities[len(results)])
            results.append((normalized, "Unknown"))
        return results[:len(original_universities)]
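    # Illustrative example: a response line such as
    #   "3. UNIVERSITY: ETH Zurich | COUNTRY: Switzerland"
    # contains both labels and parses to ("ETH Zurich", "Switzerland").
    # Extra lines that lack either label are skipped, and any shortfall
    # is padded with ("<normalized input name>", "Unknown").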
    def process_batch_worker(self, batch: List[str]) -> List[Dict]:
        """Worker function to process one batch of universities."""
        batch_results = self.classify_batch(batch)
        results = []
        for original, (normalized, country) in zip(batch, batch_results):
            results.append({
                'original_name': original,
                'normalized_name': normalized,
                'country': country
            })
        return results
    def process_universities_csv(self, input_file: str, output_file: str, start_from: int = 0):
        """
        Process the universities CSV file using batch processing and threading.
        """
        # Read all institution names from the first column
        universities = []
        with open(input_file, 'r', encoding='utf-8') as file:
            reader = csv.reader(file)
            try:
                header = next(reader)
                # Heuristic: if the first row contains no header-like keyword
                # ("university"/"college"/"institute"), treat it as data
                if header and not any(word in header[0].lower() for word in ['university', 'college', 'institute']):
                    universities = [header[0]]
            except StopIteration:
                pass
            for row in reader:
                if row and row[0].strip():
                    universities.append(row[0].strip())

        # Skip ahead to the requested starting position
        universities = universities[start_from:]
        total_universities = len(universities)
        print(f"Found {total_universities} institutions to process")
        print(f"Processing in batches of {self.batch_size} with {self.max_workers} workers")

        # Split the list into batches
        batches = [universities[i:i + self.batch_size]
                   for i in range(0, len(universities), self.batch_size)]

        all_results = []
        start_time = time.time()

        # Process batches concurrently
        with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            # Submit all batches up front
            future_to_batch = {executor.submit(self.process_batch_worker, batch): i
                               for i, batch in enumerate(batches)}

            # Collect results as they complete
            completed = 0
            for future in concurrent.futures.as_completed(future_to_batch):
                batch_idx = future_to_batch[future]
                try:
                    batch_results = future.result()
                    all_results.extend(batch_results)
                    completed += 1

                    # Progress update
                    processed_count = min(completed * self.batch_size, total_universities)
                    elapsed = time.time() - start_time
                    rate = processed_count / elapsed if elapsed > 0 else 0
                    remaining = (total_universities - processed_count) / rate if rate > 0 else 0
                    print(f"Completed batch {completed}/{len(batches)} - "
                          f"Processed {processed_count}/{total_universities} institutions "
                          f"({rate:.1f}/sec, ~{remaining/60:.1f} min remaining)")

                    # Save progress every five batches
                    if completed % 5 == 0:
                        self._save_results(all_results, output_file)
                        print("Progress saved!")

                    # Small pause between handling completed batches
                    time.sleep(1)
                except Exception as e:
                    print(f"Error processing batch {batch_idx}: {e}")

        # Restore the original input order
        order = {name: i for i, name in enumerate(universities)}
        all_results.sort(key=lambda x: order.get(x['original_name'], float('inf')))

        # Save final results
        self._save_results(all_results, output_file)

        elapsed = time.time() - start_time
        print(f"\nProcessing complete! Processed {len(all_results)} institutions in {elapsed/60:.1f} minutes")
        print(f"Average rate: {len(all_results)/elapsed:.1f} institutions/second")
        return all_results
    def _save_results(self, results: List[Dict], output_file: str):
        """Save results to a CSV file."""
        with open(output_file, 'w', newline='', encoding='utf-8') as file:
            if results:
                writer = csv.DictWriter(file, fieldnames=['original_name', 'normalized_name', 'country'])
                writer.writeheader()
                writer.writerows(results)
# Example usage
if __name__ == "__main__":
    # Configuration - adjust these for optimal performance
    INPUT_FILE = "university2counties.csv"
    OUTPUT_FILE = "universities_with_countries.csv"
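    # The input CSV is read from its first column only, one institution per
    # row, e.g. (illustrative rows, not from the source data):
    #   University Name     <- optional header, skipped by the heuristic
    #   Harvard U.             in process_universities_csv
    #   Univ. of Toronto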
    MODEL = "gemma3:12b-it-qat"

    # Conservative settings for reliability
    BATCH_SIZE = 25  # smaller batches for stability
    MAX_WORKERS = 3  # conservative threading

    # Initialize the classifier
    classifier = UniversityClassifier(
        model=MODEL,
        batch_size=BATCH_SIZE,
        max_workers=MAX_WORKERS
    )

    # Process the file
    results = classifier.process_universities_csv(
        input_file=INPUT_FILE,
        output_file=OUTPUT_FILE,
        start_from=0
    )

    # Print a summary of the country distribution
    countries = {}
    for result in results:
        country = result['country']
        countries[country] = countries.get(country, 0) + 1

    print("\nCountry distribution:")
    for country, count in sorted(countries.items()):
        print(f"{country}: {count}")